[Android] tflite 모델로 수어 지문자 번역

deBaeloper08·2022년 10월 18일
0

MediaPipe를 사용하여 얻은 손의 LandMark들을 갖고 수어 지문자를 번역하고자 한다. 모델 같은 경우 직접 학습하지 않고 학습된 모델을 사용했다.

수어 모델 Github

res폴더에 Assets 폴더를 생성한 후 tflite 모델 파일을 Assets 폴더에 넣어준다.

// Tensorflow
implementation 'org.tensorflow:tensorflow-lite:2.10.0'
implementation 'org.tensorflow:tensorflow-lite-task-vision-play-services:0.4.2'
implementation 'com.google.android.gms:play-services-tflite-gpu:16.0.0'
implementation 'org.tensorflow:tensorflow-lite-task-vision:0.4.0'

app gradle 파일에 해당 코드를 작성하고 sync를 해준다.

private fun getTfliteInterpreter(path: String): Interpreter? {
        try {
            return Interpreter(loadModelFile(this@MainActivity, path)!!)
        }
        catch (e: Exception) {
            e.printStackTrace()
        }
        return null
    }

private fun loadModelFile(activity: Activity, path: String): MappedByteBuffer? {
        val fileDescriptor = activity.assets.openFd(path)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }

getTfliteInterpreter 함수는 path 경로에 있는 tflite 모델을 불러오는 함수다.

hands.setResultListener {
    translate(it)
    glSurfaceView.setRenderData(it)
    glSurfaceView.requestRender()
}

Mediapipe를 통해 얻은 결과물을 translate 함수에 인자값으로 넘겨준다.

private fun translate(result : HandsResult){
        if (result.multiHandLandmarks().isEmpty()) {
            return
        }
        val landmarkList = result.multiHandLandmarks()[0].landmarkList
        val joint = Array(21){FloatArray(3)}
        for(i in 0..19) {
            joint[i][0] = landmarkList[i].x
            joint[i][1] = landmarkList[i].y
            joint[i][2] = landmarkList[i].z
        }

        val v1 = joint.slice(0..19).toMutableList()
        for(i in 4..16 step(4)) {
            v1[i] = v1[0]
        }
        var v2 = joint.slice(1..20)
        val v = Array(20) { FloatArray(3) }

        for(i in 0..19) {
            for(j in 0..2) {
                v[i][j] = v2[i][j] - v1[i][j]
            }
        }

        for(i in 0..19) {
            val norm = sqrt(v[i][0] * v[i][0] + v[i][1] * v[i][1] + v[i][2] * v[i][2])
            for(j in 0..2) {
                v[i][j] /= norm
            }
        }

        val tmpv1 = mutableListOf<FloatArray>()
        for(i in 0..18) {
            if(i != 3 && i != 7 && i != 11 && i != 15) {
                tmpv1.add(v[i])
            }
        }
        val tmpv2 = mutableListOf<FloatArray>()
        for(i in 1..19) {
            if(i != 4 && i != 8 && i != 12 && i != 16) {
                tmpv2.add(v[i])
            }
        }

        val einsum = FloatArray(15)
        for( i in 0..14) {
            einsum[i] = tmpv1[i][0] * tmpv2[i][0] + tmpv1[i][1] * tmpv2[i][1] + 
            	tmpv1[i][2] * tmpv2[i][2]
        }
        val angle = FloatArray(15)
        val data = FloatArray(15)
        for(i in 0..14) {
            angle[i] = Math.toDegrees(acos(einsum[i]).toDouble()).toFloat()
            data[i] = round(angle[i] * 100000) / 100000
        }

        val interpreter = getTfliteInterpreter("converted_model.tflite")
        val byteBuffer = ByteBuffer.allocateDirect(15*4).order(ByteOrder.nativeOrder())

        for(d in data) {
            byteBuffer.putFloat(d)
        }

        val modelOutput = ByteBuffer.allocateDirect(26*4).order(ByteOrder.nativeOrder())
        modelOutput.rewind()

        interpreter!!.run(byteBuffer,modelOutput)

        val outputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1,26), DataType.FLOAT32)
        outputFeature0.loadBuffer(modelOutput)

        // ByteBuffer to FloatBuffer
        val outputsFloatBuffer = modelOutput.asFloatBuffer()
        val outputs = mutableListOf<Float>()
        for(i in 1..26) {
            outputs.add(outputsFloatBuffer.get())
        }

        val sortedOutput = outputs.sortedDescending()
        val index = outputs.indexOf(sortedOutput[0])

        Log.d("TAG", "translate: ${classes[index]}")
    }

translate 함수는 Mediapipe로 손의 LandMark 값들을 받아 수어로 번역하는 함수다. 필자가 사용한 모델의 경우 손의 LandMark들이 이루는 각도를 계산하여 수어를 번역한다. main.py에서 데이터를 가공하여 input에 넣는 과정을 kotlin으로 변환하여 입력값을 넣어주었다.

이처럼 Log에 수어를 번역한 결과를 확인할 수 있다.

profile
안녕하세요! 안드로이드 개발자입니다😊

2개의 댓글

comment-user-thumbnail
2023년 5월 23일

너무 도움이 되는데 어떻게 해야할지 모르겠어요 ㅠㅠ
어떤 클래스에 어떤걸 넣어야하는지..ㅠㅠㅠ

답글 달기
comment-user-thumbnail
2024년 1월 19일

안녕하세요, 비슷한 프로젝트를 진행하고 있는 대학생입니다.
mediapipe를 학습시킨 모델을 tflite 로 저장하는것까지는 성공했는데
안드로이드에 올리는 과정에서 많은 오류가 나던 중 이 글을 발견하게 되었습니다.
혹시 실례가 되지 않는다면 안드로이드 파일을 받아볼 수 있을까요?
제가 안드로이드에는 익숙하지 못해서 이 설명만으로는 따라하기가 조금 힘들어서요..
절박한 마음에 댓글 작성해봅니다
감사합니다

답글 달기