input텐서에 있는 모든 요소의 최대값을 반환합니다.
지정된 차원에서 텐서의 각 행의 최대값인 튜플을 반환합니다.
그리고 각 최대값(argmax)의 인덱스 위치를 나타냅니다.(values, indices)
축소된 행에 최대값이 여러 개 있는 경우 첫 번째 최대값의 인덱스가 반환됩니다.
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce.
keepdim (bool) – whether the output tensor has dim retained or not. Default: False.
out (tuple, optional) – the result tuple of two output tensors (max, max_indices)
example = torch.tensor([[1,4,7], [3,6,9]])
텐서가 있을 때 max함수를 쓸 때 2가지로 사용할 수 있습니다.
torch.max(example)
example.max()
example = torch.tensor([[1,4,7], [3,6,9]])
print('\n', example)
tensor([[1, 4, 7],
[3, 6, 9]])
print('\n', torch.max(example)) # 텐서 안에 최대값
tensor(9)
print('\n', torch.max(example, dim=1)) # 행방향에 대한 최대값 (행에 대한 vlaue, index 값)
torch.return_types.max(
values=tensor([7, 9]),
indices=tensor([2, 2]))
print('\n', torch.max(example, dim=0)) # 열방향에 대한 최대값 (열에 대한 vlaue, index 값)
torch.return_types.max(
values=tensor([3, 6, 9]),
indices=tensor([1, 1, 1]))
print('\n', torch.max(example, dim=1)[0]) #행방향에서 value만 추출
tensor([7, 9])
print('\n', torch.max(example, dim=1)[1]) # 행방향에서 index만 추출
tensor([2, 2])
print('\n', torch.max(example, dim=0)[0]) # 열방향에서 value만 추출
tensor([3, 6, 9])
print('\n', torch.max(example, dim=0)[1]) # 열방향에서 index만 추출
tensor([1, 1, 1])