101일차 시작.... (EarlyStopping)

조동현·2022년 12월 1일
0

[교육] Python DL

목록 보기
7/16
post-thumbnail

📊 조기 종료 (EarlyStopping)


📌 조기 종료 (EarlyStopping)이란?

  • 정의
    - 학습 도중 monitor할 대상 지표의 값이 일정 주기 안에 반복된다면, 해당 시점에서 학습을 종료하는 시스템

  • monitor 종류
    - loss, val_loss, mse, mae : 낮아져야하는 지표로, 일정 주기 안에 반복되거나 높아진다면 조기 종료한다.
    - accuracy : 높아져야하는 지표로, 일정 주기 안에 반복되거나 낮아진다면 조기 종료한다.

  • EarlyStopping Tip
    [ 조기종료 대상 지표로 loss 값 보다는 val_loss 값을 사용하는 것이 과대적합을 예방하는데 유리하다. ]


📌 조기 종료 (EarlyStopping) 실습

1. 라이브러리 Import

# 기본 라이브러리
import numpy as np
import matplotlib.pyplot as plt

# 데이터세트
from keras.datasets import boston_housing

# feature scaling
from sklearn.preprocessing import MinMaxScaler

# 딥러닝 모델
from keras.models import Model
from keras.layers import Input, Dense
from keras.optimizers import Adam

# 딥러닝 EarlyStopping
from keras.callbacks import EarlyStopping

# 회귀모델 성능 지표
from sklearn.metrics import r2_score

2. 데이터 준비

(x_train, y_train), (x_test, y_test) = boston_housing.load_data()
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
# (404, 13) (102, 13) (404,) (102,)

3. feature scaling

scaler = MinMaxScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.fit_transform(x_test)

4. Functional API 모델 구축

inputs = Input(shape=(13,))
net1 = Dense(units=30, activation='relu')(inputs)
net2 = Dense(units=20, activation='relu')(net1)
net3 = Dense(units=10, activation='relu')(net2)
outputs = Dense(units=1, activation='linear')(net3)

model = Model(inputs, outputs)

5. 모델 학습 최적화 설정

model.compile(optimizer=Adam(learning_rate=0.001), loss='mse', metrics=['mae'])
print(model.summary())

6. EarlyStopping 설정

# monitor 속성 : 조기종료에 사용할 대상을 지정
# patience 속성 : 반복되었을 때 조기종료할 특정 주기를 지정
es = EarlyStopping(
    monitor='val_loss',
    patience=3
)

7. 모델 학습

history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=32,
    epochs=100,
    validation_split=0.2,
    callbacks=[es],
    verbose=0
)
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='valid loss')
plt.legend()
plt.show()

8. 모델 검증

loss, mae = model.evaluate(
    x=x_test,
    y=y_test,
    batch_size=32,
    verbose=0
)
print('loss : ', loss)
print('mae : ', mae)

9. 모델 예측값, 실제값 비교

y_pred = model.predict(x=x_test).flatten()
print('예측값 : ', y_pred[:4])
print('실제값 : ', np.array(y_test)[:4])

10. 회귀모델 성능 평가 지표 - 설명력

print('설명력 : ', r2_score(y_test, y_pred))
# 설명력 :  0.46321772178245846


profile
데이터 사이언티스트를 목표로 하는 개발자

0개의 댓글