해당 코드(https://github.com/AntixK/PyTorch-VAE)를 통해 VAE를 공부하고 있는데 문제가 생겼다.
encoder는 분명 점점 작아지면서 정보를 압축하는 것으로 알고 있는데 hidden dims가 점점 커지는 걸 그대로 사용하고 있는 것이었다.
결과적으로 hidden dim은 점점 커지는게 맞고, encoder에서 말하는 압축을 위해 점점 작아진다는 건 실제 공간인 spatial dimension이 줄어드는 것이다.
예를들어)
w,h = 128
input_channel = 32
output_channel = 64
kernel_size = 3
stride = 2
padding = 1
라는 상황을 가정할 때,
(128,128,32) * (3,3,32)64에 대해 stride와 padding을 고려해 계산해야 한다.
결국
(128+1(padding)-3(filter_w)+1 / 2(stride))의 결과인
(63,63,64)가 최종 output이 된다.
128 * 128 * 32 = 524,288
63 * 63 * 64 = 254,016으로
최종적으로 hidden dimension은 늘었지만 spatial dimension은 줄어드는 것을 알 수 있다.