LLama2 정리

Plato·2023년 12월 22일
0

딥러닝

목록 보기
8/21

아키텍처

  • LLama1과 거의 동일함
  • 차이
    • Multi Head Attention(MHA) 대신 Grouped Query Attention(GQA) 사용
      • 이유: 학습과 추론 비용을 낮추기 위해
      • MHA
        • Query Key 그리고 Value 텐서가 동일한 개수의 헤드를 가짐
      • Multi Query Attention(MQA)
        • Query는 여러 개의 헤드를 갖고 Key와 Value는 한 개의 헤드를 가짐
      • GQA
        • Query head를 n 개의 그룹으로 나눔
          • 각 그룹에 한 개의 Key와 Value 헤드를 할당함
        • 이유: MHA는 많은 연산을 요구하는 대신 추론 능력이 우월하지만 MQA는 적은 연산을 요구하는 대신 추론 능력이 떨어짐.
          • GQA는 MHA보다 더 적은 수의 Key, Value 헤드를 가짐
            • 그래서 MHA보다 적은 연산량을 요구함
          • MQA보다 더 많은 수의 Key, Value 헤드를 가짐
            • 그래서 MQA보다 추론 능력이 좋음
    • LLama1 대비 context length를 두 배 늘림

하이퍼 파라미터

  • LLama1과 동일

사전 학습 데이터 세트

  • LLama1보다 큰 데이터 세트
  • 해로운 학습 데이터를 공격적으로 제거하지 않음
    • 이유: 해로운 데이터가 포함된 데이터 세트에 학습할 때, 안전성 확보를 위한 미세 조정의 결과가 좋았음

미세 조정

  • Supervised Fine Tuning(SFT)
    • 27540 개의 prompt/answer pair에 미세 조정
    • 크지만 낮은 품질의 미세 조정 데이터 세트보다, 작더라도 높은 품질의 데이터 세트가 유용함을 실험적으로 확인
    • 프롬프트와 답변 사이에 특수 토큰 삽입
      • 프롬프트 토큰에 대한 손실 함수를 0으로 설정
  • RLHF
    • 사전 학습된 llama2-base에 reward를 생성하는 헤드를 부착하여 학습
    • 보상 모델 학습 -> 보상을 사용하여 llama2 미세 조정 -> 미세 조정한 llama2 출력에 대한 선호 데이터 수집 -> 보상 모델 추가 학습 -> llama2 미세 조정 -> ... 반복
      • 이유: 미세 조정한 llama2 출력으로 보상 모델을 추가 학습하지 않으면, llama2가 보상 모델의 허점을 악용할 수 있음.
    • 안전과 유용성에 대한 보상을 따로 학습함
      • 이유: 안전과 유용성 사이에 trade-off가 존재할 수 있음. 그렇기에 한 개의 보상 모델이 안전과 유용성을 둘 다 고려하기는 어려울 수 있다.
      • 다만 안전에 대한 보상을 학습할 때, 안전 데이터 세트뿐 아니라 유용성 데이터 세트도 일부분 포함함
        • 비율 safety:helpfulness=90:10safety:helpfulness=90:10
        • 이유: llama2의 두 출력이 다 안전할 때, 도움 되는 출력을 선호하도록 만들기 위함
      • 도움되는 정도에 대한 보상을 학습할 때도 safety 데이터 세트를 일부분 포함시킴
    • annotator가 선호하는 출력뿐만 아니라 선택한 출력이 얼마나 더 나은지도 답변
      • 가능한 답변: 매우 낫다, 낫다, 조금 낫다, 미세하게 낫다 혹은 잘 모르겠다
    • loss=log(σ(rθ(x,yc)rθ(x,yr)))loss = −log(σ(r_θ(x, y_c) − r_θ(x, y_r)))
      • rθ(x,y)r_\theta(x,y)는 파라미터 θ\theta를 갖는 보상 모델이 (프롬프트 xx, 모델 출력 yy)쌍에 할당한 보상
      • 사람이 선호하는 출력이 ycy_c 그리고 선호하지 않는 출력이 yry_r
      • 이 손실 함수는, 선호된 출력이 그렇지 못한 출력보다 더 높은 보상을 갖도록 유도함
    • 많이 나은 출력이 선호되지 않는 출력보다 많이 큰 보상을 갖도록 손실 함수를 아래와 같이 수정함
      • loss=log(σ(rθ(x,yc)rθ(x,yr)m(r)))loss = −log(σ(r_θ(x, y_c) − r_θ(x, y_r) - m(r)))
        • m(r)m(r)은 얼마나 나은지에 대한 답변을 나타냄. 이산적인 값을 갖고 값의 크기는 다음과 같음. 매우 낫다 > 낫다 > 조금 낫다> 미세하게 낫다 혹은 잘 모르겠다
    • 보상 모델의 중요성
      • 보상 모델의 정확성으로 LLAMA2-CHAT의 성능을 쉽게 추측할 수 있었다. 보상 모델이 정확해지면 정확해질수록 LLAMA2-CHAT의 성능이 좋아졌다는 것. 당연한 얘기이긴 하다
    • RL 알고리즘
      • Proximal Policy Optimization(PPO)
      • Rejection Sampling fine-tuning
        • llama2로 K개의 출력을 생성한 뒤, 제일 큰 보상을 갖는 출력을 정답으로 지도 학습
        • 다만 제일 최신 모델의 출력만 사용하지 않고 RLHF 과정 동안의 모든 모델 출력 중에서 제일 큰 보상을 갖는 출력을 정답으로 사용
          • 이유: 최신 모델의 출력만으로 학습하니 특정 작업의 수행 능력이 떨어짐. 이전에 학습한 내용을 잊는 것으로 판단하여 과거 여러 모델의 출력을 사용하도록 만듦
      • 초반에는 rejection sampling으로 미세 조정하고 후반에는 K개의 출력으로 PPO를 진행함
    • multi-turn dialogue
      • 대화의 시작부터 끝까지 attend 할 지침이 있을 수 있다. ex) "이제부터 너는 벤자민 프랭클린이야. 생각과 행동 그리고 말 모두 벤자민 프랭클린처럼 하도록 해."
        • 그런데 오랫동안 기억해야 할 지침을 LLAMA2가 까먹는 문제 발생
        • 해결 방법: 지침을 오랫동안 따르는 데이터를 생성하여 미세 조정
          • 모든 유저 메시지에 지침을 prepend한 뒤 RLHF 모델로 답변 생성
            • 이유: 지침을 prepend하면 지침을 따르는 데이터를 생성할 수 있음
          • 답변 생성 후 유저 메시지에 prepend한 지침 삭제
            • 이유: 지침을 prepend하지 않더라도 지침을 따르는 모델을 원하기 때문
          • 마지막 turn에 속하는 토큰을 제외한 모든 토큰의 loss를 0으로 설정
            • 이유: 위의 방식으로 생성한 데이터에서, 유저 메시지는 원본과 동일하지만 어시스턴트 메시지는 RLHF 모델이 생성한 텍스트다. 원본의 유저 메시지는 원본의 어시스턴트 메시지에 대한 응답인데, 어시스턴트 메시지만 바뀐 것이기에 이상한 데이터가 만들어질 수 있다. context가 망가진 데이터이기 때문에, 마지막 turn으로만 학습한다...?
              • 필자의 이해가 맞았다면, 동의하기 어려운 방식이다. 나은 방식이 떠오르지 않지만, 마지막 turn의 토큰으로만 학습하는 게 이 문제를 어떻게 해결하는 건지 모르겠다.

0개의 댓글