K교차검증

TaeHyun Lee·2023년 4월 24일
0

AI 공부

목록 보기
15/17

K교차검증이란?

k겹 교차 검증이란 데이터셋을 여러 개로 나누어 하나씩 테스트셋으로 사용하고 나머지를 모두 합해서 학습셋으로 사용하는 방법입니다. 이렇게 하면 가지고 있는 데이터의 100%를 학습셋으로 사용할 수 있고, 또 동시에 테스트셋으로도 사용할 수 있습니다. 예를 들어 5겹 교차 검증(5-fold cross validation)의 예가 있습니다.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score

import pandas as pd

# 깃허브에 준비된 데이터를 가져옵니다.
!git clone https://github.com/taehojo/data.git

# 광물 데이터를 불러옵니다.
df = pd.read_csv('./data/sonar3.csv', header=None)

# 음파 관련 속성을 X로, 광물의 종류를 y로 저장합니다.
X = df.iloc[:,0:60]
y = df.iloc[:,60]

# 몇 겹으로 나눌 것인지 정합니다.
k = 5

# KFold 함수를 불러옵니다. 분할하기 전에 샘플이 치우치지 않도록 섞어 줍니다.
kfold = KFold(n_splits=k, shuffle=True)

# 정확도가 채워질 빈 리스트를 준비합니다.
acc_score = []

def model_fn():
    model = Sequential() # 딥러닝 모델의 구조를 시작합니다.
    model.add(Dense(24, input_dim=60, activation='relu'))
    model.add(Dense(10, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    return model

# k겹 교차 검증을 이용해 k번의 학습을 실행합니다.
# for 문에 의해 k번 반복합니다.
# split()에 의해 k개의 학습셋, 테스트셋으로 분리됩니다.
for train_index, test_index in kfold.split(X): 
    X_train, X_test = X.iloc[train_index,:], X.iloc[test_index,:]  
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]

    model = model_fn()
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    history = model.fit(X_train, y_train, epochs=200, batch_size=10, verbose=0) 
    
    accuracy = model.evaluate(X_test, y_test)[1] # 정확도를 구합니다.
    acc_score.append(accuracy)                   # 정확도 리스트에 저장합니다.

# k번 실시된 정확도의 평균을 구합니다.
avg_acc_score = sum(acc_score) / k

# 결과를 출력합니다.
print('정확도: ', acc_score)
print('정확도 평균: ', avg_acc_score)

2/2 [==============================] - 0s 2ms/step - loss: 1.0080 - accuracy: 0.7381
2/2 [==============================] - 0s 2ms/step - loss: 0.7071 - accuracy: 0.8095
2/2 [==============================] - 0s 2ms/step - loss: 0.3312 - accuracy: 0.8810
2/2 [==============================] - 0s 2ms/step - loss: 0.4377 - accuracy: 0.9024
2/2 [==============================] - 0s 3ms/step - loss: 0.6416 - accuracy: 0.7317
정확도: [0.738095223903656, 0.8095238208770752, 0.8809523582458496, 0.9024389982223511, 0.7317073345184326]
정확도 평균: 0.8125435471534729

profile
서커스형 개발자

0개의 댓글