[pytorch] 😎 Custom ImageFolder Class!

강콩콩 · April 11, 2022

😁 Hello! Today we'll take a quick look at how to use the ImageFolder class, which comes in handy when building vision models,
😊 and then use it as the basis for a custom class :)

ImageFolder

😉 ImageFolder is a kind of Dataset class: given only the path to the data, it builds a Dataset object with very little code.

https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html?highlight=imagefolder#torchvision.datasets.ImageFolder

🐢 그리고 μ €λŠ”, μ§€λ‚œλ²ˆμ— μ‚¬μš©ν–ˆλ˜ Stanford Dog Dataset을 ν™œμš©ν•˜λ„λ‘ ν•˜κ² μŠ΅λ‹ˆλ‹€. :)

# ubuntu linux
wget http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
tar -xvf images.tar

Running the commands above gives you a folder named Images full of cute dog photos.


😋 I confirmed that the photos of golden retrievers, my favorite, were downloaded properly!
😃 Now let's try out the ImageFolder class. :)

import torch
import torchvision
from torchvision import transforms

size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

dog_transform = transforms.Compose([
                transforms.RandomResizedCrop(
                    (size, size), scale=(0.5, 1.0)),  
                transforms.RandomHorizontalFlip(), 
                transforms.ToTensor(),  # convert to a tensor
                transforms.Normalize(mean, std)  # standardize each channel
            ])

dog_dataset = torchvision.datasets.ImageFolder('Images/', transform=dog_transform)

dog_dataset.class_to_idx
>>>{'n02085620-Chihuahua': 0,
 'n02085782-Japanese_spaniel': 1,
 'n02085936-Maltese_dog': 2,
 ...

📝 Checking dataset.class_to_idx shows the mapping between the Dataset object's classes and their indices.

data_loader = torch.utils.data.DataLoader(dog_dataset,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=2)
                                          
images, labels = next(iter(data_loader))  # fetch one batch and reuse it
images.shape, labels.shape
>>> (torch.Size([4, 3, 224, 224]), torch.Size([4]))

✨ And with the DataLoader class, we've confirmed that batches come out as expected :)
😜 Pretty easy, right?

Why Custom Class?

😒 "Wait, why build a custom class at all? It already works fine." If that's what you're thinking:
😀 I thought so too, but there turned out to be one blind spot in using the class above!

It was "how class_to_idx gets generated"!

🤔 The ImageFolder class is simple and powerful, but it assigns the indices in class_to_idx in alphabetical (sorted) order of the class folder names.
✔ With three labels apple / banana / cider, the result is {"apple" : 0, "banana" : 1, "cider" : 2}.
😃 Alphabetical order is a common enough rule, but in real work I've had cases where a model had to be trained with a class_to_idx that does not follow alphabetical label order.
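💡 To make the difference concrete, here is a minimal sketch in plain Python (no torchvision needed). The label names come from the example above; preferred_order is a made-up order used only for illustration.

```python
# Folder names as they might appear on disk (illustrative values).
labels = ["cider", "apple", "banana"]

# ImageFolder-style mapping: indices follow the sorted class names.
alphabetical = {name: idx for idx, name in enumerate(sorted(labels))}

# Hypothetical task-specific order: the index is the position in this list.
preferred_order = ["cider", "apple", "banana"]
custom = {name: idx for idx, name in enumerate(preferred_order)}

print(alphabetical)  # {'apple': 0, 'banana': 1, 'cider': 2}
print(custom)        # {'cider': 0, 'apple': 1, 'banana': 2}
```

The same three folders end up with different integer labels depending on which list drives the enumeration, which is exactly the knob ImageFolder does not expose.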

Providing a new class_to_idx

😋 For testing, I built a new class_to_idx object :)

import os

label_list = os.listdir('Images/')

custom_class_to_idx = {label : idx for idx, label in enumerate(label_list)}

# a dict whose idx values are not assigned in alphabetical order
custom_class_to_idx
>>> {'n02111500-Great_Pyrenees': 0,
 'n02111889-Samoyed': 32,
 'n02112018-Pomeranian': 1,
 'n02112137-chow': 97,
 ...
 

😎 I'll use this mapping to build the Dataset class :)
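💡 One thing to keep in mind: os.listdir returns entries in arbitrary order, so a mapping built this way may differ between machines. If the same mapping has to be reused later, one option is to save the label list once and reload it. The sketch below assumes a JSON file; the file name label_list.json and the three labels are just illustrative values, not part of the original setup.

```python
import json
import os
import tempfile

# Illustrative subset of folder names, in the order we want to keep.
label_list = [
    "n02111500-Great_Pyrenees",
    "n02112018-Pomeranian",
    "n02111889-Samoyed",
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "label_list.json")  # hypothetical file name
    # Save the order once...
    with open(path, "w", encoding="utf-8") as f:
        json.dump(label_list, f)
    # ...and reload it wherever the same mapping is needed.
    with open(path, encoding="utf-8") as f:
        restored = json.load(f)

# JSON arrays keep their order, so the rebuilt mapping is identical.
restored_class_to_idx = {label: idx for idx, label in enumerate(restored)}
print(restored_class_to_idx)
```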

Analyzing the source of the ImageFolder & DatasetFolder classes

https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder

✨ The link above is the source code of ImageFolder.
😎 ImageFolder has almost no logic of its own; it simply inherits the logic of the DatasetFolder class!

class ImageFolder(DatasetFolder):
    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

😁 So let's look at the source code of the DatasetFolder class as well.
🤣 Whoa! It inherits from yet another class, VisionDataset! Do we have to analyze that class too?
😎 No need to go that far. All we need is to change the 'class_to_idx' attribute.

class DatasetFolder(VisionDataset):
    """
    (docstring omitted)
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``dir`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)       

✨ Aha, self.class_to_idx is determined by the self.find_classes method.
😋 So overriding just the find_classes method should get us what we want!
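💡 The idea can be tried out without torchvision. The sketch below uses MiniDatasetFolder, a hypothetical stand-in that only copies the "call self.find_classes inside __init__" pattern from DatasetFolder; it is not the real class, and FixedOrderFolder is likewise a made-up name.

```python
import os
import tempfile
from typing import Dict, List, Tuple


class MiniDatasetFolder:
    """Hypothetical stand-in for DatasetFolder; NOT the torchvision class."""

    def __init__(self, root: str) -> None:
        self.root = root
        # Same pattern as DatasetFolder.__init__: the mapping comes from a
        # method, so a subclass can replace it.
        self.classes, self.class_to_idx = self.find_classes(self.root)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        classes = sorted(e.name for e in os.scandir(directory) if e.is_dir())
        return classes, {name: i for i, name in enumerate(classes)}


class FixedOrderFolder(MiniDatasetFolder):
    """Uses a caller-supplied class order instead of the sorted one."""

    def __init__(self, root: str, class_list: List[str]) -> None:
        # Must be set before super().__init__ runs find_classes.
        self._class_list = list(class_list)
        super().__init__(root)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        return self._class_list, {n: i for i, n in enumerate(self._class_list)}


with tempfile.TemporaryDirectory() as root:
    for name in ("apple", "banana", "cider"):
        os.mkdir(os.path.join(root, name))
    default_map = MiniDatasetFolder(root).class_to_idx
    custom_map = FixedOrderFolder(root, ["cider", "apple", "banana"]).class_to_idx

print(default_map)  # {'apple': 0, 'banana': 1, 'cider': 2}
print(custom_map)   # {'cider': 0, 'apple': 1, 'banana': 2}
```

Because the base __init__ goes through self.find_classes, the subclass changes the mapping without touching anything else in the constructor.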

😊 The find_classes method in turn calls the module-level find_classes function, shown below :)

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

👌 OK! It returns the classes and class_to_idx used by the Dataset class.
(classes is a list of the class names arranged in idx order.)
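💡 Since classes is ordered by idx, the two return values are inverses of each other: class_to_idx goes from name to index, and classes[idx] goes back from index to name. A small sketch with made-up values (predicted_indices stands in for hypothetical model outputs):

```python
classes = ["apple", "banana", "cider"]                       # index -> class name
class_to_idx = {name: i for i, name in enumerate(classes)}   # class name -> index

# Hypothetical predicted label indices, e.g. from an argmax over model logits.
predicted_indices = [2, 0, 0, 1]
predicted_names = [classes[i] for i in predicted_indices]
print(predicted_names)  # ['cider', 'apple', 'apple', 'banana']

# Round trip: every name maps to an index that maps back to the same name.
round_trip_ok = all(classes[idx] == name for name, idx in class_to_idx.items())
print(round_trip_ok)  # True
```

This is also why a custom class list has to be passed in the intended index order: the position in the list is the label the model will be trained on.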

Writing the custom ImageFolder & DatasetFolder classes!

# from https://pytorch.org/vision/0.11/_modules/torchvision/datasets/folder.html
################################################################################
################################################################################
# copied from folder.py

import os
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import find_classes  # needed by make_dataset when class_to_idx is None

def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                if is_valid_file(fname):
                    path = os.path.join(root, fname)
                    item = path, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

# copied from folder.py END

################################################################################
################################################################################


class CustomDatasetFolder(VisionDataset):
    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        class_list: List[str],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(class_list)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:

        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, class_list: List[str]) -> Tuple[List[str], Dict[str, int]]:
        return class_list, {label : idx for idx, label in enumerate(class_list)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


class CustomImageFolder(CustomDatasetFolder):
    def __init__(
        self,
        root: str,
        class_list: List[str],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            class_list,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

🤣 Phew... that got long as I wrote it ;<

custom_dog_dataset = CustomImageFolder('Images/', label_list, transform=dog_transform)

custom_dog_dataset.class_to_idx
>>> {'n02111500-Great_Pyrenees': 0,
 'n02111889-Samoyed': 32,
 'n02112018-Pomeranian': 1,
 'n02112137-chow': 97,
 ...

✌ And with that, the Dataset class built from the class_to_idx dict I wanted is complete!

Closing

😂 Today's post leaves me a little less satisfied than I expected.
🤦‍♂️ I figured a custom class would be quick to build with plain inheritance and overriding, but..! The concept was simple, yet the code ended up quite long.
🤔 I may well have missed something, so a second pass seems necessary: analyze the code further and see whether there is a simpler, cleaner way to do this.
😋 Well... the result I wanted did come out, so that counts for something, right?! (Ha. Ha. Ha.)
😘 Thanks for reading, and see you next time!
