Pytorch Tensor의 max함수

chanykim·2022년 7월 7일
0

TORCH.MAX

torch.max(input) → Tensor

input텐서에 있는 모든 요소의 최대값을 반환합니다.

Parameter

  • input (Tensor) – the input tensor.

torch.max(input, dim, keepdim=False, *, out=None)

지정된 차원에서 텐서의 각 행의 최대값인 튜플을 반환합니다.
그리고 각 최대값(argmax)의 인덱스 위치를 나타냅니다.(values, indices)
축소된 행에 최대값이 여러 개 있는 경우 첫 번째 최대값의 인덱스가 반환됩니다.

Parameter

  • input (Tensor) – the input tensor.

  • dim (int) – the dimension to reduce.

  • keepdim (bool) – whether the output tensor has dim retained or not. Default: False.

  • out (tuple, optional) – the result tuple of two output tensors (max, max_indices)

예제

example = torch.tensor([[1,4,7], [3,6,9]])텐서가 있을 때 max함수를 쓸 때 2가지로 사용할 수 있습니다.
torch.max(example)
example.max()

example = torch.tensor([[1,4,7], [3,6,9]])

print('\n', example)
tensor([[1, 4, 7],
        [3, 6, 9]])


print('\n', torch.max(example)) # 텐서 안에 최대값
tensor(9)


print('\n', torch.max(example, dim=1)) # 행방향에 대한 최대값 (행에 대한 vlaue, index 값) 
torch.return_types.max(
values=tensor([7, 9]),
indices=tensor([2, 2]))


print('\n', torch.max(example, dim=0)) # 열방향에 대한 최대값 (열에 대한 vlaue, index 값) 
torch.return_types.max(
values=tensor([3, 6, 9]),
indices=tensor([1, 1, 1]))


print('\n', torch.max(example, dim=1)[0]) #행방향에서 value만 추출
tensor([7, 9])


print('\n', torch.max(example, dim=1)[1]) # 행방향에서 index만 추출
tensor([2, 2])


print('\n', torch.max(example, dim=0)[0]) # 열방향에서 value만 추출
tensor([3, 6, 9])


print('\n', torch.max(example, dim=0)[1]) # 열방향에서 index만 추출
tensor([1, 1, 1])
profile
오늘보다 더 나은 내일

0개의 댓글