어제 보았던 로지스틱 회귀 모델은 여러 개의 이진 분류기를 훈련시켜 연결하지 않고, 직접 다중 클래스를 지원하도록 일반화할 수 있습니다. 이를, 소프트맥스 회귀(Softmax Regression) 또는 다항 로지스틱 회귀(Multinomial Logistic Regression)라고 합니다.
개념은 샘플 x가 주어지면, 소프트맥스 회귀 모델이 각 클래스 k에 대한 점수 를 계산하고, 그 점수에 소프트맥스 함수(softmax function, 또는 정규화된 지수 함수(normalized exponential)라 부름) 를 적용하여 각 클래스의 확률을 추정합니다. 를 계산하는 식은 선형 회귀 예측을 위한 식과 매우 비슷합니다.
[클래스 k에 대한 소프트맥스 점수]
각 클래스는 자신만의 파라미터 벡터 가 있습니다. 이 벡터들은 파라미터 행렬(parameter matrix) 에 행으로 저장됩니다.
샘플 x에 대해 각 클래스의 점수가 계산되면, 소프트맥스 함수를 통과시켜 클래스 k에 속할 확률 을 추정할 수 있습니다. 이 함수는 각 점수에 지수 함수를 적용한 후, 정규화합니다. (모든 지수 함수 결과의 합으로 나눕니다.)
[소프트맥스 함수]
* k = 클래스 수
* s(x) 는 샘플 x에 대한 각 클래스의 점수를 담고 있는 벡터
* 는 샘플 x에 대한 각 클래스의 점수가 주어졌을 때 이 샘플이 클래스 k에 속할 추정 확률
로지스틱 회귀 분류기와 마찬가지로 소프트맥스 회귀 분류기는 추정 확률이 가장 높은 클래스(가장 높은 점수를 가진 클래스)를 선택합니다.
[소프트맥스 회귀 분류기의 예측]
*argmax 연산은 함수를 최대화하는 변수의 값을 반환합니다. 이 식에서는 추정확률 가 최대인 k 값을 반환합니다.
[TIP]
소프트맥스 회귀 분류기는 한 번에 하나의 클래스만 예측합니다. 즉, 다중 클래스이지 다중 출력이 아닙니다. 그래서 종류가 다른 붓꽃같이 상호 배타적인 클래스에서만 사용해야 합니다. (하나의 사진서 여러 사람의 얼굴을 인식하는 데는 사용할 수 없습니다.)
소프트맥스 회귀의 훈련 방법은 모델이 타깃 클래스에 대해서 높은 확률을 추정하도록 만드는 것이 목적입니다.
아래의 크로스 엔트로피(cross entropy) 비용 함수를 최소화하는 것은 타깃 클래스에 대해 낮은 확률을 예측하는 모델을 억제하므로 이이 목적에 부합합니다. 크로스 엔트로피는 추정된 클래스의 확률이 타깃 클래스에 얼마나 잘 맞는지 측정하는 용도로 종종 사용됩니다.
[크로스 엔트로피 비용 함수]
* i번째 샘플에 대한 타깃 클래스가 k일 때, 가 1이고 그 외에는 0입니다.
* 딱 2개의 클래스가 있을 때 (K=2), 이 비용 함수는 로지스틱 회귀 비용 함수와 같습니다.
[클래스 K에 대한 크로스 엔트로피의 그래디언트 벡터]
이제 각 클래스에 대한 그레디언트 벡터를 계산할 수 있으므로 비용 함수를 최소화하기 위한 파라미터 행렬 Theta를 찾기 위해 경사 하강법 (또는 다른 최적화 알고리즘)을 사용할 수 있습니다.
이제 아래에서 소프트맥스 회귀를 사용해 붓꽃을 세 개의 클래스로 분류해보겠습니다.
사이킷런의 LogisticRegression은 클래스가 둘 이상일 때 기본적으로 일대다(OvA) 전략을 사용합니다.
하지만 multi_class 매개변수를 multinomial로 바꾸면 소프트맥스 회귀를 사용할 수 있습니다.
소프트맥스 회귀를 사용하려면 solver 매개변수에 lbfgs와 같이 소프트맥스 회귀를 지원하는 알고리즘을 지정해야 합니다.
또한, 기본적으로 하이퍼파라미터 C를 사용하여 조절할 수 있는 규제가 적용됩니다.
######크로스 엔트로피?######
크로스 엔트로피는 원래 정보 이론에서 유래했습니다. 매일 날씨 정보를 효율적으로 전달하려 한다고 가정합시다.
8가지 정보(맑음, 비 등)이 있다면, 총2의 3제곱이 8이므로 3비트를 사용하여 이 정보들을 인코딩 할 수 있습니다.
그러나 거의 대부분 날이 맑음이라면 맑음을 하나의 비트 (0)으로 인코딩하고, 다른 7개의 선택사항을 1로 시작하는 4비트로 표현하는 것이 효율적입니다. 크로스 엔트로피는 선택사항마다 전송한 평균 비트 수를 측정합니다. 날씨에 대한 가정이 완벽하다면 크로스 엔트로피는 날씨 자체의 엔트로피와 동일할 것입니다. 하지만, 이 예측이 틀려 비가 자주 온다면, 크로스 엔트로피는 쿨백-라이블러 발산(Kullback-Leibler divergence)이라 불리는 양만큼 커질 것입니다.
두 확률 분포 p와 q사이의 크로스 엔트로피는 로 정의합니다.
(ex. 맑은은 1비트, 다른 날씨는 4비트로 전송된다고 하고 맑은 날의 비율이 80%라면, 평균 전송 비트 수는 1.6이 됩니다.
하지만 맑은 날의 비율이 50%라면 평균 전송 비트 수는 2.5로 늘어납니다. 이 두 엔트로피의 차이가 이상적인 확률 분포와 이에 근사하는 확률 분포 사이의 차이를 나타내는 쿨백-라이블러 발산입니다. 이와 비슷한 예가 구글 브레인 팀의 머신러닝 연구원인 크리스토퍼 올라(Christopher Olah)의 블로그에 나와있습니다. http://colah.github.io/posts/2015-09-Visual-Information/
####################
#소프트맥스 회귀
X = iris['data'][: , (2, 3)] #Petal Length , #Petal Width
y = iris['target']
softmax_reg = LogisticRegression(multi_class ='multinomial', solver = 'lbfgs', C=10, random_state=42)
softmax_reg.fit(X,y)
x0, x1 = np.meshgrid(
np.linspace(0, 8, 500).reshape(-1, 1),
np.linspace(0, 3.5, 200).reshape(-1, 1),
)
X_new = np.c_[x0.ravel(), x1.ravel()]
y_proba = softmax_reg.predict_proba(X_new)
y_predict = softmax_reg.predict(X_new)
zz1 = y_proba[:, 1].reshape(x0.shape)
zz = y_predict.reshape(x0.shape)
plt.figure(figsize=(10,4))
plt.plot(X[y==2, 0], X[y==2, 1], 'g^', label='Iris-Verginica')
plt.plot(X[y==1, 0], X[y==1, 1], 'bs', label='Iris-Versicolor')
plt.plot(X[y==0, 0], X[y==0, 1], 'yo', label='Iris-Setosa')
from matplotlib.colors import ListedColormap
custom_cmap = ListedColormap(['#fafab0', '#9898ff', '#a0faa0'])
plt.contourf(x0, x1, zz, cmap=custom_cmap)
contour = plt.contour(x0, x1, zz1, cmap=plt.cm.brg)
plt.clabel(contour, inline=1, fontsize=12)
plt.xlabel('Petal Length', fontsize=14)
plt.ylabel('Petal Width ',fontsize=14, rotation=0)
plt.legend(loc='center left', fontsize=14)
plt.axis([0,7, 0, 3.5])
plt.show()
softmax_reg.predict([[5,2]])
#꽃잎의 길이가 5, 너비가 2인 붓꽃을 예측하면 94.2%확률로 class 2 (Irirs-Virginica) 또는, 5.8% 확률로 Iris-Versicolor로 예측
array([2])
softmax_reg.predict_proba([[5,2]])
array([[6.38014896e-07, 5.74929995e-02, 9.42506362e-01]])
이 모델은 추정확률 50% 이하인 클래스도 예측할 수 있다는 점에 주목하면서, 소프트맥스 회귀에 대해 포스팅을 마치겠습니다.
블로그
출처
이 글의 상당 부분은 [핸즈온 머신러닝, 한빛미디어/오렐리앙 제롱/박해선] 서적을 참고하였습니다.
나머지는 부수적인 함수나 메서드에 대해 부족한 설명을 적어두었습니다.
학습용으로 포스팅 하는 것이기 때문에 복제보다는 머신러닝에 관심이 있다면 구매해보시길 추천합니다.
도움이 되셨다면 로그인 없이 가능한
아래 하트♥공감 버튼을 꾹 눌러주세요!
'## 오래된 게시글 (미관리) ## > Python (Linux)' 카테고리의 다른 글
35. Python - Matplotlib 한글 설정 (환경 설정으로 고정), 마이너스 깨짐 (0) | 2019.01.31 |
---|---|
34. Python - 4장 연습문제 (0) | 2019.01.30 |
32. Python - 로지스틱 회귀 (0) | 2019.01.28 |
31. Python - 규제(릿지 회귀, 라쏘 회귀, 엘라스틱넷, 조기종료) (0) | 2019.01.27 |
30. Python - 다항 회귀 (0) | 2019.01.25 |