CIFAR-10 Classification

-J1-·2022년 1월 21일
0

Let's get used to the TensorFlow framework and understand how a CNN works!


IMPORT LIBRARIES

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

LOAD DATA

# Keras ships a ready-made loader for the CIFAR-10 dataset.
cifar10 = tf.keras.datasets.cifar10
# load_data() returns two (images, labels) pairs: one for training, one for
# testing. Unpack them into train_images/train_labels and
# test_images/test_labels.
(
    (train_images, train_labels),
    (test_images, test_labels),
) = cifar10.load_data()

DATA ANALYSIS

# Check the array shape and the number of samples in each split.
print(train_images.shape)
print(len(train_images))

print(test_images.shape)
print(len(test_images))

# Data visualization: show the first nine training images in one 3x3 grid.
# Label index -> class name, in the order used by the TensorFlow CIFAR-10
# CNN tutorial.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
# Flatten so that indexing works whether labels are stored as (N,) or (N, 1).
flat_labels = np.ravel(train_labels)

plt.figure(figsize=(6, 6))  # large enough for nine thumbnails with titles
for i in range(9):
    plt.subplot(3, 3, i + 1)  # cell i+1 of a 3x3 grid
    plt.imshow(train_images[i])
    # No colorbar: for RGB images (as CIFAR-10's are) imshow ignores the
    # colormap, so a colorbar would not describe the pixel values.
    plt.title(class_names[int(flat_labels[i])])
    plt.axis('off')  # hide ticks and grid lines around each thumbnail
plt.tight_layout()
# Show once, after all subplots are drawn. Calling show() inside the loop
# renders/closes the figure after the first image and scatters the rest
# into separate figures.
plt.show()

MODEL TRAINING

# Model building: three Conv2D + MaxPooling2D stages followed by a small
# fully connected classifier that outputs one logit per class.
model = tf.keras.Sequential([
    # Declare the input shape up front (taken from the data itself) so the
    # model is built immediately and model.summary() works before fit().
    tf.keras.Input(shape=train_images.shape[1:]),
    # Normalize pixel values from 0-255 to 0-1. tf.keras.layers.Rescaling is
    # the stable name; experimental.preprocessing.Rescaling is a deprecated
    # alias.
    tf.keras.layers.Rescaling(1. / 255),
    # Convolutional layer with 32 kernels of size 3x3 and ReLU activation.
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    # Max pooling (default 2x2 window) to shrink the spatial dimensions.
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(128, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    # Flatten the n-dimensional feature maps into a 1-D vector.
    tf.keras.layers.Flatten(),
    # Fully connected (dense) layer with 128 units.
    tf.keras.layers.Dense(128, activation='relu'),
    # 10 output units, one per class. No activation, so these are logits.
    tf.keras.layers.Dense(10),
])
model.summary()

# Compile the model: Adam optimizer, cross-entropy loss, accuracy metric.
model.compile(
    optimizer='adam',
    # from_logits=True because the last Dense layer has no softmax.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train for 15 epochs. Keep the History object so the per-epoch loss and
# accuracy can be plotted later.
history = model.fit(train_images, train_labels, epochs=15)

# Check the loss and accuracy on the held-out test set.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)

profile
Jaywalking with Jaewon🏃‍♀️

0개의 댓글