아키텍처
- LLama1과 거의 동일함
- 차이
- Multi Head Attention(MHA) 대신 Grouped Query Attention(GQA) 사용
- 이유: 학습과 추론 비용을 낮추기 위해
- MHA
- Query Key 그리고 Value 텐서가 동일한 개수의 헤드를 가짐
- Multi Query Attention(MQA)
- Query는 여러 개의 헤드를 갖고 Key와 Value는 한 개의 헤드를 가짐
- GQA
- Query head를 n 개의 그룹으로 나눔
- 각 그룹에 한 개의 Key와 Value 헤드를 할당함
- 이유: MHA는 많은 연산을 요구하는 대신 추론 능력이 우월하지만 MQA는 적은 연산을 요구하는 대신 추론 능력이 떨어짐.
- GQA는 MHA보다 더 적은 수의 Key, Value 헤드를 가짐
- MQA보다 더 많은 수의 Key, Value 헤드를 가짐
- LLama1 대비 context length를 두 배 늘림
하이퍼 파라미터
사전 학습 데이터 세트
- LLama1보다 큰 데이터 세트
- 해로운 학습 데이터를 공격적으로 제거하지 않음
- 이유: 해로운 데이터가 포함된 데이터 세트에 학습할 때, 안전성 확보를 위한 미세 조정의 결과가 좋았음
미세 조정
- Supervised Fine Tuning(SFT)
- 27540 개의 prompt/answer pair에 미세 조정
- 크지만 낮은 품질의 미세 조정 데이터 세트보다, 작더라도 높은 품질의 데이터 세트가 유용함을 실험적으로 확인
- 프롬프트와 답변 사이에 특수 토큰 삽입
- 프롬프트 토큰에 대한 손실 함수를 0으로 설정
- RLHF
- 사전 학습된 llama2-base에 reward를 생성하는 헤드를 부착하여 학습
- 보상 모델 학습 -> 보상을 사용하여 llama2 미세 조정 -> 미세 조정한 llama2 출력에 대한 선호 데이터 수집 -> 보상 모델 추가 학습 -> llama2 미세 조정 -> ... 반복
- 이유: 미세 조정한 llama2 출력으로 보상 모델을 추가 학습하지 않으면, llama2가 보상 모델의 허점을 악용할 수 있음.
- 안전과 유용성에 대한 보상을 따로 학습함
- 이유: 안전과 유용성 사이에 trade-off가 존재할 수 있음. 그렇기에 한 개의 보상 모델이 안전과 유용성을 둘 다 고려하기는 어려울 수 있다.
- 다만 안전에 대한 보상을 학습할 때, 안전 데이터 세트뿐 아니라 유용성 데이터 세트도 일부분 포함함
- 비율 safety:helpfulness=90:10
- 이유: llama2의 두 출력이 다 안전할 때, 도움 되는 출력을 선호하도록 만들기 위함
- 도움되는 정도에 대한 보상을 학습할 때도 safety 데이터 세트를 일부분 포함시킴
- annotator가 선호하는 출력뿐만 아니라 선택한 출력이 얼마나 더 나은지도 답변
- 가능한 답변: 매우 낫다, 낫다, 조금 낫다, 미세하게 낫다 혹은 잘 모르겠다
- loss=−log(σ(rθ(x,yc)−rθ(x,yr)))
- rθ(x,y)는 파라미터 θ를 갖는 보상 모델이 (프롬프트 x, 모델 출력 y)쌍에 할당한 보상
- 사람이 선호하는 출력이 yc 그리고 선호하지 않는 출력이 yr
- 이 손실 함수는, 선호된 출력이 그렇지 못한 출력보다 더 높은 보상을 갖도록 유도함
- 많이 나은 출력이 선호되지 않는 출력보다 많이 큰 보상을 갖도록 손실 함수를 아래와 같이 수정함
- loss=−log(σ(rθ(x,yc)−rθ(x,yr)−m(r)))
- m(r)은 얼마나 나은지에 대한 답변을 나타냄. 이산적인 값을 갖고 값의 크기는 다음과 같음. 매우 낫다 > 낫다 > 조금 낫다> 미세하게 낫다 혹은 잘 모르겠다
- 보상 모델의 중요성
- 보상 모델의 정확성으로 LLAMA2-CHAT의 성능을 쉽게 추측할 수 있었다. 보상 모델이 정확해지면 정확해질수록 LLAMA2-CHAT의 성능이 좋아졌다는 것. 당연한 얘기이긴 하다
- RL 알고리즘
- Proximal Policy Optimization(PPO)
- Rejection Sampling fine-tuning
- llama2로 K개의 출력을 생성한 뒤, 제일 큰 보상을 갖는 출력을 정답으로 지도 학습
- 다만 제일 최신 모델의 출력만 사용하지 않고 RLHF 과정 동안의 모든 모델 출력 중에서 제일 큰 보상을 갖는 출력을 정답으로 사용
- 이유: 최신 모델의 출력만으로 학습하니 특정 작업의 수행 능력이 떨어짐. 이전에 학습한 내용을 잊는 것으로 판단하여 과거 여러 모델의 출력을 사용하도록 만듦
- 초반에는 rejection sampling으로 미세 조정하고 후반에는 K개의 출력으로 PPO를 진행함
- multi-turn dialogue
- 대화의 시작부터 끝까지 attend 할 지침이 있을 수 있다. ex) "이제부터 너는 벤자민 프랭클린이야. 생각과 행동 그리고 말 모두 벤자민 프랭클린처럼 하도록 해."
- 그런데 오랫동안 기억해야 할 지침을 LLAMA2가 까먹는 문제 발생
- 해결 방법: 지침을 오랫동안 따르는 데이터를 생성하여 미세 조정
- 모든 유저 메시지에 지침을 prepend한 뒤 RLHF 모델로 답변 생성
- 이유: 지침을 prepend하면 지침을 따르는 데이터를 생성할 수 있음
- 답변 생성 후 유저 메시지에 prepend한 지침 삭제
- 이유: 지침을 prepend하지 않더라도 지침을 따르는 모델을 원하기 때문
- 마지막 turn에 속하는 토큰을 제외한 모든 토큰의 loss를 0으로 설정
- 이유: 위의 방식으로 생성한 데이터에서, 유저 메시지는 원본과 동일하지만 어시스턴트 메시지는 RLHF 모델이 생성한 텍스트다. 원본의 유저 메시지는 원본의 어시스턴트 메시지에 대한 응답인데, 어시스턴트 메시지만 바뀐 것이기에 이상한 데이터가 만들어질 수 있다. context가 망가진 데이터이기 때문에, 마지막 turn으로만 학습한다...?
- 필자의 이해가 맞았다면, 동의하기 어려운 방식이다. 나은 방식이 떠오르지 않지만, 마지막 turn의 토큰으로만 학습하는 게 이 문제를 어떻게 해결하는 건지 모르겠다.