https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
새롭게 알게된 점
Pytorch의 Linear layer 모듈은 오직 TensorFloat32 dtype만을 지원한다.
근데 이 때, weight가 float16이면 32가 들어왔을 때 계산이 안되므로, weight도 float32로 넣어주어야 한다.
pretrainedweight = torch.nn.Parameter(h_zs.clone())
layer1 = nn.Linear(512, 100)
layer1.weight = pretrainedweight
를 아래와 같이 바꾸면 에러가 사라짐.
pretrainedweight = torch.nn.Parameter(h_zs.clone().float())
layer1 = nn.Linear(512, 100)
layer1.weight = pretrainedweight.float()
혹은 gpu에서 autocast를 실행해서 float16으로 계산하도록 만들어준다.