Android handwritten-digit prediction app using a machine learning model in Kotlin

Technology used:

Dataset

Dataset demo

First, create the TFLite model, then copy it (named `mnist.tflite`) into the `assets` folder of your Android Studio project.

Android

1. activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<!--
    Layout for MainActivity (activity_main.xml).
    From top to bottom: a white result panel (prediction, probability, time cost),
    a 280dp x 280dp FingerPaintView to draw a digit on, and the Detect / Clear buttons.
-->
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    android:background="@android:color/black"
    tools:context="com.codewithgolap.tflite.mnist.MainActivity">

    <!-- Result panel. TextViews placed directly in a TableLayout act as full-width rows. -->
    <TableLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:padding="16dp"
        android:background="@android:color/white">

        <TextView
            style="@style/ResultText"
            android:fontFamily="@font/poppins_bold"
            android:text="@string/prediction"
            android:textColor="@android:color/black"
            android:textSize="19sp"
            android:letterSpacing="0.05" />

        <!-- Predicted digit; set from code after each detection. -->
        <TextView
            android:id="@+id/tv_prediction"
            style="@style/ResultText"
            android:text="@string/empty"
            android:textColor="@android:color/black"
            android:textSize="24sp" />

        <!-- Column headers: probability | time cost. -->
        <TableRow
            android:layout_marginTop="16dp">

            <TextView
                style="@style/ResultText"
                android:text="@string/probability"
                android:background="#FFC107"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginEnd="2dp" />

            <TextView
                style="@style/ResultText"
                android:text="@string/timecost"
                android:background="#FFC107"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginStart="2dp" />
        </TableRow>

        <!-- Values for the headers above; set from code after each detection. -->
        <TableRow>

            <TextView
                android:id="@+id/tv_probability"
                style="@style/ResultText"
                android:text="@string/empty"
                android:background="#FFEB3B"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginEnd="2dp" />

            <TextView
                android:id="@+id/tv_timecost"
                style="@style/ResultText"
                android:text="@string/empty"
                android:background="#FFEB3B"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginStart="2dp" />
        </TableRow>
    </TableLayout>

    <!-- Drawing surface for the digit. -->
    <com.nex3z.fingerpaintview.FingerPaintView
        android:id="@+id/fingerPaintView"
        android:layout_width="280dp"
        android:layout_height="280dp"
        android:layout_marginTop="48dp"
        android:layout_gravity="center"
        android:background="@android:color/white"
        android:foreground="@drawable/shape_rect_border" />

    <!-- Action buttons. Width 0dp with equal weights: the two buttons split the row evenly. -->
    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal"
        android:paddingStart="16dp"
        android:paddingEnd="16dp"
        android:layout_marginTop="48dp">

        <Button
            android:id="@+id/btn_detect"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="@string/detect"
            android:textSize="16sp"
            android:fontFamily="@font/poppins_medium" />

        <Button
            android:id="@+id/btn_clear"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="@string/clear"
            android:textSize="16sp"
            android:fontFamily="@font/poppins_medium" />
    </LinearLayout>
</LinearLayout>

2. Classifier.kt

package com.codewithgolap.tflite.mnist
import android.content.Context
import android.graphics.Bitmap
import android.os.SystemClock
import android.util.Log
import android.util.Size
import org.tensorflow.lite.Delegate
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.Tensor
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.nnapi.NnApiDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.Closeable
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
 * Wraps a TensorFlow Lite [Interpreter] that classifies a single grayscale digit image.
 *
 * The model is memory-mapped from the app assets ([MODEL_FILE_NAME]). The input and output
 * buffers are allocated once here and reused by every [classify] call. Because those buffers
 * are shared mutable state, one instance must not be used from several threads at once.
 *
 * Call [close] when done to release the interpreter and any hardware delegate.
 *
 * @param context used to load the model file.
 * @param device backend: [Device.CPU] (no delegate), [Device.NNAPI] or [Device.GPU].
 * @param numThreads number of CPU threads the interpreter may use.
 * @throws java.io.IOException if the model file cannot be loaded.
 * @throws IllegalArgumentException if the interpreter rejects the model.
 */
class Classifier(
    context: Context,
    device: Device = Device.CPU,
    numThreads: Int = 4
) {
    private val delegate: Delegate? = when (device) {
        Device.CPU -> null
        Device.NNAPI -> NnApiDelegate()
        Device.GPU -> GpuDelegate()
    }

    private val interpreter: Interpreter = try {
        Interpreter(
            FileUtil.loadMappedFile(context, MODEL_FILE_NAME),
            Interpreter.Options().apply {
                setNumThreads(numThreads)
                delegate?.let { addDelegate(it) }
            }
        )
    } catch (e: Exception) {
        // Construction failed, so close() can never be called: release the delegate here.
        (delegate as? AutoCloseable)?.close()
        throw e
    }

    private val inputTensor: Tensor = interpreter.getInputTensor(0)
    private val outputTensor: Tensor = interpreter.getOutputTensor(0)

    /**
     * Model input size in pixels, built as Size(shape[2], shape[1]).
     * Assumes an NHWC-style input shape [batch, height, width, channels]; verify against the model.
     */
    val inputShape: Size = with(inputTensor.shape()) { Size(this[2], this[1]) }

    // Scratch array for the ARGB pixels of one input image.
    private val imagePixels = IntArray(inputShape.height * inputShape.width)

    // One float per pixel (assumes a single float32 channel), in native byte order.
    private val imageBuffer: ByteBuffer =
        ByteBuffer.allocateDirect(Float.SIZE_BYTES * inputShape.height * inputShape.width).apply {
            order(ByteOrder.nativeOrder())
        }

    private val outputBuffer: TensorBuffer =
        TensorBuffer.createFixedSize(outputTensor.shape(), outputTensor.dataType())

    init {
        Log.v(
            LOG_TAG, "[Input] shape = ${inputTensor.shape()?.contentToString()}, " +
                "dataType = ${inputTensor.dataType()}"
        )
        Log.v(
            LOG_TAG, "[Output] shape = ${outputTensor.shape()?.contentToString()}, " +
                "dataType = ${outputTensor.dataType()}"
        )
    }

    /**
     * Runs the model on [image] and returns the highest-scoring output.
     *
     * The image is resized to [inputShape] if it has a different size.
     *
     * @return the arg-max output index, its output value, and the inference time in
     *   milliseconds (measured with [SystemClock.uptimeMillis], preprocessing excluded).
     */
    fun classify(image: Bitmap): Recognition {
        convertBitmapToByteBuffer(image)
        val start = SystemClock.uptimeMillis()
        interpreter.run(imageBuffer, outputBuffer.buffer.rewind())
        val timeCost = SystemClock.uptimeMillis() - start
        val probs = outputBuffer.floatArray
        val top = probs.argMax()
        Log.v(LOG_TAG, "classify(): timeCost = $timeCost, top = $top, probs = ${probs.contentToString()}")
        return Recognition(top, probs[top], timeCost)
    }

    /** Releases the interpreter, then the delegate. The delegate must outlive the interpreter. */
    fun close() {
        interpreter.close()
        // Check AutoCloseable rather than Closeable. NnApiDelegate implements AutoCloseable
        // only, so a Closeable check would skip (and leak) it.
        if (delegate is AutoCloseable) {
            delegate.close()
        }
    }

    /**
     * Fills [imageBuffer] with one float per pixel of [bitmap], leaving it positioned at 0.
     *
     * The bitmap is scaled to [inputShape] first when its size differs. Otherwise getPixels()
     * would either overflow [imagePixels] (larger bitmap) or leave stale pixels from a previous
     * call in it (smaller bitmap).
     */
    private fun convertBitmapToByteBuffer(bitmap: Bitmap) {
        val width = inputShape.width
        val height = inputShape.height
        val sized = if (bitmap.width == width && bitmap.height == height) {
            bitmap
        } else {
            Bitmap.createScaledBitmap(bitmap, width, height, true)
        }
        sized.getPixels(imagePixels, 0, width, 0, 0, width, height)
        if (sized !== bitmap) {
            // Only recycle the temporary copy we created, never the caller's bitmap.
            sized.recycle()
        }

        imageBuffer.rewind()
        imagePixels.forEach { pixel -> imageBuffer.putFloat(convertPixel(pixel)) }
        // Reset the position so the buffer is read from the start.
        imageBuffer.rewind()
    }

    /**
     * Maps an ARGB color to the inverted luma in [0, 1], using BT.601 weights (0.299, 0.587, 0.114).
     * White maps to about 0 and black to 1. Alpha is ignored.
     */
    private fun convertPixel(color: Int): Float {
        val red = color shr 16 and 0xFF
        val green = color shr 8 and 0xFF
        val blue = color and 0xFF
        val luma = red * 0.299f + green * 0.587f + blue * 0.114f
        return (255 - luma) / 255.0f
    }

    companion object {
        private val LOG_TAG: String = Classifier::class.java.simpleName
        private const val MODEL_FILE_NAME: String = "mnist.tflite"
    }
}
/**
 * Returns the index of the largest element. On ties, the first such index is returned.
 *
 * @throws IllegalArgumentException if the array is empty.
 */
fun FloatArray.argMax(): Int =
    indices.maxByOrNull { this[it] }
        ?: throw IllegalArgumentException("Cannot find arg max in empty list")

package com.codewithgolap.tflite.mnist
/**
 * Execution backend for the classifier: CPU runs without a delegate, NNAPI uses the
 * Android Neural Networks API delegate, and GPU uses the GPU delegate.
 */
enum class Device { CPU, NNAPI, GPU }

package com.codewithgolap.tflite.mnist
/**
 * Outcome of a single classification.
 *
 * @property label index of the highest-scoring model output (the predicted digit for an MNIST model).
 * @property confidence the model's output value at [label].
 * @property timeCost inference time in milliseconds.
 */
data class Recognition(val label: Int, val confidence: Float, val timeCost: Long)

package com.codewithgolap.tflite.mnist
import android.graphics.Bitmap
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import kotlinx.android.synthetic.main.activity_main.*
import java.io.IOException
/**
 * Single-screen activity with three actions:
 * - the user draws a digit on the finger-paint view;
 * - "Detect" runs [Classifier] on the drawing and shows the prediction, its probability
 *   and the inference time;
 * - "Clear" resets both the drawing and the results.
 */
class MainActivity : AppCompatActivity() {
    // Created in onCreate(); stays uninitialized if the model could not be loaded.
    private lateinit var classifier: Classifier

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        init()
    }

    private fun init() {
        initClassifier()
        initView()
    }

    /**
     * Loads the model from the app assets.
     *
     * On failure it shows a toast and logs the error. [classifier] then stays uninitialized,
     * and [onDetectClick] refuses to run.
     */
    private fun initClassifier() {
        try {
            classifier = Classifier(this)
            Log.v(LOG_TAG, "Classifier initialized")
        } catch (e: IOException) {
            // The model file could not be read.
            onClassifierInitFailed(e)
        } catch (e: IllegalArgumentException) {
            // The TFLite Interpreter throws this when the asset is not a valid model.
            onClassifierInitFailed(e)
        }
    }

    /** Reports a classifier construction failure to the user and the log. */
    private fun onClassifierInitFailed(e: Exception) {
        Toast.makeText(this, R.string.failed_to_create_classifier, Toast.LENGTH_LONG).show()
        Log.e(LOG_TAG, "init(): Failed to create Classifier", e)
    }

    /** Wires the Detect and Clear buttons. */
    private fun initView() {
        btn_detect.setOnClickListener { onDetectClick() }
        btn_clear.setOnClickListener { clearResult() }
    }

    /**
     * Classifies the current drawing and shows the result.
     *
     * Does nothing if the classifier failed to load. Shows a hint if nothing has been drawn.
     */
    private fun onDetectClick() {
        if (!this::classifier.isInitialized) {
            Log.e(LOG_TAG, "onDetectClick(): Classifier is not initialized")
            return
        }
        if (fingerPaintView.isEmpty) {
            Toast.makeText(this, R.string.please_write_a_digit, Toast.LENGTH_SHORT).show()
            return
        }
        // Export the drawing directly at the model's input resolution.
        val image: Bitmap = fingerPaintView.exportToBitmap(
            classifier.inputShape.width, classifier.inputShape.height
        )
        renderResult(classifier.classify(image))
    }

    /** Shows the predicted label, its confidence and the formatted inference time. */
    private fun renderResult(result: Recognition) {
        tv_prediction.text = result.label.toString()
        tv_probability.text = result.confidence.toString()
        tv_timecost.text = getString(R.string.timecost_value).format(result.timeCost)
    }

    /** Clears the drawing and resets all result fields. */
    private fun clearResult() {
        fingerPaintView.clear()
        tv_prediction.setText(R.string.empty)
        tv_probability.setText(R.string.empty)
        tv_timecost.setText(R.string.empty)
    }

    override fun onDestroy() {
        // Every activity instance builds its own Classifier in onCreate().
        // Release its interpreter and delegate so they are not leaked.
        if (this::classifier.isInitialized) {
            classifier.close()
        }
        super.onDestroy()
    }

    companion object {
        private val LOG_TAG: String = MainActivity::class.java.simpleName
    }
}

--

--

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store