pytorch | torch.bmm

nawnoes·2021년 8월 6일
0

PyTorch

목록 보기
6/7

TORCH.BMM

Batch matrix multiplication 으로 두 operand가 모두 batch일때 사용하며, 브로드캐스트 기능을 지원하지 않는다. 두 입력은 3-D 텐서가 되어야한다.

  • [B, N, M] x [B, M, P] = [B, N, P]
outi=inputi@mat2iout_i = input_i @ mat2_i

torch.bmm(input, mat2, *, deterministic=False, out=None) -> tensor

Example:

input = torch.randn(10,3,4)
mat2 = torch.randn(10,4,5)
res = torch.bmm(input, mat2)
res.size()

# result : torch.Size([10,3,5])

0개의 댓글