티스토리 뷰
Machine Learning in Stata
(Cameron의 노트를 정리한 것입니다)
machine learning은 간단히 보면 $x$ 가 주어졌을 때 $y$ 를 예측하는 알고리듬으로 볼 수 있다.
특별히, 주어진 자료에 기초하여 예측하는 알고리듬으로 정리할 수 있다.
ML은 예측하는 모형인 것인지 인과적 효과를 검정하는 모형은 아니다.
특별히 supervised learning은 기본적으로 regression 으로 이해할 수 있다.
예측을 잘 하는 모형을 선택하는 알고리듬이 필요하다.
여러가지 방법이 있을 수 있는데, 많이 사용되는 방법은 크게 두 가지 정도로 나눠볼 수 있다.
- penalty measures: Mallows, AIC, BIC 등
- cross validation
CV의 경우
- 데이터를 training set 과 test set 을 구분하여
- training set에서는 모형을 추정하고
- 추정된 모형을 이용하여 test set에서 예측을 한 후
- 예측오차가 가장 작은 모형을 선택하는 알고리듬이다
모형의 예측오차는 보통 MSE(mean squared error)를 이용하여 측정한다.
$$ err = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 $$
몇 가지 실습을 통해서 CV를 해보자.
. clear
. qui set obs 40
. set seed 10101
. gen x1 = _n - mod(_n+1,2)
. foreach i of num 2/4 {
2. gen x`i' = x1^`i'
3. }
. gen y = 2 + 0.1*(x1-20)^2 + rnormal(0,10)
. tw (scatter y x1) (lfitci y x1)
그러면 아래와 같은 산포도를 얻을 수 있다.
다음으로는 training set과 test set 으로 구분해보자.
. gen dtrain = (mod(_n, 2) == 1)
. reg y x1-x4 if dtrain
Source | SS df MS Number of obs =
> 20
-------------+---------------------------------- F(4, 15) =
> 9.74
Model | 3784.49406 4 946.123515 Prob > F = 0.
> 0004
Residual | 1457.28188 15 97.1521252 R-squared = 0.
> 7220
-------------+---------------------------------- Adj R-squared = 0.
> 6479
Total | 5241.77594 19 275.882944 Root MSE = 9.
> 8566
--------------------------------------------------------------------------
> ----
y | Coef. Std. Err. t P>|t| [95% Conf. Inter
> val]
-------------+------------------------------------------------------------
> ----
x1 | -.3544394 3.969393 -0.09 0.930 -8.815001 8.10
> 6123
x2 | -.4518467 .4031622 -1.12 0.280 -1.311167 .407
> 4731
x3 | .0242291 .015139 1.60 0.130 -.0080388 .056
> 4971
x4 | -.0003269 .0001878 -1.74 0.102 -.0007272 .000
> 0734
_cons | 41.78578 11.40795 3.66 0.002 17.47031 66.1
> 0125
--------------------------------------------------------------------------
> ----
. predict yhat_train
(option xb assumed; fitted values)
. tw (scatter y x1 if dtrain) (line yhat_train x1, sort color(blue)), savi
> ng(train, replace) legend(pos(6) row(1)) title("training set")
(file train.gph saved)
. tw (scatter y x1 if !dtrain) (line yhat_train x1, sort color(blue)), sa
> ving(non_train, replace) legend(pos(6) row(1)) title("test set")
(file non_train.gph saved)
. graph combine train.gph non_train.gph
그러면 아래와 같은 그림을 얻을 수 있다. training set에서 test set보다 오차가 작음을 확인할 수 있다.
다항식을 이용한 선형 모형에서 몇 번째 항까지 포함시키는 것이 좋은지 CV를 통해서 확인할 수 있다.
. forvalues k = 1/4 {
2. qui reg y x1-x`k' if dtrain
3. qui predict y`k'hat
4. qui gen y`k'errorsq = (y - y`k'hat)^2
5. qui sum y`k'errorsq if dtrain == 1
6. qui scalar mse`k'train = r(mean)
7. qui sum y`k'errorsq if dtrain == 0
8. qui scalar mse`k'test = r(mean)
9. }
. di _n "MSE linear Train = " mse1train " Test = " mse1test _n ///
> "MSE quadratic Train = " mse2train " Test = " mse2test
> _n ///
> "MSE cubic Train = " mse3train " Test = " m
> se3test _n ///
> "MSE quartic Train = " mse4train " Test = " mse4test
> _n
MSE linear Train = 252.32258 Test = 412.98285
MSE quadratic Train = 92.781786 Test = 184.43114
MSE cubic Train = 87.577254 Test = 208.24569
MSE quartic Train = 72.864095 Test = 207.78885
위 결과에 따르면 2차항까지 포함시켰을 때가 MSE가 가장 작음을 확인할 수 있다.
한편, crossfold를 이용하면 k-fold CV를 할 수 있다.
. forvalues k = 1/4 {
2. qui set seed 12345
3. qui crossfold reg y x1-x`k'
4. qui matrix RMSEs`k' = r(est)
5. qui svmat RMSEs`k', names(rmses`k')
6. qui sum rmses`k'
7. qui scalar cv`k' = r(mean)
8. }
. di _n "CV(5) for k = 1,2,3,4 is " cv1 ", " cv2 ", " cv3 ", " cv4
CV(5) for k = 1,2,3,4 is 18.798389, 11.847123, 11.96073, 13.002836
위의 결과 역시 CV가 2차항까지 포함시켰을 때 가장 작음을 알 수 있다.
한편, AIC 혹은 BIC 역시 계산할 수 있다. AIC의 경우 다음과 같이 정의된다.
$$ AIC = -2\ln(ll) + 2k $$ $$ BIC = -2\ln(ll) + \ln(n)k $$
. forvalues k = 1/4 {
2. qui reg y x1-x`k'
3. qui scalar aic`k' = -2*e(ll) + 2*e(rank)
4. qui scalar bic`k' = -2*e(ll) + ln(e(N))*e(rank)
5. }
. di _n "AIC for k = 1,...,4 = " aic1 ", " aic2 ", " aic3 ", " aic4 ,
> ///
> _n "BIC for k = 1,...,4 = " bic1 ", " bic2 ", " bic3 ", " bic4
AIC for k = 1,...,4 = 348.99841, 314.26217, 316.01317, 315.3112
BIC for k = 1,...,4 = 352.37617, 319.32881, 322.76869, 323.7556
'STATA' 카테고리의 다른 글
회귀선과 잔차 (0) | 2019.03.08 |
---|---|
신뢰구간이란 무엇인가? (2) | 2018.10.31 |
명령어 자동화 (0) | 2018.10.10 |
Cheat Sheet for Policy Evaluation (1) | 2018.01.29 |
정책분석을 위한 STATA 출간 (0) | 2018.01.22 |