torch.cat(tensors, dim=0, *, out=None) → Tensor
torch.cat은 tensor를 concatenate해주는 역할을 한다. 모든 텐서는 concatenating을 진행하는 dimension을 제외한 모두 같은 shape를 지녀야 연산이 가능하다.
Example:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497]])
x의 shape는 (2,3) 이지만, dim=0으로 cat하면 (4,3)이 된다.
x의 shape는 (2,3) 이지만, dim=1으로 cat하면 (2,6)이 된다.