๐งโ๐ป ๋ชจ๊ฐ์ฝ 9์ฃผ์ฐจ
์ด์ ๊ธ ๐งโ๐ป ๋ชจ๊ฐ์ฝ 8์ฃผ์ฐจ ๋ณด๋ฌ๊ฐ๊ธฐ.
๐ ์ค๋์ ํ ์ผ
- ๐งโ๐ป MNIST Classification ์ค์ต
๐งโ๐ป MNIST Classification ์ค์ต
ํน๊ฐ ์ค์ต์์ Fashion MNIST๋ฅผ ์ด์ฉํ Classification์ ์ง์ ์ค์ตํ์ผ๋, ์ ํต์ ์ธ MNIST์ ๋ํ ์ค์ต์ ํ์ง ์์์๋ค. ์ค๋์ ๊ฐ๋ณ๊ฒ Fashion MNIST ์ค์ต์ ์ฐธ๊ณ ํ์ฌ MNIST Classification์ ๊ตฌํํด๋ณด๋ ค๊ณ ํ๋ค.
Import
ํ ์ํ๋ก์ฐ 2.0, keras, numpy, matplotlib, collections๋ฅผ ์ฌ์ฉํ๋ค.
try:
# Colab only
%tensorflow_version 2.x
except Exception:
pass
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import collections as col
colab์ ๊ธฐ๋ณธ์ ์ผ๋ก tensorflow 1๋ฒ์ ์ ์ฌ์ฉํ๋ค. ์์ ์ฝ๋๋ 2๋ฒ์ ์ ์ฌ์ฉํ๋๋ก ๊ฐ์ ํ๋ค. colab์์๋ง ์ฌ์ฉ์ด ๊ฐ๋ฅํ๋ค.
Dataset ๊ฐ์ ธ์ค๊ธฐ
keras์์๋ MNIST ๋ฐ์ดํฐ๋ฅผ ๊ธฐ๋ณธ์ ์ผ๋ก ์ ๊ณตํ๋ค. ์๋ ์ฝ๋๋ฅผ ํตํด ํ๋ จ์ ๊ณผ ํ ์คํธ์ ์ ์์ฝ๊ฒ ๊ฐ์ ธ์ฌ ์ ์๋ค.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(train_images.shape)
print(train_labels.shape)
print(test_images.shape)
print(test_labels.shape)
Label ๋ณ ๊ฐ์ ํ์ธ
collections์ Counterํจ์๋ฅผ ํตํด label๋ณ๋ก ๋ช๊ฐ์ฉ ์๋์ง ํ์ธํ ์ ์๋ค. Fashion MNIST์๋ ๋ฌ๋ฆฌ ๊ณ ๋ฅด๊ฒ ๋ถํฌ๋์ด์์ง ์๋ค. ๋ณด๊ธฐ์ข๊ฒ matplotlib์ histogram์ ํตํด ์๊ฐํ ํด ๋ณด์๋ค.
col.Counter(train_labels)
col.Counter(test_labels)
plt.hist(train_labels, color = 'green', alpha = .5)
plt.hist(test_labels, color = 'green', alpha = .5)
์ด๋ฏธ์ง ํ์ธ
ํ๋ จ ์ด๋ฏธ์ง ์ ์์ ์ฒซ๋ฒ์งธ ์ด๋ฏธ์ง๋ฅผ ์๊ฐํํ์ฌ ํ์ธํด๋ณด์๋ค.
plt.figure(figsize = (10, 10)) # ์ ์ฒด ํผ๊ฒจ ์ฌ์ด์ฆ๋ฅผ 10*10์ผ๋ก
for i in range(25) :
plt.subplot(5, 5, i + 1)
plt.xticks([]) # x ๋๊ธ ์ ๊ฑฐ
plt.yticks([]) # y ๋๊ธ ์ ๊ฑฐ
plt.xlabel(train_labels[i]) # x ์ถ ๋ผ๋ฒจ์ ์ ๋ต
plt.imshow(train_images[i])
์ด๋ฏธ์ง ๋ถ์, ์ ๊ทํ
์ด๋ฏธ์ง๋ 255๊น์ง์ ๊ฐ๋๋ฅผ ๊ฐ๋ ํฝ์ 24*24๊ฐ๋ก ์ด๋ฃจ์ด ์ ธ ์๋ค. ๋ฐ๋ผ์ 255๋ก ๋๋ ์ค์ผ๋ก ์จ ์ด๋ฏธ์ง์ ๋ชจ๋ ํฝ์ ์ 0~1์ฌ์ด์ ์๋ก ํํํ ์ ์๋ค. (์ ๊ทํ) 255๋ก ๋๋ ๋ค, Counter์ figure๋ฅผ ํตํด ์ ๊ทํ๊ฐ ๋์์์ ํ์ธํ ์ ์๋ค.
col.Counter(train_images[0].reshape(784)) # 0 ๋ฒ์งธ ์ด๋ฏธ์ง(24*24 2์ฐจ์ ๋ฐฐ์ด)๋ฅผ 784์ฌ์ด์ฆ์ 1์ฐจ์ ๋ฐฐ์ด๋ก reshapeํ์ฌ ์์(์์ ๊ฒฐ์ ํ๋ ์์น)๋ฅผ ํ์ธํด ๋ณด์๋ค.
0~255์ ๊ฐ์ ๊ฐ๋๊ฑธ ํ์ธํ ์ ์๋ค.
train_images = train_images / 255.0
test_images = test_images / 255.0
col.Counter(train_images[0].reshape(784))
255.0์ผ๋ก ๋๋์ด 0~1.0์ฌ์ด์ ๊ฐ์ ๊ฐ๊ฒ๋๊ฒ์ ํ์ธํ ์ ์๋ค.
plt.figure(figsize = (10, 10))
plt.imshow(train_images[0])
plt.colorbar()
plt.show()
figure()๋ฅผ ํตํด ๊ทธ๋ฆผ์ ๊ทธ๋ ค๋ณด์์๋ ํ์ธ์ด ๊ฐ๋ฅํ๋ค.
ํ์ต ๋ชจ๋ธ
keras๋ฅผ ์ด์ฉํ๋ฉด ๊ฐ๋จํ๊ฒ ๋ชจ๋ธ์ ๋ง๋ค์ด๋ผ ์ ์๋ค. costํจ์๋ก sparse categorical crossentropy๋ฅผ ์ฌ์ฉํ๋ค. crossentropy๋ ๋จ์ํ ๋ง์๋ ํ๋ฆฌ๋๋ก๋ง cost๋ฅผ ๊ณ์ฐํ๋๊ฒ ์๋๋ผ ์ผ๋ง์ ์ฐจ์ด๋ก ์ธํด ํ๋ ธ๋์ง ๊น์ง ๊ณ์ฐ์ ํฌํจํ๋ค. ๋ฐ๋ผ์ ๋ถ๋ฅ๋ฌธ์ ์์ ๋ ๋์ cost๋ฅผ ๊ณ์ฐํด๋ธ๋ค. ๊ด๋ จ ๋ด์ฉ์ ์ฌ๊ธฐ์์ ์ฐธ์กฐํ๋ค. crossentropy์ค์์๋ sparse๋ฅผ ์ฌ์ฉํ๋ ์ด์ ๋ one-hot encoding๋์ด์๋ ์๋ฃ๊ฐ ์๋๊ธฐ ๋๋ฌธ์ด๋ค. Optimizer๋ Adam optimizer๋ฅผ ์ฌ์ฉํ๋ค. ๋๋ถ๋ถ์ ๊ฒฝ์ฐ์์ ๊ฐ์ฅ ์ ์๋ํ๋ Optimizer๋ก, Optimizer์ ๋ํ์(..)์ด๋ผ๋ ๋ง๋ ์๋ค. ์ฌ์ค ์์ง ๊น๊ฒ ๊ณต๋ถํด๋ณด์ง ์์์ ๋ฌด์จ ์ฐจ์ด๊ฐ ์๋์ง๋ ๋ชจ๋ฅธ๋ค. ์ฐจํ์ ๊ณต๋ถํด์ผ๊ฒ ๋ค.
model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape = (28, 28)))
model.add(keras.layers.Dense(128, activation = 'relu'))
model.add(keras.layers.Dense(10, activation = 'softmax'))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10)
ํ๋ จ ์ ์ ๋ํด์ 99%๊น์ง์ ์ ํ๋๊ฐ ๋์ฌ ์ ๋๋ก ํ์ต์ด ๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
๊ฒ์ฆ
ํ์ต๊ฒฐ๊ณผ๋ฅผ ๊ฒ์ฆ. model.evaluate()๋ฅผ ํตํด์ ํ์ต๋ ๋ชจ๋ธ์ ๊ฒ์ฆํ ์ ์๋ค.
loss, accuracy = model.evaluate(test_images, test_labels, verbose = 1)
print('\ntest loss\t: ', loss)
print('test accuracy\t: ', accuracy)
ํ ์คํธ ์ ์ ๋ํด์๋ ์ ํ๋๊ฐ 97%๊ฐ๋์ด ๋์ด์ ํ์ธํ ์ ์๋ค. ์๋ฌด๋ ๋ ํ๋ จ ์ ์ ๋ํ overfitting์ด ์ฝ๊ฐ ์๋๊ฒ์ผ๋ก ์ถ์ ๋๋ค.
predictions = model.predict(test_images)
plt.figure(figsize=(20,20))
for i in range(25) :
plt.subplot(5, 5, i+1)
plt.imshow(test_images[i])
plt.xticks([])
plt.yticks([])
description = str(np.argmax(predictions[i])) + " / " + str(test_labels[i])
plt.xlabel(description)
plt.show()
ํ ์คํธ ์ ์ด๊ธฐ 25๊ฐ์ ์ด๋ฏธ์ง์ ๋ํด ์ด๋ฏธ์ง์ ์์ธก๊ฐ, ์ ๋ต์ ํจ๊ป ๋ณด์ฌ์ฃผ๋๋ก ์ถ๋ ฅํด๋ณด์๋ค.
์ด์ ์ Fashion MNIST๋ฅผ ํตํด ์ค์ตํด๋ณธ ๊ฒฝํ์ผ๋ก ์ธํด ์ด๋ฒ ์ค์ต์ ์ด๋ ต์ง์๊ฒ ์งํํ ์ ์์๋ค. ๋ ์ ๋ฒ์ ๋์ ํ๋ Bank marketing dataset ์ค์ต๊ณผ ๋ฌ๋ฆฌ ์ด๋ฒ MNIST classification์์๋ ์ค์ง Keras๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ๋ง๋ค์๋๋ฐ, ๊ทธ ๊ณผ์ ์ด ๋งค์ฐ ๊ฐ๋จํ๊ณ ์ง๊ด์ ์ด์ด์ ์ด์ ๋ ์ ๋ง ๋ฅ๋ฌ๋์ด ๋ง์ด ๋์คํ๊ฐ ๋์์์ ๋๊ผ๋ค.
Comments