YOLOv8 classification 안드로이드

알로에·2023년 4월 28일
2

YOLOv8-Classification

목록 보기
1/1

YOLOv8 classification

YOLOv8 classification을 안드로이드에 적용해보자.

classification은 말 그대로 분류이다.

https://docs.ultralytics.com/tasks/ 사이트에 있는 사진이다.

classification을 왜 쓰는가?

말 그대로 사진에서 가장 높은 확률의 객체를 찾는 모델이다. Object Detection과 다른 점이 있다면 하나의 객체를 결과로 가져오며, 바운딩 박스가 따로 존재하지 않는다.
그럼 당연히 object detection을 하면 되는거 아닌가 라고 생각할 수 있다.
classification의 장점은 추론 속도가 매우 빠르다. 만약 바운딩 박스를 찾을 필요가 없고 다수의 객체를 찾을 필요가 없다면 classification 모델로 충분하다.
내 핸드폰은 갤럭시 S21 Ultra 인데, Object Detection은 한 장의 사진을 추론하는데 0.2초 정도가 걸린다. classification은 한 장의 사진을 추론하는데 35ms 정도가 걸린다. 대충 5~6배의 빠른 추론 속도를 가진다. 따라서 상황에 맞게 모델을 쓰면 될 듯 하다.

✔ 1. 모델 변환하기

https://github.com/ultralytics/ultralytics 사이트에서 사전 학습된 classification 모델을 다운 받는다.

nano 모델을 다운 받는다.

다운 받은 경로에 파이썬 파일을 생성하고, 아래와 같이 작성한다.

from ultralytics import YOLO

# Load a model
model = YOLO('yolov8n-cls.pt')  # load an official model

# Export the model
model.export(format='onnx')

ultralytics가 설치 되지 않았다면 cmd창에서 ultralytics를 먼저 설치하면 된다.

pip install ultralytics

파이썬 파일을 만들지 않아도, cmd 창에서도 바로 onnx를 변환할 수 있다.

yolo export model=yolov8n-cls.pt format=onnx 

편한 방식대로 .pt 모델을 .onnx 모델로 변환하면 된다.

변환이 끝나면 아래 사이트를 통해서 백본을 확인할 수 있다.
https://netron.app/

모델의 input을 확인한다.


1개의 사진에 대해 가로, 세로 224 사이즈이며 RGB 3차원인 것을 확인할 수 있다.

다음은 모델의 output 이다.

학습된 데이터는 총 1000개이다. 한 장의 사진에 대해 1000개의 확률 값이 출력임을 확인할 수 있다.

Object Detection과 달리 바운딩 박스가 없으므로 NMS도 할 필요가 없다. 따라서 1000개의 확률 값에서 제일 높은 값을 가진 클래스를 결과로 표출하면 된다.

✔ 2. 앱 생성

새로운 안드로이드 앱을 생성하고, 모델과 라벨링된 txt 파일을 assets 폴더에 넣는다. 좌측 상단에 Android → Projcet 로 변경한 뒤에 app → src → main 폴더 우클릭 후 new → Directory에 assets 폴더를 생성하면 된다.

위에서 변환했던 onnx 모델과 라벨링 txt파일을 저장하면 된다.
라벨링 데이터는 아래 사진과 같이 그냥 각 0 ~ 999 까지의 클래스가 한 줄씩 저장되어있는 텍스트 파일이다.

이후 다시 좌측 상단에 Projcet를 Android로 변경한다.

✔ 3. 라이브러리 추가

앱 수준의 graddle에 카메라를 사용할 라이브러리와 onnx 모델을 추론할 onnxruntime 라이브러리를 추가한다.

 // https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime-android
implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.12.1'

// CameraX core library using the camera2 implementation
def camerax_version = "1.3.0-alpha06"
// The following line is optional, as the core library is included indirectly by camera-camera2
implementation "androidx.camera:camera-core:${camerax_version}"
implementation "androidx.camera:camera-camera2:${camerax_version}"
// If you want to additionally use the CameraX Lifecycle library
implementation("androidx.camera:camera-lifecycle:${camerax_version}")
// If you want to additionally use the CameraX View class
implementation("androidx.camera:camera-view:${camerax_version}")

✔ 4. 카메라 생성

  1. Manifest 파일에서 카메라 권한을 설정한다.
<uses-permission android:name="android.permission.CAMERA" />
  1. 화면에 보여줄 xml 파일을 수정한다. 우리는 main activity에서 바로 실행한다. 따라서 activity_main.xml 파일을 수정한다. 기본 앱에서 text를 지우고 아래와 같이 추가한다.
 <androidx.camera.view.PreviewView
        android:id="@+id/previewView"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" >

    </androidx.camera.view.PreviewView>

    <TextView
        android:id="@+id/textView"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="600dp"
        android:textColor="#000000"
        android:textSize="20sp"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

아래 사진 처럼 추가하면 된다.

PreviewView는 카메라로 부터 받아온 사진을 화면에 보여줄 view이다. TextView는 추론해서 나온 결과 (ex: 컴퓨터 마우스, 키보드, ...) 를 보여줄 view이다.

  1. 메인 액티비티 onCreate 내부에 아래 코드를 추가한다.
previewView = findViewById(R.id.previewView)
textView = findViewById(R.id.textView)
window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)

xml에서 만든 previewView와 textview를 가져오고, 화면이 꺼지지 않게 설정한다.

  1. 권한 요청
    onCreate에 아래 메서드를 추가하고 메서드를 정의한다.
//onCreate에 추가 
setPermissions()
private fun setPermissions() {
        val permissions = ArrayList<String>()
        permissions.add(android.Manifest.permission.CAMERA)

        permissions.forEach {
            if (ActivityCompat.checkSelfPermission(this, it) != PackageManager.PERMISSION_GRANTED) {
                ActivityCompat.requestPermissions(this, permissions.toTypedArray(), 1)
            }
       }
}

그리고 onRequestPermissionsResult를 오버라이딩해서 요청을 거절하면 앱을 종료하게 설정한다.

   override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<out String>,
        grantResults: IntArray
    ) {
        if (requestCode == 1) {
            grantResults.forEach {
                if (it != PackageManager.PERMISSION_GRANTED) {
                    Toast.makeText(this, "권한을 허용하지 않으면 사용할 수 없습니다!", Toast.LENGTH_SHORT).show()
                    finish()
                }
            }
        }
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
    }
  1. 카메라 설정
    oncreate에서 아래 메서드를 추가하고 정의하는 부분이다.
//onCreate 내부 
   setCamera()
  private fun setCamera() {
        //카메라 제공 객체
        val processCameraProvider = ProcessCameraProvider.getInstance(this).get()

        //전체 화면
        previewView.scaleType = PreviewView.ScaleType.FILL_CENTER

        //후면 카메라
        val cameraSelector =
            CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build()

        val resolutionSelector = ResolutionSelector.Builder()
            .setAspectRatioStrategy(AspectRatioStrategy.RATIO_16_9_FALLBACK_AUTO_STRATEGY).build()

        // 16:9 화면으로 받아옴
        val preview = Preview.Builder().setResolutionSelector(resolutionSelector).build()

        // preview 에서 받아와서 previewView 에 보여준다.
        preview.setSurfaceProvider(previewView.surfaceProvider)

        //분석 중이면 그 다음 화면이 대기중인 것이 아니라 계속 받아오는 화면으로 새로고침 함. 분석이 끝나면 그 최신 사진을 다시 분석
        val analysis = ImageAnalysis.Builder().setResolutionSelector(resolutionSelector)
            .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST).build()

        //여기서 it == imageProxy 객체이다.
        analysis.setAnalyzer(Executors.newSingleThreadExecutor()) {
            imageProcess(it)
            it.close()
        }

        // 카메라의 수명 주기를 메인 액티비티에 귀속
        processCameraProvider.bindToLifecycle(this, cameraSelector, preview, analysis)
    }

 private fun imageProcess(imageProxy: ImageProxy) {
 // 추후에 여기서 이미지를 가지고 추론할 부분 
 }

주석으로 간략히 쓰여있지만, 카메라 객체를 생성하고 카메라 설정을 하는 단계이다. 16:9 화면으로 받아와서 해당 사진을 나중에 분석하는 메서드인 imageProcess에서 추론 하는 코드를 적으면 된다.

✔ 5. 모델 불러오기

assets안에 있는 파일은 바로 사용할 수 없다. 모델을 불러오고 각종 처리를 담당할 클래스를 새로 생성한다. DataProcess 라는 이름의 새 클래스를 생성한다.

클래스 내부에 모델명과 라벨링 파일, 그 외 입력에 필요한 사이즈를 미리 설정한다.

companion object {
        const val BATCH_SIZE = 1
        const val INPUT_SIZE = 224
        const val PIXEL_SIZE = 3
        const val FILE_NAME = "yolov8n-cls.onnx"
        const val LABEL_NAME = "yolov8n-cls.txt"
    }

클래스 내부에 각각 모델과 라벨링 파일을 불러오는 메서드를 정의한다.

private lateinit var classes: Array<String>

fun loadLabel(context: Context) {
        // txt 파일 불러오기
        BufferedReader(InputStreamReader(context.assets.open(LABEL_NAME))).use { reader ->
            var line: String?
            val classList = ArrayList<String>()
            while (reader.readLine().also { line = it } != null) {
                classList.add(line!!)
            }
            classes = classList.toTypedArray()
        }
    }

    fun loadModel(context: Context) {
        //onnx 파일 불러오기
        val assetManager = context.assets
        val outputFile = File(context.filesDir.toString() + "/" + FILE_NAME)

        assetManager.open(FILE_NAME).use { inputStream ->
            FileOutputStream(outputFile).use { outputStream ->
                val buffer = ByteArray(1024)
                var read: Int
                while (inputStream.read(buffer).also { read = it } != -1) {
                    outputStream.write(buffer, 0, read)
                }
            }
        }
    }

classes라는 배열 안에는 1000개의 이름이 담겨있다.

메인 액티비티에서 onnx 추론할 객체를 선언하고, 모델을 불러오는 코드를 추가한다.

private lateinit var dataProcess: DataProcess
private lateinit var ortEnvironment: OrtEnvironment
private lateinit var session: OrtSession

override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        previewView = findViewById(R.id.previewView)
        textView = findViewById(R.id.textView)
        window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)

        setPermissions()
        dataProcess = DataProcess()
        load()
        setCamera()
}

private fun load() {
        dataProcess.loadModel(this) // onnx 모델 불러오기
        dataProcess.loadLabel(this) // coco txt 파일 불러오기

        ortEnvironment = OrtEnvironment.getEnvironment()
        session = ortEnvironment.createSession(
            this.filesDir.absolutePath.toString() + "/" + DataProcess.FILE_NAME,
            OrtSession.SessionOptions()
        )
}

load 함수에서 이전에 정의했던 모델을 불러오는 메서드를 실행하고, 추론할 session 객체를 생성한다.

✔ 6. 추론

  1. 사진의 전처리
    이전에 카메라에서 받아온 화면을 imageProxy 객체로 받아오게 된다.
    추론에 바로 이 imageProxy를 입력으로 넣을 수 없다. 224x224x3 (가로 세로 RGB)의 FloayBuffer로 모델의 input이 되어야 하기 때문에 적절한 전처리를 해줘야 한다.

따라서 DataProcess 클래스 내부에 아래 코드를 추가한다.

 fun imageToBitmap(imageProxy: ImageProxy): Bitmap {
        val bitmap = imageProxy.toBitmap()
        val matrix = Matrix().apply { postRotate(90f) }
        val scaledBitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true)
        return Bitmap.createBitmap(
            scaledBitmap,
            0,
            0,
            scaledBitmap.width,
            scaledBitmap.height,
            matrix,
            true
        )
}


    fun bitmapToFloatBuffer(bitmap: Bitmap): FloatBuffer {
        val imageSTD = 255f
        val buffer = FloatBuffer.allocate(BATCH_SIZE * PIXEL_SIZE * INPUT_SIZE * INPUT_SIZE)
        buffer.rewind()

        val area = INPUT_SIZE * INPUT_SIZE
        val bitmapData = IntArray(area)
        bitmap.getPixels(
            bitmapData,
            0,
            bitmap.width,
            0,
            0,
            bitmap.width,
            bitmap.height
        ) //배열에 RGB 담기

        //하나씩 받아서 버퍼에 할당
        for (i in 0 until INPUT_SIZE - 1) {
            for (j in 0 until INPUT_SIZE - 1) {
                val idx = INPUT_SIZE * i + j
                val pixelValue = bitmapData[idx]
                // 위에서 부터 차례대로 R 값 추출, G 값 추출, B값 추출 -> 255로 나누어서 0~1 사이로 정규화
                buffer.put(idx, ((pixelValue shr 16 and 0xff) / imageSTD))
                buffer.put(idx + area, ((pixelValue shr 8 and 0xff) / imageSTD))
                buffer.put(idx + area * 2, ((pixelValue and 0xff) / imageSTD))
                //원리 bitmap == ARGB 형태의 32bit, R값의 시작은 16bit (16 ~ 23bit 가 R영역), 따라서 16bit 를 쉬프트
                //그럼 A값이 사라진 RGB 값인 24bit 가 남는다. 이후 255와 AND 연산을 통해 맨 뒤 8bit 인 R값만 가져오고, 255로 나누어 정규화를 한다.
                //다시 8bit 를 쉬프트 하여 R값을 제거한 G,B 값만 남은 곳에 다시 AND 연산, 255 정규화, 다시 반복해서 RGB 값을 buffer 에 담는다.
            }
        }
        buffer.rewind()
        return buffer
}

첫 번째 메서드는 ImageProxy 객체를 받아와서 비트맵으로 변환하는 코드이다. 카메라에 받아온 화면을 디버깅해보면 사진이 -90도 회전되어있는 상태임을 알 수 있다. 따라서 다시 90도 회전을 하고, 사이즈를 모델의 입력 사이즈에 맞게 224 224 사이즈로 축소 시킨다.
두 번째 메서드는 비트맵 사진을 받아와서 FloatBuffer에 담는 과정이다. 사진의 R,G,B 픽셀 값을 담는 과정이다.

위에서 메인 엑티비티에서 정의했던 이미지 분석을 하는 메서드에 전처리 코드를 추가하면 된다.

 private fun imageProcess(imageProxy: ImageProxy) {
        val bitmap = dataProcess.imageToBitmap(imageProxy)
        val floatBuffer = dataProcess.bitmapToFloatBuffer(bitmap)
        val inputName = session.inputNames.iterator().next() // session 이름
        //모델의 요구 입력값 [1 3 224 224] [배치 사이즈, 픽셀(RGB), 너비, 높이], 모델마다 크기는 다를 수 있음.
        val shape = longArrayOf(
            DataProcess.BATCH_SIZE.toLong(),
            DataProcess.PIXEL_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong()
        )
        val inputTensor = OnnxTensor.createTensor(ortEnvironment, floatBuffer, shape)
        val resultTensor = session.run(Collections.singletonMap(inputName, inputTensor))
        val outputs = resultTensor.get(0).value as Array<*> // [1 10000]
    }

위에서 말한 것과 같이 imageProxy에서 bitmap, floatbuffer에 사진의 내용을 담고, 추론을 진행하게 된다.

✔ 7. 후처리

추론된 결과 값은 [1 1000] 으로 1000개의 클래스에 대한 확률 값이 담겨있다.
이제 이 1000개의 값 중에서 가장 높은 값을 가진 클래스를 찾아 문자열로 변경하면 된다.
DataProcess 클래스 내부에 아래 메서드를 추가한다.

  //제일 높은 값 하나만 반환
fun getHighConf(outputs: Array<*>): Int? {
        val confThresholds = 0.6f
        val output = outputs[0] as FloatArray
        return output.withIndex().filter { it.value >= confThresholds }
            .maxByOrNull { it.value }?.index
}

 fun getClassName(i: Int?): String? {
        return if (i != null) {
            classes[i]
        } else null
}

첫 번째 메서드는 1000개의 클래스 중에서 가장 확률 값이 높은 클래스의 인덱스 값을 반환한다.
두 번째 메서드는 인덱스 값에 맞는 라벨링 문자를 반환하는 코드이다.
이제 이 내용을 imageProcess 메서드에 추가하면 된다.

private fun imageProcess(imageProxy: ImageProxy) {
        val bitmap = dataProcess.imageToBitmap(imageProxy)
        val floatBuffer = dataProcess.bitmapToFloatBuffer(bitmap)
        val inputName = session.inputNames.iterator().next() // session 이름
        //모델의 요구 입력값 [1 3 224 224] [배치 사이즈, 픽셀(RGB), 너비, 높이], 모델마다 크기는 다를 수 있음.
        val shape = longArrayOf(
            DataProcess.BATCH_SIZE.toLong(),
            DataProcess.PIXEL_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong()
        )
        val inputTensor = OnnxTensor.createTensor(ortEnvironment, floatBuffer, shape)
        val resultTensor = session.run(Collections.singletonMap(inputName, inputTensor))
        val outputs = resultTensor.get(0).value as Array<*> // [1 1000]
        //새로 추가된 부분
        val index = dataProcess.getHighConf(outputs)
        val name = dataProcess.getClassName(index)

        runOnUiThread {
            name?.let { textView.text = it }
        }
    }

받아온 글자를 화면에 보여주면 완성이다. 아래는 그 예시이다.

아래는 전체 코드이다.

//메인 액티비티

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.content.pm.PackageManager
import android.os.Bundle
import android.view.WindowManager
import android.widget.TextView
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import androidx.camera.core.CameraSelector
import androidx.camera.core.ImageAnalysis
import androidx.camera.core.ImageProxy
import androidx.camera.core.Preview
import androidx.camera.core.resolutionselector.AspectRatioStrategy
import androidx.camera.core.resolutionselector.ResolutionSelector
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.view.PreviewView
import androidx.core.app.ActivityCompat
import java.util.*
import java.util.concurrent.Executors

class MainActivity : AppCompatActivity() {

    private lateinit var previewView: PreviewView
    private lateinit var textView: TextView
    private lateinit var dataProcess: DataProcess
    private lateinit var ortEnvironment: OrtEnvironment
    private lateinit var session: OrtSession

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        previewView = findViewById(R.id.previewView)
        textView = findViewById(R.id.textView)
        window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)

        setPermissions()
        dataProcess = DataProcess()
        load()
        setCamera()
    }

    private fun setCamera() {
        //카메라 제공 객체
        val processCameraProvider = ProcessCameraProvider.getInstance(this).get()

        //전체 화면
        previewView.scaleType = PreviewView.ScaleType.FILL_CENTER

        // 후면 카메라
        val cameraSelector =
            CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build()

        val resolutionSelector = ResolutionSelector.Builder()
            .setAspectRatioStrategy(AspectRatioStrategy.RATIO_16_9_FALLBACK_AUTO_STRATEGY).build()

        // 16:9 화면으로 받아옴
        val preview = Preview.Builder().setResolutionSelector(resolutionSelector).build()

        // preview 에서 받아와서 previewView 에 보여준다.
        preview.setSurfaceProvider(previewView.surfaceProvider)

        //분석 중이면 그 다음 화면이 대기중인 것이 아니라 계속 받아오는 화면으로 새로고침 함. 분석이 끝나면 그 최신 사진을 다시 분석
        val analysis = ImageAnalysis.Builder().setResolutionSelector(resolutionSelector)
            .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST).build()

        //여기서 it == imageProxy 객체이다.
        analysis.setAnalyzer(Executors.newSingleThreadExecutor()) {
            imageProcess(it)
            it.close()
        }

        // 카메라의 수명 주기를 메인 액티비티에 귀속
        processCameraProvider.bindToLifecycle(this, cameraSelector, preview, analysis)
    }

    // 이미지 처리 s21 Ultra == 35ms ~ 42ms
    private fun imageProcess(imageProxy: ImageProxy) {
        val bitmap = dataProcess.imageToBitmap(imageProxy)
        val floatBuffer = dataProcess.bitmapToFloatBuffer(bitmap)
        val inputName = session.inputNames.iterator().next() // session 이름
        //모델의 요구 입력값 [1 3 224 224] [배치 사이즈, 픽셀(RGB), 너비, 높이], 모델마다 크기는 다를 수 있음.
        val shape = longArrayOf(
            DataProcess.BATCH_SIZE.toLong(),
            DataProcess.PIXEL_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong()
        )
        val inputTensor = OnnxTensor.createTensor(ortEnvironment, floatBuffer, shape)
        val resultTensor = session.run(Collections.singletonMap(inputName, inputTensor))
        val outputs = resultTensor.get(0).value as Array<*> // [1 1000]
        val index = dataProcess.getHighConf(outputs)
        val name = dataProcess.getClassName(index)

        runOnUiThread {
            name?.let { textView.text = it }
        }
    }

    private fun load() {
        dataProcess.loadModel(this) // onnx 모델 불러오기
        dataProcess.loadLabel(this) // coco txt 파일 불러오기

        ortEnvironment = OrtEnvironment.getEnvironment()
        session = ortEnvironment.createSession(
            this.filesDir.absolutePath.toString() + "/" + DataProcess.FILE_NAME,
            OrtSession.SessionOptions()
        )
    }

    override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<out String>,
        grantResults: IntArray
    ) {
        if (requestCode == 1) {
            grantResults.forEach {
                if (it != PackageManager.PERMISSION_GRANTED) {
                    Toast.makeText(this, "권한을 허용하지 않으면 사용할 수 없습니다!", Toast.LENGTH_SHORT).show()
                    finish()
                }
            }
        }
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
    }

    private fun setPermissions() {
        val permissions = ArrayList<String>()
        permissions.add(android.Manifest.permission.CAMERA)

        permissions.forEach {
            if (ActivityCompat.checkSelfPermission(this, it) != PackageManager.PERMISSION_GRANTED) {
                ActivityCompat.requestPermissions(this, permissions.toTypedArray(), 1)
            }
        }
    }
}
//데이터 처리를 위한 DataProcess 클래스

import android.content.Context
import android.graphics.Bitmap
import android.graphics.Matrix
import androidx.camera.core.ImageProxy
import java.io.BufferedReader
import java.io.File
import java.io.FileOutputStream
import java.io.InputStreamReader
import java.nio.FloatBuffer

class DataProcess {

    private lateinit var classes: Array<String>

    companion object {
        const val BATCH_SIZE = 1
        const val INPUT_SIZE = 224
        const val PIXEL_SIZE = 3
        const val FILE_NAME = "yolov8n-cls.onnx"
        const val LABEL_NAME = "yolov8n-cls.txt"
    }

//    // 여러 개 반환 추천하지 않음
//    fun classNames(classList: List<Int>): String {
//        var names = ""
//        classList.forEachIndexed { index, it ->
//            if (index != 0) {
//                names += ", "
//            }
//            names += classes[it]
//        }
//        return names
//    }

    fun getClassName(i: Int?): String? {
        return if (i != null) {
            classes[i]
        } else null
    }

//    // 여러 개 반환 추천하지 않음
//    fun dataConfThresh(outputs: Array<*>): List<Int> {
//        val confThresholds = 0.2f
//        val output = outputs[0] as FloatArray
//        return output.withIndex().filter { it.value >= confThresholds }.map { it.index }
//    }

    //제일 높은 값 하나만 반환
    fun getHighConf(outputs: Array<*>): Int? {
        val confThresholds = 0.6f
        val output = outputs[0] as FloatArray
        return output.withIndex().filter { it.value >= confThresholds }
            .maxByOrNull { it.value }?.index
    }

    fun imageToBitmap(imageProxy: ImageProxy): Bitmap {
        val bitmap = imageProxy.toBitmap()
        val matrix = Matrix().apply { postRotate(90f) }
        val scaledBitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true)
        return Bitmap.createBitmap(
            scaledBitmap,
            0,
            0,
            scaledBitmap.width,
            scaledBitmap.height,
            matrix,
            true
        )
    }


    fun bitmapToFloatBuffer(bitmap: Bitmap): FloatBuffer {
        val imageSTD = 255f
        val buffer = FloatBuffer.allocate(BATCH_SIZE * PIXEL_SIZE * INPUT_SIZE * INPUT_SIZE)
        buffer.rewind()

        val area = INPUT_SIZE * INPUT_SIZE
        val bitmapData = IntArray(area)
        bitmap.getPixels(
            bitmapData,
            0,
            bitmap.width,
            0,
            0,
            bitmap.width,
            bitmap.height
        ) //배열에 RGB 담기

        //하나씩 받아서 버퍼에 할당
        for (i in 0 until INPUT_SIZE - 1) {
            for (j in 0 until INPUT_SIZE - 1) {
                val idx = INPUT_SIZE * i + j
                val pixelValue = bitmapData[idx]
                // 위에서 부터 차례대로 R 값 추출, G 값 추출, B값 추출 -> 255로 나누어서 0~1 사이로 정규화
                buffer.put(idx, ((pixelValue shr 16 and 0xff) / imageSTD))
                buffer.put(idx + area, ((pixelValue shr 8 and 0xff) / imageSTD))
                buffer.put(idx + area * 2, ((pixelValue and 0xff) / imageSTD))
                //원리 bitmap == ARGB 형태의 32bit, R값의 시작은 16bit (16 ~ 23bit 가 R영역), 따라서 16bit 를 쉬프트
                //그럼 A값이 사라진 RGB 값인 24bit 가 남는다. 이후 255와 AND 연산을 통해 맨 뒤 8bit 인 R값만 가져오고, 255로 나누어 정규화를 한다.
                //다시 8bit 를 쉬프트 하여 R값을 제거한 G,B 값만 남은 곳에 다시 AND 연산, 255 정규화, 다시 반복해서 RGB 값을 buffer 에 담는다.
            }
        }
        buffer.rewind()
        return buffer
    }

    fun loadLabel(context: Context) {
        // txt 파일 불러오기
        BufferedReader(InputStreamReader(context.assets.open(LABEL_NAME))).use { reader ->
            var line: String?
            val classList = ArrayList<String>()
            while (reader.readLine().also { line = it } != null) {
                classList.add(line!!)
            }
            classes = classList.toTypedArray()
        }
    }

    fun loadModel(context: Context) {
        //onnx 파일 불러오기
        val assetManager = context.assets
        val outputFile = File(context.filesDir.toString() + "/" + FILE_NAME)

        assetManager.open(FILE_NAME).use { inputStream ->
            FileOutputStream(outputFile).use { outputStream ->
                val buffer = ByteArray(1024)
                var read: Int
                while (inputStream.read(buffer).also { read = it } != -1) {
                    outputStream.write(buffer, 0, read)
                }
            }
        }
    }

}

classification은 object detection에 비해 코드도 적을 뿐 아니라 추론 속도도 굉장히 빠른 편이다. 초보자도 쉽게 사용할 수 있을 듯 하다.

아래는 깃허브 주소이다. 모든 코드를 보고 싶거나, 모델이나 라벨링 텍스트 파일을 다운 받고 싶으면 참고하면 될 듯 하다.
https://github.com/Yurve/YOLOv8_Classification_android

2개의 댓글

comment-user-thumbnail
2023년 7월 13일

쭉 둘러보는데 너무 유익하네요. 감사합니다!

1개의 답글