Android Handwritten Digit Prediction App Using a Machine Learning Model in Kotlin

In this post, we are going to see how to create a handwritten-digit prediction app for Android using a machine learning model in Kotlin.

Technology used:

  1. Kotlin
  2. Machine Learning
  3. TensorFlow (TensorFlow Lite)

Dataset

The MNIST dataset is used to create the TFLite model.

Dataset demo

Python code link: CLICK HERE to open the code in Google Colab.

First, create the TFLite model, then copy it into the assets folder of your Android Studio project. The Classifier class below loads it from there under the name mnist.tflite.

Android

1. activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    android:background="@android:color/black"
    tools:context="com.codewithgolap.tflite.mnist.MainActivity">

    <!-- Result panel: predicted digit on top, then a header row and a value row
         for probability and inference time. -->
    <TableLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:padding="16dp"
        android:background="@android:color/white">

        <TextView
            style="@style/ResultText"
            android:fontFamily="@font/poppins_bold"
            android:text="@string/prediction"
            android:textColor="@android:color/black"
            android:textSize="19sp"
            android:letterSpacing="0.05"/>

        <!-- Filled in by MainActivity with the predicted digit. -->
        <TextView
            android:id="@+id/tv_prediction"
            style="@style/ResultText"
            android:text="@string/empty"
            android:textColor="@android:color/black"
            android:textSize="24sp"/>

        <!-- Header row. -->
        <TableRow
            android:layout_marginTop="16dp">

            <TextView
                style="@style/ResultText"
                android:text="@string/probability"
                android:background="#FFC107"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginEnd="2dp"/>

            <TextView
                style="@style/ResultText"
                android:text="@string/timecost"
                android:background="#FFC107"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginStart="2dp"/>
        </TableRow>

        <!-- Value row; both cells are filled in by MainActivity. -->
        <TableRow>

            <TextView
                android:id="@+id/tv_probability"
                style="@style/ResultText"
                android:text="@string/empty"
                android:background="#FFEB3B"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginEnd="2dp"/>

            <TextView
                android:id="@+id/tv_timecost"
                style="@style/ResultText"
                android:text="@string/empty"
                android:background="#FFEB3B"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginStart="2dp"/>
        </TableRow>
    </TableLayout>

    <!-- Drawing surface for the digit. -->
    <com.nex3z.fingerpaintview.FingerPaintView
        android:id="@+id/fingerPaintView"
        android:layout_width="280dp"
        android:layout_height="280dp"
        android:layout_marginTop="48dp"
        android:layout_gravity="center"
        android:background="@android:color/white"
        android:foreground="@drawable/shape_rect_border"/>

    <!-- Action buttons share the row equally: 0dp width plus equal weights. -->
    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal"
        android:paddingStart="16dp"
        android:paddingEnd="16dp"
        android:layout_marginTop="48dp">

        <Button
            android:id="@+id/btn_detect"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="@string/detect"
            android:textSize="16sp"
            android:fontFamily="@font/poppins_medium"/>

        <Button
            android:id="@+id/btn_clear"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="@string/clear"
            android:textSize="16sp"
            android:fontFamily="@font/poppins_medium"/>
    </LinearLayout>
</LinearLayout>

2. Classifier.kt

package com.codewithgolap.tflite.mnist
import android.content.Context
import android.graphics.Bitmap
import android.os.SystemClock
import android.util.Log
import android.util.Size
import org.tensorflow.lite.Delegate
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.Tensor
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.nnapi.NnApiDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.Closeable
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
 * Wraps a TensorFlow Lite [Interpreter] loaded from the asset named by `MODEL_FILE_NAME`
 * and classifies bitmaps with it.
 *
 * The pixel array, input buffer and output buffer are allocated once and reused by every
 * [classify] call, so overlapping calls from several threads would presumably interfere
 * with each other. Call [close] when the classifier is no longer needed.
 *
 * @param context used to locate the model file.
 * @param device selects which delegate (if any) is attached to the interpreter.
 * @param numThreads thread count passed to the interpreter options.
 */
class Classifier(
    context: Context,
    device: Device = Device.CPU,
    numThreads: Int = 4
) {
    private val delegate: Delegate? = when (device) {
        Device.CPU -> null
        Device.NNAPI -> NnApiDelegate()
        Device.GPU -> GpuDelegate()
    }

    private val interpreter: Interpreter = Interpreter(
        FileUtil.loadMappedFile(context, MODEL_FILE_NAME),
        Interpreter.Options().apply {
            setNumThreads(numThreads)
            delegate?.let { addDelegate(it) }
        }
    )

    private val inputTensor: Tensor = interpreter.getInputTensor(0)
    private val outputTensor: Tensor = interpreter.getOutputTensor(0)

    /**
     * Bitmap size expected by [classify]: width is taken from input-shape index 2 and
     * height from index 1 (assumes a batch/height/width/... layout — confirm for other models).
     */
    val inputShape: Size = with(inputTensor.shape()) { Size(this[2], this[1]) }

    // Scratch buffers, reused on every call. The input buffer holds one 4-byte float per pixel.
    private val imagePixels = IntArray(inputShape.height * inputShape.width)
    private val imageBuffer: ByteBuffer =
        ByteBuffer.allocateDirect(4 * inputShape.height * inputShape.width).apply {
            order(ByteOrder.nativeOrder())
        }
    private val outputBuffer: TensorBuffer =
        TensorBuffer.createFixedSize(outputTensor.shape(), outputTensor.dataType())

    init {
        Log.v(
            LOG_TAG, "[Input] shape = ${inputTensor.shape()?.contentToString()}, " +
                "dataType = ${inputTensor.dataType()}"
        )
        Log.v(
            LOG_TAG, "[Output] shape = ${outputTensor.shape()?.contentToString()}, " +
                "dataType = ${outputTensor.dataType()}"
        )
    }

    /**
     * Runs the model on [image] and returns the best-scoring class.
     *
     * @param image bitmap to classify; it is scaled to [inputShape] if its size differs.
     * @return the index of the highest output value, that value, and the time spent inside
     *   the interpreter call in milliseconds (bitmap conversion is not included).
     */
    fun classify(image: Bitmap): Recognition {
        convertBitmapToByteBuffer(image)

        val start = SystemClock.uptimeMillis()
        interpreter.run(imageBuffer, outputBuffer.buffer.rewind())
        val timeCost = SystemClock.uptimeMillis() - start

        val probs = outputBuffer.floatArray
        val top = probs.argMax()
        Log.v(LOG_TAG, "classify(): timeCost = $timeCost, top = $top, probs = ${probs.contentToString()}")
        return Recognition(top, probs[top], timeCost)
    }

    /** Releases the interpreter and, when it can be closed, the delegate. */
    fun close() {
        interpreter.close()
        // Closeable extends AutoCloseable, so this covers everything the previous
        // `is Closeable` check did and also releases delegates that only implement
        // AutoCloseable (NnApiDelegate appears to be one — confirm against the
        // TensorFlow Lite version in use).
        (delegate as? AutoCloseable)?.close()
    }

    /**
     * Fills the input buffer with one float per pixel of [bitmap].
     *
     * The pixel array is sized for [inputShape], so a bitmap of any other size is scaled
     * first; reading it directly would overflow the array (larger bitmap) or leave part of
     * the array unfilled (smaller bitmap).
     */
    private fun convertBitmapToByteBuffer(bitmap: Bitmap) {
        val width = inputShape.width
        val height = inputShape.height
        val sized =
            if (bitmap.width == width && bitmap.height == height) {
                bitmap
            } else {
                Bitmap.createScaledBitmap(bitmap, width, height, true)
            }

        imageBuffer.rewind()
        sized.getPixels(imagePixels, 0, width, 0, 0, width, height)
        for (pixel in imagePixels) {
            imageBuffer.putFloat(convertPixel(pixel))
        }

        // Only recycle the temporary copy; the caller still owns the bitmap it passed in.
        if (sized !== bitmap) {
            sized.recycle()
        }
    }

    /**
     * Converts a packed colour to an inverted grey level in `0..1`: the red, green and blue
     * bytes are weighted 0.299 / 0.587 / 0.114, subtracted from 255 and divided by 255, so
     * white maps to 0 and black to 1. The top (alpha) byte is not read.
     */
    private fun convertPixel(color: Int): Float {
        val red = color shr 16 and 0xFF
        val green = color shr 8 and 0xFF
        val blue = color and 0xFF
        return (255 - (red * 0.299f + green * 0.587f + blue * 0.114f)) / 255.0f
    }

    companion object {
        private val LOG_TAG: String = Classifier::class.java.simpleName
        private const val MODEL_FILE_NAME: String = "mnist.tflite"
    }
}
/**
 * Returns the index of the largest element; the first such index wins on ties.
 *
 * @throws IllegalArgumentException if the array is empty.
 */
fun FloatArray.argMax(): Int {
    // Compare through the index range so no IndexedValue wrapper is allocated per element.
    return indices.maxByOrNull { this[it] }
        ?: throw IllegalArgumentException("Cannot find arg max in empty list")
}

3. Device.kt

package com.codewithgolap.tflite.mnist
/** Execution back-ends that [Classifier] can be asked to use. */
enum class Device {
    /** No delegate is attached to the interpreter. */
    CPU,

    /** An `NnApiDelegate` is attached to the interpreter. */
    NNAPI,

    /** A `GpuDelegate` is attached to the interpreter. */
    GPU,
}

4. Recognition.kt

package com.codewithgolap.tflite.mnist
/**
 * Outcome of a single classification.
 *
 * @property label index of the highest-scoring model output.
 * @property confidence the model output value at [label].
 * @property timeCost duration of the interpreter call, in milliseconds.
 */
data class Recognition(
    val label: Int,
    val confidence: Float,
    val timeCost: Long,
)

5. MainActivity.kt

package com.codewithgolap.tflite.mnist
import android.graphics.Bitmap
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import kotlinx.android.synthetic.main.activity_main.*
import java.io.IOException
/**
 * Single screen of the app: the user draws a digit on `fingerPaintView`, taps "detect",
 * and the predicted digit, its probability and the inference time are displayed.
 */
class MainActivity : AppCompatActivity() {
    // Created in initClassifier(); stays uninitialized if creating it throws an IOException.
    private lateinit var classifier: Classifier

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        init()
    }

    private fun init() {
        initClassifier()
        initView()
    }

    /**
     * Creates the [Classifier]. An [IOException] (for example, when the model file cannot be
     * read from the app's assets) is reported with a toast and leaves [classifier]
     * uninitialized; no network access is involved.
     */
    private fun initClassifier() {
        try {
            classifier = Classifier(this)
            Log.v(LOG_TAG, "Classifier initialized")
        } catch (e: IOException) {
            Toast.makeText(this, R.string.failed_to_create_classifier, Toast.LENGTH_LONG).show()
            Log.e(LOG_TAG, "init(): Failed to create Classifier", e)
        }
    }

    /** Wires up the button click handlers. */
    private fun initView() {
        btn_detect.setOnClickListener { onDetectClick() }
        btn_clear.setOnClickListener { clearResult() }
    }

    /** Classifies the current drawing and shows the result, if there is anything to classify. */
    private fun onDetectClick() {
        if (!this::classifier.isInitialized) {
            Log.e(LOG_TAG, "onDetectClick(): Classifier is not initialized")
            return
        }
        if (fingerPaintView.isEmpty) {
            Toast.makeText(this, R.string.please_write_a_digit, Toast.LENGTH_SHORT).show()
            return
        }

        // Export the drawing at the size the model expects, classify it, and show the result.
        val image: Bitmap = fingerPaintView.exportToBitmap(
            classifier.inputShape.width, classifier.inputShape.height
        )
        renderResult(classifier.classify(image))
    }

    /**
     * Displays a classification: `label` is the predicted digit, `confidence` its
     * probability, and `timeCost` the inference time.
     */
    private fun renderResult(result: Recognition) {
        tv_prediction.text = result.label.toString()
        tv_probability.text = result.confidence.toString()
        tv_timecost.text = getString(R.string.timecost_value, result.timeCost)
    }

    /** Clears the drawing and resets all result fields. */
    private fun clearResult() {
        fingerPaintView.clear()
        tv_prediction.setText(R.string.empty)
        tv_probability.setText(R.string.empty)
        tv_timecost.setText(R.string.empty)
    }

    override fun onDestroy() {
        // Release the interpreter and delegate. The guard matters because the classifier
        // is never assigned when initClassifier() fails.
        if (this::classifier.isInitialized) {
            classifier.close()
        }
        super.onDestroy()
    }

    companion object {
        private val LOG_TAG: String = MainActivity::class.java.simpleName
    }
}

Watch the full video for a step-by-step walkthrough.

--

--

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Golap Gunjan Barman

Hi everyone, I'm Golap, an Android app developer and UI/UX designer.