Idiomatic TensorFlow on Android — Get started with the TensorFlow Support Library


Working with data on Android is inconvenient!

If you’ve used TensorFlow Lite on Android before, chances are you’ve had to deal with the tedious task of pre-processing data: working with float arrays in a statically typed language, then resizing, transforming, normalizing, and performing the other standard steps required before the data is fit for consumption by the model.

Well, no more! The TFLite support library nightly is now available, and in this post, we’ll go over its usage and build a wrapper around a tflite model.

Note: A companion repository for this post is available here. Follow along, or jump straight into the source!

Scope

This post is limited in scope to loading and creating a wrapper class around a tflite model; however, you can see a fully functional project in the repository linked above. The code is liberally commented and very straightforward.

If you still have any queries, please don’t hesitate to reach out or drop a comment. I’ll be glad to help you out.

Setting up the project

We’re going to be deploying a TFLite version of the popular YOLOv3 object detection model on an Android device. Without further ado, let’s jump into it.

Create a new project using Android Studio, name it anything you like, and wait for the initial gradle sync to complete. Next, we’ll install the dependencies.

Adding dependencies

Add the following dependencies to your app-level build.gradle.

// Permissions handling
implementation 'com.github.quickpermissions:quickpermissions-kotlin:0.4.0'

// Tensorflow lite
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'

// CameraView
implementation 'com.otaliastudios:cameraview:2.6.2'

Here is what each dependency does:

  1. Quick Permissions: A library that makes requesting runtime permissions quick and easy.
  2. TensorFlow Lite: The core TFLite library.
  3. TensorFlow Lite Support: The support library we’ll be using to make our data-related tasks simpler.
  4. CameraView: A library that provides a simple API for accessing the camera.

Configuring the gradle project

Our project still needs a little more configuration before we’re ready for the code. We’ll be shipping the model inside our assets, which are compressed by default, and the interpreter cannot load a compressed model. To keep the model uncompressed, add the following option under the android block in the app-level build.gradle file.

aaptOptions {
    noCompress "tflite"
}

Note: After this initial configuration, run the gradle sync again to fetch all dependencies.
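
If your project uses the Gradle Kotlin DSL (build.gradle.kts) instead of Groovy, the same configuration might look roughly like the sketch below. This is an assumption about your setup rather than something taken from the companion project; the coordinates and versions are copied from the snippets above.

```kotlin
// app/build.gradle.kts (sketch, assuming the Kotlin DSL is in use)
android {
    aaptOptions {
        // keep .tflite files uncompressed so the model can be loaded from assets
        noCompress("tflite")
    }
}

dependencies {
    // Permissions handling
    implementation("com.github.quickpermissions:quickpermissions-kotlin:0.4.0")

    // TensorFlow Lite and its support library
    implementation("org.tensorflow:tensorflow-lite:0.0.0-nightly")
    implementation("org.tensorflow:tensorflow-lite-support:0.0.0-nightly")

    // CameraView
    implementation("com.otaliastudios:cameraview:2.6.2")
}
```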

Jumping into the code

First things first: we need a model to load. The one I used can be found here. Place the model inside app/src/main/assets so that we can load it at runtime.

The labels for detected objects can be found here. Place them in the same directory as the model.

Warning: If you plan to use your own custom model, note that its input and output shapes may not match the ones used in this project.
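
One way to catch such a mismatch early is to compare the shape the model reports against the one your code assumes. The helper below is a hypothetical sketch in plain Kotlin and is not part of the support library; on a device, the actual shape would come from something like interpreter.getInputTensor(0).shape().

```kotlin
// Hypothetical helper: fails fast when a tensor's shape differs from what the code expects.
fun requireShape(name: String, expected: IntArray, actual: IntArray) {
    require(expected.contentEquals(actual)) {
        "$name: expected shape ${expected.contentToString()} " +
            "but the model reports ${actual.contentToString()}"
    }
}
```

Calling requireShape("input", intArrayOf(1, 300, 300, 3), actualShape) right after creating the interpreter would throw an IllegalArgumentException for a model with a different input size, instead of failing later during inference.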

Creating a wrapper class

We’re going to wrap our model and its associated methods inside a class called YOLO. The initial code is as follows.

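import android.content.Context
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil

class YOLO(private val context: Context) {
    // TFLite interpreter that runs the model
    private val interpreter: Interpreter

    companion object {
        // file names of the model and the labels inside the assets directory
        private const val MODEL_FILE = "detect.tflite"
        private const val LABEL_FILE = "labelmap.txt"
    }

    init {
        // memory-map the model from the assets and hand it to the interpreter
        val options = Interpreter.Options()
        interpreter = Interpreter(FileUtil.loadMappedFile(context, MODEL_FILE), options)
    }
}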

Let’s break this class down into its core functionality and behaviour.

  1. First, upon being created, the class loads the model from the app assets through the FileUtil class provided by the support library.
  2. Next, we have a class member, interpreter. It is an instance of the TFLite Interpreter that runs the model.
  3. Finally, we have some constants in the companion object. These are just the file names of the model and the labels inside our assets.

Moving on, let’s add a convenience method to load our labels from the assets.

import android.content.res.AssetManager
import java.io.BufferedReader
import java.io.InputStreamReader

class YOLO(private val context: Context) {
    // other stuff

    // lazily load object labels
    private val labelList by lazy { loadLabelList(context.assets) }

    // read the label file from the assets, one label per line
    private fun loadLabelList(
        assetManager: AssetManager
    ): List<String> {
        val labelList = mutableListOf<String>()
        val reader =
            BufferedReader(InputStreamReader(assetManager.open(LABEL_FILE)))
        var line = reader.readLine()
        while (line != null) {
            labelList.add(line)
            line = reader.readLine()
        }
        reader.close()
        return labelList
    }
}

Here we’ve declared a method that loads the label file, along with a lazily initialized property that holds the returned list.
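
If you prefer something shorter, Kotlin’s standard library can do the reading and closing for you. The function below is a sketch with a hypothetical name, readLabels, that works on any InputStream. On Android you would pass it assetManager.open(LABEL_FILE); taking a plain stream also makes it easy to try off-device.

```kotlin
import java.io.InputStream

// Hypothetical stand-alone variant of loadLabelList: reads one label per line
// and closes the stream when done, even if reading throws.
fun readLabels(input: InputStream): List<String> =
    input.bufferedReader().use { reader -> reader.readLines() }
```

For a quick check, readLabels("person\nbicycle".byteInputStream()) gives a two-element list.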

Let’s get down to brass tacks. We’re now going to define a method that takes in a bitmap, passes it to the model, and returns the detected object classes.

// additional imports needed for this step
import android.graphics.Bitmap
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer

class YOLO(private val context: Context) {
    private val interpreter: Interpreter

    // companion object, init block and loadLabelList() as shown earlier

    // lazily load object labels
    private val labelList by lazy { loadLabelList(context.assets) }

    // create image processor to resize image to input dimensions
    private val imageProcessor by lazy {
        ImageProcessor.Builder()
            .add(ResizeOp(300, 300, ResizeOp.ResizeMethod.BILINEAR))
            .build()
    }

    // create tensorflow representation of an image
    private val tensorImage by lazy { TensorImage(DataType.UINT8) }

    fun detectObjects(bitmap: Bitmap): List<String> {
        tensorImage.load(bitmap)

        // resize image using processor
        val processedImage = imageProcessor.process(tensorImage)

        // load image data into input buffer
        val inputBuffer = TensorBuffer.createFixedSize(intArrayOf(1, 300, 300, 3), DataType.UINT8)
        inputBuffer.loadBuffer(processedImage.buffer, intArrayOf(1, 300, 300, 3))

        // create output buffers
        val boundBuffer = TensorBuffer.createFixedSize(intArrayOf(1, 10, 4), DataType.FLOAT32)
        val classBuffer = TensorBuffer.createFixedSize(intArrayOf(1, 10), DataType.FLOAT32)
        val classProbBuffer = TensorBuffer.createFixedSize(intArrayOf(1, 10), DataType.FLOAT32)
        val numBoxBuffer = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)

        // run interpreter; the map keys are the model's output indices
        interpreter.runForMultipleInputsOutputs(
            arrayOf(inputBuffer.buffer), mapOf(
                0 to boundBuffer.buffer,
                1 to classBuffer.buffer,
                2 to classProbBuffer.buffer,
                3 to numBoxBuffer.buffer
            )
        )

        // map detected class indices to label names; the + 1 offset presumably
        // skips a placeholder entry on the first line of the label file
        return classBuffer.floatArray.map { labelList[it.toInt() + 1] }
    }
}

Whoa, that’s a wall of code! Let’s go through it and break it down.

We’ve declared two new lazily initialized properties: an ImageProcessor and a TensorImage. These classes are exposed by the support library to make loading and processing images much simpler.

As shown here, we can load a bitmap directly into the TensorImage and then pass it on to the ImageProcessor for further processing.

The ImageProcessor has several operations available, but the one we’ve used here resizes our input images to 300 × 300, because that is the input size our model expects.

After processing the image, we create several TensorBuffers. These are representations of tensors that we can manipulate and access easily. The shapes of these TensorBuffers are determined by the model. Take a look at the model summary to figure out the appropriate shapes.
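
To get a feel for what those shapes mean in memory, here is a small sketch in plain Kotlin using java.nio. The helper names (elementCount, byteSize, allocateFloatTensor, readFloats) are made up for illustration and are not part of the support library, which takes care of all of this inside TensorBuffer.

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Hypothetical helper: number of elements in a tensor of the given shape.
fun elementCount(shape: IntArray): Int = shape.fold(1) { acc, dim -> acc * dim }

// Hypothetical helper: bytes needed for a tensor, given the size of one element
// (1 byte for UINT8, 4 bytes for FLOAT32).
fun byteSize(shape: IntArray, bytesPerElement: Int): Int =
    elementCount(shape) * bytesPerElement

// Allocates a direct, native-ordered buffer large enough for a FLOAT32 tensor.
fun allocateFloatTensor(shape: IntArray): ByteBuffer =
    ByteBuffer.allocateDirect(byteSize(shape, 4)).order(ByteOrder.nativeOrder())

// Reads every float in the buffer, starting from position 0,
// without disturbing the position of the original buffer.
fun readFloats(buffer: ByteBuffer): FloatArray {
    val view = buffer.duplicate().order(buffer.order())
    view.rewind()
    val floats = FloatArray(view.remaining() / 4)
    view.asFloatBuffer().get(floats)
    return floats
}
```

For the shapes in this post, elementCount(intArrayOf(1, 300, 300, 3)) is 270000, so the UINT8 input needs 270000 bytes, while the [1, 10, 4] FLOAT32 bounding-box output needs 160 bytes.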

We load the TensorImage into the input TensorBuffer, and then pass the input and output buffers into the interpreter.

Note: The YOLOv3 model has multiple outputs, which is why we had to use multiple output buffers.

After running inference, the interpreter fills the output buffers, which we can read back as FloatArrays. Right now, we’re only interested in the one that contains the predicted classes. Using the handy Kotlin map function, we map the numerical classes output by the model to their labels and return them.
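
The method above returns a label for every output slot. As a rough sketch of how the remaining outputs could be used, the plain-Kotlin function below keeps only the first count detections whose score clears a threshold. Detection, toDetections, and the 0.5 default threshold are assumptions for illustration; the arrays stand in for classBuffer.floatArray and classProbBuffer.floatArray, and count for numBoxBuffer.floatArray[0], assuming that output holds the number of valid detections as its name suggests.

```kotlin
// Hypothetical container for one detection; not part of the support library.
data class Detection(val label: String, val score: Float)

// Turns raw output arrays into labelled detections.
fun toDetections(
    classes: FloatArray,       // class index per slot
    scores: FloatArray,        // confidence per slot
    count: Int,                // number of valid slots reported by the model
    labels: List<String>,
    labelOffset: Int = 1,      // same + 1 offset the post uses when indexing labelList
    minScore: Float = 0.5f     // assumed threshold, tune for your model
): List<Detection> {
    // never read past the arrays, even if the reported count is too large
    val n = minOf(count, classes.size, scores.size)
    return (0 until n)
        .filter { scores[it] >= minScore }
        .mapNotNull { i ->
            // skip indices that fall outside the label list instead of crashing
            labels.getOrNull(classes[i].toInt() + labelOffset)
                ?.let { label -> Detection(label, scores[i]) }
        }
}
```

Returning a small data class instead of bare strings keeps the score next to each label, which is handy if you later want to sort or filter the results in the UI.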

This class can now be used by our application to run inference on a bitmap. How convenient!

Conclusion

And that’s it! Without the support library, we’d have written much more code to resize the image, convert bitmaps to float arrays, manually allocate float arrays to store the output, and so on.

To find out more, visit the documentation here.