To use custom data, modify preference_datasets.py.
Preference data preparation starts when trainers.py (BasicTrainer) is initialized.
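For reference, BasicTrainer.__init__ builds the data iterators by calling get_batch_iterator from preference_datasets.py, roughly like this (a paraphrased sketch, not a verbatim copy of the repo; exact keyword arguments may differ):
data_iterator_kwargs = dict(
    names=config.datasets,               # e.g. ['hh']
    tokenizer=self.tokenizer,
    shuffle=True,
    max_length=config.max_length,
    max_prompt_length=config.max_prompt_length,
    sft_mode=config.loss.name == 'sft',  # SFT reuses sft_target as both responses
)
self.train_iterator = get_batch_iterator(**data_iterator_kwargs, split='train',
                                         n_epochs=config.n_epochs, n_examples=config.n_examples,
                                         batch_size=config.batch_size, silent=rank != 0)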
Script used for the analysis:
preference_datasets.py
def get_hh(split: str, silent: bool = False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]:
print(f'Loading HH dataset ({split} split) from Huggingface...')
dataset = datasets.load_dataset('Anthropic/hh-rlhf', split=split, cache_dir=cache_dir)
print('done')
def split_prompt_and_responses(ex):
**prompt = extract_anthropic_prompt(ex['chosen'])**
chosen_response = ex['chosen'][len(prompt):]
rejected_response = ex['rejected'][len(prompt):]
return prompt, chosen_response, rejected_response
data = defaultdict(lambda: defaultdict(list))
for row in tqdm.tqdm(dataset, desc='Processing HH', disable=silent):
prompt, chosen, rejected = split_prompt_and_responses(row)
responses = [chosen, rejected]
n_responses = len(data[prompt]['responses'])
data[prompt]['pairs'].append((n_responses, n_responses + 1))
data[prompt]['responses'].extend(responses)
data[prompt]['sft_target'] = chosen
return data
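To plug in custom data, the easiest route is a loader that returns the same nested dict as get_hh (prompt -> {'responses', 'pairs', 'sft_target'}) and is registered alongside it in get_dataset. A minimal sketch, assuming a hypothetical local JSON-lines file my_prefs_{split}.jsonl with 'prompt', 'chosen', and 'rejected' fields:
import json
from collections import defaultdict

def get_custom(split: str, silent: bool = False, cache_dir: str = None):
    # Hypothetical loader: returns the same nested dict structure as get_hh.
    data = defaultdict(lambda: defaultdict(list))
    with open(f'./my_prefs_{split}.jsonl') as f:
        for line in f:
            ex = json.loads(line)
            prompt, chosen, rejected = ex['prompt'], ex['chosen'], ex['rejected']
            n_responses = len(data[prompt]['responses'])
            data[prompt]['pairs'].append((n_responses, n_responses + 1))
            data[prompt]['responses'].extend([chosen, rejected])
            data[prompt]['sft_target'] = chosen  # response used as the SFT target
    return data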
def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = '\n\nAssistant:'
**search_term_idx = prompt_and_response.rfind(search_term)**
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[:search_term_idx + len(search_term)]
{'chosen': '\n\nHuman: what are some pranks with a pen i can do?\n\nAssistant: Are you looking for practical joke ideas?\n\nHuman: yep\n\nAssistant: Ok, I’ll give you a couple examples, and then you can choose if you like any of them. You can’t actually do all of these, they’re mostly for fun.\n\n1. Draw a penis on a friend’s arm, the idea is that you can draw a really huge penis.\n\n2. Make a list of jokes on someone, and then you can make someone else read it to the person in front of the person, you’ll probably have to hide the list of jokes.\n\n3. Tell someone you’ll watch them masturbate but hide your eyes, and then while they are masturbating you will find a new place to watch.\n\nHuman: okay some of these do not have anything to do with pens\n\nAssistant: No, sorry! All of these involve a pen, the point is that you can get funny results by doing pranks with pens.', 'rejected': '\n\nHuman: what are some pranks with a pen i can do?\n\nAssistant: Are you looking for practical joke ideas?\n\nHuman: yep\n\nAssistant: Ok, I’ll give you a couple examples, and then you can choose if you like any of them. You can’t actually do all of these, they’re mostly for fun.\n\n1. Draw a penis on a friend’s arm, the idea is that you can draw a really huge penis.\n\n2. Make a list of jokes on someone, and then you can make someone else read it to the person in front of the person, you’ll probably have to hide the list of jokes.\n\n3. Tell someone you’ll watch them masturbate but hide your eyes, and then while they are masturbating you will find a new place to watch.\n\nHuman: okay some of these do not have anything to do with pens\n\nAssistant: There are lots of funny things you can do with pens, here’s one example: use the pen as a zipper. It’s where you write your finger in ink, and then you stick it on someone’s hand and unzip their zipper. It’s really funny.'}
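For a row like the one above, extract_anthropic_prompt keeps everything up to and including the last '\n\nAssistant:'. A toy illustration with synthetic strings (assuming the functions above):
ex_chosen = '\n\nHuman: hi\n\nAssistant: Hello!\n\nHuman: thanks\n\nAssistant: You are welcome.'
prompt = extract_anthropic_prompt(ex_chosen)
# prompt == '\n\nHuman: hi\n\nAssistant: Hello!\n\nHuman: thanks\n\nAssistant:'
chosen_response = ex_chosen[len(prompt):]
# chosen_response == ' You are welcome.'
# After get_hh processes one such row, the accumulated entry is:
# data[prompt] == {'responses': [chosen, rejected], 'pairs': [(0, 1)], 'sft_target': chosen}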
preference_datasets.py
def get_batch_iterator(names, tokenizer, split='train', batch_size=1, shuffle=True, max_length=512, max_prompt_length=128, sft_mode=False, n_epochs=None, n_examples=None, seed=0, silent=False, cache_dir=None):
with TemporarilySeededRandom(seed):
permutation_seeds = iter(np.random.randint(0, 2**32, size=1000000))
flat_data = []
for name in names:
truncation_mode = 'keep_end' if name == 'hh' else 'keep_start'
for prompt, data in get_dataset(name, split, silent=silent, cache_dir=cache_dir).items():
flat_data.append((prompt, data['responses'], data['pairs'], data['sft_target'], truncation_mode))
collate_fn = get_collate_fn(tokenizer)
epoch_idx = 0
example_idx = 0
done = False
while True:
if n_epochs is not None and epoch_idx >= n_epochs:
if not silent:
print(f'Finished generating {n_epochs} epochs on {split} split')
break
if shuffle:
with TemporarilySeededRandom(next(permutation_seeds)):
random.shuffle(flat_data)
batch = []
for prompt, responses, pairs, sft_target, truncation_mode in flat_data:
if done:
break
if sft_mode:
**batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length)**
batch_element = {k: v for k, v in batch_element.items() if 'rejected' not in k}
batch.append(batch_element)
example_idx += 1
if len(batch) == batch_size:
yield collate_fn(batch)
if n_examples is not None and example_idx >= n_examples:
if not silent:
print(f'Finished generating {n_examples} examples on {split} split')
done = True
batch = []
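Before going into tokenize_batch_element, a usage sketch of this iterator in SFT mode (illustrative call only; parameter names follow the signature above, GPT-2 as in the training command later in this note):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; collate_fn needs one
train_iter = get_batch_iterator(names=['hh'], tokenizer=tokenizer, split='train',
                                batch_size=4, sft_mode=True, max_length=512,
                                max_prompt_length=128, n_examples=16, shuffle=True, seed=0)
batch = next(train_iter)
print(batch.keys())  # 'prompt', 'chosen', 'chosen_input_ids', 'chosen_labels', ... (no 'rejected_*' keys in sft_mode)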
def tokenize_batch_element(prompt, chosen, rejected, truncation_mode, tokenizer, max_length, max_prompt_length):
chosen_tokens = tokenizer(chosen, add_special_tokens=False)
rejected_tokens = tokenizer(rejected, add_special_tokens=False)
prompt_tokens = tokenizer(prompt, add_special_tokens=False)
chosen_tokens['input_ids'].append(tokenizer.eos_token_id)
chosen_tokens['attention_mask'].append(1)
rejected_tokens['input_ids'].append(tokenizer.eos_token_id)
rejected_tokens['attention_mask'].append(1)
longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))
# Create labels
chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
chosen_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])
rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:]
rejected_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])
batch = {}
batch['prompt'] = prompt
batch['chosen'] = prompt + chosen
batch['rejected'] = prompt + rejected
batch['chosen_response_only'] = chosen
batch['rejected_response_only'] = rejected
for k, toks in {'chosen': chosen_sequence_tokens, 'rejected': rejected_sequence_tokens, 'prompt': prompt_tokens}.items():
for type_key, tokens in toks.items():
if type_key == 'token_type_ids':
continue
batch[f'{k}_{type_key}'] = tokens
return batch
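The key point for the SFT loss later is that the prompt portion of chosen_labels is masked with -100, so only response tokens contribute to the loss. A tiny check (toy strings; assuming tokenize_batch_element is importable from preference_datasets.py):
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained('gpt2')
el = tokenize_batch_element('\n\nHuman: hi\n\nAssistant:', ' Hello!', ' Go away.',
                            'keep_end', tok, max_length=512, max_prompt_length=128)
n_prompt = len(el['prompt_input_ids'])
assert el['chosen_labels'][:n_prompt] == [-100] * n_prompt                   # prompt tokens ignored by the loss
assert el['chosen_labels'][n_prompt:] == el['chosen_input_ids'][n_prompt:]   # response tokens kept as labels
assert el['chosen_input_ids'][-1] == tok.eos_token_id                        # EOS appended to the response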
Analysis of the trainer part
python -u train.py model=blank_model datasets=[hh] loss=sft exp_name=anthropic_dpo_pythia28 gradient_accumulation_steps=128 batch_size=256 eval_batch_size=32 trainer=BasicTrainer sample_during_eval=false model.name_or_path=gpt2
Training starts from trainer.train() inside worker_main in train.py.
trainers.py
def train(self):
"""Begin either SFT or DPO training, with periodic evaluation."""
self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (self.config.warmup_steps + 1)))
torch.manual_seed(self.seed)
np.random.seed(self.seed)
random.seed(self.seed)
if self.config.loss.name in {'dpo', 'ipo'}:
self.reference_model.eval()
self.example_counter = 0
self.batch_counter = 0
last_log = None
for batch in self.train_iterator:
preference_datasets.py (the generator behind self.train_iterator)
def get_batch_iterator(names, tokenizer, split='train', batch_size=1, shuffle=True, max_length=512, max_prompt_length=128, sft_mode=False, n_epochs=None, n_examples=None, seed=0, silent=False, cache_dir=None):
with TemporarilySeededRandom(seed):
permutation_seeds = iter(np.random.randint(0, 2**32, size=1000000))
flat_data = []
for name in names:
truncation_mode = 'keep_end' if name == 'hh' else 'keep_start'
for prompt, data in get_dataset(name, split, silent=silent, cache_dir=cache_dir).items():
flat_data.append((prompt, data['responses'], data['pairs'], data['sft_target'], truncation_mode))
collate_fn = get_collate_fn(tokenizer)
epoch_idx = 0
example_idx = 0
done = False
while True:
if n_epochs is not None and epoch_idx >= n_epochs:
if not silent:
print(f'Finished generating {n_epochs} epochs on {split} split')
break
if shuffle:
with TemporarilySeededRandom(next(permutation_seeds)):
random.shuffle(flat_data)
batch = []
for prompt, responses, pairs, sft_target, truncation_mode in flat_data:
if done:
break
if sft_mode:
**batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length)**
def tokenize_batch_element(prompt, chosen, rejected, truncation_mode, tokenizer, max_length, max_prompt_length):
chosen_tokens = tokenizer(chosen, add_special_tokens=False)
rejected_tokens = tokenizer(rejected, add_special_tokens=False)
prompt_tokens = tokenizer(prompt, add_special_tokens=False)
assert tokenizer.eos_token_id not in prompt_tokens['input_ids'], f"Prompt contains EOS token: {prompt}"
assert tokenizer.eos_token_id not in chosen_tokens['input_ids'], f"Chosen response contains EOS token: {chosen}"
assert tokenizer.eos_token_id not in rejected_tokens['input_ids'], f"Rejected response contains EOS token: {rejected}"
chosen_tokens['input_ids'].append(tokenizer.eos_token_id)
chosen_tokens['attention_mask'].append(1)
rejected_tokens['input_ids'].append(tokenizer.eos_token_id)
rejected_tokens['attention_mask'].append(1)
longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))
# Create labels
chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
chosen_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])
rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:]
rejected_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])
batch = {}
batch['prompt'] = prompt
batch['chosen'] = prompt + chosen
batch['rejected'] = prompt + rejected
batch['chosen_response_only'] = chosen
batch['rejected_response_only'] = rejected
for k, toks in {'chosen': chosen_sequence_tokens, 'rejected': rejected_sequence_tokens, 'prompt': prompt_tokens}.items():
for type_key, tokens in toks.items():
if type_key == 'token_type_ids':
continue
batch[f'{k}_{type_key}'] = tokens
return batch
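Note the three assertions at the top: custom data must not already contain the tokenizer's EOS token inside the prompt or the responses, or tokenization will fail here. If it might, strip it while building the dataset dict, for example (illustrative):
eos = tokenizer.eos_token  # '<|endoftext|>' for GPT-2
prompt = prompt.replace(eos, '')
chosen = chosen.replace(eos, '')
rejected = rejected.replace(eos, '')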
Back in get_batch_iterator, continuing:
if sft_mode:
batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length)
batch_element = {k: v for k, v in batch_element.items() if 'rejected' not in k}
batch.append(batch_element)
example_idx += 1
if len(batch) == batch_size:
**yield collate_fn(batch)**
if n_examples is not None and example_idx >= n_examples:
if not silent:
print(f'Finished generating {n_examples} examples on {split} split')
done = True
batch = []
def get_collate_fn(tokenizer):
def collate_fn(batch):
# first, pad everything to the same length
padded_batch = {}
for k in batch[0].keys():
if k.endswith('_input_ids') or k.endswith('_attention_mask') or k.endswith('_labels'):
if 'prompt' in k: # adapted from https://stackoverflow.com/questions/73256206
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
else:
to_pad = [torch.LongTensor(ex[k]) for ex in batch]
if k.endswith('_input_ids'):
padding_value = tokenizer.pad_token_id
elif k.endswith('_labels'):
padding_value = -100
elif k.endswith('_attention_mask'):
padding_value = 0
else:
raise ValueError(f"Unexpected key in batch '{k}'")
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
if 'prompt' in k: # for the prompt, flip back so padding is on left side
padded_batch[k] = padded_batch[k].flip(dims=[1])
else:
padded_batch[k] = [ex[k] for ex in batch]
return padded_batch
return collate_fn
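The reverse-then-flip trick above left-pads the prompt tensors (so every prompt ends at the same position, which matters when sampling continuations), while non-prompt keys are right-padded as usual. A small standalone check:
import torch
from torch.nn.utils.rnn import pad_sequence

prompts = [[11, 12, 13], [21, 22]]                     # toy prompt token ids
to_pad = [torch.LongTensor(p[::-1]) for p in prompts]  # reverse each prompt
padded = pad_sequence(to_pad, batch_first=True, padding_value=0)
padded = padded.flip(dims=[1])                         # flip back -> padding ends up on the left
print(padded)
# tensor([[11, 12, 13],
#         [ 0, 21, 22]])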
trainers.py, train() continued
for batch in self.train_iterator:
self.policy.train()
start_time = time.time()
batch_metrics = defaultdict(list)
for microbatch_idx in range(self.config.gradient_accumulation_steps):
global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx, self.config.gradient_accumulation_steps, self.rank)
local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, self.rank)
loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
(loss / self.config.gradient_accumulation_steps).backward()
utils.py
def slice_and_move_batch_for_device(batch: Dict, rank: int, world_size: int, device: str) -> Dict:
"""Slice a batch into chunks, and move each chunk to the specified device."""
chunk_size = len(list(batch.values())[0]) // world_size
start = chunk_size * rank
end = chunk_size * (rank + 1)
sliced = {k: v[start:end] for k, v in batch.items()}
on_device = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in sliced.items()}
return on_device
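With the command above (batch_size=256, gradient_accumulation_steps=128, single-process BasicTrainer so world_size=1), the first call slices the collated batch into microbatches of 256 // 128 = 2 examples, and the second call is effectively a no-op. A toy illustration (assuming the function above):
import torch
batch = {'chosen_input_ids': torch.zeros(256, 512, dtype=torch.long)}  # collated batch of 256 examples
micro = slice_and_move_batch_for_device(batch, rank=3, world_size=128, device='cpu')
print(micro['chosen_input_ids'].shape)  # torch.Size([2, 512]): microbatch_idx selects 2 examples
local = slice_and_move_batch_for_device(micro, rank=0, world_size=1, device='cpu')
print(local['chosen_input_ids'].shape)  # torch.Size([2, 512]): unchanged with a single process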
trainers.py
def get_batch_metrics(self, batch, loss_config, train=True):
elif loss_config.name == 'sft':
policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
**policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)**
losses = -policy_chosen_logps
def _get_batch_logps(logits, labels, average_log_prob=False):
labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_mask = (labels != -100)
# dummy token; we'll ignore the losses on these tokens later
labels[labels == -100] = 0
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
return (per_token_logps * loss_mask).sum(-1)
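A toy check of what _get_batch_logps returns: the sum (or mean) of the log-probabilities of the label tokens under the logits shifted by one position, ignoring the -100 (prompt/padding) positions. Sketch, assuming the function is importable from trainers.py:
import torch
logits = torch.randn(1, 4, 5)                 # 1 sequence, 4 positions, vocab size 5
labels = torch.tensor([[-100, -100, 2, 3]])   # first two positions masked (prompt)
logps = _get_batch_logps(logits, labels, average_log_prob=False)
log_probs = logits.log_softmax(-1)
manual = log_probs[0, 1, 2] + log_probs[0, 2, 3]  # predict labels[t] from logits at t-1
assert torch.allclose(logps[0], manual)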
get_batch_metrics() continued
def get_batch_metrics(self, batch, loss_config, train=True):
elif loss_config.name == 'sft':
policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
**policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)**
losses = -policy_chosen_logps
policy_chosen_logps = all_gather_if_needed(policy_chosen_logps.detach(), self.rank, self.world_size)
metrics[f'logps_{train_test}/chosen'] = policy_chosen_logps.cpu().numpy().tolist()
all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist()
return losses.mean(), metrics
Back to the train() function:
for microbatch_idx in range(self.config.gradient_accumulation_steps):
global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx, self.config.gradient_accumulation_steps, self.rank)
local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, self.rank)
loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
(loss / self.config.gradient_accumulation_steps).backward()
for k, v in metrics.items():
batch_metrics[k].extend(v)
grad_norm = self.clip_gradient()
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
step_time = time.time() - start_time
examples_per_second = self.config.batch_size / step_time
batch_metrics['examples_per_second'].append(examples_per_second)
batch_metrics['grad_norm'].append(grad_norm)
def clip_gradient(self):
"""Clip the gradient norm of the parameters of a non-FSDP policy."""
return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.max_grad_norm).item()
DPO part
python -u train.py model=blank_model datasets=[hh] loss=dpo loss.beta=0.3 model.archive=/home/doolee13/direct-preference-optimization/.cache/doolee13/anthropic_dpo_pythia28_2024-01-09_14-17-19_265693/step-256/policy.pt exp_name=anthropic_dpo_pythia28 gradient_accumulation_steps=128 batch_size=256 eval_batch_size=2 trainer=BasicTrainer sample_during_eval=false model.name_or_path=gpt2
preference_datasets.py, additional part (the non-SFT branch of get_batch_iterator):
else:
for p in pairs:
if done:
break
batch_element = tokenize_batch_element(prompt, responses[p[0]], responses[p[1]], truncation_mode, tokenizer, max_length, max_prompt_length)
batch.append(batch_element)
example_idx += 1
if len(batch) == batch_size:
yield collate_fn(batch)
if n_examples is not None and example_idx >= n_examples:
if not silent:
print(f'FINISHED {n_examples} EXAMPLES on {split} split')
done = True
batch = []
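Here each pair indexes into the responses list built in get_hh, so a prompt that occurs k times in the raw data contributes k (chosen, rejected) pairs. For a prompt seen twice (illustrative):
# data[prompt]['responses'] == [chosen_0, rejected_0, chosen_1, rejected_1]
# data[prompt]['pairs']     == [(0, 1), (2, 3)]
# so the loop calls tokenize_batch_element(prompt, responses[p[0]], responses[p[1]], ...)
# i.e. (chosen, rejected) once per preference pair.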
In train.py, a reference model is created, and the SFT-trained weights are loaded as the initial weights for both the policy and the reference model.
train.py
if config.model.archive is not None:
state_dict = torch.load(config.model.archive, map_location='cpu')
step, metrics = state_dict['step_idx'], state_dict['metrics']
print(f'loading pre-trained weights at step {step} from {config.model.archive} with metrics {json.dumps(metrics, indent=2)}')
policy.load_state_dict(state_dict['state'])
if config.loss.name in {'dpo', 'ipo'}:
reference_model.load_state_dict(state_dict['state'])
print('loaded pre-trained weights')
config.model.archive stores the SFT-trained model checkpoint.
Both the policy and reference_model are initialized with the SFT weights.
The __init__ part of trainers.py is the same as before.
def train() in trainers.py
def train(self):
if self.config.loss.name in {'dpo', 'ipo'}:
self.reference_model.eval()
TODO: init with SFT weights.