Graph Data와 교차 검증

강하예진(Erica)·2023년 4월 28일
0

Graph Neural Network

목록 보기
3/10

Graph Data로도 교차 검증(Cross validation)이 가능하다. 다시 말해, 일반적인 ML 학습처럼 train-test 분할이 가능하다는 것이다. 여기엔 두 가지 방법이 있다.

1. Edge 기반 split

그래프의 일부 엣지를 테스트용으로 제거하고, 나머지 엣지를 사용해 모델을 학습한 다음, 제거된 엣지를 예측하는 데 사용한다.

  • 실행 방법
  1. 전체 그래프에서 일부 엣지를 무작위로 선택하여 테스트 셋으로 분할
  2. 테스트 셋으로 분할된 엣지를 그래프에서 제거. 이로 인해 학습 셋이 된다.
  3. 학습 셋을 사용하여 모델을 학습시키고, 제거된 엣지를 예측하는 데 사용.
    이 과정을 k번 반복하여 k-fold 교차 검증을 실행한다.

2. Node 기반 split

일부 노드를 테스트 셋으로 선택하고, 해당 노드와 연결된 엣지를 테스트 셋에 포함시키는 방법이다. 엣지가 노드로 바뀔 뿐, 프로세스는 Edge 기반 split과 동일하다.
유저가 상품에 남길 평점을 예측하는 회귀 예측 문제를 다루는 경우엔 일반적으로 Node 기반 분할이 더 적절하다.

  • 이유는 아래와 같다.
  1. 노드 기반 분할은 사용자 및 도서 노드를 각각 훈련, 검증 및 테스트 집합으로 분할할 수 있다. 이렇게 하면 테스트 및 검증 집합에서 새로운 사용자 및 도서에 대한 모델 성능을 평가할 수 있게 된다.
  2. 엣지 기반 분할 방식은 전체 연결 정보를 사용하므로, 테스트 및 검증 집합에서 사용자와 도서 사이의 연결 정보를 볼 수 있다(cheating). 이로 인해 테스트 및 검증 집합에서 모델 성능이 과대평가될 수 있다. 반면, 노드 기반 분할은 독립적인 사용자 및 도서 집합을 사용하여 이러한 문제를 해결한다.
  3. 노드 기반 분할은 데이터를 분할하는 가장 직관적인 방법이며, 실제 추천 시스템에서 사용자 및 도서가 각각의 개체로 취급되기 때문에 실제 상황을 더 잘 반영한다.

3. 주의점

엣지(노드)를 제거할 때 연결되지 않은 노드나 그래프의 일부 영역이 완전히 제거될 수도 있다. 그래프의 연결성이 훼손되는 것이다. 이러한 상황을 피하려면, 분할 전략을 조심스럽게 선택하거나 연결 구성품(connected components)을 고려해야 한다.

분할 전략을 선택할 때는 그래프의 특성과 모델의 목표를 고려해야 한다. 예를 들어, Link prediction task의 경우 Edge 기반 분할이 더 적합하고, Node classification Task의 경우 Node 기반 분할이 더 적절할 것이다.

PyTorch Geometric은 이런 분할들을 쉽게 적용할 수 있도록 돕는 라이브러리다. 이를 사용해 예제 코드를 작성해 보겠다.

예제 코드 - Node base split

우선, 몇 가지 짚고 넘어갈 사항이 있다.
1. random_split을 사용하지 않는다. 사용자가 평가하지 않은 도서에 대한 정보가 포함될 수 있기 때문. 따라서 사용자-도서 쌍의 평점 정보를 고려하여 분할해야 한다.
2. Graph Data는 masking 방법으로 분할을 진행한다.
3. masking 방법을 사용하기 위해, PyTorch Geometric과 sklearn 라이브러리를 함께 사용한다.

그래프 데이터에 train_mask와 test_mask가 추가되었으므로, 이를 사용하여 모델을 학습할 때 훈련 및 테스트 데이터에 해당하는 사용자와 도서 노드를 구분할 수 있다.

이 코드를 통해 PyTorch Geometric의 DataLoader를 사용하여 그래프 데이터를 로드할 때, batch.train_mask 및 batch.test_mask를 사용하여 각 배치에 대한 훈련 및 테스트 노드를 선택할 수 있다.

4. 예제 코드 - DataLoad

PyTorch Geometric의 데이터 로더를 사용하여 그래프 데이터를 로드하고, batch.train_mask 및 batch.test_mask를 사용하여 각 배치에 대한 훈련 및 테스트 노드를 선택할 수 있다.

graph_data는 PyTorch Geometric의 Data 객체로부터 생성된 그래프 데이터이며, batch_size와 shuffle을 사용하여 데이터 로더를 구성할 수 있다. 이렇게 생성된 데이터 로더를 통해 배치별로 그래프 데이터를 처리하며, 각 배치의 훈련 및 테스트 노드를 batch.train_mask 및 batch.test_mask를 사용하여 선택할 수 있게 된다.

profile
Recommend System & BackEnd Engineering

0개의 댓글