DPO (Direct Preference Optimization) Code Review

이두현 · March 17, 2024

To use custom data, modify preference_datasets.py.

Preference data preparation starts when trainers.py (BasicTrainer) is initialized.

Scripts used for this analysis

preference_datasets.py

def get_hh(split: str, silent: bool = False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]:
    print(f'Loading HH dataset ({split} split) from Huggingface...')
    dataset = datasets.load_dataset('Anthropic/hh-rlhf', split=split, cache_dir=cache_dir)
    print('done')

    def split_prompt_and_responses(ex):
        prompt = extract_anthropic_prompt(ex['chosen'])
        chosen_response = ex['chosen'][len(prompt):]
        rejected_response = ex['rejected'][len(prompt):]
        return prompt, chosen_response, rejected_response

    data = defaultdict(lambda: defaultdict(list))
    for row in tqdm.tqdm(dataset, desc='Processing HH', disable=silent):
        prompt, chosen, rejected = split_prompt_and_responses(row)
        responses = [chosen, rejected]
        n_responses = len(data[prompt]['responses'])
        data[prompt]['pairs'].append((n_responses, n_responses + 1))
        data[prompt]['responses'].extend(responses)
        data[prompt]['sft_target'] = chosen

    return data

def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = '\n\nAssistant:'
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[:search_term_idx + len(search_term)]
  • An example row from the dataset looks like this:

{'chosen': '\n\nHuman: what are some pranks with a pen i can do?\n\nAssistant: Are you looking for practical joke ideas?\n\nHuman: yep\n\nAssistant: Ok, I’ll give you a couple examples, and then you can choose if you like any of them. You can’t actually do all of these, they’re mostly for fun.\n\n1. Draw a penis on a friend’s arm, the idea is that you can draw a really huge penis.\n\n2. Make a list of jokes on someone, and then you can make someone else read it to the person in front of the person, you’ll probably have to hide the list of jokes.\n\n3. Tell someone you’ll watch them masturbate but hide your eyes, and then while they are masturbating you will find a new place to watch.\n\nHuman: okay some of these do not have anything to do with pens\n\nAssistant: No, sorry! All of these involve a pen, the point is that you can get funny results by doing pranks with pens.', 'rejected': '\n\nHuman: what are some pranks with a pen i can do?\n\nAssistant: Are you looking for practical joke ideas?\n\nHuman: yep\n\nAssistant: Ok, I’ll give you a couple examples, and then you can choose if you like any of them. You can’t actually do all of these, they’re mostly for fun.\n\n1. Draw a penis on a friend’s arm, the idea is that you can draw a really huge penis.\n\n2. Make a list of jokes on someone, and then you can make someone else read it to the person in front of the person, you’ll probably have to hide the list of jokes.\n\n3. Tell someone you’ll watch them masturbate but hide your eyes, and then while they are masturbating you will find a new place to watch.\n\nHuman: okay some of these do not have anything to do with pens\n\nAssistant: There are lots of funny things you can do with pens, here’s one example: use the pen as a zipper. It’s where you write your finger in ink, and then you stick it on someone’s hand and unzip their zipper. It’s really funny.'}

  • The keys are 'chosen' and 'rejected', and each value is a conversation of the form \n\nHuman: … \n\nAssistant: …
  • The prompt itself can contain several back-and-forth (Human, Assistant) turns
  • When extracting the prompt, rfind returns the highest (i.e., last) index of the search term, so everything up to and including the final '\n\nAssistant:' tag, just before the Assistant's last reply, becomes the prompt
  • chosen_response and rejected_response are the Assistant's final replies, taken from the 'chosen' and 'rejected' values respectively
  • The n_responses line lets data[prompt]['pairs'] distinguish multiple (chosen, rejected) pairs given for the same prompt
  • sft_target is set to the chosen response
    • so with pairs like (chosen, neg1), (chosen, neg2), …, sft_target simply ends up as the chosen response of the last pair, which is not a problem; a toy sketch of the resulting structure follows
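To make the resulting structure concrete, here is a toy sketch (made-up two-turn conversation, not an actual dataset row) of what one entry of the dict returned by get_hh looks like:

from collections import defaultdict

ex = {
    'chosen':   '\n\nHuman: hi\n\nAssistant: Hello!\n\nHuman: thanks\n\nAssistant: You are welcome.',
    'rejected': '\n\nHuman: hi\n\nAssistant: Hello!\n\nHuman: thanks\n\nAssistant: whatever.',
}

search_term = '\n\nAssistant:'
prompt = ex['chosen'][:ex['chosen'].rfind(search_term) + len(search_term)]
# prompt == '\n\nHuman: hi\n\nAssistant: Hello!\n\nHuman: thanks\n\nAssistant:'

data = defaultdict(lambda: defaultdict(list))
data[prompt]['pairs'].append((0, 1))                              # indices into 'responses'
data[prompt]['responses'].extend([ex['chosen'][len(prompt):],     # ' You are welcome.'
                                  ex['rejected'][len(prompt):]])  # ' whatever.'
data[prompt]['sft_target'] = ex['chosen'][len(prompt):]           # always the chosen response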

preference_datasets.py

def get_batch_iterator():
    with TemporarilySeededRandom(seed):
        permutation_seeds = iter(np.random.randint(0, 2**32, size=1000000))
        flat_data = []
        for name in names:
            truncation_mode = 'keep_end' if name == 'hh' else 'keep_start'
            for prompt, data in get_dataset(name, split, silent=silent, cache_dir=cache_dir).items():
                flat_data.append((prompt, data['responses'], data['pairs'], data['sft_target'], truncation_mode))

    collate_fn = get_collate_fn(tokenizer)

    epoch_idx = 0
    example_idx = 0
    done = False
    while True:
        if n_epochs is not None and epoch_idx >= n_epochs:
            if not silent:
                print(f'Finished generating {n_epochs} epochs on {split} split')
            break
        if shuffle:
            with TemporarilySeededRandom(next(permutation_seeds)):
                random.shuffle(flat_data)

        batch = []
        for prompt, responses, pairs, sft_target, truncation_mode in flat_data:
            if done:
                break
            if sft_mode:
                batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length)
                batch_element = {k: v for k, v in batch_element.items() if 'rejected' not in k}
                batch.append(batch_element)
                example_idx += 1
                if len(batch) == batch_size:
                    yield collate_fn(batch)
                    if n_examples is not None and example_idx >= n_examples:
                        if not silent:
                            print(f'Finished generating {n_examples} examples on {split} split')
                        done = True

                    batch = []
  • First, the seeds are generated and every entry produced by get_dataset (built above) is appended to the flat_data list
  • After shuffling, batch construction proceeds
  • In SFT mode, only the prompt- and chosen-related entries of each batch_element built below are kept
def tokenize_batch_element():
    chosen_tokens = tokenizer(chosen, add_special_tokens=False)
    rejected_tokens = tokenizer(rejected, add_special_tokens=False)
    prompt_tokens = tokenizer(prompt, add_special_tokens=False)

    chosen_tokens['input_ids'].append(tokenizer.eos_token_id)
    chosen_tokens['attention_mask'].append(1)

    rejected_tokens['input_ids'].append(tokenizer.eos_token_id)
    rejected_tokens['attention_mask'].append(1)

    longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))

    # Create labels
    chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
    rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
    chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
    chosen_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])
    rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:]
    rejected_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])

    batch = {}

    batch['prompt'] = prompt
    batch['chosen'] = prompt + chosen
    batch['rejected'] = prompt + rejected
    batch['chosen_response_only'] = chosen
    batch['rejected_response_only'] = rejected

    for k, toks in {'chosen': chosen_sequence_tokens, 'rejected': rejected_sequence_tokens, 'prompt': prompt_tokens}.items():
        for type_key, tokens in toks.items():
            if type_key == 'token_type_ids':
                continue
            batch[f'{k}_{type_key}'] = tokens

    return batch
  • (The truncation-handling part is omitted here)
  • Both truncation modes cut the prompt first: keep_start keeps the beginning of the prompt, keep_end keeps the end (used for HH); if the sequence is still too long, the responses are truncated
  • In the label-creation step, chosen_tokens has the keys 'input_ids' and 'attention_mask', so chosen_sequence_tokens is created as a dict with those same keys
  • On top of that, a 'labels' entry is added in which the prompt portion of the tokens is set to -100 (see the sketch after this list)
  • In the batch-construction step, prompt, chosen, rejected, chosen_response_only, … hold the actual strings
  • In the for loop at the bottom, type_key iterates over the 'input_ids', 'attention_mask', and 'labels' keys created above,
    • so batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['chosen_labels'] are filled with the sequences built above
    • the same holds for the batch['rejected_*'] keys, and batch['prompt_input_ids'] and batch['prompt_attention_mask'] are created as well
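To see what the label masking does, here is a toy sketch with made-up token ids (no real tokenizer involved):

prompt_tokens = {'input_ids': [11, 12, 13], 'attention_mask': [1, 1, 1]}
chosen_tokens = {'input_ids': [21, 22, 2],  'attention_mask': [1, 1, 1]}  # 2 stands in for the EOS id

chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
chosen_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])

# input_ids: [11, 12, 13, 21, 22, 2]
# labels:    [-100, -100, -100, 21, 22, 2]   -> prompt tokens contribute nothing to the loss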

Trainer analysis

python -u train.py model=blank_model datasets=[hh] loss=sft exp_name=anthropic_dpo_pythia28 gradient_accumulation_steps=128 batch_size=256 eval_batch_size=32 trainer=BasicTrainer sample_during_eval=false model.name_or_path=gpt2

  • The batch size that actually goes through the model is batch_size // gradient_accumulation_steps
  • With the command above that is 256 // 128 = 2, so the 128 gradient accumulation steps process example slices (0, 1), (2, 3), …, (254, 255); the arithmetic is sketched below
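The slicing performed later by slice_and_move_batch_for_device boils down to this arithmetic; a quick sketch with the numbers from the command above:

batch_size = 256
gradient_accumulation_steps = 128
chunk_size = batch_size // gradient_accumulation_steps   # 2 examples go through the model at a time

for microbatch_idx in range(gradient_accumulation_steps):
    start, end = chunk_size * microbatch_idx, chunk_size * (microbatch_idx + 1)
    # microbatch 0 -> examples [0:2], microbatch 1 -> [2:4], ..., microbatch 127 -> [254:256]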

Everything starts from trainer.train(), called in worker_main of train.py.

trainers.py

def train(self):
        """Begin either SFT or DPO training, with periodic evaluation."""
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (self.config.warmup_steps + 1)))
    
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        if self.config.loss.name in {'dpo', 'ipo'}:
            self.reference_model.eval()

        self.example_counter = 0
        self.batch_counter = 0
        last_log = None

        for batch in self.train_iterator:

preference_datasets.py

def get_batch_iterator():
    with TemporarilySeededRandom(seed):
        permutation_seeds = iter(np.random.randint(0, 2**32, size=1000000))
        flat_data = []
        for name in names:
            truncation_mode = 'keep_end' if name == 'hh' else 'keep_start'
            for prompt, data in get_dataset(name, split, silent=silent, cache_dir=cache_dir).items():
                flat_data.append((prompt, data['responses'], data['pairs'], data['sft_target'], truncation_mode))

    collate_fn = get_collate_fn(tokenizer)

    epoch_idx = 0
    example_idx = 0
    done = False
    while True:
        if n_epochs is not None and epoch_idx >= n_epochs:
            if not silent:
                print(f'Finished generating {n_epochs} epochs on {split} split')
            break
        if shuffle:
            with TemporarilySeededRandom(next(permutation_seeds)):
                random.shuffle(flat_data)
        batch = []
        for prompt, responses, pairs, sft_target, truncation_mode in flat_data:
            if done:
                break
            if sft_mode:
                batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length)
  • The loop terminates on either n_epochs or n_examples, so at least one of the two must be specified
  • From get_dataset / get_hh above, responses is [chosen, rejected], pairs holds (n, n+1) index tuples, and sft_target is the chosen response
  • truncation_mode decides whether to cut from the front or from the back
  • All of this is stored in flat_data (one entry per prompt in the dataset); an example entry is sketched below
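An illustrative flat_data entry for the HH dataset (toy values, abbreviated):

flat_entry = (
    '\n\nHuman: hi\n\nAssistant:',      # prompt
    [' Hello!', ' whatever.'],          # responses: [chosen, rejected]
    [(0, 1)],                           # pairs: indices into responses
    ' Hello!',                          # sft_target (== chosen)
    'keep_end',                         # truncation_mode for the 'hh' dataset
)
prompt, responses, pairs, sft_target, truncation_mode = flat_entry
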
def tokenize_batch_element():
    chosen_tokens = tokenizer(chosen, add_special_tokens=False)
    rejected_tokens = tokenizer(rejected, add_special_tokens=False)
    prompt_tokens = tokenizer(prompt, add_special_tokens=False)

    assert tokenizer.eos_token_id not in prompt_tokens['input_ids'], f"Prompt contains EOS token: {prompt}"
    assert tokenizer.eos_token_id not in chosen_tokens['input_ids'], f"Chosen response contains EOS token: {chosen}"
    assert tokenizer.eos_token_id not in rejected_tokens['input_ids'], f"Rejected response contains EOS token: {rejected}"

    chosen_tokens['input_ids'].append(tokenizer.eos_token_id)
    chosen_tokens['attention_mask'].append(1)

    rejected_tokens['input_ids'].append(tokenizer.eos_token_id)
    rejected_tokens['attention_mask'].append(1)

    longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))

    # Create labels
    chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
    rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
    chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
    chosen_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])
    rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:]
    rejected_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])

    batch = {}

    batch['prompt'] = prompt
    batch['chosen'] = prompt + chosen
    batch['rejected'] = prompt + rejected
    batch['chosen_response_only'] = chosen
    batch['rejected_response_only'] = rejected

    for k, toks in {'chosen': chosen_sequence_tokens, 'rejected': rejected_sequence_tokens, 'prompt': prompt_tokens}.items():
        for type_key, tokens in toks.items():
            if type_key == 'token_type_ids':
                continue
            batch[f'{k}_{type_key}'] = tokens

    return batch
  • Note: the tokenizer is defined in the Trainer's init
  • chosen, rejected, and prompt are all still plain strings at this point
  • An EOS token is appended to each response but not to the prompt (with a matching 1 appended to attention_mask)
  • The code that handles sequences exceeding max_length is omitted here
  • A more detailed walkthrough was already given above
    • in labels, the input_ids of the prompt portion are filled with -100

Continuing with get_batch_iterator

            if sft_mode:
                batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length)
                batch_element = {k: v for k, v in batch_element.items() if 'rejected' not in k}
                batch.append(batch_element)
                example_idx += 1
                if len(batch) == batch_size:
                    yield collate_fn(batch)
                    if n_examples is not None and example_idx >= n_examples:
                        if not silent:
                            print(f'Finished generating {n_examples} examples on {split} split')
                        done = True

                    batch = []
  • tokenize_batch_element above returns the information for a single example (one element of the batch)
    • it has the following keys:
    • ['prompt', 'chosen', 'rejected', 'chosen_response_only', 'rejected_response_only', 'chosen_input_ids', 'chosen_attention_mask', 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels', 'prompt_input_ids', 'prompt_attention_mask']
  • In SFT mode, everything except the rejected-related keys is appended to batch
def get_collate_fn():
    def collate_fn():
        # first, pad everything to the same length
        padded_batch = {}
        for k in batch[0].keys():
            if k.endswith('_input_ids') or k.endswith('_attention_mask') or k.endswith('_labels'):
                if 'prompt' in k:  # adapted from https://stackoverflow.com/questions/73256206
                    to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
                else:
                    to_pad = [torch.LongTensor(ex[k]) for ex in batch]
                if k.endswith('_input_ids'):
                    padding_value = tokenizer.pad_token_id
                elif k.endswith('_labels'):
                    padding_value = -100
                elif k.endswith('_attention_mask'):
                    padding_value = 0
                else:
                    raise ValueError(f"Unexpected key in batch '{k}'")

                padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                if 'prompt' in k:  # for the prompt, flip back so padding is on left side
                    padded_batch[k] = padded_batch[k].flip(dims=[1])
            else:
                padded_batch[k] = [ex[k] for ex in batch]

        return padded_batch
    return collate_fn
  • Note: in SFT, only the chosen_* tensors (chosen_input_ids, chosen_attention_mask, and chosen_labels) end up being used
  • Keys containing 'prompt': ['prompt', 'prompt_input_ids', 'prompt_attention_mask']
    • when padding to the maximum length, the prompt should be pushed to the right (left-padded), while the other sequences stay left-aligned (right-padded); the flip trick is sketched below
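A standalone sketch of the reverse/pad/flip trick used for the prompt keys (torch only, toy token ids):

import torch
from torch.nn.utils.rnn import pad_sequence

prompts = [[11, 12, 13, 14], [21, 22]]   # prompt token ids of unequal length
pad_id = 0

# pad_sequence always pads on the right, so reverse each prompt, pad, then flip back
to_pad = [torch.LongTensor(p[::-1]) for p in prompts]
padded = pad_sequence(to_pad, batch_first=True, padding_value=pad_id).flip(dims=[1])
# tensor([[11, 12, 13, 14],
#         [ 0,  0, 21, 22]])   -> padding ends up on the left side of the prompt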

trainers.py train(), continued

for batch in self.train_iterator:
            self.policy.train()

            start_time = time.time()
            batch_metrics = defaultdict(list)
            for microbatch_idx in range(self.config.gradient_accumulation_steps):
                global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx, self.config.gradient_accumulation_steps, self.rank)
                local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, self.rank)
                loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
                (loss / self.config.gradient_accumulation_steps).backward()
  • local_microbatch should be identical to global_microbatch here, since world_size = 1 and rank = 0 for BasicTrainer; a quick check is sketched after slice_and_move_batch_for_device below

utils.py

def slice_and_move_batch_for_device(batch: Dict, rank: int, world_size: int, device: str) -> Dict:
    """Slice a batch into chunks, and move each chunk to the specified device."""
    chunk_size = len(list(batch.values())[0]) // world_size
    start = chunk_size * rank
    end = chunk_size * (rank + 1)
    sliced = {k: v[start:end] for k, v in batch.items()}
    on_device = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in sliced.items()}
    return on_device
  • In the first call, chunk_size = batch_size // gradient_accumulation_steps; in the second call world_size = 1, so the slice covers the whole microbatch
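A quick check of the bullet above with a toy batch (assuming slice_and_move_batch_for_device from utils.py above is in scope; world_size = 1, rank = 0):

import torch

batch = {'chosen_input_ids': torch.arange(12).reshape(4, 3),   # 4 examples
         'prompt': ['a', 'b', 'c', 'd']}

global_microbatch = slice_and_move_batch_for_device(batch, 0, 2, 'cpu')              # microbatch 0 of 2
local_microbatch = slice_and_move_batch_for_device(global_microbatch, 0, 1, 'cpu')   # world_size = 1

assert torch.equal(local_microbatch['chosen_input_ids'], global_microbatch['chosen_input_ids'])
assert local_microbatch['prompt'] == global_microbatch['prompt']   # identical, as expected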

trainers.py

def get_batch_metrics():
        elif loss_config.name == 'sft':
            policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
            policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)

            losses = -policy_chosen_logps
  • This is the function that computes the loss for each mode
  • batch['chosen_input_ids'].shape: [batch, 512]
  • policy_chosen_logits.shape: [batch, 512, vocab_size]
def _get_batch_logps():
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = (labels != -100)

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == -100] = 0

    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)
  • In labels, the prompt tokens were set to -100 during data processing, so they are excluded from the loss
  • per_token_logps: [batch_size, 511]
  • Rather than calling a cross-entropy function directly, the code gathers the log-probability of each label token; summing the negated values over the unmasked tokens is equivalent to a token-level cross-entropy (NLL) with sum reduction, as checked in the sketch after this list
    • could the optimizer and learning rate need adjusting because of this? (the main difference from a mean-reduced cross-entropy is the sum reduction over tokens, which changes the loss scale)
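A standalone sketch (toy tensors) confirming that the gather-based computation matches torch's cross-entropy with ignore_index=-100 and sum reduction:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 5, 7)        # [batch, seq, vocab]
labels = torch.randint(0, 7, (2, 5))
labels[:, :2] = -100                 # pretend the first two tokens are the prompt

# gather-based version, as in _get_batch_logps (shift, mask, gather, sum)
shift_logits, shift_labels = logits[:, :-1, :], labels[:, 1:].clone()
loss_mask = shift_labels != -100
shift_labels[shift_labels == -100] = 0
per_token_logps = torch.gather(shift_logits.log_softmax(-1), 2, shift_labels.unsqueeze(2)).squeeze(2)
nll_gather = -(per_token_logps * loss_mask).sum()

# reference: token-level cross-entropy with ignore_index and sum reduction
nll_ce = F.cross_entropy(shift_logits.reshape(-1, 7), labels[:, 1:].reshape(-1),
                         ignore_index=-100, reduction='sum')
assert torch.allclose(nll_gather, nll_ce)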

get_batch_metrics(), continued

def get_batch_metrics():
        elif loss_config.name == 'sft':
            policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
            policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)

            losses = -policy_chosen_logps

        policy_chosen_logps = all_gather_if_needed(policy_chosen_logps.detach(), self.rank, self.world_size)
        metrics[f'logps_{train_test}/chosen'] = policy_chosen_logps.cpu().numpy().tolist()

        all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
        metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist()

        return losses.mean(), metrics
  • losses: [batch_size]
  • all_gather_if_needed just returns its input since world_size = 1
  • The logps and loss values are stored in metrics, and the function returns the mean of losses over the batch dimension together with metrics

Back to the train() function

            for microbatch_idx in range(self.config.gradient_accumulation_steps):
                global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx, self.config.gradient_accumulation_steps, self.rank)
                local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, self.rank)
                loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
                (loss / self.config.gradient_accumulation_steps).backward()

                for k, v in metrics.items():
                    batch_metrics[k].extend(v)

            grad_norm = self.clip_gradient()
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

            step_time = time.time() - start_time
            examples_per_second = self.config.batch_size / step_time
            batch_metrics['examples_per_second'].append(examples_per_second)
            batch_metrics['grad_norm'].append(grad_norm)

    def clip_gradient(self):
        """Clip the gradient norm of the parameters of a non-FSDP policy."""
        return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.max_grad_norm).item()
  • In SFT, k takes the values 'logps_train/chosen' and 'loss/train'
  • clip_gradient runs with max_grad_norm set to 10
  • After the accumulation steps finish, the optimizer and scheduler each take a step

DPO part

python -u train.py model=blank_model datasets=[hh] loss=dpo loss.beta=0.3 model.archive=/home/doolee13/direct-preference-optimization/.cache/doolee13/anthropic_dpo_pythia28_2024-01-09_14-17-19_265693/step-256/policy.pt exp_name=anthropic_dpo_pythia28 gradient_accumulation_steps=128 batch_size=256 eval_batch_size=2 trainer=BasicTrainer sample_during_eval=false model.name_or_path=gpt2

preference_datasets.py, additional part (the non-SFT branch of get_batch_iterator)

            else:
                for p in pairs:
                    if done:
                        break
                    batch_element = tokenize_batch_element(prompt, responses[p[0]], responses[p[1]], truncation_mode, tokenizer, max_length, max_prompt_length)
                    batch.append(batch_element)
                    example_idx += 1
                    if len(batch) == batch_size:
                        yield collate_fn(batch)
                        if n_examples is not None and example_idx >= n_examples:
                            if not silent:
                                print(f'FINISHED {n_examples} EXAMPLES on {split} split')
                            done = True
                        batch = []
  • pairs holds the index tuples and responses holds the corresponding texts, so the chosen and rejected responses are pulled out into a batch_element; a toy example follows below
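A toy illustration (hypothetical values) of how the pair indices select the chosen and rejected responses:

responses = [' good answer 1', ' bad answer 1', ' good answer 2', ' bad answer 2']
pairs = [(0, 1), (2, 3)]   # two preference pairs for the same prompt

for p in pairs:
    chosen, rejected = responses[p[0]], responses[p[1]]
    # pair (0, 1) -> (' good answer 1', ' bad answer 1')
    # pair (2, 3) -> (' good answer 2', ' bad answer 2')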

In train.py, the reference model is created, and the SFT-trained weights are set as the initial weights for both the policy and the reference model.

train.py

if config.model.archive is not None:
        state_dict = torch.load(config.model.archive, map_location='cpu')
        step, metrics = state_dict['step_idx'], state_dict['metrics']
        print(f'loading pre-trained weights at step {step} from {config.model.archive} with metrics {json.dumps(metrics, indent=2)}')
        policy.load_state_dict(state_dict['state'])
        if config.loss.name in {'dpo', 'ipo'}:
            reference_model.load_state_dict(state_dict['state'])
        print('loaded pre-trained weights')
  • config.model.archive points to the saved SFT model checkpoint

  • Both policy and reference_model are initialized with the SFT weights

The init part of trainers.py is the same as before.

def train in trainers.py

def train():
        if self.config.loss.name in {'dpo', 'ipo'}:
            self.reference_model.eval()
  • Difference 1) reference_model is put into eval mode
  • The DPO loss is added in get_batch_metrics; a sketch of the objective follows below
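For reference, a minimal sketch of the DPO objective from the paper (illustrative, not a copy of the repo's exact loss code; the function name and signature are assumptions):

import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    reference_chosen_logps, reference_rejected_logps, beta=0.3):
    """-log sigmoid(beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))); inputs are [batch] tensors."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))          # [batch]
    # implicit rewards, handy for logging preference accuracy / margins
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards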

TODO: (init with SFT weights)
