GNN과 NeighborSampler

강하예진(Erica)·2023년 5월 2일
0

Graph Neural Network

목록 보기
8/10
post-thumbnail

Task에 따라 적합한 그래프 로드 방식이 있다

만약 회귀 예측을 목표로 하고 있다면, 전체 그래프에서 지역적인 이웃 정보를 사용하여 추론하는 작업이 필요하다. 그럴 때는 Data로 그래프를 구성한 뒤, DataLoader를 바로 사용해서는 원하는 결과를 얻을 수 없다. 아마도 수많은 인덱스 에러를 마주칠 것이다. 그 이유는 바로 Pytorch Geometric의 DataLoader가 주로 노드 분류나 그래프 분류 작업에 사용되고, 이런 작업은 일반적으로 그래프 데이터를 배치 처리하는 데 초점을 맞추기 때문이다.

그래서 어떻게 해결할까?

이 문제를 해결하려면 NeighborSampler를 사용하면 된다. 그러면 각 미니 배치에 대해 지역적인 그래프 정보를 샘플링할 수 있다.
즉, 배핑된 인덱스를 사용해 직절한 피쳐를 로드하는 것은 DataLoader와 NeighborSampler의 결합을 통해 이루어진다.

과정

  1. Data 객체 생성: 데이터를 PyTorch Geometric의 Data 객체로 변환하자. Data 객체는 그래프의 노드 피처, 엣지 리스트 및 노드 label(영화 평점 등)을 저장한다.
from torch_geometric.data import Data

data = Data(x=torch.tensor(x, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long), y=torch.tensor(ratings, dtype=torch.float))
  1. DataLoader는 NeighborSampler를 사용해 미니 배치를 생성한다. 각 미니 배치에는 노드 인덱스와 해당 인덱스에 대한 지역적 그래프 정보가 포함된다. 여기서의 핵심은 각 미니 배치에 대해 지역적인 그래프 정보를 샘플링하는 것이다.
from torch_geometric.data import NeighborSampler

loader = NeighborSampler(data.edge_index, sizes=[10, 5], batch_size=32, shuffle=True, num_nodes=data.num_nodes)
  1. 학습 단계에서 앞서 정의한 loader를 사용해 미니 배치를 순회한다. 이때, 각 미니 배치의 노드 인덱스를 사용해 매핑된 노드 피쳐를 가져올 수 있다. NeighborSampler가 DataLoader의 기능을 대체한다. 그러니 train_loader와 test_loader를 설정할 때 기존 DataLoader를 사용할 필요가 없다.

위의 코드에서는 NeighborSampler를 사용하여 각 레이어에서 이웃 노드를 샘플링한다. sizes=[10, 5]는 각 레이어에서 이웃 노드를 10개와 5개로 샘플링한다는 것을 의미한다. 또한 batch_size=32를 사용하여 각 배치에서 32개의 노드를 처리한다.

profile
Recommend System & BackEnd Engineering

0개의 댓글