[TypeError] edge_attr

강하예진(Erica)·2023년 5월 4일
0

I Nailed It!

목록 보기
3/4

TypeError: GraphSAGERegressor.forward() missing 1 required positional argument: 'edge_attr'

train 함수 정의를 진행하며, edge_attr 전달에서 문제를 겪었다. 이 문제는
GraphSAGERegressor 모델 코드의 forward 메서드가 edge_attr을 요구하고 있지만, 현재 train 함수에서는 adjs만 전달하고 있기 때문에 발생한 에러였다.
DataLoader(neighbor sampler)에 edge_index와 edge_attr을 모두 리스트로 전달하도록 train 함수를 수정한 뒤, 문제를 해결했다.

@@@@@@@@@@@ 문제 해결 방법

  1. 우선, NeighborSampler에서 얻은 adjs를 사용해 edge_index를 추출해냈다. 여기서 중요한 점은 내가 data.edge_list가 data.edge_attr과 동일한 순서로 정렬되어 있도록 데이터를 구성해뒀다는 점이다.

  2. 그렇다는 것은, 각 adj에 대해 별도의 edge_attr 부분집합을 가져오지 않고 그냥 data.edge_attr 전체를 사용해도 된다는 것이다. 따라서 코드를 이렇게 구성했다.

edge_attr_list = [data.edge_attr.to(device) for _ in adjs]
  1. 이게 가능한 이유는, NeighborSampler에서 사용되는 EdgeIndex가 data.edge_index의 부분집합(배치 크기)이기 때문이다. 따라서 동일한 EdgeIndex로 인덱싱을 해도, edge_attr에서 제대로 원하는 엣지에 해당하는 엣지 가중치(edge_attr) 값을 가져올 수 있다. 따라서 NeighborSampler에서는 data.edge_attr를 사용해도 문제 없이 동작한다.

@@@@@@@@@@@ 에러 관련 주요 개념들

1. adjs와 EdgeIndex

adjs는 Data 객체의 속성(attribute)중 하나다. 그래프 데이터의 특성에 따라 sampling한 edge_index를 담고 있는, EdgeIndex 객체들의 리스트이다. 다시 말해 adjs 안에 EdgeIndex 객체들이 담겨 있고, EdgeIndex 객체 안에 edge_index들이 담겨 있다. EdgeIndex는 리스트이며, Pytorch Geometic에서 제공하는 데이터 구조로 그래프에서 사용되는 edge의 index 쌍을 저장한다.
즉, adjs는 그래프를 구성하는 edge 중에서, 일부분만 샘플링해 만든 여러 개의 EdgeIndex를 담고 있다.
코드에서 NeighborSampler를 사용해 그래프 데이터를 로드할 경우, adjs를 사용해 그래프 레이어를 계산해야 한다.

  • 너무 당연한 내용이지만
    GraphSAGE 모델은 노드의 특성(feature) 정보, 이웃 노드들의 feature 정보, 그리고 해당 노드와 이웃하는 엣지(edge)의 속성(edge_attr) 정보를 모두 사용하여 노드의 임베딩을 계산한다. 이렇게 함으로써, GraphSAGE 모델은 그래프 구조와 노드의 특성 정보를 모두 활용하여 노드의 임베딩을 학습할 수 있다.
profile
Recommend System & BackEnd Engineering

0개의 댓글