Intro to Machine Learning on Android (Part 2): Building an app to recognize handwritten digits with TensorFlow Lite

By Eric Hsiao

With the launch of mobile machine learning frameworks like TensorFlow Lite, it’s never been easier for mobile developers to build new and exciting features into their apps. Powerful apps use machine learning under the hood to accomplish complex tasks like identifying crop disease or automatically captioning pictures, all in real time and without an Internet connection.

In part one of this tutorial, we went over how to convert a custom model to TensorFlow Lite and discussed some tips and tricks to evaluate and trim the unnecessary layers in the TensorFlow graph. In the end, we prepared a model trained on MNIST data for inference.

The good news is that we’re finished with the hardest part: training and converting the model. To review, here are some details about our finished model (mnist.tflite):

  • input shape of 1x28x28x1 (batch size x image height x image width x number of channels)
  • output shape of 1x10 (one score for each handwritten digit from 0 to 9)

In this post, we’ll go over how to take that model and create a simple Android app.

Let’s fire up Android Studio and add our converted model to the app.

1. First, create an assets folder (src/main/assets) and copy your model (mnist.tflite) into it.

2. Next, add the TensorFlow Lite dependency in your app’s build.gradle file.

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:+'
}

The + always pulls in the latest released version, which is convenient while experimenting, but for stable builds you’ll typically want to pin the library to a specific version number. Since TensorFlow Lite is in active development, you might want to use the nightly builds when you’re testing things out.

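3. Still in build.gradle, add the following option so that the model file is not compressed when the app is built. The usual way to load a TensorFlow Lite model is to memory-map it from the assets folder, and compressed assets can’t be opened that way.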

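android {
    aaptOptions {
        // List both extensions in one call; in some Android Gradle plugin
        // versions a second noCompress call replaces the first.
        noCompress "tflite", "lite"
    }
}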

4. Sync the project with your Gradle files and make sure there are no build errors.

To build the app’s interface, we create a canvas that the user can draw on. Once they’ve finished writing a digit, they tap the detect button to pass that image to the model and get a prediction.

To save some time, I won’t dive into how to create the UI and will instead leave that as an activity for the reader. There are some great tutorials and open source projects out there to help you get started.

Paint Tutorial (left), MNIST TFMobile Tutorial (middle), Finished Version (right)

Let’s take a step back and understand the inputs and outputs of the digits model we’ve trained. From an app developer’s point of view, it’s important to understand the following:

  • The input and output dimensions.
  • Pre-processing for the input: depending on the model, you may need to transform image data before it’s passed into the model. A common example is normalizing pixel values so that they fall within a fixed range, such as 0 to 1.
  • Post-processing on the output: after the input runs through the model, there may be additional steps to interpret the result. For example, the digits model returns an array of 10 probabilities, so we run an argmax (find the index with the largest value) on the model output to get the predicted number.

Here’s a high-level view of what we’ll need to do to get this model working:

Let’s break this down.

1. First, create a new class for the model called DigitsDetector.java that loads the model with TensorFlow Lite’s Interpreter class, reading our mnist.tflite file from the assets folder.
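Loading usually means memory-mapping the model file. On Android this is commonly done by opening the asset with AssetManager.openFd() and mapping the byte range reported by the returned AssetFileDescriptor’s getStartOffset() and getDeclaredLength(); the resulting buffer can then be handed to new Interpreter(...). That is also why the model has to stay uncompressed. The sketch below shows only the mapping step, in plain Java against an ordinary file so it runs off-device. ModelFileMapper and mapModel are hypothetical names, and the offset and length parameters stand in for the values the AssetFileDescriptor would report.

```java
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

final class ModelFileMapper {
    private ModelFileMapper() {}

    /**
     * Memory-maps a read-only region of a file.
     *
     * Hypothetical helper: on Android, the file, offset, and length would come
     * from an AssetFileDescriptor (getFileDescriptor(), getStartOffset(),
     * getDeclaredLength()) instead of a plain path.
     */
    static MappedByteBuffer mapModel(String path, long startOffset, long declaredLength)
            throws IOException {
        try (FileInputStream input = new FileInputStream(path);
             FileChannel channel = input.getChannel()) {
            // A mapping does not depend on the channel that created it,
            // so it stays valid after the try block closes the channel.
            return channel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }
    }
}
```

Mapping the file instead of reading it into a byte array avoids copying the whole model onto the Java heap.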

2. Next, initialize the input and output buffers (byte buffers or arrays) with the correct dimensions as fields of the class. Allocate them once and reuse them for every prediction to avoid repeated memory allocations.

  • You may find it useful to create static variables to hold onto the dimension sizes.
  • Here we allocate the input as a direct ByteBuffer, but you can also use a FloatBuffer with a size of 1 x 28 x 28 x 1 if you’d like.
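A minimal sketch of the class-scope buffers, assuming float32 input and output. The field names inputBuffer and mnistOutput match the ones used later in this post; the constant names and the DigitsBuffers wrapper are hypothetical. Since each float32 value takes four bytes, the input needs 1 x 28 x 28 x 1 x 4 = 3,136 bytes.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class DigitsBuffers {
    // Input shape: 1 x 28 x 28 x 1 (batch x height x width x channels).
    static final int BATCH_SIZE = 1;
    static final int IMG_HEIGHT = 28;
    static final int IMG_WIDTH = 28;
    static final int NUM_CHANNELS = 1;
    // Each float32 value occupies 4 bytes.
    static final int BYTES_PER_FLOAT = 4;
    // Output shape: 1 x 10, one score per digit 0-9.
    static final int NUM_CLASSES = 10;

    // Allocated once and reused for every prediction.
    final ByteBuffer inputBuffer;
    final float[][] mnistOutput;

    DigitsBuffers() {
        inputBuffer = ByteBuffer.allocateDirect(
                BATCH_SIZE * IMG_HEIGHT * IMG_WIDTH * NUM_CHANNELS * BYTES_PER_FLOAT);
        // TensorFlow Lite expects direct input buffers in the platform's native byte order.
        inputBuffer.order(ByteOrder.nativeOrder());
        mnistOutput = new float[BATCH_SIZE][NUM_CLASSES];
    }
}
```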

3. Next, let’s create a public method called detectDigit that takes in a Bitmap. In the detectDigit method, we’ll pre-process the image to prepare it for the model, run inference, and then interpret the result. We’ll define these methods in later steps.
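To show how detectDigit ties the three steps together, here is a sketch that runs off-device. It makes two assumptions that differ from the app: it takes the already cropped, resized, and grayscaled 28 x 28 values as a float[] instead of a Bitmap, and it replaces tflite.run with a hypothetical InferenceRunner interface so a fake model can be plugged in. DigitsPipelineSketch is likewise a hypothetical name.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class DigitsPipelineSketch {
    /** Hypothetical stand-in for the interpreter call tflite.run(input, output). */
    interface InferenceRunner {
        void run(ByteBuffer input, float[][] output);
    }

    private static final int IMG_SIZE = 28;
    private static final int NUM_CLASSES = 10;

    private final InferenceRunner runner;
    // Buffers allocated once, as in step 2.
    private final ByteBuffer inputBuffer = ByteBuffer
            .allocateDirect(IMG_SIZE * IMG_SIZE * 4)
            .order(ByteOrder.nativeOrder());
    private final float[][] mnistOutput = new float[1][NUM_CLASSES];

    DigitsPipelineSketch(InferenceRunner runner) {
        this.runner = runner;
    }

    /** Pre-process, run inference, then interpret the result. */
    int detectDigit(float[] grayscale) {
        preprocess(grayscale);
        runInference();
        return postprocess();
    }

    private void preprocess(float[] grayscale) {
        if (grayscale.length != IMG_SIZE * IMG_SIZE) {
            throw new IllegalArgumentException("expected " + IMG_SIZE * IMG_SIZE + " values");
        }
        inputBuffer.rewind();
        for (float value : grayscale) {
            inputBuffer.putFloat(value);
        }
    }

    private void runInference() {
        inputBuffer.rewind();
        runner.run(inputBuffer, mnistOutput);
    }

    private int postprocess() {
        // Index of the highest score is the predicted digit.
        int best = 0;
        for (int i = 1; i < NUM_CLASSES; i++) {
            if (mnistOutput[0][i] > mnistOutput[0][best]) {
                best = i;
            }
        }
        return best;
    }
}
```

In the app, the runner would simply forward to tflite.run(inputBuffer, mnistOutput). Keeping the interpreter behind a small seam like this also lets you check the pre- and post-processing without a device.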

4. To pre-process the bitmap pixels for inference, recall that the model’s input has a shape of 1 x 28 x 28 x 1 (batch size x height x width x number of channels).

If the bitmap we pass into the detectDigit method is not 28 x 28, we’ll need to scale it down. In this case, we’ll center crop the image to a square and then resize it to fit the model input.

Here’s the pseudocode for modifying the original image:

BitmapUtils is a utility class taken from (https://github.com/jinkg/YalinEmail/blob/master/android-unifiedemail/src/main/java/com/android/mail/utils/BitmapUtil.java).
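The following is a plain-Java sketch of the same center-crop-then-resize idea, not the BitmapUtils code itself. It works on a row-major ARGB int[] (the format Bitmap.getPixels() fills) and uses nearest-neighbor sampling for simplicity; Android’s Bitmap.createScaledBitmap with filtering enabled would give smoother results. CropAndScale and centerCropAndResize are hypothetical names.

```java
final class CropAndScale {
    private CropAndScale() {}

    /**
     * Center-crops row-major ARGB pixels (width x height) to a square and
     * resizes the square to size x size with nearest-neighbor sampling.
     */
    static int[] centerCropAndResize(int[] pixels, int width, int height, int size) {
        if (pixels.length != width * height) {
            throw new IllegalArgumentException("pixel count does not match width * height");
        }
        // The square's side is the shorter dimension; trim evenly from the longer one.
        int side = Math.min(width, height);
        int xOffset = (width - side) / 2;
        int yOffset = (height - side) / 2;

        int[] out = new int[size * size];
        for (int y = 0; y < size; y++) {
            // Map the center of each output pixel back into the cropped square.
            int srcY = yOffset + (int) ((y + 0.5) * side / size);
            for (int x = 0; x < size; x++) {
                int srcX = xOffset + (int) ((x + 0.5) * side / size);
                out[y * size + x] = pixels[srcY * width + srcX];
            }
        }
        return out;
    }
}
```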

Take resizedBitmap, read its pixels, and convert them to grayscale so that white pixels become 0 and black pixels become 255, writing each value into inputBuffer.
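A minimal sketch of this conversion, assuming pixels holds the 28 x 28 ARGB values read from resizedBitmap with getPixels() and inputBuffer is the direct float32 buffer from step 2. It averages the red, green, and blue channels to get a gray level and then inverts it. GrayscaleInput is a hypothetical name. If your model was trained on values scaled to the range 0 to 1 instead, you would also divide each value by 255.

```java
import java.nio.ByteBuffer;

final class GrayscaleInput {
    private GrayscaleInput() {}

    /**
     * Writes ARGB pixels into a float32 buffer as inverted grayscale:
     * white becomes 0f and black becomes 255f.
     */
    static void fillInputBuffer(int[] pixels, ByteBuffer inputBuffer) {
        inputBuffer.rewind();
        for (int pixel : pixels) {
            int r = (pixel >> 16) & 0xFF;
            int g = (pixel >> 8) & 0xFF;
            int b = pixel & 0xFF;
            // Average of the three channels: 0 = black, 255 = white.
            float gray = (r + g + b) / 3f;
            // Invert so that ink (dark pixels) produces large values.
            inputBuffer.putFloat(255f - gray);
        }
        // Leave the buffer at position 0, ready to be read by the interpreter.
        inputBuffer.rewind();
    }
}
```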

5. Run inputBuffer through the TensorFlow Lite model and store the result in mnistOutput. When we converted the model to tflite in the first post, we defined the output layer as softmax_tensor, so mnistOutput holds that layer’s ten output scores.

protected void runInference() {
    // Reads inputBuffer (1 x 28 x 28 x 1 floats) and writes ten scores into mnistOutput (1 x 10).
    tflite.run(inputBuffer, mnistOutput);
}

6. Finally, let’s process the result. Because the output comes from a softmax layer, it holds ten probabilities that sum to 1, and a value of exactly 1 is rare. Instead of looking for a 1, take the argmax: the index with the largest value is the number the model has predicted.
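A sketch of the argmax step, operating on mnistOutput[0], the row of ten softmax scores. DigitPostprocess is a hypothetical name.

```java
final class DigitPostprocess {
    private DigitPostprocess() {}

    /** Returns the index of the largest score, which is the predicted digit. */
    static int argMax(float[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            // Strict comparison: on ties, the lowest index wins.
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

Call it as argMax(mnistOutput[0]). If you want to reject uncertain drawings, you could additionally compare the winning score against a threshold of your choosing.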

Putting this all together and running the app, here’s what we get when we click on the detect button.

We’ve added a preview of the 28 x 28 image that’s ultimately passed to the model.
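One way to build such a preview is to turn the pre-processed values back into pixels. The sketch below assumes the inverted grayscale encoding from step 4 (0 for white, 255 for black) and produces opaque ARGB ints; on Android the array could be passed to Bitmap.createBitmap(colors, 28, 28, Bitmap.Config.ARGB_8888) and shown in an ImageView. The float values could be copied out of inputBuffer after a rewind with inputBuffer.asFloatBuffer().get(values). InputPreview and toArgb are hypothetical names.

```java
final class InputPreview {
    private InputPreview() {}

    /**
     * Converts inverted grayscale model inputs (0 = white, 255 = black)
     * back into opaque ARGB pixels for display.
     */
    static int[] toArgb(float[] inputs) {
        int[] argb = new int[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            // Clamp to 0..255, round, then undo the inversion from pre-processing.
            int ink = Math.round(Math.max(0f, Math.min(255f, inputs[i])));
            int gray = 255 - ink;
            argb[i] = 0xFF000000 | (gray << 16) | (gray << 8) | gray;
        }
        return argb;
    }
}
```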

Here’s the complete source code for reference on GitHub.

In this tutorial, we presented a simple example with relatively straightforward pre- and post-processing steps. These steps can be applied to any model that you wish to add to your app:

  1. Include the TensorFlow Lite dependency
  2. Create a separate class that wraps the model and its pre- and post-processing steps. In this post we called it DigitsDetector.java, which had a method called detectDigit that took in a Bitmap and returned the number detected in that image.
  3. Define input and output buffers with the appropriate dimensions.
  4. Implement pre-processing to prepare a raw input (in this case our original image) and load it into the input buffer (which is passed into the model).
  5. Run inference on the input and write the result into an output buffer.
  6. Interpret the result. We read the output buffer and then decided which number was drawn on the original image.

There have already been some fantastic user experiences built using on-device machine learning and TensorFlow Lite. I’m excited to see what developers do with TensorFlow Lite as it continues to mature. If you’ve come across any cool projects or have any questions, please feel free to share those in the comments below!
