torch.flatten(input, start_dim=0, end_dim=- 1) → Tensor
torch.flatten 은 입력을 1차원 텐서로 reshape 해준다. start_dim과 end_dim을 입력해줄 수 있다.
start_dim
end_dim
요런식으로 작동한다.
start_dim을 설정해보면
다음과 같은데,
flatten을 시작하는 dimension을 설정해준다.