ml-agent 활용 예제 - 01. Unity 환경 제작하기

Seulgi Kim·2023년 5월 7일
0

reinforce learning

목록 보기
7/14

01. 게임 환경 제작하기

참고) https://www.youtube.com/watch?v=EqoU1PodQQ4&t=8297s

위 유투브를 따라서 제작하면 아래 사진과 같은 게임 환경을 손쉽게 제작할 수 있다.
여기서 우리가 원하는 바는 게임 환경 제작이 아닌 강화학습을 위한 환경 제작이므로, 수정사항이 몇 가지 존재한다.

02. Unity에 ml-agent 설치하기

가장 먼저 ml-agent를 Window-package manager를 통해 설치한다.

add packages from disk를 통해 원하는 ml-agent 버전을 설치할 수 있다.

03. agent 설정

이후 ml-agent를 사용하기 위한 함수 설정을 해주어야 하는데, 이는 다음 글을 참고했다.

https://github.com/Unity-Technologies/ml-agents/blob/develop/docs/Learning-Environment-Create-New.md

에이전트의 스크립트에서, 위에 ml-agent 관련 라이브러리 추가.

using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

에이전트의 스크립트에서, 클래스를 MonoBehavior에서 Agent로 수정.

여기서는 Bird라는 object에서 점프를 구현한 BirdJump 스크립트가 존재했는데, 이를 Agent라고 수정했다.

public class BirdJump : Agent

Start(), Update() 함수를 지우고 ml-agent 관련 함수 추가

Agent 클래스의 필수 함수로는 3가지가 있다.
Initialize() : 초기화 작업을 위해 한번 호출. Start()과 비슷한 기능.
OnEpisodeBegin() : 에피소드가 시작할 때마다 호출. 환경 초기화의 역할을 한다. 여기서는 에이전트(새)가 충돌이 발생할 때마다 한 에피소드가 종료되며 호출된다.
CollectObservations() : 환경 정보 관측과 수집의 역할을 맡은 함수.
OnActionReceived() : 정책으로부터 전달받은 행동을 실행하는 함수.

Initialize()

Rigidbody2D rb;
Transform tr;
public override void Initialize()
{
    rb = GetComponent<Rigidbody2D>();
    tr = GetComponent<Transform>();
}

에이전트의 속도와 위치를 관측하기 위해 GetComponent<> 를 이용하여 각각 변수에 담아준다.

OnepisodeBegin()

public override void OnEpisodeBegin()
{
    rb.velocity = Vector2.zero;
    tr.position = Vector2.zero;
    rb.angularVelocity = 0.0f;
    
    pipe.Reset();
}

에이전트가 충돌할 때마다 에피소드는 종료되며, 그때마다 환경을 초기화 해준다.
가장 먼저 에이전트의 위치와 속도를 초기화를 하며, 그 후 현재 게임에 존재하는 모든 파이프를 제거해주어야 한다.
하지만 기존에 알고있던 Destroy 메소드를 이용하여 단순히 Destroy(newpipe)를 해주면 가장 최근에 생성된 파이프만 삭제된다는 것을 찾았다.

이를 수정하기 위해서, Prefab을 이용해 생성된 모든 파이프 정보를 기억해야할 필요가 있다.
MakePipe.cs 파일 안의 스크립트를 다음과 같이 수정하자.

public class MakePipe : MonoBehaviour
{
    public GameObject pipe;
    public float timeDiff;
    public float deleteTime;
    float timer = 0;
    // Start is called before the first frame update
    void Start()
    {
        
    }

    // Update is called once per frame
    public GameObject newpipe;
    void Update()
    {
        timer += Time.deltaTime;
        if (timer > timeDiff)
        {
            newpipe = Instantiate(pipe);
            newpipe.transform.position = new Vector3(3, Random.Range(-2.7f, 2.7f), 0);
            newpipe.gameObject.tag = "TARGET";
            timer = 0;
            Destroy(newpipe, deleteTime);
        }
    }

    public void Reset()
    {
        GameObject[] gameObjects;
        gameObjects = GameObject.FindGameObjectsWithTag("TARGET");
        for (int i = 0; i < gameObjects.Length; i++)
        {
            Destroy(gameObjects[i]);
        }
    }
}

Reset() 명령이 들어오면 게임 오브젝트 중 TARGET 태그를 전부 찾아 제거하는 로직으로 짰다.

CollectObservations(VectorSensor sensor)

Transform instance;
public override void CollectObservations(VectorSensor sensor)
{
    GameObject[] gameObjects;
    gameObjects = GameObject.FindGameObjectsWithTag("TARGET");
    for (int i = 0; i < gameObjects.Length; i++)
    {
        instance = gameObjects[i].GetComponent<Transform>();
        sensor.AddObservation(instance.localPosition);
    }
    for (int i = gameObjects.Length; i < 3; i++)
    {
        sensor.AddObservation(Vector3.zero);
    }
           
    sensor.AddObservation(tr.localPosition);

    // Agent velocity
    sensor.AddObservation(rb.velocity.y);
}

이 함수에서는 브레인에 전달할 관측 정보를 적는다.
여기서는 에이전트의 위치, 떨어지는 속도, 파이프의 위치를 관측할 것이다.
그러기 위해서는 prefab으로 생성된 파이프의 위치를 가져와야 하는데, 다른 함수에서 Instantiate로 생성한 인스턴스에 바로 접근하는 방법을 모르겠다.
현재는 해당 인스턴스에 TARGET이라는 태그를 적고, FindGameObjectsWithTag 함수를 이용하여 태그로 접근한다.

OnActionReceived(ActionBuffers actionBuffers)

const int k_NoAction = 0;
const int k_Up = 1;
public override void OnActionReceived(ActionBuffers actionBuffers)
{
    var action = actionBuffers.DiscreteActions[0];

    switch (action)
    {
        case k_NoAction
    	    // do nothing
            break;
        case k_Up:
            rb.velocity = Vector2.up * jumpPower;
            break;
    }
}

에이전트는 가만히 있거나 점프를 하는 두 가지의 행동 중 하나를 선택해야만 한다.
따라서 DiscreteActions에 해당하고, Discrete Branch의 수는 1이다.
얻어지는 action에 따라 0이면 가만히 있기, 1이면 점프하기를 뜻한다.

OnCollisionEnter2D(Collision2D other)

private void OnCollisionEnter2D(Collision2D other)
{
    // SceneManager.LoadScene("GameOverScene");
    SetReward((float)Score.score);
    EndEpisode();
}

기존의 OnCollisionEnter2D 함수를 수정하여, GameOverScene 씬을 불러오는 것이 아니라 에피소드가 종료되게 만들었다.
지금까지 통과한 파이프의 개수인 Score.score를 보상으로 준다.

에피소드가 종료될 때 스코어를 보상으로 주면 충돌할 때 보상이 얻어져서 잘못된 학습이 이루어진다. 이를 해결하기 위하여 OnTriggerEnter2D() 함수를 추가하여 파이프 사이를 통과할 때마다 보상을 얻고 충돌하여 에피소드가 종료되면 -1의 패널티를 받게 수정하였다.

Heuristic(in ActionBuffers actionsOut)

public override void Heuristic(in ActionBuffers actionsOut)
{
    var discreteActionsOut = actionsOut.DiscreteActions;
    discreteActionsOut[0] = k_NoAction;
    if (Input.GetMouseButton(0))
    {
        discreteActionsOut[0] = k_Up;
    }
}

테스트를 위하여 Heuristic 함수도 생성해주었는데, OnActionReceived와 같이 아무것도 누르지 않았을 때는 discreteActionsOut[0]에 k_NoAction을, 마우스 클릭이 들어왔을 때는 k_Up을 입력해주었다.

Behavior Parameters 설정

이제 Behavior Parameters만 설정해주면 테스트를 해볼 수 있다.
Bird의 Add Components에서 Behavior Parameters를 추가하고, 알맞게 설정한다.

Vector Observation은 파이프의 위치 3개(x, y, z), 에이전트의 위치 3개 (x, y, z), 에이전트의 y축 속도 1개 로 총 7개이다.
Actions는 에이전트의 액션은 총 2가지로, 가만히 있거나 점프를 하는 discrete action이다.
점프 속도 항도 추가하여 Continuous Action으로 설정하여도 되나, 여기서는 Discrete Branches는 1, Branch 0 Size는 2로 설정하였다.
그 후 Behavior Type을 Heuristic Only로 설정하면 환경이 알맞게 제작되었는지 테스트 가능하다.

0개의 댓글