GNN과 범주형 변수 처리

강하예진(Erica)·2023년 5월 11일
0

Graph Neural Network

목록 보기
10/10

Vector Embedding 방식

Category Embedding 방식을 통해, 지정한 임의 차원의 벡터로 변환할 수 있다. Embedding Layer을 사용해야 한다.
물론 One-Hot Encoding 방식을 쓸 수도 있지만, 여기서는 벡터 임베딩 방식을 다룬다.

Label Encoder

  1. Label Encoder를 통해 위치정보 str를 정수로 변환한다.
  • Label Encoder란? - 특정 column의 row마다 있는 모든 문자열에 대해서 고유한 정수로 인코딩. str를 0부터 시작하는 int로 변환시키는 기능을 하며, 문자열은 유니크한 정수로 변환되고, 문자열의 알파벳 순서에 따라 정수가 할당된다.

Embedding Layer

  1. Embedding Layer을 사용해 변환된 정수를 n차원 벡터로 변환한다. Keras, Pytorch 모두 사용 가능하다. 다시 말해 유니크한 정수값을 벡터에 표현하는 과정이다.

Shape?

  1. 결과는 [num_sample, 1, n차원]의 텐서가 된다.
    이걸 numpy array로 변환해서 차원을 줄인다. 그러면 최종적으로 만들어진 Location node feature vector는 [num_sample, n차원]의 벡터가 된다.

Concatenate

  1. 마지막으로는 기존 node feature matix에 연결해야 한다.
    임베딩 벡터를 적절한 user node의 feature matrix에 연결하기 위해서는, NumPy의 np.concatenate() 함수를 사용하면 된다.
    이 과정까지 마치면 기존 feature matrix의 shape가 (num_users, num_features + n차원)으로 확장될 것이다.

예제 코드

import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder
from torch.nn import Embedding

# LabelEncoder를 사용해 위치 정보를 정수로 변환
le = LabelEncoder()
train_data['Location_encoded'] = le.fit_transform(train_data['Location'])

# 임베딩 레이어 초기화
embedding_layer = Embedding(num_embeddings=151, embedding_dim=25)

# 위치 정보를 25차원 벡터로 변환
location_embeddings = embedding_layer(torch.unsqueeze(torch.tensor(train_data['Location_encoded'].values, dtype=torch.long), dim=1))

# 결과를 NumPy array로 변환
location_embeddings = location_embeddings.detach().numpy().squeeze()

# 기존 user node feature matrix를 (num_users, num_features)형태라고 가정
previous_node_feature_matrix = np.random.rand(num_users, num_features)

# 위에서 만든 Location 임베딩 벡터와 기존 feature matrix를 연결
node_features_extended = np.concatenate((previous_node_feature_matrix, location_embeddings), axis=1)

advance

  • train_data['Location'].nunique()의 결과가 151일 때, input_dim이 151로 설정되는 이유

train_data['Location'].nunique()의 결과가 151이라는 것은 151개의 다른 위치가 있다는 것이다.
따라서 input_dim을 151로 설정하는 것은 각각의 위치 정보를 고유한 방식으로 표현하기 위한 것이다.

참고로, 완성된 embedding의 shape는 [] 이다.

profile
Recommend System & BackEnd Engineering

0개의 댓글