Study/밑바닥부터 시작하는 딥러닝

Ch6 학습 관련 기술들

Bonseong 2021. 1. 25. 10:56

 

 

 

 

6.1 매개변수 갱신

 
  • 신경망 학습의 목적 : 손실 함수의 값을 가능한 낮추는 매개변수를 찾는 것 (최적화)
  • 여태까지 미분을 통해 매개변수의 값을 갱신했음 (확률적 경사 하강법)
 

6.1.1 확률적 경사 하강법 (SGD)

In [1]:
class SGD:
    def __init__(self, lr=0.01):
        self.lr = lr
    def update(self, params, grads):
        for key in params.keys():
            params[key] -= self.lr*grads[key]
 
  • 대부분의 딥러닝 프레임워크에서는 다양한 최적화 기법을 구현해 제공
 

6.1.2 SGD의 단점

 
  • 비등상성 함수 (방향에 따라 성질, 기울기가 달라지는 함수) 에서느 탐색 경로가 비효율적
 

6.1.3 모멘텀

 
  • 운동량을 뜻하는 단어

 

In [3]:
class Momentum:
    def __init__(self, lr=0.01, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.v = None
        
    def update(self, params, grads):
        if self.v is None:
            self.v = {}
            for key, val in params.items():
                self.v[key] = np.zeros_like(val)
                
        for key in params.keys():
            self.v[key] = self.momentum*self.v[key] - self.lr*grads[key]
            params[key] += self.v[key]
 
  • SGD와 비교했을 때 지그재그의 정도가 덜함
  • 예시에서는 x축의 힘은 아주 작지만 방향은 변하지 않아 한 방향으로 일정하게 가속
  • y축의 힘은 크지만 번갈아 받아 상충하며 y축의 속도는 안정적이지 않음
 

6.1.4 AdaGrad

 
  • 학습률 감소 : 학습을 진행하면서 학습률을 점차 줄여가는 방법
  • 매개변수 전체의 학습률 값을 일괄적으로 낮추는 방법
In [6]:
class AdaGrad:
    def __init__(self, lr=0.01):
        self.lr = lr
        self.h = None
        
    def update(self, params, grads):
        if self.h is None:
            self.h = {}
            for key, val in params.items():
                self.h[key] = np.zeros_like(val)
                
            for key in params.keys():
                self.h[key] += grads[key] * grads[key]
                params[key] -= self.lr *grads[key] / (np.sqrt(self.h[key])+1e-7) # 0이 담겨져 있다고 해도 0으로 나누는 것을 막아줌
 
  • 최솟값을 향해 효율적으로 움직임
  • y축 방향은 기울기가 커서 처음에는 크게 움직이지만, 그 움직임에 비례해 갱신 정도도 큰 폭으로 작아지도록 조정
  • y축 방향으로 갱신 강도가 빠르게 약해지고, 지그재그 움직임이 줄어듬
 

6.1.5 Adam

 
  • 직관적으로 모멘텀과 AdaGrad를 융합한 방법
  • 하이퍼파라미터의 편향 보정이 진행됨
 

6.1.6 갱신 방법 선택

 
  • 가장 뛰어난 방법은 아직까진 없음
  • 각 문제의 상황에 따라 방법 선택
 

6.2 가중치의 초깃값

 

6.2.1 초깃값을 0으로 하면?

 
  • 가중치 감소 : 오버피팅을 억제해 범용 성능을 높이는 테크닉
  • 가중치를 균일한 값으로 설정하면 오차역전파법에서 모든 가중치의 값이 똑같이 갱신되기 때문에 좋지 않음
  • 가중치가 고르게 되어버리는 상황을 막으려면 초깃값을 무작위로 설정해야 함
 

6.2.2 은닉층의 활성화값 분포

In [11]:
import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    return 1/(1+np.exp(-x))

x = np.random.randn(1000,100)
node_num = 100 # 각층의 뉴런은 100개
hidden_layer_size = 5 # 5개의 층
activations = {}

for i in range(hidden_layer_size):
    if i != 0:
        x = activations[i-1]
        
    w = np.random.randn(node_num, node_num) * 1
    a = np.dot(x,w)
    z = sigmoid(a)
    activations[i] = z
In [12]:
for i, a in activations.items():
    plt.subplot(1, len(activations), i+1)
    plt.title(str(i+1) + '-layer')
    plt.hist(a.flatten(), 30, range=(0,1))
plt.show()
 
 
  • 각 층의 활성화 값이 0과 1에 몰려 있음
  • 기울기 소실 : 데이터가 0과 1에 치우쳐 분포하게 되면 역전파의 기욹리 앖이 점점 작아지다가 사라짐
In [13]:
for i in range(hidden_layer_size):
    if i != 0:
        x = activations[i-1]
        
    w = np.random.randn(node_num, node_num) * 0.01
    a = np.dot(x,w)
    z = sigmoid(a)
    activations[i] = z
In [14]:
for i, a in activations.items():
    plt.subplot(1, len(activations), i+1)
    plt.title(str(i+1) + '-layer')
    plt.hist(a.flatten(), 30, range=(0,1))
plt.show()
 
 
  • 이번에는 0.5 부근에 집중되어 있음
  • 다수의 뉴런이 거의 같은 값을 출력하고 있음 -> 뉴런을 여러개 둔 의미가 없어짐
  • 활성화값들이 치우치면 표현력을 제한한다는 관점에서 문제가 됨
 
  • Xavier 초깃값 : 일반적인 딥러닝 프레임워크에서 표준적으로 사용하는 초깃값
  • 활성화값들을 광범위하게 분포시킬 목적, 앞 계층의 노드가 n개라면 표준편차가 1/sqrt(n)인 분포를 사용하면 된다는 결론
In [19]:
for i in range(hidden_layer_size):
    if i != 0:
        x = activations[i-1]
    
    node_num=100
    w = np.random.randn(node_num, node_num) / np.sqrt(node_num)
    a = np.dot(x,w)
    z = sigmoid(a)
    activations[i] = z
In [20]:
for i, a in activations.items():
    plt.subplot(1, len(activations), i+1)
    plt.title(str(i+1) + '-layer')
    plt.hist(a.flatten(), 30, range=(0,1))
plt.show()
 
 
  • 층이 깊어지면서 형태가 다소 일그러지지만, 넓게 분포됨
  • 시그모이드 함수의 표현력도 제한받지 않음 (일그러짐 현상은 tanh함수를 사용하면 해결 됨)
 

6.2.3 ReLU를 사용할 때의 가중치 초깃값

 
  • sigmoid 함수와 tanh 함수는 좌우 대칭이라 중앙 부근 선형인 함수 -> Xavier 초깃값
  • ReLU는 He 초깃값 사용
  • He 초깃값 : 앞 계층의 노드가 n개 일 때, 표준편차가 sqrt(2/n) 인 정규분포 사용
 

6.3 배치 정규화

  • 배치 정규화 : 각 층이 활성화를 적당히 퍼뜨리도록 강제함
 

6.3.1 배치 정규화 알고리즘

  • 학습을 빨리 진행할 수 있음
  • 초깃값에 크게 의존하지 않음
  • 오버피팅 억제
 
  • 학습 시 미니배치를 단위로 정규화
  • 데이터 분포가 표준정규분포를 따르도록 함 (데이터 분포가 덜 치우치게 할 수 있음)
  • 각 계층마다 정규화 데이터에 고유한 확대화 이동 변환을 수행
 

6.4 바른 학습을 위해

  • 오버피팅 : 신경망이 훈련 데이터에만 지나치게 적응되어 그 외의 데이터에는 제대로 대응하지 못하는 상태
 

6.4.1 오버피팅

 
  • 매개변수가 많고 표현력이 높거나
  • 훈련 데이터가 적을 때 오버피팅 발생
In [35]:
import os
import sys
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.optimizer import SGD

sys.path.append(os.pardir)
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
x_train = x_train[:300]
t_train = t_train[:300]
In [39]:
#weight_decay_lambda = 0

network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100], output_size=10,
                        weight_decay_lambda=weight_decay_lambda)
optimizer = SGD(lr=0.01) # 학습률이 0.01인 SGD로 매개변수 갱신

max_epochs = 201
train_size = x_train.shape[0]
batch_size = 100

train_loss_list = []
train_acc_list = []
test_acc_list = []

iter_per_epoch = max(train_size / batch_size, 1)
epoch_cnt = 0

for i in range(1000000000):
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]

    grads = network.gradient(x_batch, t_batch)
    optimizer.update(network.params, grads)

    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)

        print("epoch:" + str(epoch_cnt) + ", train acc:" + str(train_acc) + ", test acc:" + str(test_acc))

        epoch_cnt += 1
        if epoch_cnt >= max_epochs:
            break
 
epoch:0, train acc:0.06666666666666667, test acc:0.0853
epoch:1, train acc:0.08666666666666667, test acc:0.1029
epoch:2, train acc:0.12, test acc:0.1275
epoch:3, train acc:0.14333333333333334, test acc:0.1458
epoch:4, train acc:0.19333333333333333, test acc:0.1688
epoch:5, train acc:0.22, test acc:0.1908
epoch:6, train acc:0.22333333333333333, test acc:0.2027
epoch:7, train acc:0.24666666666666667, test acc:0.2146
epoch:8, train acc:0.28, test acc:0.2263
epoch:9, train acc:0.30666666666666664, test acc:0.2354
epoch:10, train acc:0.33, test acc:0.252
epoch:11, train acc:0.35333333333333333, test acc:0.2639
epoch:12, train acc:0.38666666666666666, test acc:0.2814
epoch:13, train acc:0.44, test acc:0.2995
epoch:14, train acc:0.49666666666666665, test acc:0.3188
epoch:15, train acc:0.5166666666666667, test acc:0.3333
epoch:16, train acc:0.5266666666666666, test acc:0.3434
epoch:17, train acc:0.5433333333333333, test acc:0.3593
epoch:18, train acc:0.5466666666666666, test acc:0.3744
epoch:19, train acc:0.5633333333333334, test acc:0.3838
epoch:20, train acc:0.5733333333333334, test acc:0.3949
epoch:21, train acc:0.58, test acc:0.4109
epoch:22, train acc:0.5833333333333334, test acc:0.4195
epoch:23, train acc:0.6, test acc:0.4355
epoch:24, train acc:0.6166666666666667, test acc:0.4492
epoch:25, train acc:0.6366666666666667, test acc:0.4595
epoch:26, train acc:0.66, test acc:0.4742
epoch:27, train acc:0.6866666666666666, test acc:0.486
epoch:28, train acc:0.6933333333333334, test acc:0.4987
epoch:29, train acc:0.69, test acc:0.5004
epoch:30, train acc:0.7033333333333334, test acc:0.5164
epoch:31, train acc:0.7166666666666667, test acc:0.5383
epoch:32, train acc:0.7266666666666667, test acc:0.5452
epoch:33, train acc:0.7366666666666667, test acc:0.5497
epoch:34, train acc:0.7366666666666667, test acc:0.5627
epoch:35, train acc:0.7533333333333333, test acc:0.568
epoch:36, train acc:0.7366666666666667, test acc:0.571
epoch:37, train acc:0.76, test acc:0.5781
epoch:38, train acc:0.76, test acc:0.5873
epoch:39, train acc:0.7566666666666667, test acc:0.5829
epoch:40, train acc:0.7833333333333333, test acc:0.5974
epoch:41, train acc:0.7766666666666666, test acc:0.5993
epoch:42, train acc:0.78, test acc:0.6064
epoch:43, train acc:0.7933333333333333, test acc:0.6132
epoch:44, train acc:0.8066666666666666, test acc:0.6211
epoch:45, train acc:0.81, test acc:0.6146
epoch:46, train acc:0.8133333333333334, test acc:0.6265
epoch:47, train acc:0.8266666666666667, test acc:0.633
epoch:48, train acc:0.8266666666666667, test acc:0.6423
epoch:49, train acc:0.8233333333333334, test acc:0.6375
epoch:50, train acc:0.8533333333333334, test acc:0.6467
epoch:51, train acc:0.83, test acc:0.644
epoch:52, train acc:0.8466666666666667, test acc:0.6508
epoch:53, train acc:0.85, test acc:0.6604
epoch:54, train acc:0.8633333333333333, test acc:0.6703
epoch:55, train acc:0.86, test acc:0.6638
epoch:56, train acc:0.8633333333333333, test acc:0.6714
epoch:57, train acc:0.8833333333333333, test acc:0.6723
epoch:58, train acc:0.8833333333333333, test acc:0.6848
epoch:59, train acc:0.8866666666666667, test acc:0.6772
epoch:60, train acc:0.89, test acc:0.6789
epoch:61, train acc:0.8933333333333333, test acc:0.6866
epoch:62, train acc:0.89, test acc:0.6871
epoch:63, train acc:0.89, test acc:0.6868
epoch:64, train acc:0.9033333333333333, test acc:0.6991
epoch:65, train acc:0.9033333333333333, test acc:0.6948
epoch:66, train acc:0.9066666666666666, test acc:0.6959
epoch:67, train acc:0.91, test acc:0.7046
epoch:68, train acc:0.9133333333333333, test acc:0.7041
epoch:69, train acc:0.9166666666666666, test acc:0.7034
epoch:70, train acc:0.9133333333333333, test acc:0.7063
epoch:71, train acc:0.9333333333333333, test acc:0.7125
epoch:72, train acc:0.9266666666666666, test acc:0.7115
epoch:73, train acc:0.9266666666666666, test acc:0.7115
epoch:74, train acc:0.9266666666666666, test acc:0.7184
epoch:75, train acc:0.9166666666666666, test acc:0.7119
epoch:76, train acc:0.9366666666666666, test acc:0.7163
epoch:77, train acc:0.9333333333333333, test acc:0.7194
epoch:78, train acc:0.9333333333333333, test acc:0.7203
epoch:79, train acc:0.94, test acc:0.7224
epoch:80, train acc:0.9433333333333334, test acc:0.7226
epoch:81, train acc:0.9366666666666666, test acc:0.7211
epoch:82, train acc:0.95, test acc:0.7246
epoch:83, train acc:0.9466666666666667, test acc:0.7274
epoch:84, train acc:0.9533333333333334, test acc:0.7237
epoch:85, train acc:0.9533333333333334, test acc:0.7257
epoch:86, train acc:0.9566666666666667, test acc:0.7285
epoch:87, train acc:0.9533333333333334, test acc:0.7251
epoch:88, train acc:0.9533333333333334, test acc:0.7319
epoch:89, train acc:0.96, test acc:0.7299
epoch:90, train acc:0.9766666666666667, test acc:0.728
epoch:91, train acc:0.97, test acc:0.7373
epoch:92, train acc:0.9633333333333334, test acc:0.7323
epoch:93, train acc:0.9733333333333334, test acc:0.7374
epoch:94, train acc:0.9733333333333334, test acc:0.7364
epoch:95, train acc:0.9733333333333334, test acc:0.7326
epoch:96, train acc:0.97, test acc:0.7377
epoch:97, train acc:0.9766666666666667, test acc:0.7392
epoch:98, train acc:0.9766666666666667, test acc:0.7414
epoch:99, train acc:0.98, test acc:0.7383
epoch:100, train acc:0.9766666666666667, test acc:0.7419
epoch:101, train acc:0.98, test acc:0.7396
epoch:102, train acc:0.98, test acc:0.7421
epoch:103, train acc:0.9866666666666667, test acc:0.7427
epoch:104, train acc:0.9833333333333333, test acc:0.742
epoch:105, train acc:0.98, test acc:0.741
epoch:106, train acc:0.9866666666666667, test acc:0.7424
epoch:107, train acc:0.99, test acc:0.7415
epoch:108, train acc:0.9833333333333333, test acc:0.7429
epoch:109, train acc:0.99, test acc:0.7463
epoch:110, train acc:0.9866666666666667, test acc:0.7418
epoch:111, train acc:0.99, test acc:0.7442
epoch:112, train acc:0.99, test acc:0.7453
epoch:113, train acc:0.99, test acc:0.7444
epoch:114, train acc:0.99, test acc:0.7413
epoch:115, train acc:0.9933333333333333, test acc:0.7451
epoch:116, train acc:0.9933333333333333, test acc:0.7481
epoch:117, train acc:0.9933333333333333, test acc:0.7494
epoch:118, train acc:0.9933333333333333, test acc:0.7449
epoch:119, train acc:0.9933333333333333, test acc:0.7449
epoch:120, train acc:0.9933333333333333, test acc:0.7425
epoch:121, train acc:0.9933333333333333, test acc:0.7477
epoch:122, train acc:0.9933333333333333, test acc:0.7453
epoch:123, train acc:0.9933333333333333, test acc:0.7487
epoch:124, train acc:0.9933333333333333, test acc:0.7473
epoch:125, train acc:0.9933333333333333, test acc:0.7494
epoch:126, train acc:0.9966666666666667, test acc:0.7496
epoch:127, train acc:0.9966666666666667, test acc:0.7494
epoch:128, train acc:0.9966666666666667, test acc:0.7491
epoch:129, train acc:0.9966666666666667, test acc:0.7481
epoch:130, train acc:0.9966666666666667, test acc:0.7467
epoch:131, train acc:0.9966666666666667, test acc:0.7489
epoch:132, train acc:0.9966666666666667, test acc:0.7489
epoch:133, train acc:0.9966666666666667, test acc:0.7497
epoch:134, train acc:0.9966666666666667, test acc:0.7511
epoch:135, train acc:0.9966666666666667, test acc:0.7504
epoch:136, train acc:0.9966666666666667, test acc:0.7512
epoch:137, train acc:0.9966666666666667, test acc:0.7512
epoch:138, train acc:0.9966666666666667, test acc:0.7524
epoch:139, train acc:0.9966666666666667, test acc:0.7517
epoch:140, train acc:0.9966666666666667, test acc:0.7512
epoch:141, train acc:1.0, test acc:0.7507
epoch:142, train acc:0.9966666666666667, test acc:0.753
epoch:143, train acc:1.0, test acc:0.7514
epoch:144, train acc:0.9966666666666667, test acc:0.7528
epoch:145, train acc:0.9966666666666667, test acc:0.7521
epoch:146, train acc:1.0, test acc:0.7535
epoch:147, train acc:0.9966666666666667, test acc:0.7536
epoch:148, train acc:0.9966666666666667, test acc:0.7529
epoch:149, train acc:1.0, test acc:0.7533
epoch:150, train acc:1.0, test acc:0.7506
epoch:151, train acc:1.0, test acc:0.7528
epoch:152, train acc:1.0, test acc:0.7533
epoch:153, train acc:1.0, test acc:0.7539
epoch:154, train acc:1.0, test acc:0.7551
epoch:155, train acc:1.0, test acc:0.7549
epoch:156, train acc:1.0, test acc:0.7566
epoch:157, train acc:1.0, test acc:0.7573
epoch:158, train acc:1.0, test acc:0.755
epoch:159, train acc:1.0, test acc:0.7541
epoch:160, train acc:1.0, test acc:0.7562
epoch:161, train acc:1.0, test acc:0.7542
epoch:162, train acc:1.0, test acc:0.7538
epoch:163, train acc:1.0, test acc:0.7568
epoch:164, train acc:1.0, test acc:0.7564
epoch:165, train acc:1.0, test acc:0.7545
epoch:166, train acc:1.0, test acc:0.7556
epoch:167, train acc:1.0, test acc:0.7572
epoch:168, train acc:1.0, test acc:0.757
epoch:169, train acc:1.0, test acc:0.7554
epoch:170, train acc:1.0, test acc:0.7563
epoch:171, train acc:1.0, test acc:0.7557
epoch:172, train acc:1.0, test acc:0.7566
epoch:173, train acc:1.0, test acc:0.7563
epoch:174, train acc:1.0, test acc:0.7569
epoch:175, train acc:1.0, test acc:0.7563
epoch:176, train acc:1.0, test acc:0.7571
epoch:177, train acc:1.0, test acc:0.7568
epoch:178, train acc:1.0, test acc:0.7579
epoch:179, train acc:1.0, test acc:0.7574
epoch:180, train acc:1.0, test acc:0.7566
epoch:181, train acc:1.0, test acc:0.7598
epoch:182, train acc:1.0, test acc:0.7591
epoch:183, train acc:1.0, test acc:0.7596
epoch:184, train acc:1.0, test acc:0.7577
epoch:185, train acc:1.0, test acc:0.7595
epoch:186, train acc:1.0, test acc:0.7594
epoch:187, train acc:1.0, test acc:0.7607
epoch:188, train acc:1.0, test acc:0.7594
epoch:189, train acc:1.0, test acc:0.7604
epoch:190, train acc:1.0, test acc:0.7599
epoch:191, train acc:1.0, test acc:0.7598
epoch:192, train acc:1.0, test acc:0.7615
epoch:193, train acc:1.0, test acc:0.7613
epoch:194, train acc:1.0, test acc:0.7596
epoch:195, train acc:1.0, test acc:0.7601
epoch:196, train acc:1.0, test acc:0.7598
epoch:197, train acc:1.0, test acc:0.7603
epoch:198, train acc:1.0, test acc:0.7604
epoch:199, train acc:1.0, test acc:0.7612
epoch:200, train acc:1.0, test acc:0.7601
In [40]:
# 그래프 그리기==========
markers = {'train': 'o', 'test': 's'}
x = np.arange(max_epochs)
plt.plot(x, train_acc_list, marker='o', label='train', markevery=10)
plt.plot(x, test_acc_list, marker='s', label='test', markevery=10)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()
 
 
  • 훈련 데이터에 대해 정확도가 100%임
 

6.4.2 가중치 감소

 
  • 큰 가중치에 대해 그에 상응하는 패널티를 부과해 오버피팅 억제
In [41]:
weight_decay_lambda = 0.1

network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100], output_size=10,
                        weight_decay_lambda=weight_decay_lambda)
optimizer = SGD(lr=0.01) # 학습률이 0.01인 SGD로 매개변수 갱신

max_epochs = 201
train_size = x_train.shape[0]
batch_size = 100

train_loss_list = []
train_acc_list = []
test_acc_list = []

iter_per_epoch = max(train_size / batch_size, 1)
epoch_cnt = 0

for i in range(1000000000):
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]

    grads = network.gradient(x_batch, t_batch)
    optimizer.update(network.params, grads)

    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)

        print("epoch:" + str(epoch_cnt) + ", train acc:" + str(train_acc) + ", test acc:" + str(test_acc))

        epoch_cnt += 1
        if epoch_cnt >= max_epochs:
            break
 
epoch:0, train acc:0.14, test acc:0.0978
epoch:1, train acc:0.14333333333333334, test acc:0.1045
epoch:2, train acc:0.15333333333333332, test acc:0.1145
epoch:3, train acc:0.16666666666666666, test acc:0.1252
epoch:4, train acc:0.18666666666666668, test acc:0.1337
epoch:5, train acc:0.18666666666666668, test acc:0.1452
epoch:6, train acc:0.19, test acc:0.1578
epoch:7, train acc:0.20333333333333334, test acc:0.1647
epoch:8, train acc:0.21, test acc:0.1754
epoch:9, train acc:0.22666666666666666, test acc:0.1801
epoch:10, train acc:0.25333333333333335, test acc:0.1876
epoch:11, train acc:0.2633333333333333, test acc:0.1939
epoch:12, train acc:0.3, test acc:0.2039
epoch:13, train acc:0.3233333333333333, test acc:0.211
epoch:14, train acc:0.37333333333333335, test acc:0.224
epoch:15, train acc:0.3566666666666667, test acc:0.2256
epoch:16, train acc:0.37333333333333335, test acc:0.231
epoch:17, train acc:0.41, test acc:0.2532
epoch:18, train acc:0.4166666666666667, test acc:0.2673
epoch:19, train acc:0.44333333333333336, test acc:0.2873
epoch:20, train acc:0.45666666666666667, test acc:0.2963
epoch:21, train acc:0.48333333333333334, test acc:0.317
epoch:22, train acc:0.48333333333333334, test acc:0.3311
epoch:23, train acc:0.5, test acc:0.3523
epoch:24, train acc:0.51, test acc:0.3653
epoch:25, train acc:0.5233333333333333, test acc:0.3804
epoch:26, train acc:0.52, test acc:0.3887
epoch:27, train acc:0.5266666666666666, test acc:0.3936
epoch:28, train acc:0.5566666666666666, test acc:0.4127
epoch:29, train acc:0.58, test acc:0.4294
epoch:30, train acc:0.59, test acc:0.4389
epoch:31, train acc:0.5933333333333334, test acc:0.4475
epoch:32, train acc:0.6133333333333333, test acc:0.4579
epoch:33, train acc:0.61, test acc:0.4651
epoch:34, train acc:0.6166666666666667, test acc:0.4654
epoch:35, train acc:0.6233333333333333, test acc:0.4728
epoch:36, train acc:0.6233333333333333, test acc:0.4931
epoch:37, train acc:0.6633333333333333, test acc:0.5165
epoch:38, train acc:0.6633333333333333, test acc:0.5266
epoch:39, train acc:0.6466666666666666, test acc:0.5168
epoch:40, train acc:0.6566666666666666, test acc:0.5224
epoch:41, train acc:0.62, test acc:0.5112
epoch:42, train acc:0.6633333333333333, test acc:0.5324
epoch:43, train acc:0.6533333333333333, test acc:0.5314
epoch:44, train acc:0.6633333333333333, test acc:0.5326
epoch:45, train acc:0.66, test acc:0.534
epoch:46, train acc:0.69, test acc:0.5616
epoch:47, train acc:0.7066666666666667, test acc:0.5845
epoch:48, train acc:0.75, test acc:0.6006
epoch:49, train acc:0.7466666666666667, test acc:0.6074
epoch:50, train acc:0.7266666666666667, test acc:0.6067
epoch:51, train acc:0.74, test acc:0.6074
epoch:52, train acc:0.7433333333333333, test acc:0.6122
epoch:53, train acc:0.7266666666666667, test acc:0.6167
epoch:54, train acc:0.7366666666666667, test acc:0.6137
epoch:55, train acc:0.7333333333333333, test acc:0.6078
epoch:56, train acc:0.7433333333333333, test acc:0.6217
epoch:57, train acc:0.7566666666666667, test acc:0.6301
epoch:58, train acc:0.7333333333333333, test acc:0.6184
epoch:59, train acc:0.78, test acc:0.6408
epoch:60, train acc:0.78, test acc:0.6507
epoch:61, train acc:0.7633333333333333, test acc:0.645
epoch:62, train acc:0.75, test acc:0.6399
epoch:63, train acc:0.7633333333333333, test acc:0.6395
epoch:64, train acc:0.77, test acc:0.6399
epoch:65, train acc:0.7866666666666666, test acc:0.6502
epoch:66, train acc:0.79, test acc:0.648
epoch:67, train acc:0.76, test acc:0.6435
epoch:68, train acc:0.7766666666666666, test acc:0.653
epoch:69, train acc:0.7866666666666666, test acc:0.6468
epoch:70, train acc:0.7866666666666666, test acc:0.6466
epoch:71, train acc:0.7833333333333333, test acc:0.662
epoch:72, train acc:0.77, test acc:0.6596
epoch:73, train acc:0.78, test acc:0.6679
epoch:74, train acc:0.79, test acc:0.6645
epoch:75, train acc:0.8033333333333333, test acc:0.6619
epoch:76, train acc:0.8, test acc:0.668
epoch:77, train acc:0.79, test acc:0.6659
epoch:78, train acc:0.7966666666666666, test acc:0.6665
epoch:79, train acc:0.8033333333333333, test acc:0.6673
epoch:80, train acc:0.8133333333333334, test acc:0.6717
epoch:81, train acc:0.8066666666666666, test acc:0.6652
epoch:82, train acc:0.81, test acc:0.6699
epoch:83, train acc:0.8233333333333334, test acc:0.6731
epoch:84, train acc:0.8133333333333334, test acc:0.6782
epoch:85, train acc:0.8066666666666666, test acc:0.678
epoch:86, train acc:0.8166666666666667, test acc:0.6782
epoch:87, train acc:0.8266666666666667, test acc:0.6751
epoch:88, train acc:0.8166666666666667, test acc:0.6743
epoch:89, train acc:0.8133333333333334, test acc:0.677
epoch:90, train acc:0.8233333333333334, test acc:0.6788
epoch:91, train acc:0.8166666666666667, test acc:0.6748
epoch:92, train acc:0.8333333333333334, test acc:0.6779
epoch:93, train acc:0.8166666666666667, test acc:0.6681
epoch:94, train acc:0.8366666666666667, test acc:0.6798
epoch:95, train acc:0.83, test acc:0.6808
epoch:96, train acc:0.8233333333333334, test acc:0.6726
epoch:97, train acc:0.8233333333333334, test acc:0.6847
epoch:98, train acc:0.82, test acc:0.6875
epoch:99, train acc:0.8266666666666667, test acc:0.6809
epoch:100, train acc:0.8266666666666667, test acc:0.6882
epoch:101, train acc:0.81, test acc:0.6804
epoch:102, train acc:0.8333333333333334, test acc:0.6954
epoch:103, train acc:0.84, test acc:0.692
epoch:104, train acc:0.85, test acc:0.6822
epoch:105, train acc:0.8266666666666667, test acc:0.6802
epoch:106, train acc:0.8233333333333334, test acc:0.676
epoch:107, train acc:0.8333333333333334, test acc:0.6742
epoch:108, train acc:0.8366666666666667, test acc:0.678
epoch:109, train acc:0.8166666666666667, test acc:0.6865
epoch:110, train acc:0.82, test acc:0.6843
epoch:111, train acc:0.8266666666666667, test acc:0.6814
epoch:112, train acc:0.84, test acc:0.6918
epoch:113, train acc:0.8466666666666667, test acc:0.7037
epoch:114, train acc:0.8433333333333334, test acc:0.692
epoch:115, train acc:0.8533333333333334, test acc:0.7005
epoch:116, train acc:0.8466666666666667, test acc:0.6994
epoch:117, train acc:0.85, test acc:0.7008
epoch:118, train acc:0.8433333333333334, test acc:0.6966
epoch:119, train acc:0.8533333333333334, test acc:0.6992
epoch:120, train acc:0.8366666666666667, test acc:0.6848
epoch:121, train acc:0.8533333333333334, test acc:0.6941
epoch:122, train acc:0.8466666666666667, test acc:0.6884
epoch:123, train acc:0.8566666666666667, test acc:0.6955
epoch:124, train acc:0.85, test acc:0.6888
epoch:125, train acc:0.86, test acc:0.6948
epoch:126, train acc:0.8566666666666667, test acc:0.6895
epoch:127, train acc:0.8466666666666667, test acc:0.7055
epoch:128, train acc:0.8633333333333333, test acc:0.6953
epoch:129, train acc:0.8566666666666667, test acc:0.6928
epoch:130, train acc:0.8533333333333334, test acc:0.7037
epoch:131, train acc:0.86, test acc:0.7023
epoch:132, train acc:0.86, test acc:0.7029
epoch:133, train acc:0.8433333333333334, test acc:0.6954
epoch:134, train acc:0.85, test acc:0.6962
epoch:135, train acc:0.8366666666666667, test acc:0.6977
epoch:136, train acc:0.84, test acc:0.6955
epoch:137, train acc:0.8333333333333334, test acc:0.6908
epoch:138, train acc:0.8466666666666667, test acc:0.6966
epoch:139, train acc:0.8433333333333334, test acc:0.6981
epoch:140, train acc:0.85, test acc:0.7024
epoch:141, train acc:0.85, test acc:0.7046
epoch:142, train acc:0.84, test acc:0.694
epoch:143, train acc:0.8633333333333333, test acc:0.6941
epoch:144, train acc:0.8366666666666667, test acc:0.695
epoch:145, train acc:0.8433333333333334, test acc:0.704
epoch:146, train acc:0.8533333333333334, test acc:0.7095
epoch:147, train acc:0.8466666666666667, test acc:0.7061
epoch:148, train acc:0.87, test acc:0.7058
epoch:149, train acc:0.8666666666666667, test acc:0.7085
epoch:150, train acc:0.8566666666666667, test acc:0.7027
epoch:151, train acc:0.8566666666666667, test acc:0.7022
epoch:152, train acc:0.8566666666666667, test acc:0.6997
epoch:153, train acc:0.8533333333333334, test acc:0.7004
epoch:154, train acc:0.85, test acc:0.7018
epoch:155, train acc:0.8566666666666667, test acc:0.7018
epoch:156, train acc:0.8466666666666667, test acc:0.7031
epoch:157, train acc:0.8566666666666667, test acc:0.7073
epoch:158, train acc:0.8566666666666667, test acc:0.6965
epoch:159, train acc:0.86, test acc:0.7038
epoch:160, train acc:0.8633333333333333, test acc:0.7076
epoch:161, train acc:0.8666666666666667, test acc:0.6949
epoch:162, train acc:0.85, test acc:0.7108
epoch:163, train acc:0.8433333333333334, test acc:0.7057
epoch:164, train acc:0.86, test acc:0.7045
epoch:165, train acc:0.8533333333333334, test acc:0.703
epoch:166, train acc:0.8533333333333334, test acc:0.7004
epoch:167, train acc:0.8633333333333333, test acc:0.7031
epoch:168, train acc:0.86, test acc:0.7055
epoch:169, train acc:0.8633333333333333, test acc:0.7071
epoch:170, train acc:0.85, test acc:0.7062
epoch:171, train acc:0.8733333333333333, test acc:0.7107
epoch:172, train acc:0.8666666666666667, test acc:0.7118
epoch:173, train acc:0.8633333333333333, test acc:0.7076
epoch:174, train acc:0.8466666666666667, test acc:0.7095
epoch:175, train acc:0.8633333333333333, test acc:0.711
epoch:176, train acc:0.85, test acc:0.7021
epoch:177, train acc:0.86, test acc:0.7041
epoch:178, train acc:0.85, test acc:0.7046
epoch:179, train acc:0.84, test acc:0.6988
epoch:180, train acc:0.8566666666666667, test acc:0.698
epoch:181, train acc:0.8666666666666667, test acc:0.7105
epoch:182, train acc:0.8733333333333333, test acc:0.7119
epoch:183, train acc:0.8533333333333334, test acc:0.7059
epoch:184, train acc:0.8766666666666667, test acc:0.7139
epoch:185, train acc:0.8533333333333334, test acc:0.7007
epoch:186, train acc:0.8666666666666667, test acc:0.6988
epoch:187, train acc:0.8566666666666667, test acc:0.7012
epoch:188, train acc:0.8433333333333334, test acc:0.6951
epoch:189, train acc:0.8566666666666667, test acc:0.6957
epoch:190, train acc:0.8666666666666667, test acc:0.7065
epoch:191, train acc:0.86, test acc:0.703
epoch:192, train acc:0.8533333333333334, test acc:0.71
epoch:193, train acc:0.86, test acc:0.7073
epoch:194, train acc:0.8533333333333334, test acc:0.7044
epoch:195, train acc:0.8466666666666667, test acc:0.6919
epoch:196, train acc:0.8533333333333334, test acc:0.6989
epoch:197, train acc:0.8466666666666667, test acc:0.6962
epoch:198, train acc:0.8466666666666667, test acc:0.7009
epoch:199, train acc:0.86, test acc:0.7074
epoch:200, train acc:0.8566666666666667, test acc:0.6976
In [42]:
# 그래프 그리기==========
markers = {'train': 'o', 'test': 's'}
x = np.arange(max_epochs)
plt.plot(x, train_acc_list, marker='o', label='train', markevery=10)
plt.plot(x, test_acc_list, marker='s', label='test', markevery=10)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()
 
 
  • 오버피팅이 억제되었음
  • 정확도가 100%에 도달하지 못했음
 

6.4.3 드롭아웃

 
  • 드롭아웃 : 뉴런을 임의로 삭제하면서 학습
  • 은닉층의 뉴런을 무작위로 골라 삭제
  • 훈련 떄는 데이터를 흘릴 대마다 삭제할 뉴런을 무작위로 삭제, 시험 때는 모든 뉴런에 신호 전달 (각 뉴런의 출력에 훈련 때 삭제 안 한 비율을 곱하여 출력)
In [43]:
class Dropout:
    def __init__(self, dropout_ratio=0.5):
        self.dropout_ratio = dropout_ratio
        self.mask = None
        
    def forward(self, x, train_flg = True):
        if train_flg:
            self.mask = np.random.rand(*x.shape) > self.dropout_ratio
            return x * self.mask
        else:
            return x * (1.0 - self.dropout_ratio)
    
    def backward(self, dout):
        return dout * self.mask
 

6.5 적절한 하이퍼파라미터 찾기

 

6.5.1 검증 데이터

 
  • 하이퍼파라미터의 성능을 평가할 때는 시험 데이터를 사용해서는 안됨 (시험데이터에만 적합하도록 조정됨)
  • 검증데이터 : 하이퍼파라미터 조정용 데이터
In [45]:
def shuffle_dataset(x, t):
    permutation = np.random.permutation(x.shape[0])
    x = x[permutation, :] if x.ndim == 2 else x[permutation, :, :, :]
    t = t[permutation]

    return x, t
In [50]:
(x_train, t_train), (x_test, t_test) = load_mnist()
x_train, t_train = shuffle_dataset(x_train, t_train)

validation_rate = 0.2
validation_num = int(x_train.shape[0] * validation_rate)

x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = x_train[validation_num:]
 

6.5.2 하이퍼파라미터 최적화

 
  • 하이퍼파라미터의 최적 값이 존재하는 범위를 조금씩 줄여나감
  • 범위를 줄이려면 우선 대략적인 범위를 설정하고 그 범위에서 무작위로 하이퍼파라미터 값을 골라낸 후, 그 값으로 정확도를 평가
  • 정확도를 잘 살피면서 이 작업을 반복하여 하이퍼파라미터의 최적 값의 범위를 좁혀나감
 
  • 하이퍼파라미터 최적화는 아주 오래걸리기 때문에 학습을 위한 에폭을 작게 하여 1회 평가에 걸리는 시간을 단축하는 것이 효과적
 
  • 0단계 : 하이퍼파라미터의 값의 범위를 설정
  • 1단계 : 설정된 범위에서 하이퍼파라미터의 값을 무작위로 추출
  • 2단계 : 1단계에서 샘플링한 하이퍼퍼라미터 값을 사용하여 학습하고, 검증 데이터로 정확도를 평가 (에폭은 작게 설정)
  • 3단계 : 1단계와 2단계를 특정 횟수 반복하여 그 정확도를 알아보고, 하이퍼파라미터의 범위를 좁힘
 
  • 베이즈 최적화 : 베이즈정리를 중심으로 최적화 수행
 

6.5.3 하이퍼파라미터 최적화 구현하기

In [56]:
import sys
import os
sys.path.append(os.pardir)  # 부모 디렉터리의 파일을 가져올 수 있도록 설정
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.util import shuffle_dataset
from common.trainer import Trainer

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

# 결과를 빠르게 얻기 위해 훈련 데이터를 줄임
x_train = x_train[:500]
t_train = t_train[:500]

# 20%를 검증 데이터로 분할
validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate)
x_train, t_train = shuffle_dataset(x_train, t_train)
x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]


def __train(lr, weight_decay, epocs=50):
    network = MultiLayerNet(input_size=784,
                            hidden_size_list=[100, 100, 100, 100, 100, 100],
                            output_size=10, weight_decay_lambda=weight_decay)
    trainer = Trainer(network, x_train, t_train, x_val, t_val,
                      epochs=epocs, mini_batch_size=100,
                      optimizer='sgd',
                      optimizer_param={'lr': lr}, verbose=False)
    trainer.train()

    return trainer.test_acc_list, trainer.train_acc_list


# 하이퍼파라미터 무작위 탐색======================================
optimization_trial = 100
results_val = {}
results_train = {}
for _ in range(optimization_trial):
    # 탐색한 하이퍼파라미터의 범위 지정===============
    weight_decay = 10 ** np.random.uniform(-8, -4)
    lr = 10 ** np.random.uniform(-6, -2)
    # ================================================

    val_acc_list, train_acc_list = __train(lr, weight_decay)
    print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))
    key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)
    results_val[key] = val_acc_list
    results_train[key] = train_acc_list
 
val acc:0.15 | lr:1.3064061966199673e-06, weight decay:3.523413975825515e-06
val acc:0.08 | lr:1.9723289195873304e-05, weight decay:5.141522191326748e-06
val acc:0.11 | lr:1.964099851494618e-06, weight decay:7.548978076917676e-05
val acc:0.08 | lr:3.010601554614792e-06, weight decay:1.6403622609140678e-08
val acc:0.25 | lr:5.231029051409021e-06, weight decay:1.3906858938139786e-06
val acc:0.17 | lr:0.0005624879713589749, weight decay:2.8316582369721114e-06
val acc:0.07 | lr:1.1064544425774134e-05, weight decay:5.7161875456100166e-08
val acc:0.13 | lr:6.168881782031404e-05, weight decay:1.1555349861548783e-07
val acc:0.15 | lr:5.1038706943552785e-05, weight decay:1.2187052116419099e-08
val acc:0.06 | lr:1.2149469604064779e-06, weight decay:3.204451755905063e-08
val acc:0.12 | lr:2.647967120408096e-05, weight decay:1.6307474686213448e-06
val acc:0.33 | lr:0.002364053383212174, weight decay:5.744622847314584e-07
val acc:0.15 | lr:4.601440490443153e-06, weight decay:1.3786367303265164e-07
val acc:0.23 | lr:0.001321431365550546, weight decay:3.086147644804013e-05
val acc:0.13 | lr:6.415823754611082e-06, weight decay:1.5855652844963293e-06
val acc:0.67 | lr:0.003956787855427095, weight decay:6.431916325335931e-07
val acc:0.12 | lr:0.00018295567746388544, weight decay:3.274360199680452e-05
val acc:0.05 | lr:2.8246773528035726e-05, weight decay:1.0792966092564235e-08
val acc:0.05 | lr:4.216761318853959e-05, weight decay:2.3160829343726325e-05
val acc:0.15 | lr:0.0008221557951401586, weight decay:2.4600899194607493e-07
val acc:0.14 | lr:3.292811360181099e-05, weight decay:3.437065070795765e-06
val acc:0.12 | lr:0.00020582349509227642, weight decay:2.5310505217400406e-05
val acc:0.11 | lr:0.00031924362302009535, weight decay:9.947908591933742e-08
val acc:0.08 | lr:1.0266005591228873e-05, weight decay:5.609099348067004e-08
val acc:0.63 | lr:0.005748429116209908, weight decay:1.2088315642907163e-07
val acc:0.1 | lr:9.946633130994964e-06, weight decay:1.9426938750068496e-08
val acc:0.16 | lr:5.039315671574693e-05, weight decay:1.6044050862929304e-05
val acc:0.06 | lr:4.496846057227187e-06, weight decay:6.698303588972123e-08
val acc:0.13 | lr:8.441785805774514e-05, weight decay:5.380837311569492e-07
val acc:0.11 | lr:0.0007190633959395153, weight decay:7.653586403854133e-07
val acc:0.21 | lr:0.0015496443423869382, weight decay:1.3267208540107235e-07
val acc:0.15 | lr:1.341145555858737e-05, weight decay:1.830608728622233e-07
val acc:0.11 | lr:2.049194232257206e-05, weight decay:3.47811742302489e-05
val acc:0.5 | lr:0.005396006985609333, weight decay:6.019939134011659e-07
val acc:0.05 | lr:3.5074182689965176e-06, weight decay:2.7263755172440965e-05
val acc:0.23 | lr:0.000619039929250422, weight decay:2.3393962305077777e-06
val acc:0.06 | lr:7.827372759790367e-05, weight decay:2.011095372946977e-07
val acc:0.41 | lr:0.002557427241071804, weight decay:1.793846115128445e-08
val acc:0.08 | lr:9.341295387429083e-06, weight decay:1.1463397425574124e-06
val acc:0.15 | lr:2.5789456361847497e-06, weight decay:4.7692950489936866e-08
val acc:0.19 | lr:0.0008811874887106767, weight decay:1.1285684754381494e-07
val acc:0.09 | lr:4.510365322080007e-05, weight decay:4.275768368372417e-06
val acc:0.1 | lr:3.3581910698348326e-05, weight decay:5.84930287477968e-07
val acc:0.08 | lr:6.85951171325959e-05, weight decay:2.6872143154911565e-06
val acc:0.15 | lr:0.0006701035216747515, weight decay:8.67234465767478e-08
val acc:0.09 | lr:2.3937040400845545e-06, weight decay:9.147874991616181e-07
val acc:0.73 | lr:0.004789492053726677, weight decay:2.4116749997527427e-06
val acc:0.14 | lr:0.0006386290851895926, weight decay:4.593032517922764e-08
val acc:0.57 | lr:0.004632051737538101, weight decay:7.447387587577219e-07
val acc:0.12 | lr:0.00011084164815963858, weight decay:1.8257325313377146e-07
val acc:0.14 | lr:3.589460848820216e-05, weight decay:2.589753466329225e-07
val acc:0.12 | lr:0.00015140972298987668, weight decay:1.806996720943166e-05
val acc:0.22 | lr:0.0011369210195055063, weight decay:3.311644686414794e-08
val acc:0.09 | lr:3.201426162517283e-05, weight decay:2.4040749087747273e-05
val acc:0.1 | lr:0.00015658349763918267, weight decay:1.4623731430829729e-06
val acc:0.36 | lr:0.00175523458124374, weight decay:3.805404250770036e-08
val acc:0.06 | lr:5.582618227014247e-06, weight decay:2.4037984239206284e-05
val acc:0.55 | lr:0.0035177356634564006, weight decay:2.9893098494220664e-07
val acc:0.09 | lr:0.00016453712024085224, weight decay:4.6866552538560094e-05
val acc:0.24 | lr:5.633633869632294e-06, weight decay:5.4777708623415806e-05
val acc:0.6 | lr:0.004490766518383531, weight decay:4.3614065351457263e-08
val acc:0.05 | lr:2.5933777782954918e-06, weight decay:2.3099201425165093e-07
val acc:0.09 | lr:2.7371614700900138e-06, weight decay:1.344270523754482e-07
val acc:0.19 | lr:0.00033680468005940317, weight decay:4.539117553858931e-05
val acc:0.06 | lr:0.00010079829031940394, weight decay:1.8444568161656226e-06
val acc:0.09 | lr:3.3849261076014134e-05, weight decay:3.384622585800945e-06
val acc:0.47 | lr:0.0024834292353355008, weight decay:8.445885135841837e-06
val acc:0.77 | lr:0.009595895583117506, weight decay:4.822567864967486e-08
val acc:0.48 | lr:0.0037796535275257154, weight decay:3.38379710272746e-08
val acc:0.06 | lr:7.123467599592044e-05, weight decay:3.2722216114308133e-07
val acc:0.09 | lr:0.00040590240912440377, weight decay:1.1973614865273034e-06
val acc:0.13 | lr:0.000148126150873015, weight decay:6.14009550847387e-06
val acc:0.19 | lr:0.0005252838245137564, weight decay:5.4898585770196515e-08
val acc:0.22 | lr:0.00040052638478782243, weight decay:3.4905681461397354e-07
val acc:0.2 | lr:0.0008779841616784123, weight decay:8.656573660314793e-05
val acc:0.48 | lr:0.004328265520138318, weight decay:3.7717200337941226e-07
val acc:0.1 | lr:0.00015759861505203532, weight decay:9.486658076482954e-06
val acc:0.58 | lr:0.0043047964253477, weight decay:2.9222811209087445e-05
val acc:0.12 | lr:1.7460996376734702e-05, weight decay:8.190890784323276e-08
val acc:0.79 | lr:0.00963119570917399, weight decay:2.1914663615639613e-06
val acc:0.08 | lr:1.8793261863499837e-06, weight decay:4.6122070660026077e-07
val acc:0.17 | lr:0.0004709374845170836, weight decay:1.1333361601209602e-06
val acc:0.05 | lr:0.0001267704201397086, weight decay:4.2266855771649093e-08
val acc:0.1 | lr:0.00023026478587671493, weight decay:1.4497538410201577e-05
val acc:0.12 | lr:0.0004844216182755952, weight decay:4.437376377574713e-05
val acc:0.16 | lr:0.0007863740743236546, weight decay:1.2361398113507698e-05
val acc:0.1 | lr:1.3883319373679108e-05, weight decay:9.643317905633718e-05
val acc:0.17 | lr:4.471144506159626e-06, weight decay:6.456872651811511e-05
val acc:0.08 | lr:4.042839822824157e-05, weight decay:1.3035533247680704e-06
val acc:0.11 | lr:0.00032318910288822066, weight decay:5.235239511608154e-08
val acc:0.72 | lr:0.008873643748338008, weight decay:3.1828875763226756e-08
val acc:0.1 | lr:0.0011382996201518773, weight decay:3.320539162811981e-05
val acc:0.06 | lr:5.8706353422682814e-06, weight decay:6.5345295496167e-07
val acc:0.11 | lr:7.45699568362334e-06, weight decay:1.5848309598203948e-05
val acc:0.48 | lr:0.005890925782411005, weight decay:2.6970997037997328e-06
val acc:0.51 | lr:0.0046676334629489075, weight decay:2.1379347199777646e-06
val acc:0.13 | lr:7.50302258522357e-06, weight decay:3.478729542184295e-06
val acc:0.06 | lr:2.8106074481510875e-05, weight decay:3.59379326453999e-05
val acc:0.07 | lr:1.8105906973045258e-05, weight decay:9.330028016399396e-08
val acc:0.11 | lr:4.815987745319018e-05, weight decay:3.576291742129845e-05
In [57]:
# 그래프 그리기========================================================
print("=========== Hyper-Parameter Optimization Result ===========")
graph_draw_num = 20
col_num = 5
row_num = int(np.ceil(graph_draw_num / col_num))
i = 0

for key, val_acc_list in sorted(results_val.items(), key=lambda x: x[1][-1], reverse=True):
    print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)

    plt.subplot(row_num, col_num, i+1)
    plt.title("Best-" + str(i+1))
    plt.ylim(0.0, 1.0)
    if i % 5:
        plt.yticks([])
    plt.xticks([])
    x = np.arange(len(val_acc_list))
    plt.plot(x, val_acc_list)
    plt.plot(x, results_train[key], "--")
    i += 1

    if i >= graph_draw_num:
        break

plt.show()
 
=========== Hyper-Parameter Optimization Result ===========
Best-1(val acc:0.79) | lr:0.00963119570917399, weight decay:2.1914663615639613e-06
Best-2(val acc:0.77) | lr:0.009595895583117506, weight decay:4.822567864967486e-08
Best-3(val acc:0.73) | lr:0.004789492053726677, weight decay:2.4116749997527427e-06
Best-4(val acc:0.72) | lr:0.008873643748338008, weight decay:3.1828875763226756e-08
Best-5(val acc:0.67) | lr:0.003956787855427095, weight decay:6.431916325335931e-07
Best-6(val acc:0.63) | lr:0.005748429116209908, weight decay:1.2088315642907163e-07
Best-7(val acc:0.6) | lr:0.004490766518383531, weight decay:4.3614065351457263e-08
Best-8(val acc:0.58) | lr:0.0043047964253477, weight decay:2.9222811209087445e-05
Best-9(val acc:0.57) | lr:0.004632051737538101, weight decay:7.447387587577219e-07
Best-10(val acc:0.55) | lr:0.0035177356634564006, weight decay:2.9893098494220664e-07
Best-11(val acc:0.51) | lr:0.0046676334629489075, weight decay:2.1379347199777646e-06
Best-12(val acc:0.5) | lr:0.005396006985609333, weight decay:6.019939134011659e-07
Best-13(val acc:0.48) | lr:0.0037796535275257154, weight decay:3.38379710272746e-08
Best-14(val acc:0.48) | lr:0.004328265520138318, weight decay:3.7717200337941226e-07
Best-15(val acc:0.48) | lr:0.005890925782411005, weight decay:2.6970997037997328e-06
Best-16(val acc:0.47) | lr:0.0024834292353355008, weight decay:8.445885135841837e-06
Best-17(val acc:0.41) | lr:0.002557427241071804, weight decay:1.793846115128445e-08
Best-18(val acc:0.36) | lr:0.00175523458124374, weight decay:3.805404250770036e-08
Best-19(val acc:0.33) | lr:0.002364053383212174, weight decay:5.744622847314584e-07
Best-20(val acc:0.25) | lr:5.231029051409021e-06, weight decay:1.3906858938139786e-06
 
 

6.6 정리

  • 매개변수 갱신 방법에는 확률적 경사하강법 외에 모멘텀, AdaGrad, Adam 등이 있음
  • 가중치 초깃값을 정하는 방법은 올바른 학습을 하는 데 매우 중요
  • 가중치 초깃값으로는 Xavier 초깃값과 He 초깃값이 효과적
  • 배치정규화를 이용하면 정규화 기술로는 가중치 감소와 드롭아웃이 있음
  • 하이퍼파라미터 값 탐색은 최적 값이 존재할 법한 범위를 점차 좁히면서 하는 것이 효과적