Pytorch Tensor concatenation

김유상·2022년 12월 22일
0

pytorch에서는 tensor를 다루는 여러 가지 방법을 제공하는데 이번에는 tensor를 이어 붙이는 방법을 사용해 보았다. torch.cat() 함수를 이용해 간단하게 수행할 수 있으며 split한 것을 반대로 붙여준다고 생각하면 될 것 같다.

매개변수 형식은 (tensor, dimension)으로 설정할 수 있고 dimension은 어느 축으로 concatenation을 수행할 것인지를 뜻한다.
dim = 0으로 설정하면 0 번째 차원에 이어 붙이는 것이다.

2차원 배열에 비유하자면 dim = 0은 세로로 이어 붙이고, dim = 1은 가로로 이어 붙인다.

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497]])

반환 형식은 이어 붙인 tensor이고 이것을 그대로 반환한다.

Referenced: https://pytorch.org/docs/stable/generated/torch.cat.html

profile
continuous programming

0개의 댓글