본 사이트를 참고하여 정리한 내용입니다.
GNN의 가장 대표적인 활용 사례: graph 분류 작업
ex) 각 원자에 해당하는 graph와 feature를 기반으로, 전체 분자의 behavior 예측
그러나, GNN은 오직 node embeddings만을 학습한다.
❔ 어떻게 하면 node embeddings를 결합하여 노드단위가 아닌 '전체 graph embedding'을 만들어낼 수 있을까?
이 포스트의 주요 내용을 미리 정리하자면 다음과 같다.
1. global pooling 이라는 새로운 종류의 layer.
(node embeddings를 결합하기 위함)
2. GIN 이라는 새로운 아키텍쳐 소개
(GCN과 GraphSAGE와 비교)
목적: GNN의 discriminative(=representational) power 극대화
Weisfeiler-Lehman test를 사용해서 GNN의 power를 특성화할 수 있다.
비슷한 양상을 띤다.
1. 모든 node가 같은 label로 시작한다.
2. Labels from neighboring nodes are aggregated and hashed to produce a new label.
3. 앞 두 과정들 반복, until the labels stop changing.
이러한 WL test에 영감 받아 new aggregator를 디자인하게 되었다.
이 새로운 aggregator는 non-isomorphic graph를 다룰 때 서로 다른 node embeddings를 만들어내야 한다.
👉 2개의 injective 함수를 사용하며, MLP를 통해 학습할 수 있다.
특정 노드 i의 hidden vector를 GIN으로 계산하는 식은 다음과 같다.
위 식에서 '입실론'은 target node의 중요도를 (compared to its neighbors) 결정한다. epsilon can be a learnable parameter or a fixed scalar.
: GNN으로 계산된 node embeddings를 이용해서 graph embedding을 만들어내는 것
simple ways to obtain a graph embeddings
위와 같은 분석 결과로 아래와 같은 global pooling 식이 완성되었다.
각 층마다 노드 임베딩은 합산되고 그 결과들이 concatenate(모든 결과 결합)된다. 이를 통해 sum operator의 expressiveness와, concatenation 이전의 iterations의 정보들(writer는 memory라 표현함)을 결합시킬 수 있다.
GIN layer에 MLP 적용
그렇다면 GIN의 architecture는
(images by author)
class GIN(torch.nn.Module):
"""GIN"""
def __init__(self, dim_h):
super(GIN, self).__init__()
self.conv1 = GINConv(
Sequential(Linear(dataset.num_node_features, dim_h),
BatchNorm1d(dim_h), ReLU(),
Linear(dim_h, dim_h), ReLU()))
self.conv2 = GINConv(
Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
Linear(dim_h, dim_h), ReLU()))
self.conv3 = GINConv(
Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
Linear(dim_h, dim_h), ReLU()))
self.lin1 = Linear(dim_h*3, dim_h*3)
self.lin2 = Linear(dim_h*3, dataset.num_classes)
def forward(self, x, edge_index, batch):
# Node embeddings
h1 = self.conv1(x, edge_index)
h2 = self.conv2(h1, edge_index)
h3 = self.conv3(h2, edge_index)
# Graph-level readout
h1 = global_add_pool(h1, batch)
h2 = global_add_pool(h2, batch)
h3 = global_add_pool(h3, batch)
# Concatenate graph embeddings
h = torch.cat((h1, h2, h3), dim=1)
# Classifier
h = self.lin1(h)
h = h.relu()
h = F.dropout(h, p=0.5, training=self.training)
h = self.lin2(h)
return h, F.log_softmax(h, dim=1)
gcn = GCN(dim_h=32)
gin = GIN(dim_h=32)
gcn = train(gcn, train_loader)
gin = train(gin, train_loader)
위 코드는 실습코드의 일부이며 GIN과 GCN의 결과를 비교해본다.
정확도 면에서 GIN이 GCN을 훨씬 앞선다.
Although GINs achieve good performance, especially with social graphs, their theoretical superiority doesn’t always translate well in the real world. It is true with other “provably powerful” architectures, which tend to underperform in practice, such as the 3WLGNN. -Author