[PyTorch] Tensor의 Type

olxtar·2022년 8월 10일
0
post-thumbnail

Comment :

tensor.long()이 뭔지 찾아보다가 PyTorch의 Tensor의 타입에 대해서 간단히 정리함


Tensor's Type

참고 : [PyTorch]TENSOR ATTRIBUTES PyTorch 공식Docu, 생각보다 종류가 많음...


  • 32-bit floating Point : FloatTensor
  • 64-bit floating Point : DoubleTensor
  • 16-bit floating Point : HalfTensor
  • 8-bit integer : ByteTensor(unsigned), CharTensor(signed)
  • 16-bit integer : ShortTensor
  • 32-bit integer : IntTensor
  • 64-bit integer : LongTensor

PyTorch의 Tensor 데이터 타입은 위와 같으며 보통 실수 계산을 하기 위해서는 FloatTensor, 정수를 사용하기 위해서는 LongTensor를 사용하며 Boolean, 즉 True/False 사용 시 ByteTensor를 사용한다.






Tensor's Type 변경


1. type

  • Tensor.type(dtype=None, non_blocking=False) \rightarrow str or Tensor

Returns the type if dtype is not provided, else casts this object to the specified type.
dtype을 설정하지 않고 사용하면 해당 Tensor의 타입을 출력하고, dtype을 설정하면 해당 dtype으로 Tensor를 변환하여 돌려줌

t_tensor = torch.tensor(1)
t_long = t_tensor.type('torch.LongTensor')
t_float = t_tensor.type('torch.FloatTensor')
t_byte = t_tensor.type('torch.ByteTensor')

print(t_long.type())
print(t_float.type())
print(t_byte.type())

>>>
torch.LongTensor
torch.FloatTensor
torch.ByteTensor
profile
예술과 기술

0개의 댓글