GPU를 이용한 pytorch를 사용하다보면 이런 에러를 만날 수도 있다.
이 에러는 tensor값을 boolen 값으로 비교하려고 할 떄 발생하는 에러이다.
예를 들면 다음과 같은 상황이다.
a = [[1, 2],[3, 4]]
a = torch.tensor(a)
if a == 1:
print('a is 1')
else:
print('a is not 1')
위의 코드와 같이 tensor값을 비교하면 다음과 같은 에러가 발생한다.
이를 해결하기 위한 방법으로는 여러가지가 있겠지만, 나는 이 방법을 사용했다.
a = torch.tensor(a)
if a.cpu().detach().numpy() == 1:
...
이와 같이 tensor값을 다시 cpu에서 사용가능한 값으로 바꿔주면 된다.
cpu를 먼저 선언해주고, detach후 numpy배열로 바꿔야 에러가 안나고 잘 바뀐다.
예외) 나는 값보다 타입으로 비교가 가능하여
if type(a) == int:
...
이처럼 비교가 가능한 타입으로 바꿔주면 에러없이 실행이 가능하다.
도움이 되었습니다. 감사합니다.