첨부 실행 코드는 나눔고딕코딩 폰트를 사용합니다.
유용한 소스 코드가 있으면 icodebroker@naver.com으로 보내주시면 감사합니다.
블로그 자료는 자유롭게 사용하세요.

■ 학습 조기 종료시키기

----------------------------------------------------------------------------------------------------

import keras

import keras.callbacks as callbacks

import keras.datasets.mnist as mnist

import keras.models as models

import keras.utils as utils

import keras.layers as layers

import matplotlib.pyplot as pp

import numpy as np

 

# Fix NumPy's global RNG so the random subset selection below is reproducible.
np.random.seed(3)

print("데이터 로드를 시작합니다.")

# Load MNIST.
#   trainInputNDArray         : (60000, 28, 28)
#   trainCottectOutputNDArray : (60000,)
#   testInputNDArray          : (10000, 28, 28)
#   testCorrectOutputNDArray  : (10000,)
(trainInputNDArray, trainCottectOutputNDArray), (testInputNDArray, testCorrectOutputNDArray) = mnist.load_data()

# Split the 60000 training samples: the last 10000 become the validation set.
validationInputNDArray         = trainInputNDArray[50000:]
validationCorrectOutputNDArray = trainCottectOutputNDArray[50000:]
trainInputNDArray              = trainInputNDArray[:50000]
trainCottectOutputNDArray      = trainCottectOutputNDArray[:50000]

# Flatten each 28x28 image into a 784-element vector and scale pixel values by 1/255.
#   trainInputNDArray      : (50000, 784)
#   validationInputNDArray : (10000, 784)
#   testInputNDArray       : (10000, 784)
trainInputNDArray      = trainInputNDArray.reshape(-1, 784).astype("float32") / 255.0
validationInputNDArray = validationInputNDArray.reshape(-1, 784).astype("float32") / 255.0
testInputNDArray       = testInputNDArray.reshape(-1, 784).astype("float32") / 255.0

# Keep only small random subsets (700 train / 300 validation), presumably so that
# overfitting -- and therefore early stopping -- shows up quickly.
# replace = False draws distinct indices; the default (replace = True) could pick
# the same sample several times and silently duplicate rows.
trainRandomIndexNDArray      = np.random.choice(50000, 700, replace = False)
validationRandomIndexNDArray = np.random.choice(10000, 300, replace = False)

trainInputNDArray              = trainInputNDArray[trainRandomIndexNDArray]
trainCottectOutputNDArray      = trainCottectOutputNDArray[trainRandomIndexNDArray]
validationInputNDArray         = validationInputNDArray[validationRandomIndexNDArray]
validationCorrectOutputNDArray = validationCorrectOutputNDArray[validationRandomIndexNDArray]

#   trainInputNDArray              : (700, 784)
#   trainCottectOutputNDArray      : (700,)
#   validationInputNDArray         : (300, 784)
#   validationCorrectOutputNDArray : (300,)

# One-hot encode the labels. num_classes = 10 is explicit so the encoded width always
# matches the 10-unit softmax output, even if a small random subset happens to miss
# the highest digit (to_categorical would otherwise infer max(label) + 1 columns).
trainCottectOutputNDArray      = utils.to_categorical(trainCottectOutputNDArray, num_classes = 10)
validationCorrectOutputNDArray = utils.to_categorical(validationCorrectOutputNDArray, num_classes = 10)
testCorrectOutputNDArray       = utils.to_categorical(testCorrectOutputNDArray, num_classes = 10)

#   trainCottectOutputNDArray      : (700, 10)
#   validationCorrectOutputNDArray : (300, 10)
#   testCorrectOutputNDArray       : (10000, 10)

print("데이터 로드를 종료합니다.")

 

print("모델 정의를 시작합니다.")

# Two-layer MLP: 784 inputs -> 64 ReLU units -> 10-way softmax, built from a layer list.
model = models.Sequential([
    layers.Dense(units = 64, input_dim = 784, activation = "relu"),
    layers.Dense(units = 10, activation = "softmax"),
])

# Categorical cross-entropy (expects one-hot targets) optimized with plain SGD;
# accuracy is tracked as an additional metric.
model.compile(
    loss      = "categorical_crossentropy",
    optimizer = "sgd",
    metrics   = ["accuracy"],
)

print("모델 정의를 종료합니다.")

 

print("모델 학습을 시작합니다.")

# Stop training once val_loss (EarlyStopping's default monitor, spelled out here)
# has not improved for 20 consecutive epochs; epochs = 1000 is only an upper bound.
# NOTE(review): restore_best_weights is not set, so the model keeps the weights of
# the last epoch run rather than those of the best epoch.
earlyStopping = callbacks.EarlyStopping(monitor = "val_loss", patience = 20)

history = model.fit(
    trainInputNDArray,
    trainCottectOutputNDArray,
    epochs          = 1000,
    batch_size      = 100,
    validation_data = (validationInputNDArray, validationCorrectOutputNDArray),
    callbacks       = [earlyStopping],
)

print("모델 학습을 종료합니다.")

 

print("학습 결과를 조회합니다.")

# Evaluate on the held-out test set; with metrics = ["accuracy"] compiled in, the
# result is [loss, accuracy].
evaluateList = model.evaluate(testInputNDArray, testCorrectOutputNDArray, batch_size = 32)

print("")
print("loss     : " + str(evaluateList[0]))
print("accuracy : " + str(evaluateList[1]))

# The History key for the "accuracy" metric is "acc" in older Keras releases and
# "accuracy" in newer ones; use whichever this run actually recorded instead of
# hard-coding "acc" (which raises KeyError on newer Keras).
accuracyKey = "accuracy" if "accuracy" in history.history else "acc"

# Loss on the left y-axis, accuracy on a twin right y-axis; both share the epoch x-axis.
figure, lossAxeSubplot = pp.subplots()
accuracyAxeSubplot = lossAxeSubplot.twinx()

lossAxeSubplot.plot(history.history["loss"    ], "y", label = "train loss")
lossAxeSubplot.plot(history.history["val_loss"], "r", label = "val loss"  )

accuracyAxeSubplot.plot(history.history[accuracyKey         ], "b", label = "train acc")
accuracyAxeSubplot.plot(history.history["val_" + accuracyKey], "g", label = "val acc"  )

lossAxeSubplot.set_xlabel("epoch")
lossAxeSubplot.set_ylabel("loss")

# Each axis has its own legend; place them apart so they do not overlap.
lossAxeSubplot.legend(loc = "upper left")

accuracyAxeSubplot.set_ylabel("accuracy")
accuracyAxeSubplot.legend(loc = "lower left")

pp.show()

----------------------------------------------------------------------------------------------------

Posted by 사용자 icodebroker

댓글을 달아 주세요