이번 포스팅에선 OpenCV를 이용하여 비슷한 이미지 찾아내는 예제를 살펴보겠다.
예제 데이터 출처는 CIFAR-100이다.
코드들은 전부 jupyter notebook을 사용했으며, CIFAR-100도 전부 다운받아 로컬에 위치시켜 놓고 실습을 진행하였다.
OpenCV의 수많은 기능 중 이미지에서 색상 히스토그램(histogram)을 추출하고, 이들을 서로 비교하는 기능을 제공한다.
잠!깐만~~~ 여기서 히스토그램(histogram)이란, 이미지에서 픽셀 별 색상 값의 분포라고 볼 수 있습니다. 즉, 우리는 이러한 히스토그램을 통해, 각 이미지의 색상 분포를 비교하여 서로 유사한 이미지를 판단하는 척도로 사용할 예정이다.
예제로 하기에는 RGB 각 채널별로 0~255 범위의 각 값에 해당하는 픽셀의 개수를 일일이 저장하기에는 계산량이 많아진다. 고로! 단순화하여 4개 구간(0~63, 64~127, 128~191, 192~255)로 나누어 픽셀 수를 세기로 한다.
OpenCV의 예제 페이지에 Plotting Histograms가 있다. 이 중에서 Using Matplotlib 내 색상별 히스토그램을 그리는 코드를 참조해서 히스토그램을 그려보자.
작업에 필요한 패키지들을 import 해주자.
import os
import pickle
import cv2
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from PIL import Image
(어... 생각보다 많다...?)
CIFAR-100의 train 데이터의 경로와 이미지 파일로 변환해서 저장해줬던 이미지 디렉토리 경로도 각각 변수에 저장해주겠다.
train_file_path = os.path.join(os.getcwd(), '/cifar-100-python/train')
images_dir_path = os.path.join(os.getcwd(), '/python_image_proc/cifar-images')
아래의 함수는 파일명을 인자로 받고, 해당 이미지 파일과 히스토그램을 출력해 주는 역할을 한다.
def draw_color_histogram_from_image(file_name):
image_path = os.path.join(images_dir_path, file_name)
# 이미지 열기
img = Image.open(image_path)
cv_image = cv.imread(image_path)
# Image 시각화
f=plt.figure(figsize=(10,3))
im1 = f.add_subplot(1,2,1)
im1.grid(False)
im1.imshow(img)
im1.set_title("Image")
# Histogram 시각화
im2 = f.add_subplot(1,2,2)
color = ('b','g','r') # OpenCV에서는 RGB를 BGR로 쓴다.
for i,col in enumerate(color):
# image에서 i번째 채널의 히스토그램을 뽑아서(0:blue, 1:green, 2:red)
histr = cv.calcHist([cv_image],[i],None,[256],[0,256])
# 채널 색상과 맞춰 그래프를 그린다.
im2.plot(histr,color = col)
im2.grid(False)
im2.set_title("Histogram")
이제 함수를 사용해서 이미지의 히스토그램을 보자.
draw_color_histogram_from_image('acer_saccharinum_s_000230.png')
아이디어를 차근히 정리해보자.
(왜냐하면 너무 길다....아... 길다... 쏘~~~롱.... 베리베리 롱....)
이미지 파일 경로 하나를 입력으로 받아, 검색 대상 이미지들 중 비슷한 이미지들을 골라 화면에 표시하는 기능을 만드는 것이 목표이다.
여기서 핵심은 '비슷한 이미지'를 개념을 구현하는 것이다. 이미지의 유사성을 비교하는 방법에는 다양한 방법이 있지만, 여기선 위에서 배운 히스토그램이라는 개념을 통해 이미지를 서로 비교할 것이다.
히스토그램을 만들어주는 함수 cv2.calcHist()
,
히스토그램끼리의 유사성을 계산해 주는 cv2.compareHist()
라는 함수를 사용할 것이다.
검색 대상 이미지는 여태.... 준비했던 CIFAR-100 이미지를 사용하겠다.
.py
파일로 만든다고 하면, 명령어로 입력을 받아야 하는데 이때는 인자로 sys.argv
를 쓰면된다. 아래 코드는 jupyter notebook으로 실행할 거라 이 포스팅에선 쓰지 않겠다.
설계과정을 간추려(?) 보면 아래와 같다.
[설계과정]
- 프로그램이 실행.
build_histogram_db()
- CIFAR-100 이미지들을 불러온 후,
- 불러온 이미지들을 각각 히스토그램으로 만든다.
- 이미지 이름을 Key로 하고, 히스토그램을 Value으로 하는 딕셔너리 histogram_db를 반환한다.
- CIFAR-100 히스토그램 중 입력된 이미지 이름에 해당하는 히스토그램을 입력 이미지로 선택하여
target_histogram
이라는 변수명으로 지정한다.search()
search
는 입력 이미지 히스토그램target_histogram
,
전체 검색 대상 이미지들의 히스토그램을 가진 딕셔너리histogram_db
를 입력으로 받는다.compareHist()
함수를 사용하여 입력 이미지와 검색 대상 이미지 각각의 히스토그램 간 유사도를 계산하고,
결과는 Key는 이미지 이름, Value은 유사도로 result라는 이름의 딕셔너리에 남겨준다.- 계산된 유사도(distance)를 기준으로 정렬하여 순서를 매긴 후,
- 유사도의 순서상으로 상위 5개 이미지(top_k)만 result에 남긴다.
- 고른 이미지들을 시각화하여 확인 후,
- 프로그램이 종료.
아래의 코드는 이미지 파일 1개에 대해 히스토그램을 만드는 함수 get_histogram()
을 만들어주는 함수이다.
def get_histogram(image):
histogram = []
# Create histograms per channels, in 4 bins each.
for i in range(3):
channel_histogram = cv.calcHist(images=[image],
channels=[i],
mask=None,
histSize=[4], # 히스토그램 구간을 4개로 한다.
ranges=[0, 256])
histogram.append(channel_histogram)
histogram = np.concatenate(histogram)
histogram = cv.normalize(histogram, histogram)
return histogram
제대로 동작하는지 확인해보자.
filename = train[b'filenames'][0].decode()
file_path = os.path.join(images_dir_path, filename)
image = cv.imread(file_path)
histogram = get_histogram(image)
histogram
------------------------------------------------------
array([[0.3126804 ],
[0.4080744 ],
[0.14521089],
[0.21940625],
[0.18654831],
[0.23742512],
[0.30208108],
[0.35931748],
[0.06465594],
[0.35825753],
[0.36991683],
[0.29254165]], dtype=float32)
ndarray를 잘 뱉어내니, 잘 작동하는 것이다.
이제 build_histogram_db()
를 만들어서 이미지 파일의 histogram으로 db를 만들어주자.
def build_histogram_db():
histogram_db = {}
#디렉토리에 모아 둔 이미지 파일들을 전부 가져오자.
path = images_dir_path
file_list = os.listdir(images_dir_path)
for file_name in tqdm(file_list):
file_path = os.path.join(images_dir_path, file_name)
image = cv.imread(file_path)
histogram = get_histogram(image)
histogram_db[file_name] = histogram
return histogram_db
histogram_db = build_histogram_db()
histogram_db['acer_saccharum_s_000341.png']
---------------------------------------------------------
100%|██████████| 49999/49999 [00:28<00:00, 1755.48it/s]
array([[0.54847693],
[0.15545495],
[0.02949658],
[0.0829093 ],
[0.02790217],
[0.5747847 ],
[0.13472761],
[0.07892328],
[0.00159441],
[0.05022391],
[0.4743369 ],
[0.29018256]], dtype=float32)
target_histogram
에 입력받은 이미지를 히스토그램으로 저장하는 함수(get_target_histogram
)를 만들어보자.
def get_target_histogram():
filename = input("이미지 파일명을 입력하세요 : ")
if filename not in histogram_db:
print('유효하지 않은 이미지 파일명입니다.')
return None
return histogram_db[filename]
get_target_histogram
을 사용해 파일명을 입력한 후, 히스토그램 값을 알아보자.
target_histogram = get_target_histogram()
target_histogram
---------------------------------------------
이미지 파일명을 입력하세요 : adam_s_002140.png
array([[0.14578441],
[0.37344772],
[0.31353632],
[0.18971942],
[0.14877997],
[0.53021586],
[0.24164264],
[0.10184938],
[0.12581393],
[0.44833696],
[0.3125378 ],
[0.13579917]], dtype=float32)
마지막으로 search()
함수를 구현해보자.
입력부는 target_histogram
과 build_histogram_db
로 이미 완성되어있는 상태이다.
입력부에 유사도 순으로 몇 개를 결과에 남길지 top_k
로 정해주고, 기본값으로 상위 5개를 결과로 남겨주자.
def search(histogram_db, target_histogram, top_k=5):
results = {}
# Calculate similarity distance by comparing histograms.
for file_name, histogram in tqdm(histogram_db.items()):
distance = cv.compareHist(H1=target_histogram,
H2=histogram,
method=cv.HISTCMP_CHISQR)
results[file_name] = distance
results = dict(sorted(results.items(), key=lambda item: item[1])[:top_k])
return results
result = search(histogram_db, target_histogram)
result
-------------------------------------------------------------
100%|██████████| 49999/49999 [00:00<00:00, 368613.12it/s]
{'adam_s_002140.png': 0.0,
'stone_crab_s_000334.png': 0.06996628092772557,
'army_tank_s_000750.png': 0.07936638326924272,
'orchid_s_001611.png': 0.08178092910116018,
'young_lady_s_000619.png': 0.08270758835316562}
이제 얼마나 비슷한지 이미지들을 시각화해서 살펴보자.
def show_result(result):
f=plt.figure(figsize=(10,3))
for idx, filename in enumerate(result.keys()):
img_path = os.path.join(images_dir_path, filename)
im = f.add_subplot(1,len(result),idx+1)
im.grid(False)
img = Image.open(img_path)
im.imshow(img)
(어딜봐서 비슷하단 거지....?)
위 내용들을 총정리하자면!
검색할 이미지를 입력으로 받으면 가장 유사한 이미지가 화면에 출력되는 코드이다.
target_histogram = get_target_histogram()
result = search(histogram_db, target_histogram)
show_result(result)
cifar 100을 다운받았는데 사용법을 모르겠어서 질문드려요 압축파일을 풀고 로컬에 위치시킨다는게 로컬 어디에 넣는건가요?? 다운해서 압축까진 풀었는데..