[Android] Tensorflow Lite로 딥러닝 모델 추론하기

Choi Sang Rok·2022년 4월 8일
1

딥러닝 모델을 안드로이드에서 사용하기 위해 Tensorflow Lite를 알게 되었습니다. 학습된 딥러닝 모델을 추론하기만 하면 되는 것이라서 전처리나 후처리가 중요하다고 생각합니다.
https://www.tensorflow.org/lite/?hl=ko
해당 글은 위 문서를 참조하였습니다.


✔ tflite 모델로 Interpreter 객체 생성

✔ Input data를 모델의 format에 맞게 resizing

✔ 추론(run)

✔ output 해석



1. tflite 모델로 interpreter 객체를 생성

딥러닝에서 안드로이드 개발자의 역할은, 이미 AI 엔지니어 단에서 학습된 딥러닝 모델을, Tensorflow Lite Library를 사용하여 .tflite 파일을 다루는 것 입니다.


Assets 디렉토리에 있는 tflite 형식의 모델을 사용하기 위해서는 먼저 다음과 같은 과정을 거쳐야 합니다.
  1. 모델을 메모리에 매핑
  2. 추론 향상 도구 등 옵션 설정
  3. 결과로 나타날 클래스 파일 목록 적재

위의 이 초기화 과정은,

모델의 초기화 및 추론을 담당하는 tensorflow 패키지의 클래스 Classifier의 constructor 에서 구현합니다.

protected constructor(activity: Activity, device: Device, numThreads: Int) {
...

parameter : Assets 폴더에 접근하기 위한 Activity, 추론에 사용될 런타임 장치 유형 Device, 스레드의 개수 Int

enum class Device {
CPU,NNAPI,GPU
}

Device는 열거 클래스로, 추론에 사용될 런타임 디바이스 종류를 포함합니다.

val model = loadModelFile(activity).also{
tfliteModel =it
}

Assets 디렉토리에 있는 tflite 파일을 메모리에 매핑하고 MappedByteBuffer 객체를 반환합니다.

when (device) {
    Device.NNAPI-> tfliteOptions.setUseNNAPI(true)
    Device.GPU-> gpuDelegate = GpuDelegate().also{
tfliteOptions.addDelegate(it)
}
else -> { /* Device.CPU */ }
}

NNAPI, GPU, CPU 사용 여부에 따라 옵션을 설정합니다. 현재 코드에서는 CPU를 사용합니다.

tfliteOptions.setNumThreads(numThreads)
tflite = Interpreter(model, tfliteOptions)
labels = loadLabelList(activity)
imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * getImageSizeX() * getImageSizeY()
        * getDimPixelSize() * getNumBytesPerChannel())
    .apply{
order(ByteOrder.nativeOrder())
}

스레드 개수를 설정하고, 모델 추론을 위한 클래스 Interpreter을 생성합니다.

그리고 Assets 폴더의 label 텍스트 파일로부터 클래스(surprise, happiness, ...) 리스트를 반환합니다.

마지막으로는 이미지 데이터 크기를 배치 사이즈 이미지 크기 픽셀 사이즈 * 채널 당 바이트 수로 지정합니다.



2. Input Data를 모델의 format에 맞게 resizing

카메라로부터 얻어오는 이미지는 YUV 형식이라서 이를 RGB 형식으로 변환하고, X, Y 크기, 비율 변환 등의 작업을 거칩니다. 그리고 비트맵 형식의 이미지를 ByteBuffer로 변환합니다.

util 패키지의 클래스 VisionAnalyzer에서 해당 과정이 진행됩니다.

class VisionAnalyzer(private val activity: Activity) : ImageAnalysis.Analyzer {

...

private val mClassifier: Classifier = Classifier.create(activity, Classifier.Device.CPU, 1)

VisionAnalyzer 클래스에서 Classifier 객체가 생성됩니다. Classifier의 create는 어떻게 생겼을까요?

abstract class Classifier{
		companion object{

		  fun create(activity: Activity, device: Device, numThreads: Int): C      lassifier
		    = MercyVisionClassifier(activity, device, numThreads)
		}
}

Classifier 클래스를 보면, create를 호출하였을 때 자식 클래스 MercyVisionClassifier를 호출하는 것을 볼 수 있습니다.

class MercyVisionClassifier(activity: Activity, device: Device, numThreads: Int)
    : Classifier(activity, device, numThreads) {

    ...

    override fun getDimPixelSize(): Int = 2

    override fun getImageSizeX(): Int = 24

    ...

Classifier은 공통 메서드, 속성을 포함하는 베이스 클래스였고, MercyVisionClassifier 객체를 생성해야 비로소 모델 파일 경로, 이미지 사이즈 등의 속성을 가지는 분류기 클래스가 생성됩니다. 다시 그러면 VisionAnalyzer로 돌아가 봅시다.

class VisionAnalyzer(private val activity: Activity) : ImageAnalysis.Analyzer {
     
    ...
    
    override fun analyze(image: ImageProxy, rotationDegrees: Int) {
		
		ImageUtils.convertYUV420SPToARGB8888(yuvBuffer, image.width, image.height, rgbBytes)
		}
}

analyze로 들어오는 ImageProxy는 YUV 포멧입니다. 우선 이를 RGB 포멧비트맵으로 변환해야 합니다. 이는 ImageUtils.convertYUV420SPToARGB8888 메서드에 ImageProxy의 Y, U ,V의 값을 유의미한 배열로 변환하여 넘기면 rgbBytes에 반환됩니다.

val croppedBitmap = Bitmap.createBitmap(
    mClassifier.getImageSizeX(),
    mClassifier.getImageSizeY(),
    Bitmap.Config.ARGB_8888)

val canvas = Canvas(croppedBitmap).apply {
                drawBitmap(rgbFrameBitmap, frameToCropTransform, null)
}
val results = mClassifier.recognizeImage(croppedBitmap)

그 후 croppedBitmap에 RGB 형식, 모델의 크기 만큼의 비트맵을 생성하고, 카메라 회전 각도로 인하여 변환시킬 Matrix를 가지고 변환된 rgb 비트맵 형식의 이미지를 그려줍니다.

이제 Classifier의 recognizeImage를 호출할 수 있습니다.

fun recognizeImage(bitmap: Bitmap): List<Recognition> {
	   
		...
    
    convertBitmapToByteBuffer(bitmap)
    
    ...
}

recognizeImage 메서드를 호출하면, 마지막으로 Bitmap을 생성자에서 만들었던 imgData에 ByteBuffer에 매핑합니다. 비로소 모델 추론이 가능하게 됩니다.



3. 추론(run)

한줄, tflite.run 메서드를 호출하면 끝납니다. 모든건 이 한줄을 위해..

fun recognize(bitmap: Bitmap): List<Recognition> {
    
		...

    runInference()
   
    ...
}
override fun runInference() {
    tflite?.run(imgData, labelProbArray)
}

tflite.run을 호출하면, ByteBuffer 형식의 이미지 데이터를 넘겨서 모델에서 추론 후 결과를 labelProbArray에 반환합니다.

private val labelProbArray: Array<FloatArray> = *arrayOf*(FloatArray(getNumLabels()))

labelProbArray는 실수배열을 요소로 가지는 배열입니다. 이미지를 추론한 결과는

[0][클래스 번호(인덱스 번호)] = 추론 값(정확도) 형식으로 저장됩니다.

anger
disgust/contempt
afraid
happiness
sadness
surprise
neutral

ex) labelProbArray[0][2] = 0.414 → disgust의 정확도가 0.414



4. output 해석

fun recognizeImage(bitmap: Bitmap): List<Recognition> {
    
    val pq = PriorityQueue<Recognition>(3,Comparator{lhs: Recognition, rhs: Recognition->
return@Comparator rhs.confidence.compareTo(lhs.confidence)
})
    for (i in labels.indices) {
        pq.add(Recognition(i.toString(),
            if (labels.size > i) labels[i] else "unknown",
            getNormalizedProbability(i),
            null))
    }
    val recognitions = ArrayList<Recognition>()
    val recognitionSize =min(pq.size, MAX_RESULTS)
    for (i in 0untilrecognitionSize) {
        recognitions.add(pq.poll())
    }
    return recognitions
}

runInference() 다음 코드입니다.

size 3의 우선순위 큐를 생성합니다. 큐의 원소는 분류 결과 클래스를 담고, 신뢰도는 내림차순으로 정렬됩니다

레이블의 개수(클래스의 개수) 만큼 반복하면서 우선순위 큐에 삽입합니다.

그리고 결과 클래스가 담길 list를 만들고 우선순위 큐의 사이즈와 MAX_RESULTS를 비교하는데, 둘다 3이므로 사이즈는 3이 될 것입니다.

우선순위 큐에서 poll()한 값이 list에 담길 것이고,

결과적으로 신뢰도가 가장 높은 세 개의 클래스의 정보가 담긴 리스트가 반환됩니다.

override fun updateGraph(results: List<Classifier.Companion.Recognition>) {
    for (recognition in results) {
        Log.d(TAG, "Recognition(id=${recognition.id}, title=${recognition.title}, confidence=${recognition.confidence}")
        val index = emotions.indexOf(recognition.title)
        Log.d(TAG, "index: $index")
        mGraphData[index].y= recognition.confidence
    }
    chart.notifyDataSetChanged()
    chart.invalidate()
}

앱을 실행시켜서 emotion graph를 보면 총 3개의 막대만 움직이는 것을 확인할 수 있습니다. 이는 이미지에 대해 딥러닝 모델이 추론한 결과 중 신뢰도가 가장 높은 3개의 클래스가 리스트로 반환되기 때문입니다.

profile
android_developer

5개의 댓글

comment-user-thumbnail
2022년 4월 8일

글을 쓸 때 왜 이 글을 쓰게되었는지 이유를 적는것도 좋을 것 같습니다~^^~

답글 달기
comment-user-thumbnail
2022년 4월 8일

글의 기승전결이 없는 느낌이 나네요.. 이전 글은 정말 잘 쓰셨는데 이번 글은 승전전보(승전보맞음) 같은 느낌이 들었습니다.. 좀 더 분발 부탁드립니다.

답글 달기
comment-user-thumbnail
2022년 4월 13일

어쩔티비

답글 달기
comment-user-thumbnail
2022년 4월 30일

자알 봤습니다아

답글 달기
comment-user-thumbnail
2022년 6월 15일

혹시 논문에 기재된 내용인가요? 제가 쓴 논문이랑 너무 비슷한데;;
해명해주시기 바랍니다.

답글 달기