tfds.load

Younghwan Cha·2022년 10월 15일
0

OpenCV

목록 보기
10/10

https://www.tensorflow.org/tutorials/images/segmentation?hl=ko
상단 튜토리얼을 진행하던 중 다음과 같은 코드를 마주했다.

import tensorflow_datasets as tfds
tfds.disalbe_progress_bar()

datasets, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

oxford_iiit_pet 의 datasets 와 info 를 가져오는 코드다.
이 둘은 어떻게 생겼을까? 궁금하다 궁금해...
한번 보자.

INFO

먼저, info 의 경우 다음과 같다.

tfds.core.DatasetInfo(
    name='oxford_iiit_pet',
    full_name='oxford_iiit_pet/3.2.0',
    description="""
    The Oxford-IIIT pet dataset is a 37 category pet image dataset with roughly 200
    images for each class. The images have large variations in scale, pose and
    lighting. All images have an associated ground truth annotation of breed.
    """,
    homepage='http://www.robots.ox.ac.uk/~vgg/data/pets/',
    data_path='/home/cha/tensorflow_datasets/oxford_iiit_pet/3.2.0',
    file_format=tfrecord,
    download_size=773.52 MiB,
    dataset_size=774.69 MiB,
    features=FeaturesDict({
        'file_name': Text(shape=(), dtype=tf.string),
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=37),
        'segmentation_mask': Image(shape=(None, None, 1), dtype=tf.uint8),
        'species': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=3669, num_shards=4>,
        'train': <SplitInfo num_examples=3680, num_shards=4>,
    },
    citation="""@InProceedings{parkhi12a,
      author       = "Parkhi, O. M. and Vedaldi, A. and Zisserman, A. and Jawahar, C.~V.",
      title        = "Cats and Dogs",
      booktitle    = "IEEE Conference on Computer Vision and Pattern Recognition",
      year         = "2012",
    }""",
)

대부분 이해가 가지만 아래 두 줄이 잘 이해가 안간다.

  supervised_keys=('image', 'label'),
  disable_shuffling=False,

이는 shuffle 에 대한 이해를 갖고 다시 와서 수정하도록 하자.

DATASETS

dataset 은 다음과 같다.

>DATASET
{
    'train': 
        <PrefetchDataset 
            element_spec={
                'file_name': TensorSpec(shape=(), dtype=tf.string, name=None), 
                'image': TensorSpec(shape=(None, None, 3), dtype=tf.uint8, name=None), 
                'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 
                'segmentation_mask': TensorSpec(shape=(None, None, 1), dtype=tf.uint8, name=None), 
                'species': TensorSpec(shape=(), dtype=tf.int64, name=None)
            }
        >,
    'test': 
        <PrefetchDataset 
            element_spec={
                'file_name': TensorSpec(shape=(), dtype=tf.string, name=None), 
                'image': TensorSpec(shape=(None, None, 3), dtype=tf.uint8, name=None), 
                'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 
                'segmentation_mask': TensorSpec(shape=(None, None, 1), dtype=tf.uint8, name=None), 
                'species': TensorSpec(shape=(), dtype=tf.int64, name=None)
            }
        >
}

이를 통해 우리가 원하는 dataset 의 모양을 알게되었다.
여기서 하나 더, 그렇다면 PrefetchDataset 객체는 뭘까?

print(dataset['train'].file_name 을 하니 다음과 같은 오류가 나왔다.

AttributeError: 'PrefetchDataset' object has no attribute 'file_name'

element_spec 으로 감싸져있구나
다시 print(dataset['train'].element_spec)

{
	'file_name': TensorSpec(shape=(), dtype=tf.string, name=None),
    'image': TensorSpec(shape=(None, None, 3), dtype=tf.uint8, name=None),
    'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 
    'segmentation_mask': TensorSpec(shape=(None, None, 1), dtype=tf.uint8, name=None), 
    'species': TensorSpec(shape=(), dtype=tf.int64, name=None)
}

잘 나왔다. 이제 드디어 print(dataset['train'].element_spec.file_name)
...AttributeError: 'dict' object has no attribute 'file_name'
dictionary 객체란다..그렇다면 print(dataset['train'].element_spec['file_name']
TensorSpec(shape=(), dtype=tf.string, name=None)
나왔다!
근데 나오고 나서 생각해보니까 element_spec 이잖아ㅎㅎ실제 정보가 있을리 만무하다.
그럼 데이터는 어떻게 사용할까?

다음과 같은 코드로 데이터를 확인 할 수 있다.

import numpy as np
train_dataset = tfds.load('oxford_iiit_pet:3.*.*', split=tfds.Split.TRAIN)
print(train_dataset)

for data in train_dataset.take(1):
    image, label = data['image'], data['label']
    plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
    plt.axis('off')
    print("Label: %d" % label.numpy())

profile
개발 기록

0개의 댓글