하나의 예측값만 나오는 문제(Tree 모델)
현상 요약
Tree모델에서 train데이터로 학습하고, train 데이터로 예측하였을 때
예측값이 하나만 나오는 경우가 발생하여 원인을 확인하고자 한다
원인 생각
곰곰히 생각해본 결과 node가 분할되지 않았기 때문이라는 결론을 내렸다
즉, root node에서 분할되지 않았기 때문에(root node만 존재) 어떤 input feature가 들어와도 root node의 평균 값만 내뱉는 것이다
- ※ 위의 현상이 발생한 ensemble 모델의 경우, 각 tree의 root node 평균값에 tree들의 가중치를 취하여 더한 값이 예측 값으로 나올 것이다
그러면 어떠한 경우에 분할이 발생하지 않을까?
각 모델 특성별로 root node에서 분할되지 않는 경우를 생각해 보았다
원인 1
- data point(sample) 수가 node에 존재해야하는 최소 sample 수보다 작은 경우
- 실제 data point가 너무 적거나(데이터 문제)
- 모델에서 node에 존재해야하는 최소 sample 수가 너무 많음(모델 hyper parameter 문제)
원인 2
- 어떠한 조건으로도 분할 시 information gain이 일어나지 않는 경우
- input feature가 노드 분할에 매우 쓸모 없거나
- output 변수가 특정 값에 매우 쏠려있는 경우 (분산이 매우 낮은 경우)
위의 두 원인에 대해 대표적인 두 가지 Tree 모델에서 살펴보아야할 hyper parameter는 다음과 같다
원인1 | 원인2 | |
RandomForest | min_samples_leaf | min_impurity_deacrease |
XGBoost | - | gamma, lambda(reg_lambda) |
참고) XGBoost 에서 Information gain 계산 방법
XGBoost 에서 원인2(information gain)에 해당하는 gamma, lambda에 대해 알아보자
(좋은 영상을 제작한 제작자에게 무한한 감사를 드립니다!!!)
우선 초기(root node)에 모든 output에 base_score(default = 0.5)라는 값을 빼준다. 이것은 global bias이다
정확히 왜 하는지는 아직 잘 모르겠다..
그 이후 input feature를 고르고, 해당 feature의 특정 값을 기준으로 이진 분할을 하며 information gain을 계산한다
information gain은 분할 전 대비 분할 후의 similarity 증가량이다
- 어떤 node의 similarity
- similarity(node) = (node의 output의 합)**2 / (node의 sample 수 + lambda)
- 분할 시 information gain 계산
- sim(분할 후 좌측 node) + sim(분할 후 우측 node) - sim(분할 전 node)
분할 시 information gain이 gamma보다 크면 분할을 한다
즉, gamma는 분할하기 위해 최소 증가해야하는 information gain량을 의미한다. gain이 gamma 이상은 되어야 분할 하겠다는 의미이다
그리고 lambda 는 similarity 계산 시 분모에 들어가게 되고 즉 sim(분할 후 좌측 node) + sim(분할 후 우측 node)의 값을 더 작게 만들어 information gain 자체를 줄여주는 역할을 한다 (regularization)
물론 분할 전 node에 대한 similarity 계산 시에도 분모에 더해지지만 분할 후 좌, 우 node의 similarity 계산 시 분모의 값이 더 작기 때문에 similarity를 감소시키는 lambda의 값의 효과가 더 크다고 볼 수 있다
결과적으로 lambada, gamma 이 두 녀석 모두, 값이 커질 수록 분할을 방지한다. 즉, overfitting 방지하는 parameter라고 할 수 있다
마치며
해당 문제를 접하고 매우 당황하였던 기억이 있네요..ㅎㅎ
혹시 위의 문제가 다른 케이스에서도 발생하는 경우가 있거나 제가 작성한 내용이 잘못된 경우
댓글로 남겨주시면 감사하겠습니다 :)
▼ 글이 도움이 되셨다면 아래 클릭 한번 부탁드립니다 :) ▼