Practice makes perfect!

[모두의 딥러닝] 10장 모델 설계하기 본문

Study/딥러닝

[모두의 딥러닝] 10장 모델 설계하기

na0dev 2021. 4. 17. 17:59

* 모델 정의

#딥러닝을 구동하는 데 필요한 케라스 함수 호출
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

#필요한 라이브러리 불러옴
import numpy as np
import tensorflow as tf

#실행할 때 마다 같은 결과를 출력하기 위한 설정
np.random.seed(3)
tf.random.set_seed(3)

#준비된 수술 환자 데이터 불러옴
Data_set = np.loadtxt("../dataset/ThoraricSurgery.csv",delimiter=",")

#환자의 기록과 수술 결과를 X와 Y로 구분하여 저장
X = Data_set[:,0:17]
Y = Data_set[:,17]

모두의 딥러닝

X는 정보1 ~ 정보 17, Y는 생존 여부 값을 갖는다.

 

#딥러닝 구조를 짜고 층을 설정하는 부분
model = Sequential() # 퍼셉트론 위에 숨겨진 퍼셉트론 층이 추가된 형태를 구현하기 위한 함수

#은닉층. 30개의 노드, 입력데이터에서 값을 17개 가져옴, relu를 활성화 함수로 사용 
model.add(Dense(30,input_dim=17,activation='relu'))
#출력층. 출력값을 하나로 정해 보여줘야하기 때문에 노드는 1개, sigmoid를 활성화 함수로 사용
model.add(Dense(1, activation='sigmoid')) 

# 딥러닝의 여러 개의 층 중에서 맨 마지막 층은 출력층이 되고 나머지는 은닉층의 역할을 한다.

# 위에서 두 개의 층을 만들었기 때문에 각각 은닉층과 출력층이다

 

# 위에서 정해진 모델을 컴퓨터가 알아들을 수 있게 컴파일하는 부분
model.compile(loss='mean_squared_error',optimizer='adam',metrics=['accuracy'])

# 모델 실행
# 각 샘플이 100번 재사용될 때까지 실행 반복, 샘플을 한번에 10개씩 처리
model.fit(X,Y,epochs=100,batch_size=10) 

# model.compile은 모델이 효과적으로 구현되도록 여러가지 환경을 설정해주면서 컴파일하는 부분이다.

 

  1. 어떤 오차 함수를 사용할 지 정한다.
    • 평균 제곱 오차 계열 함수 : 수렴하기까지 오래걸림.
    • 교차 엔트로피 계열 함수 : 출력 값에 로그를 취해 오차가 커지면 수렴 속도가 빨라지고, 오차가 작아지면 수렴 속도가 감소하게 만든 것
  2. 최적화 방법을 정한다.

# 오차함수로 '평균제곱오차함수(mean_squared_error)', 최적화 방법으로 'adam'을 사용한다.
# metrics() 함수는 모델이 컴파일 될 때 모델 수행 결과를 나타내도록 설정하는 부분이다. (과적합 문제 방지)

 

# 학습 프로세스가 모든 샘플에 대해 한번 실행되는 것을 1 epoch라 한다.
# 샘플을 한번에 몇 개씩 처리할지 정하기 위해 batch_size 이용한다. batch_size가 너무 크면 학습 속도가 느리고, 너무 작으면 결괏값이 불안정해질 수 있다.

 

 


* 오차 함수

 

- 교차 엔트로피 계열 : 분류 문제에서 많이 사용됨

  • categorical_crossentropy (범주형 교차 엔트로피) : 일반적인 분류
  • binary_crossentropy (이항 교차 엔트로피) : 두 개의 클래스 중에서 예측할 때 (ex-예측 값이 참, 거짓 둘 중 하나일 때 ) 
    • 위에서 예측 값은 생존(1) 또는 사망(0) 둘 중 하나이므로 binary_crossentropy를 사용하면 좋다.

 

[mean_squared_error 이용]

 

[binary_crossentropy 이용]

▷ binary_crossentropy 사용하면 accuracy가 약간 향상될거라고 했는데 오히려 떨어졌다,,, 왜지?!

 

 

 

- 평균 제곱 계열

  • mean_squared_error (평균 제곱 오차)
  • mean_absolute_error (평균 절대 오차)
  • mean_absolute_percentage_error (평균 절대 백분율)
  • mean_squared_logarithmic_error (평균 제곱 로그 오차)
반응형
Comments