
전처리 시작.
1. 크롤링 먼저하기
# 크롤링한 이미지 저장
from selenium import webdriver
import time
from selenium.webdriver.common.keys import Keys
import pandas as pd
import os
from urllib.request import (urlopen, urlparse, urlretrieve)
# 구글 이미지 URL
chrome_path = "./chromedriver"
# window
# chrome_path = "./chromedriver.exe"
base_url = "https://www.google.co.kr/imghp"
# 구글 검색 옵션
chrome_options = webdriver.ChromeOptions()
chrome_options.add_argument("lang=ko_KR") # 한국어
chrome_options.add_argument("window-size=1920x1080")
# driver = webdriver.Chrome(chrome_path, chrome_options=chrome_options)
# driver.get(base_url)
# driver.implicitly_wait(3) # element 로드될 때까지 지정한 시간만큼 대기할 수 있도록 하는 옵션
# driver.get_screenshot_as_file("google_screen.png")
# driver.close()
def selenium_scroll_option() :
SCROLL_PAUSE_SEC = 1
# 스크롤 높이 가져옴
last_height = driver.execute_script(
"return document.body.scrollHeight")
while True :
# 끝까지 스크롤 다운
driver.execute_script(
"window.scrollTo(0, document.body.scrollHeight);")
time.sleep(SCROLL_PAUSE_SEC)
# 스크롤 다운 후 스크롤 높이 다시 가져옴
new_height = driver.execute_script(
"return document.body.scrollHeight")
if new_height == last_height :
break
last_height = new_height
item_list = ['망고','용과','리치','두리안']
for i in range(len(item_list)):
image_name = item_list[i]
driver = webdriver.Chrome(chrome_path)
driver.get("http://www.google.co.kr/imghp?hl=ko")
browser = driver.find_element_by_name('q')
browser.send_keys(item_list[i])
browser.send_keys(Keys.RETURN)
selenium_scroll_option()# 스크롤 하여 이미지 확보
driver.find_element_by_xpath(
'//*[@id="islmp"]/div/div/div/div[1]/div[2]/div[2]/input').click()
selenium_scroll_option()
# 이미지 저장 src 요소를 리스트업 해서 이미지 url 저장
image = driver.find_elements_by_css_selector(".rg_i.Q4LuWd")
# 클래스 네임에서 공백은 . 을 찍어줌
# print(image)
image_url = []
for i in image:
if i.get_attribute("src") != None :
image_url.append(i.get_attribute("src"))
else :
image_url.append(i.get_attribute("data-src"))
#전체 이미지 개수
print(f"전체 다운로드한 이미지 개수 : {len(image_url)}")
image_url = pd.DataFrame(image_url)[0].unique()
for i in range(len(item_list)):
# 해당하는 파일에 이미지 다운로드
os.makedirs(f"./fruit{i}", exist_ok=True)
new_img = f"./fruit{i}/"
if image_name == f'{item_list[i]}':
for t, url in enumerate(image_url, 0):
urlretrieve(url, new_img + image_name + "_" + str(t) + ".png")
driver.close()
print("완료")
2. 불필요한 이미지 제거(일러스트, 상관없는 이미지 등)
3. 이미지 패딩 추가 및 리사이즈.
사진 저장하고 리사이즈 해야하는 듯..!
# 이미지 패딩 추가 및 리사이즈
from PIL import Image
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
# paste 이미지 붙이기(추가할 이미지 , 붙일 위치(가로, 세로))
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
# 새로운 이미지 생성
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
save_list = ['mango', 'Dragon_fruit', 'lychee', 'durian']
print(len(item_list))
for i in range(len(item_list)):
im_path = f'fruit{i}/'
save_path = f'{save_list[i]}/'
if not os.path.exists(save_path):
os.mkdir(save_path)
im_list = os.listdir(im_path) #원본 이미지 경로의 모든 이미지 list 지정
print(len(im_list))
for name in im_list:
im = Image.open(im_path+name)
new_im = expand2square(im, (0, 0, 0)).resize((225, 225))
new_im.save(save_path+name, quality=100)
print('종료')
전처리 완료.
후처리 시작.
# 후처리 시작.
from h11 import Data
import cv2
import torch
import glob
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import os
class MyCustomDatasetImage(Dataset):
def __init__(self, path):
# 정의
self.all_data = sorted(glob.glob(os.path.join(path, '*', '*.png')))
def __getitem__(self, index):
data_path = self.all_data[index]
image = Image.open(data_path)
data_split = data_path.split('/')
data_labels = data_split[1]
print(data_labels)
# windows
# data_split = data_path.split('\\')
labels = 0
if data_labels == '망고':
labels = 0
elif data_labels == '용과':
labels = 1
elif data_labels == '리치':
labels = 2
elif data_labels == '두리안':
labels = 3
print(data_labels, labels)
# cv2 PIL 이용해서 이미지 변경 하면됨.
return image, data_path, labels
# 정의 작성된 내용을 구현
def __len__(self):
# 전체 길이를 반환 -> 리스트 [] len()
return len(self.all_data)
save_list = ['mango', 'Dragon_fruit', 'lychee', 'durian']
# 2가지 구성 필요 -> dataset dataloader
for i in range(len(save_list)):
dataset = MyCustomDatasetImage(path=f"{save_list[i]}")
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
for path, label in dataloader:
print(path, label)