Pytorch Geometric - Message Passing Network

문재경·2023년 5월 5일
0

0. Graph Convolution

GNN은 CNN과 유사하게 convolution 연산을 수행한다. 다만, 그래프의 불규칙한 구조를 반영할 수 있도록 기존의 1D 또는 2D convolution가 아닌 graph convolution 연산을 사용한다. 아래 그림 출처

Graph convolution 연산의 핵심은 노드를 임베딩함에 있어 엣지로 연결된 노드들, 즉 이웃 노드들의 정보를 활용하는 것이다. 하나의 중심 노드에 대해, 그 노드가 이웃하는 노드들의 정보를 하나로 모아 중심 노드를 표현할 수 있는 벡터로 출력한다. 이러한 과정은 이웃 노드들이 엣지를 따라 중심 노드로 정보를 전달한다는 측면에서 message passing이라고도 표현된다.

이를 수식으로 표현한다면 다음과 같다.

  • xi\textrm{\bf{x}}_i: 노드 ii의 feature 벡터, 초기(k=1k=1)에는 input feature와 같지만 레이어를 지나며 임베딩 벡터로 변환

  • ej,i\textrm{\bf{e}}_{j, i}: 노드 jj에서 ii로 향하는 엣지의 feature 벡터, 반드시 존재하지는 않음

  • ()\bigoplus(\cdot): aggregation 함수

  • ϕ()\phi(\cdot): 중심 노드와 이웃 노드, 엣지의 feature 벡터로 message를 계산하는 함수

  • γ()\gamma(\cdot): aggregation의 결과를 업데이트하는 함수

  • kk: 레이어 인덱스

수식에 따르면, 노드 ii를 임베딩하는 과정은 엣지를 기준으로 중심 노드 ii와 이웃하는 노드들 jN(i)j\in N(i)의 정보를 가공해 message로 만들고, 이들을 aggregate한 결과를 다시 가공하여 최종 결과를 도출하는 것으로 정리할 수 있다.

이러한 과정에서 aggregation 함수 ()\bigoplus(\cdot)는 합이나 평균, 최댓값과 같이 출력이 입력 순서에 영향을 받지 않는 방식으로 설정되어야 한다. 또한, 모델이 학습을 통해 파라미터를 업데이트할 수 있도록 ϕ()\phi(\cdot)γ()\gamma(\cdot)는 MLPs처럼 미분 가능한 형태로 정의되어야 한다.

1. PyG의 MessagePassing 클래스

PyG에는 torch_geometric.nn.conv.MessagePassing , 줄여서 MessagePassing 클래스가 구현되어 있다. 클래스의 이름처럼 MessagePassing은 이웃한 노드 간 message 전파(propagation), 즉 그래프 신경망의 message passing을 관장한다. 위의 수식에서 ϕ()\phi(\cdot)γ()\gamma(\cdot), ()\bigoplus(\cdot)만 사용자가 정의하면 연산적인 부분은 알아서 실행되도록 구현되어 있는 것이다.

1.1. Super class

class MessagePassing(torch.nn.Module):
...

MessagePassingtorch.nn.Module을 상속 받기 때문에, 모델의 학습을 위한 기능들을 사용 및 구현할 수 있다. Message passing을 기반으로 하는 그래프 신경망 모델들은 다시 MessagePassing을 상속 받아 정의된다.

1.2. __init__

def __init__(
        self,
        aggr: Optional[Union[str, List[str], Aggregation]] = "add",
        *,
        aggr_kwargs: Optional[Dict[str, Any]] = None,
        flow: str = "source_to_target",
        node_dim: int = -2,
        decomposed_layers: int = 1,
        **kwargs,
    ):
    ...

클래스를 생성하고 초기화하는데 있어 필요한 인자는 다음과 같다.

  • aggr: aggregation 방법
    • "add", "mean", "max" 등의 키워드 또는 torch_geometric.nn.Aggregation 객체 사용 가능
    • None으로도 설정 가능 - MessagePassing 객체의 aggregate() 메소드를 통해 구현
    • 기본값은 "add"
  • aggr_kwargs: aggr이 키워드로 지정되지 않았을 때, aggregate()로 전달하기 위한 인자
    • 기본값은 None
  • flow: message passing의 방향
    • "source_to_target""target_to_source"의 이지선다
    • 보통 이웃 노드(source)에서 중심 노드(target)로의 message 전달을 가정하므로,
      기본값은 "source_to_target"
  • node_dim: 전파할 축
    • 이후 propagate()에 입력되는 edge_index와 관련
    • 기본값은 -2
  • decomposed_layers: feature decomposition 레이어 수
    • feature decomposition - CPU 환경에서의 GNN 실행을 위해 메모리 사용량을 줄이는 방법
    • 기본값은 1

1.3. propagate()

def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
	...
    out = self.message(**msg_kwargs)
    ...
    out = self.aggregate(**aggr_kwargs)
    ...
    out = self.update(**update_kwargs)
    ...
    
    return out

본격적인 message passing이 실행을 담당하며 propagate()를 통해 시작된다. 호출 시에는 edge_index와 size, 기타 노드 임베딩 과정에 필요한 요소(**kwargs)들을 인자로 받는다.

여기서 edge_index는 그래프를 구성하는 연결 관계로 message가 전달되는 경로를 설정하는 인자이다. torch.Tensor 또는 SparseTensor 중 어느 형태로 입력되는지에 따라 분기점이 존재한다.

  • torch.Tensor로 입력되는 경우,
    • 자료형은 torch.long, 크기는 [2, 엣지 수]를 만족해야 하며, 2개의 행이 source_node와 target_node의 순서로 구성되어야 함 (flow="source_to_target").
    • size의 지정 필요
  • SparseTensor로 입력되는 경우,
    • 기존의 텐서를 인덱스와 밸류, 크기로 나누어 구성하는 SparseTensor
    • 인접 행렬을 SparseTensor로 나타내면,
      인덱스 -> 엣지, 밸류 -> 엣지 속성
    • size 자동 계산

이후, message(), aggregate(), update()를 순차적으로 호출하여 노드의 임베딩 벡터를 업데이트한다. 사용자가 구현하는 방식에 따라 message()aggregate()message_and_aggregate()로 합치는 것도 가능하다.

1.4. message(): ϕ()\phi(\cdot)

def message(self, x_j: Tensor) -> Tensor:
	
    return x_j

노드 jj에서 노드 ii로 전달하는 메시지를 생성하는 함수로, propagate() 내에서 호출되기 때문에 propagate()에 전달된 인자들을 사용하는 것이 가능하다.

기본적으로는 입력된 텐서 x_j를 그대로 출력하게끔 작성되어 있다. 메시지를 생성하는 방식에 따라 그래프 신경망이 구분되므로, 구체적인 방식은 오버라이딩을 통해 구현되도록 한다.

1.5. aggregate(): ()\bigoplus(\cdot)

def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:

	return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
                                dim=self.node_dim)

torch.nn.aggr.Aggregation을 통해 message aggregation을 수행한다. message()의 출력값이 입력되며, message()와 마찬가지로 propagate()에 전달된 인자들을 사용하는 것이 가능하다.

1.6. update(): γ()\gamma(\cdot)

def update(self, inputs: Tensor) -> Tensor:
        
        return inputs

aggregate()의 출력값을 최종 노드 임베딩 벡터로 출력하기 전에 업데이트하는 역할을 한다. message()처럼 구체적인 방식은 오버라이딩을 통해 구현되며, propagate()에 전달된 인자들을 사용할 수 있다.

2. Message passing 신경망 구현

MessagePassing을 이용해서 그래프 신경망 모델을 만들어보자.

GCN 레이어를 만든다고 하면, message passing 연산은 다음과 같이 정의할 수 있다.

xik=jN(i){i}1deg(i)deg(j)(Wxjk1)+b\textrm{\bf{x}}_i^{k}=\sum_{j\in\mathcal{N}(i)\cup\lbrace i\rbrace}{1 \over \sqrt{\textrm{deg}(i)}\cdot\sqrt{\textrm{deg}(j)}}\cdot (\textrm{\bf{W}}^\top\cdot\textrm{\bf{x}}_j^{k-1})+\textrm{\bf{b}}
  1. 이전 레이어에서의 feature(또는 임베딩) 벡터 xjk1\textrm{\bf{x}}_j^{k-1}를 학습 가능한 파라미터로 구성된 가중치 행렬 W\textrm{\bf{W}}^\top로 선형 변환한다.
  2. 노드가 갖는 이웃의 수 deg(i)\sqrt{\textrm{deg}(i)}로 정규화한다.
  3. 임베딩하려는 중심 노드 ii에 대해, 이웃 노드들의 벡터들을 더한다.
    : self-loop를 추가하여 자기 자신도 이웃으로 간주
  4. bias 벡터가 더해진 최종 임베딩 벡터를 출력한다.
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # Step 3: 이웃 노드들의 벡터들을 더하는 방식으로 aggregation
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: 입력된 이전 레이어의 feature 벡터를 선형 변환
        x = self.lin(x)

        # Step 3: 인접 행렬에 self-loop 추가
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: 정규화를 위해 노드가 갖는 이웃의 수 계산
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # 주어진 조건으로 message passing
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 4: bias 벡터 추가
        out += self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 2: 노드가 갖는 이웃의 수로 정규화
        return norm.view(-1, 1) * x_j

위 코드에서 GCNConvMessagePassing을 상속 받아 정의된다.

__init__()에서는 message passing에 있어 aggregation 방식을 "add"로 설정하고,
선형 변환 self.lin = Linear(...)와 bias 벡터 self.bias=Parameter(...)를 생성하여 파라미터를 초기화한다.

레이어 객체의 작동 매커니즘은 forward() 메소드를 통해 구현한다. 입력된 xedge_index에 대해 선형 변환 및 self-loop를 추가하는 과정이 포함되어 있다. propagate()를 호출하여 message passing을 진행하며 이에 필요한 message 생성 함수는 message()를 오버라이딩하여 구현한 것을 확인할 수 있다.

이렇게 만든 GCNConv는 다음과 같이 선언하여 사용하는 것이 가능하다.

conv = GCNConv(in_channels=16, out_channels=32)
x = conv(x, edge_index)
profile
안녕하세요...

0개의 댓글