0630 개발일지

이나겸·2022년 6월 30일
0

1. 학습내용

어제 돌렸던 마스크 데이터에 대해서 나온 retina_last.pt 평가 진행
평가 코드를 넣기 전에 어제 custom_data 코드에서 train 부분을 지우고 넣음
데이터는 전부 29일 폴더안에 있어서 경로를 변경해서 돌렸다.
batch size를 2로 넣고 돌려야 코드가 터지지 않는다.

"""eval loop"""
retina.load_state_dict(torch.load(f'../29_image_train/retina_last.pt'))
# retina.load_state_dict(torch.load(f'retina_last.pt', map_location='cpu'))

"""
make_prediction 함수에는 학습된 딥러닝 모델을 활용해 예측하는 알고리즘이 저장돼 있습니다. 
threshold 파라미터를 조정해 신뢰도가 일정 수준 이상의 바운딩 박스만 선택합니다. 
보통 0.5 이상인 값을 최종 선택합니다. 
"""


def make_prediction(model, img, threshold):
    model.eval()
    preds = model(img)
    for id in range(len(preds)):
        idx_list = []
        for idx, score in enumerate(preds[id]['scores']):
            if score > threshold:  # threshold 넘는 idx 구함
                idx_list.append(idx)

        preds[id]['boxes'] = preds[id]['boxes'][idx_list]
        preds[id]['labels'] = preds[id]['labels'][idx_list]
        preds[id]['scores'] = preds[id]['scores'][idx_list]

    print("pred info ", preds)
    return preds


labels = []
preds_adj_all = []
annot_all = []

for im, annot in tqdm(test_loader, position=0, leave=True):
    im = list(img.to(device) for img in im)
    #annot = [{k: v.to(device) for k, v in t.items()} for t in annot]

    for t in annot:
        labels += t['labels']

    with torch.no_grad():
        preds_adj = make_prediction(retina, im, 0.5)
        preds_adj = [{k: v.to(torch.device('cpu'))
                      for k, v in t.items()} for t in preds_adj]

        print("make_prediction function 다음 처리 되는 값 >> ", preds_adj)

        preds_adj_all.append(preds_adj)
        annot_all.append(annot)


"""바운딩 박스 시각화 그리는 곳"""
nrows = 8
ncols = 2
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4))

batch_i = 0
for im, annot in test_loader:
    pos = batch_i * 4 + 1
    for sample_i in range(len(im)):

        img, rects = plot_image_from_output(im[sample_i], annot[sample_i])
        axes[(pos)//2, 1-((pos) % 2)].imshow(img)
        for rect in rects:
            axes[(pos)//2, 1-((pos) % 2)].add_patch(rect)

        img, rects = plot_image_from_output(
            im[sample_i], preds_adj_all[batch_i][sample_i])
        axes[(pos)//2, 1-((pos+1) % 2)].imshow(img)
        for rect in rects:
            axes[(pos)//2, 1-((pos+1) % 2)].add_patch(rect)

        pos += 2

    batch_i += 1
    if batch_i == 4:
        break

# xtick, ytick 제거
for idx, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])

colnames = ['True', 'Pred']

for idx, ax in enumerate(axes[0]):
    ax.set_title(colnames[idx])

plt.tight_layout()
plt.show()

2. 학습소감

마스크를 쓰지 않은 사람의 얼굴은 아예 detect가 되지 않은 것 같다.
object detection은 공부가 더 필요할 것 같다..

0개의 댓글