문제
torch.distributed.barrier()는 distributed training (multi-gpu training) 환경에서 multi-process로 학습을 수행할 때, 각 process (rank)들마다 진행 속도가 다를 수 있다.
- 모든 process들이 barrier()에 도달할 때까지 wait()을 걸어줌으로써 sync를 맞춰주는 역할을 한다.
 
- process가 도달을 하지 않거나, sync가 맞지 않으면 무한 대기에 빠진다.
 
 
- 모델 학습할 때 프로세스가 대강 아래와 같은데, epoch loop의 
torch.distributed.barrier()에서 stall되는 문제가 있었다. 
pseudo code
for loop:	
	for loop:	
    	...
    	output = model(input)
        ...
        loss.backward
        ...
        torch.distributed.barrier()	
        
	if torch.distributed.get_rank() == 0:
    	...
        validation()
        ...
        torch.save()
        ...
	torch.distributed.barrier()	
Validation code (rank=0만)
    def validation(self):
        acc1 = 0.0
        acc5 = 0.0
        losses = []
        with torch.no_grad():
            for _, data in enumerate(tqdm(self.val_loader)):
                input = data[0].cuda()
                gt = data[1].cuda()
                outputs = self.encoder(input)
                loss = self.CELoss(outputs, gt)
                losses.append(loss)
                acc1_, acc5_ = self.accuracy(outputs, gt, topk=(1, 5))
                acc1 += acc1_
                acc5 += acc5_
        acc1 = acc1 / len(self.val_loader)
        acc5 = acc5 / len(self.val_loader)
        loss = sum(losses) / len(losses)
        return acc1, acc5, loss
해결
- 다음 링크에 따르면, 
torch.distributed.get_rank() == 0 에 조건문을 걸고 모델을 실행시키더라도, 모델이 DistributedDataParallel()로 wrap되어있을 때 특정 상황(?)에서 forward pass때 sync를 대기하는 현상이 있다고 한다. 
- 아래와 같이 validation() 함수에서 
self.encoder.module(input)으로 변경하여 해결하였음. 
    def validation(self):
        acc1 = 0.0
        acc5 = 0.0
        losses = []
        with torch.no_grad():
            for _, data in enumerate(tqdm(self.val_loader)):
                input = data[0].cuda()
                gt = data[1].cuda()
                outputs = self.encoder.module(input)	
                loss = self.CELoss(outputs, gt)
                losses.append(loss)
                acc1_, acc5_ = self.accuracy(outputs, gt, topk=(1, 5))
                acc1 += acc1_
                acc5 += acc5_
        acc1 = acc1 / len(self.val_loader)
        acc5 = acc5 / len(self.val_loader)
        loss = sum(losses) / len(losses)
        return acc1, acc5, loss