티스토리 뷰

STATA

cross-validation in Stata

비조 2018. 10. 22. 15:59

Machine Learning in Stata

(Cameron의 노트를 정리한 것입니다)

machine learning은 간단히 보면 $x$ 가 주어졌을 때 $y$ 를 예측하는 알고리듬으로 볼 수 있다.

특별히, 주어진 자료에 기초하여 예측하는 알고리듬으로 정리할 수 있다.

ML은 예측하는 모형인 것인지 인과적 효과를 검정하는 모형은 아니다.

특별히 supervised learning은 기본적으로 regression 으로 이해할 수 있다.

예측을 잘 하는 모형을 선택하는 알고리듬이 필요하다.

여러가지 방법이 있을 수 있는데, 많이 사용되는 방법은 크게 두 가지 정도로 나눠볼 수 있다.

  • penalty measures: Mallows, AIC, BIC 등
  • cross validation

CV의 경우

  1. 데이터를 training set 과 test set 을 구분하여
  2. training set에서는 모형을 추정하고
  3. 추정된 모형을 이용하여 test set에서 예측을 한 후
  4. 예측오차가 가장 작은 모형을 선택하는 알고리듬이다

모형의 예측오차는 보통 MSE(mean squared error)를 이용하여 측정한다.

$$ err = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 $$

몇 가지 실습을 통해서 CV를 해보자.

. clear

. qui set obs 40

. set seed 10101

. gen x1 = _n - mod(_n+1,2)

. foreach i of num 2/4 {
  2.         gen x`i' = x1^`i'
  3. }

. gen y = 2 + 0.1*(x1-20)^2 + rnormal(0,10)

. tw (scatter y x1) (lfitci y x1)

그러면 아래와 같은 산포도를 얻을 수 있다.

link

다음으로는 training set과 test set 으로 구분해보자.

. gen dtrain = (mod(_n, 2) == 1)

. reg y x1-x4 if dtrain

      Source |       SS           df       MS      Number of obs   =      
>   20
-------------+----------------------------------   F(4, 15)        =      
> 9.74
       Model |  3784.49406         4  946.123515   Prob > F        =    0.
> 0004
    Residual |  1457.28188        15  97.1521252   R-squared       =    0.
> 7220
-------------+----------------------------------   Adj R-squared   =    0.
> 6479
       Total |  5241.77594        19  275.882944   Root MSE        =    9.
> 8566

--------------------------------------------------------------------------
> ----
           y |      Coef.   Std. Err.      t    P>|t|     [95% Conf. Inter
> val]
-------------+------------------------------------------------------------
> ----
          x1 |  -.3544394   3.969393    -0.09   0.930    -8.815001    8.10
> 6123
          x2 |  -.4518467   .4031622    -1.12   0.280    -1.311167    .407
> 4731
          x3 |   .0242291    .015139     1.60   0.130    -.0080388    .056
> 4971
          x4 |  -.0003269   .0001878    -1.74   0.102    -.0007272    .000
> 0734
       _cons |   41.78578   11.40795     3.66   0.002     17.47031    66.1
> 0125
--------------------------------------------------------------------------
> ----

. predict yhat_train 
(option xb assumed; fitted values)

. tw (scatter y x1 if dtrain) (line yhat_train x1, sort color(blue)), savi
> ng(train, replace) legend(pos(6) row(1)) title("training set")
(file train.gph saved)

. tw (scatter y x1 if !dtrain) (line yhat_train x1, sort  color(blue)), sa
> ving(non_train, replace) legend(pos(6)  row(1))  title("test set")
(file non_train.gph saved)

. graph combine train.gph non_train.gph

그러면 아래와 같은 그림을 얻을 수 있다. training set에서 test set보다 오차가 작음을 확인할 수 있다.

link

다항식을 이용한 선형 모형에서 몇 번째 항까지 포함시키는 것이 좋은지 CV를 통해서 확인할 수 있다.

. forvalues k = 1/4 {
  2.         qui reg y x1-x`k' if dtrain
  3.         qui predict y`k'hat
  4.         qui gen y`k'errorsq = (y - y`k'hat)^2
  5.         qui sum y`k'errorsq if dtrain == 1
  6.         qui scalar mse`k'train = r(mean)
  7.         qui sum y`k'errorsq if dtrain == 0
  8.         qui scalar mse`k'test = r(mean)
  9. }

. di _n   "MSE linear     Train = " mse1train " Test = " mse1test _n ///
>                 "MSE quadratic  Train = " mse2train " Test = " mse2test 
> _n ///
>                 "MSE cubic              Train = " mse3train " Test = " m
> se3test _n ///
>                 "MSE quartic    Train = " mse4train " Test = " mse4test 
> _n 

MSE linear     Train = 252.32258 Test = 412.98285
MSE quadratic  Train = 92.781786 Test = 184.43114
MSE cubic              Train = 87.577254 Test = 208.24569
MSE quartic    Train = 72.864095 Test = 207.78885


위 결과에 따르면 2차항까지 포함시켰을 때가 MSE가 가장 작음을 확인할 수 있다.

한편, crossfold를 이용하면 k-fold CV를 할 수 있다.

. forvalues k = 1/4 {
  2.  qui set seed 12345
  3.  qui crossfold reg y x1-x`k'
  4.  qui matrix RMSEs`k' = r(est)
  5.  qui svmat RMSEs`k', names(rmses`k')
  6.  qui sum rmses`k'
  7.  qui scalar cv`k' = r(mean)
  8. }

. di _n "CV(5) for k = 1,2,3,4 is " cv1 ", " cv2 ", " cv3 ", " cv4

CV(5) for k = 1,2,3,4 is 18.798389, 11.847123, 11.96073, 13.002836

위의 결과 역시 CV가 2차항까지 포함시켰을 때 가장 작음을 알 수 있다.

한편, AIC 혹은 BIC 역시 계산할 수 있다. AIC의 경우 다음과 같이 정의된다.

$$ AIC = -2\ln(ll) + 2k $$ $$ BIC = -2\ln(ll) + \ln(n)k $$

. forvalues k = 1/4 {
  2.         qui reg y x1-x`k'
  3.         qui scalar aic`k' = -2*e(ll) + 2*e(rank)
  4.         qui scalar bic`k' = -2*e(ll) + ln(e(N))*e(rank)
  5. }

. di      _n "AIC for k = 1,...,4 = " aic1 ", " aic2 ", " aic3 ", " aic4 ,
>  ///
>         _n "BIC for k = 1,...,4 = " bic1 ", " bic2 ", " bic3 ", " bic4 

AIC for k = 1,...,4 = 348.99841, 314.26217, 316.01317, 315.3112 
BIC for k = 1,...,4 = 352.37617, 319.32881, 322.76869, 323.7556


'STATA' 카테고리의 다른 글

회귀선과 잔차  (0) 2019.03.08
신뢰구간이란 무엇인가?  (2) 2018.10.31
명령어 자동화  (0) 2018.10.10
Cheat Sheet for Policy Evaluation  (1) 2018.01.29
정책분석을 위한 STATA 출간  (0) 2018.01.22
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/04   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30
글 보관함