DQN (Deep Q Network)

Seulgi Kim·2023년 5월 1일
0

reinforce learning

목록 보기
6/14

강화학습의 종류

  • 가치 기반 강화학습 (Value-based Reinforcement Learning)
  • 정책 기반 강화학습 (Policy-Based Reinforcement Learning)

DQN은 대표적인 가치 기반 강화학습의 알고리즘이다.

Q Learning

큐 함수를 학습하여 최적의 큐 함수를 얻고 이를 통해 의사결정을 수행

Q 함수란?

주어진 상태에서 주어진 행동에 대한 가치를 도출하는 함수.

Q함수에 대한 벨만 최적 방정식

q(s,a)=maxπqπ(s,a)=E[Rt+1+γmaxaq(st+1,a)st=s,At=a)]q_*(s,a) = \max\limits_\pi q_\pi(s,a) = E[R_{t+1} + \gamma \max\limits_{a'} q_*(s_{t+1},a') | s_t = s, A_t = a)]

Q Learning 이란?

모델이 없는 환경에서 학습하는 알고리즘이다.
주어진 상태와 행동을 수행함에 따라 얻는 가치를 나타내는 가치 함수인 큐 함수를 학습함으로써 최적의 정책을 학습한다.
알고리즘이 시작되기 전에 Q 함수는 고정된 임의의 값을 가진다.
알고리즘이 시작되면 시간 tt에 에이전트가 상태 sts_t에서 행동 ata_t를 취하고 새로운 상태 st+1s_{t+1}로 전이한다. 이때 보상 rtr_t가 얻어지며 Q(st,at)Q(s_t,a_t)가 갱신된다.
이를 수식으로 나타내면 다음과 같다.

Q(st,at)(1α)Q(st,at)+α(rt+γmaxaQ(st+1,a))Q(s_t, a_t) \leftarrow (1-\alpha) \cdot Q(s_t, a_t) + \alpha \cdot (r_t + \gamma \cdot \max\limits_a Q(s_{t+1},a))

이때 첫번째 Q(st,at)Q(s_t, a_t)는 갱신되는 Q 함수 값이며, 두번째 Q(st,at)Q(s_t, a_t)는 갱신되기 전의 이전 값이다.
여기서 α\alpha는 learning rate로, 0보다 크고 1보다 작거나 같은 값을 가진다.

충분히 QQ 함수를 학습하면, QQ 함수가 수렴한다.
여기서 수렴한 상황은 갱신되는 Q(st,at)Q(s_t, a_t)가 갱신 전의 Q(st,at)Q(s_t, a_t)와 똑같아지는 지점으로, rt+γmaxaQ(st+1,a))=Q(st,at)r_t + \gamma \cdot \max\limits_a Q(s_{t+1}, a)) = Q(s_t, a_t) 인 상황이다.

기존 Q Learning의 문제점

모든 상태와 행동에 대한 큐 함수 값을 따로 저장하여 이를 이용해 학습을 수행하고 행동을 결정함.
매우 많은 상태와 행동이 존재하는 환경에서 사용이 어려움.

Deep Q Network

Deepmind에서 공개한 알고리즘.
강화학습과 인공신경망을 결합하고, 게임화면을 입력으로 학습.
논문 : Playing Atari with Deep Reinforcement Learning (2013), Nature: Human-level Control Through Deep Reinforcement Learning (2015)

CNN을 이용해 상태를 입력으로 받아 각 행동의 QQ 함수를 근사해줌.

DQN의 장점

CNN을 통해 각 상태와 행동에 대한 큐 함수 값을 근사
모든 상태와 행동에 대한 큐 함수 값을 따로 저장하지 않고 큐 함수 값에 대한 추정을 수행
즉, 많은 상태와 행동이 존재하는 환경에서도 학습이 가능함.

DQN 알고리즘의 흐름

CNN의 입력으로 현재 상태가 들어감
CNN을 통해 현재 상태에 여러 행동에 대한 Q 함수 값을 근사하고, 최적의 행동을 선택함.
수행한 행동에 따라 보상을 받고, 그에 따른 타겟값(Q 함수 값)과 손실함수를 계산.
타겟값과 예측값의 차이가 최소화 되도록 학습을 수행.
다음으로 변화된 상태를 현재 상태로 업데이트 하고 다시 CNN에 입력.

DQN 알고리즘의 기법

  1. 경험 리플레이 (Experience Replay)
  2. 타겟 네트워크 (Target Network)

Experience Replay

논문 : Reinforcement Learning for Robots using Neural Network (1993)
경험 데이터를 매 스텝마다 리플레이 메모리에 저장
(경험 데이터 = 현재 상태, 현재 행동, 보상, 다음 상태, 게임 종료 정보)
매 스텝마다 리플레이 메모리의 경험 데이터를 임의로 일정 개수만큼 추출하여 mini-batch 학습을 수행

Experience Replay의 장점

강화학습의 특성상, 시간 순으로 데이터를 얻고, 얻어진 데이터끼리는 상관관계가 크다.
상관관계가 있는 데이터로만 학습을 하면, 전체 데이터와는 관계없는 잘못된 관계를 얻을 수 있다.
하지만 Experience Replay 기법을 이용하면, 데이터간의 상관관계를 고려하지 않고 랜덤하게 샘플링하므로, 전체적인 데이터의 경향성을 잘 반영한 결과를 얻을 수 있다.

Target Network

DQN에서는 구조가 완전히 동일한 두 종류의 네트워크를 사용한다.
1. 일반 네트워크 : 행동을 결정하거나 큐 함수 값을 예측
2. 타겟 네트워크 : 학습에 필요한 타겟값을 계산

타겟값 y={r if game is finishedr+γmaxaQ(st+1,a;θ) if game is not finishedy = \begin{cases} r {\rm ~if~game~is~finished} \\ r + \gamma \max\limits_a Q(s_{t+1}, a;\theta^-) {\rm~if~game~is~not~finished} \end{cases}
(θ\theta = 일반 네트워크의 파라미터, θ\theta^- = 타겟 네트워크의 파라미터)

일반 네트워크는 매 스텝마다 업데이트가 이루어지며, 행동에 따른 Q 함수 값과 최적의 행동을 계산.
즉, 같은 입력을 해도 매 스텝마다 네트워크가 다른 답을 줄수도 있음.
만약 일반 네트워크를 타겟값을 계산하는데 사용한다면 매 스텝마다 타겟값이 변하며, 학습이 어려워진다.

타겟 네트워크는 매 특정 스텝마다 한번씩 일반 네트워크를 복제하며, 타겟값을 계산.
학습의 목표가 되는 타겟값을 최대한 일정하게 유지하며 안정적인 학습이 가능하도록 함.

DQN의 손실함수

Huber loss를 사용.

Error=yQ(st,at;θ)Error = y - Q(s_t, a_t; \theta)
Loss={Error2 if Error<1Error if Error1Loss = \begin{cases} Error^2 ~{\rm if}~ |Error| < 1 \\ |Error| ~{\rm if}~ |Error| \geq 1 \end{cases}

0개의 댓글