어제 돌렸던 마스크 데이터에 대해서 나온 retina_last.pt 평가 진행
평가 코드를 넣기 전에 어제 custom_data 코드에서 train 부분을 지우고 넣음
데이터는 전부 29일 폴더안에 있어서 경로를 변경해서 돌렸다.
batch size를 2로 넣고 돌려야 코드가 터지지 않는다.
"""eval loop"""
retina.load_state_dict(torch.load(f'../29_image_train/retina_last.pt'))
# retina.load_state_dict(torch.load(f'retina_last.pt', map_location='cpu'))
"""
make_prediction 함수에는 학습된 딥러닝 모델을 활용해 예측하는 알고리즘이 저장돼 있습니다.
threshold 파라미터를 조정해 신뢰도가 일정 수준 이상의 바운딩 박스만 선택합니다.
보통 0.5 이상인 값을 최종 선택합니다.
"""
def make_prediction(model, img, threshold):
model.eval()
preds = model(img)
for id in range(len(preds)):
idx_list = []
for idx, score in enumerate(preds[id]['scores']):
if score > threshold: # threshold 넘는 idx 구함
idx_list.append(idx)
preds[id]['boxes'] = preds[id]['boxes'][idx_list]
preds[id]['labels'] = preds[id]['labels'][idx_list]
preds[id]['scores'] = preds[id]['scores'][idx_list]
print("pred info ", preds)
return preds
labels = []
preds_adj_all = []
annot_all = []
for im, annot in tqdm(test_loader, position=0, leave=True):
im = list(img.to(device) for img in im)
#annot = [{k: v.to(device) for k, v in t.items()} for t in annot]
for t in annot:
labels += t['labels']
with torch.no_grad():
preds_adj = make_prediction(retina, im, 0.5)
preds_adj = [{k: v.to(torch.device('cpu'))
for k, v in t.items()} for t in preds_adj]
print("make_prediction function 다음 처리 되는 값 >> ", preds_adj)
preds_adj_all.append(preds_adj)
annot_all.append(annot)
"""바운딩 박스 시각화 그리는 곳"""
nrows = 8
ncols = 2
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4))
batch_i = 0
for im, annot in test_loader:
pos = batch_i * 4 + 1
for sample_i in range(len(im)):
img, rects = plot_image_from_output(im[sample_i], annot[sample_i])
axes[(pos)//2, 1-((pos) % 2)].imshow(img)
for rect in rects:
axes[(pos)//2, 1-((pos) % 2)].add_patch(rect)
img, rects = plot_image_from_output(
im[sample_i], preds_adj_all[batch_i][sample_i])
axes[(pos)//2, 1-((pos+1) % 2)].imshow(img)
for rect in rects:
axes[(pos)//2, 1-((pos+1) % 2)].add_patch(rect)
pos += 2
batch_i += 1
if batch_i == 4:
break
# xtick, ytick 제거
for idx, ax in enumerate(axes.flat):
ax.set_xticks([])
ax.set_yticks([])
colnames = ['True', 'Pred']
for idx, ax in enumerate(axes[0]):
ax.set_title(colnames[idx])
plt.tight_layout()
plt.show()
마스크를 쓰지 않은 사람의 얼굴은 아예 detect가 되지 않은 것 같다.
object detection은 공부가 더 필요할 것 같다..