# Sequential API
# Sequential API: the model is a plain linear stack, declared as one list.
seq_model = keras.Sequential(
    [
        keras.Input(shape=(28, 28)),  # each sample is 28x28; batch dim is implicit
        layers.Flatten(),  # (28, 28) -> 784 features
        layers.Dense(300, activation='relu'),  # hidden layer
        layers.Dense(10, activation='softmax'),  # 10 softmax outputs
    ]
)
seq_model.summary()
# Functional API
# Functional API: layers are called on symbolic tensors, and the model is
# defined by its input and output tensors.
inp = keras.Input(shape=(28, 28))
# Flatten the 28x28 input, then feed it to the 300-unit ReLU hidden layer.
x = layers.Dense(300, activation='relu')(layers.Flatten()(inp))
out = layers.Dense(10, activation='softmax')(x)  # 10 softmax outputs
func_model = keras.Model(inputs=inp, outputs=out)
func_model.summary()
# Subclassing API
class model(keras.Model):
    """Subclassed MLP classifier: Flatten -> Dense(ReLU) -> Dense(softmax).

    The defaults reproduce the original fixed architecture
    (300 hidden units, 10 outputs).

    Args:
        hidden_units: Width of the hidden ReLU layer. Defaults to 300.
        num_classes: Number of softmax outputs. Defaults to 10.
        **kwargs: Forwarded unchanged to ``keras.Model.__init__``
            (e.g. ``name=``).
    """

    def __init__(self, hidden_units: int = 300, num_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.flatten = layers.Flatten()
        # Attribute names are kept as-is so existing code that reads
        # ``Dense1`` / ``Dense2`` on an instance keeps working.
        self.Dense1 = layers.Dense(hidden_units, activation='relu')
        self.Dense2 = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        """Forward pass: flatten, hidden ReLU layer, softmax output layer.

        Args:
            inputs: Batch of input samples (28x28 elsewhere in this file).

        Returns:
            Softmax outputs of ``Dense2``, one value per class.
        """
        x = self.flatten(inputs)
        x = self.Dense1(x)
        return self.Dense2(x)
# A subclassed model has no declared Input, so build() is given the input
# shape explicitly (batch dimension left as None) before summary() is called.
cls_model = model()
cls_model.build(input_shape=(None, 28, 28))
cls_model.summary()
# Visualize the model structure (모델의 구조를 시각화해보기)
