torchmetrics: pytorch-lightning 친화적인 metrics & logging

Jonas M·2023년 3월 1일
0

AI 모델의 평가 metric

ai 모델로 분류 문제를 풀 때 보통 accuracy, top-k accuracy, precision, recall, f1 score 등을 사용한다. 각각의 의미는 여기에 다루지 않겠다. 자연스럽게 모델을 학습하는 도중 validation을 진행할 때 위와 같은 평가지표(metric)을 계산하고, 최고의 weight(체크포인트)를 선정하게 된다. 최종적으로 선정된 모델을 test set에 대해 평가할 때도 같은 metric을 사용한다.

AS-IS: scikit-learn의 metric

공부를 시작한 이후로 가장 많이 사용한 방식은 sklearn의 metric 라이브러리이다. 오래되었고 당시 많은 사람들이 사용하고 있었기 때문에 아무 의심 없이 싸이킷런 라이브러리가 최선이겠거니 생각하고 사용해왔다. 그러나... 아래와 같이 iteration의 결과값들을 모아주는 작업을 진행해야 했다. (이것이 코드를 지저분하게 만드는지 생각하지도 못했다.)

# pytorch lightning module
def validation_step(self, batch, batch_idx):
	images, labels, file_paths = batch
    loss, logits = self.shared_step(batch)
    self.log('val_loss', loss, on_epoch=True)
    return logits.cpu(), labels.cpu()

def validation_epoch_end(self, outputs):
	logits, targets = torch.tensor([]), torch.tensor([])
    for output in outputs:
    	logits = torch.cat((logits, output[0]))
    	targets = torch.cat((targets, output[1]))
    preds = torch.max(logits, dim=1).indices
    preds, targets = preds.numpy(), targets.numpy()
    top1 = accuracy_score(targets, preds)
    top5 = top_k_accuracy_score(targets, logits.numpy(), k=5, labels=range(41))
    precision = precision_score(targets, preds, average='macro', zero_division=1)
    recall = recall_score(targets, preds, average='macro', zero_division=1)
    f1 = f1_score(targets, preds, average='macro', zero_division=1)
    auc = roc_auc_score(targets, logits.numpy(), multi_class='ovo', labels=range(41))

	# log
    self.log('val_acc_top1', top1, sync_dist=True)
    self.log('val_acc_top5', top5, sync_dist=True)
    self.log('val_precision', precision, sync_dist=True)
    self.log('val_recall', recall, sync_dist=True)
    self.log('val_f1_macro', f1, sync_dist=True)
    self.log('val_auc', auc, sync_dist=True)

TO-BE: torchmetrics로 간결하게

pytorch-lightning에서 torchmerics라는 라이브러리를 제공하고 있다. 이들에 따르면 위와 같은 기능들을 제공한다고 한다. "Boilerplate를 줄인다"는 말은 비슷한 코드들이 반복되는 것을 방지한다는 뜻으로 모듈화를 잘 해두었다는 의미이다. 가장 관심이 갔던 부분은 역시 "Automatic"이다.

  • 배치에서 나오는 데이터들을 accumulation 해주고
  • multiple devices 간의 싱크를 제공해준다고 한다.

첫 번째는 AS-IS에서 outputs에서 output들을 뽑아서 하나로 모아주는 작업을 직접 코드상에서 하지 않도록 만들어준다. 이에 따라 validation_step에서 특별한 이유가 없다면 output을 return 하지 않아도 된다.
두 번째는 DDP 등의 방식으로 학습을 진행할 때, 여러 디바이스에서 나오는 결과값들을 모아주는 작업을 직접하지 않도록 도와준다. 사실 pytorch-lightning 자체에서 이러한 작업들을 도와주기 때문에 요즘에는 큰 문제가 되지는 않지만, multi-gpu 학습시에 기기간의 데이터 분산과 병합 문제는 깊은 단계에서의 작업이기 때문에 혹여 문제가 발생할 시에 핸들링하기가 어렵다.
다시 말해, outputs에 loop을 돌지 않아도 되기 때문에 코드가 간결해지고, 잘 짜여진 pytorch-lightning과 그에 잘 호환되는 metric 라이브러리를 활용하여 분산학습을 용이하게 도와줄 수 있다.

def validation_step(self, batch, batch_idx):
	idx, x_val, y_val = batch
    logits = self(x_val)
    loss = self.val_criterion(logits, y_val)
    self.log('val_loss', loss, on_epoch=True, sync_dist=True)
    self.metric.update(logits, y_val)
    
def validation_epoch_end(self, outputs):
	metric_out = self.metric.compute()
    self.log('val_acc_top1', metric_out['acc_top1'].cpu().item(), sync_dist=True)
    self.log('val_acc_top5', metric_out['acc_top5'].cpu().item(), sync_dist=True)
    self.log('val_precision', metric_out['precision'].cpu().item(), sync_dist=True)
    self.log('val_recall', metric_out['recall'].cpu().item(), sync_dist=True)
    self.log('val_f1_macro', metric_out['f1_macro'].cpu().item(), sync_dist=True)
    print('Result after Epoch: {:02d}>>'.format(self.current_epoch))
    pprint(metric_out)
    self.metric.reset()
profile
Graduate School of DataScience, NLP researcher

0개의 댓글