Tensorpack으로 DataLoader 구현하기

EuiyeonKim·2020년 4월 17일
0

Dataflow란?

파이선 자체에 있는 generator와 같은 맥락

from tensorpack import DataFlow

class myDataFlow(DataFlow):
    def __iter__(): essential
    def __len__(): optional
    def reset_state(): optional

Parallel Dataflow 구현하기

Dataflow를 parallel하게 구현하는 방법에는 크게 두 가지가 있다

  1. 같은 dataflow를 여러번 실행하고, 결과를 queue에 저장하기
  2. Dataflow의 task를 distribute하기

물론 두 가지 방법을 동시에 사용할 수도 있다

같은 dataflow를 여러 프로세스에서 실행하고, 결과를 queue에 저장하기

d1 = MyDataFlow()   # some dataflow written by the user
d2 = MultiProcessRunnerZMQ(d1, num_proc=20)

위와 같은 경우에는 d1에 랜덤성이 있어야 한다. 아니면 모든 thread가 다 같은 결과를 return함

Dataflow의 task를 distribute하기

d1 = MyDataFlow()   # a dataflow that produces [image file name, label]

def f(file_name, label):
    # read image
    # run heavy pre-proecssing / augmentation on the image
    return img, label

d2 = MultiProcessMapData(d1, num_proc=20, f)

위와 같은 경우 d1은 한 프로세스에서 돌아가고, f가 threading된다.
d1을 충분히 efficient하게 구현해야 효과가 좋다.
d1에서의 랜덤성을 고려하지 않아도 된다

두 가지 방법 모두 multi threading, multi processing으로 distribute 할 수 있다.

Multi-threading과 Multi-processing의 장단점

  1. Python에서는 GIL에 의해 thread에 제한이 걸려있어서 한 프로세스 내의 thread는 python 구문을 parallel하게 interpret할 수 없다
    -> multi-threading이 생각만큼 scalable하지 않다
  2. Process들은 자원을 공유하지 않기 때문에 process끼리 communication에 overhead가 생길 수 있다

효율적인 Dataflow 구현하기

from tensorpack.dataflow import *
# ds0은 file io만 수행
ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)

# Distribute할 task (augmentation) 
augmentor = AugmentorList(lots_of_augmentors)
# process에서 augmentation을 multi-threading
# 하나의 thread가 yield된 data를 1000크기의 버퍼에 쌓고, 나머지 process들이 버퍼에서 data를 읽어task 수행
ds1 = MultiThreadMapData(
    ds0, num_thread=25,
    map_func=lambda dp:
      [augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
    buffer_size=1000)

# ZMQ 통신을 사용해서 ds1 실행
ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)

# 원래 BatchData는 default인 경우 np.ndarray로 데이터를 저장하는데, 데이터의 크기가 일정하지 않은 경우 use_list를 True로 둬야 한다
# Master에서 batch로 쪼개서 return
ds = BatchData(ds1, 256)

# TestDataSpeed로 데이터 공급 속도 측정 가능
TestDataSpeed(ds).start()
profile
병아리 딥러닝 개발자

0개의 댓글