TypeError: GraphSAGERegressor.forward() missing 1 required positional argument: 'edge_attr'
train 함수 정의를 진행하며, edge_attr 전달에서 문제를 겪었다. 이 문제는
GraphSAGERegressor 모델 코드의 forward 메서드가 edge_attr을 요구하고 있지만, 현재 train 함수에서는 adjs만 전달하고 있기 때문에 발생한 에러였다.
DataLoader(neighbor sampler)에 edge_index와 edge_attr을 모두 리스트로 전달하도록 train 함수를 수정한 뒤, 문제를 해결했다.
우선, NeighborSampler에서 얻은 adjs를 사용해 edge_index를 추출해냈다. 여기서 중요한 점은 내가 data.edge_list가 data.edge_attr과 동일한 순서로 정렬되어 있도록 데이터를 구성해뒀다는 점이다.
그렇다는 것은, 각 adj에 대해 별도의 edge_attr 부분집합을 가져오지 않고 그냥 data.edge_attr 전체를 사용해도 된다는 것이다. 따라서 코드를 이렇게 구성했다.
edge_attr_list = [data.edge_attr.to(device) for _ in adjs]
adjs는 Data 객체의 속성(attribute)중 하나다. 그래프 데이터의 특성에 따라 sampling한 edge_index를 담고 있는, EdgeIndex 객체들의 리스트이다. 다시 말해 adjs 안에 EdgeIndex 객체들이 담겨 있고, EdgeIndex 객체 안에 edge_index들이 담겨 있다. EdgeIndex는 리스트이며, Pytorch Geometic에서 제공하는 데이터 구조로 그래프에서 사용되는 edge의 index 쌍을 저장한다.
즉, adjs는 그래프를 구성하는 edge 중에서, 일부분만 샘플링해 만든 여러 개의 EdgeIndex를 담고 있다.
코드에서 NeighborSampler를 사용해 그래프 데이터를 로드할 경우, adjs를 사용해 그래프 레이어를 계산해야 한다.