Training Logic

yeoni·2023년 6월 27일
0

Tensorflow

목록 보기
11/15
  • Cifar10DataLoader 클래스, build_resnet 모델 이용 (train_x, train_y, model은 이 코드 밖에서 미리 준비되어 있어야 함)
import numpy as np
import pandas as pd
import tensorflow as tf

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# Fix the NumPy and TensorFlow random seeds so runs are reproducible.
np.random.seed(7777)
tf.random.set_seed(7777)

# Hyperparameters.
lr = 0.03
batch_size = 64  # kept small to save memory

# Optimizer, loss and running metrics used by train_step and the training loop.
opt = tf.keras.optimizers.Adam(learning_rate=lr)
loss_fn = tf.keras.losses.categorical_crossentropy  # a loss *function*, not a Loss object
train_loss = tf.keras.metrics.Mean()  # metric object that accumulates the values passed to it
train_acc = tf.keras.metrics.CategoricalAccuracy()

def train_step(x, y):
  """Run one optimization step on a single batch and update the running metrics.

  Args:
    x: input batch passed to the module-level `model`.
    y: targets for `loss_fn` / `train_acc` (presumably one-hot labels, since
       both are categorical -- TODO confirm against the data loader).

  Reads the module-level `model`, `opt`, `loss_fn`, `train_loss` and
  `train_acc`. Returns nothing.
  """
  with tf.GradientTape() as tape:
    # training=True makes layers such as BatchNormalization/Dropout use their
    # training behaviour; without it Keras layers run in inference mode.
    pred = model(x, training=True)
    per_example_loss = loss_fn(y, pred)
    # Differentiate the batch *mean*: for a non-scalar target tape.gradient
    # uses the sum, which would scale the gradients with the batch size.
    loss = tf.reduce_mean(per_example_loss)

  gradients = tape.gradient(loss, model.trainable_variables)
  opt.apply_gradients(zip(gradients, model.trainable_variables))

  # Feed per-example losses so the running mean stays a per-sample average,
  # even when the last batch of an epoch is smaller than batch_size.
  train_loss(per_example_loss)
  train_acc(y, pred)
  
# Training loop: walks train_x / train_y in fixed order, one batch at a time,
# calling whichever train_step is currently defined.
# Ceil division so the final partial batch is trained on too; the data is not
# shuffled, so floor division would skip the same trailing samples every epoch.
steps_per_epoch = -(-train_x.shape[0] // batch_size)
for epoch in range(1):
  for i in range(steps_per_epoch):
    idx = i * batch_size
    x, y = train_x[idx:idx+batch_size], train_y[idx:idx+batch_size]
    train_step(x, y)
    # 1-based counter so the progress display ends at "N / N".
    print('{} / {}'.format(i + 1, steps_per_epoch), end='\r')
  fmt = 'epoch: {}, loss: {}, acc: {}'
  print(fmt.format(epoch+1, train_loss.result(), train_acc.result()))
  # Reset the accumulated metrics so each epoch reports only its own loss/accuracy.
  train_loss.reset_state()
  train_acc.reset_state()
---------------------------------------------------------

@tf.function  # traced into a TF graph on the first call, which speeds up the TF ops (most noticeable on GPU)
def train_step(x, y):
  """Graph-compiled version of one optimization step on a single batch.

  Args:
    x: input batch passed to the module-level `model`.
    y: targets for `loss_fn` / `train_acc` (presumably one-hot labels, since
       both are categorical -- TODO confirm against the data loader).

  Reads the module-level `model`, `opt`, `loss_fn`, `train_loss` and
  `train_acc`. Returns nothing.
  """
  with tf.GradientTape() as tape:
    # training=True makes layers such as BatchNormalization/Dropout use their
    # training behaviour; without it Keras layers run in inference mode.
    pred = model(x, training=True)
    per_example_loss = loss_fn(y, pred)
    # Differentiate the batch *mean*: for a non-scalar target tape.gradient
    # uses the sum, which would scale the gradients with the batch size.
    loss = tf.reduce_mean(per_example_loss)

  gradients = tape.gradient(loss, model.trainable_variables)
  opt.apply_gradients(zip(gradients, model.trainable_variables))

  # Feed per-example losses so the running mean stays a per-sample average,
  # even when the last batch of an epoch is smaller than batch_size.
  train_loss(per_example_loss)
  train_acc(y, pred)
  
  
# Training loop: walks train_x / train_y in fixed order, one batch at a time,
# calling whichever train_step is currently defined.
# Ceil division so the final partial batch is trained on too; the data is not
# shuffled, so floor division would skip the same trailing samples every epoch.
steps_per_epoch = -(-train_x.shape[0] // batch_size)
for epoch in range(1):
  for i in range(steps_per_epoch):
    idx = i * batch_size
    x, y = train_x[idx:idx+batch_size], train_y[idx:idx+batch_size]
    train_step(x, y)
    # 1-based counter so the progress display ends at "N / N".
    print('{} / {}'.format(i + 1, steps_per_epoch), end='\r')
  fmt = 'epoch: {}, loss: {}, acc: {}'
  print(fmt.format(epoch+1, train_loss.result(), train_acc.result()))
  # Reset the accumulated metrics so each epoch reports only its own loss/accuracy.
  train_loss.reset_state()
  train_acc.reset_state()

Reference
1) 제로베이스 데이터스쿨 강의자료

profile
데이터 사이언스 / just do it

0개의 댓글