머신러닝

[분류] 랜덤포레스트 RandomForest

GinaKim 2024. 3. 3. 19:45
728x90

RandomForest

여러개의 결정 트리 분류기가 전체 데이터에서 배깅 방식으로 각자의 데이터를 샘플링 해 개별적으로 학습을 수행한 뒤 최종적으로 모든 분류기가 보팅을 통해 예측 결정을 하게 됨

분류에서는 다수결(하드보팅)로 최종 결과를 구하지만 회귀에서는 평균 또는 중앙값(소프트 보팅)을 구하는 방법을 사용

 

 

RandomForest 하이퍼파라미터

  • n_estimators: 랜덤 포레스트에서 결정 트리의 개수를 지정(디폴트는 10개) 개수 늘릴수록 학습 수행 시간이 오래 걸림
  • max_features: 결정 트리에 사용된 max_features와 같음 (디폴트는 none이 아닌 sqrt) 
  • n_jobs: cpu 코어수 지정. n_jobs=-1 경우 컴퓨터의 모든 코어를 사용(사용하는 CPU 코어 개수에 비례해서 속도도 빨라짐)
  • max_depth, min_samples_split 등..

RandomForest 실습

1. 랜덤포레스트로 학습 및 별도의 테스트 셋으로 예측 성능 평가

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# 결정 트리에서 사용한 get_human_dataset( )을 이용해 학습/테스트용 DataFrame 반환
X_train, X_test, y_train, y_test = get_human_dataset()

# 랜덤 포레스트 학습 및 별도의 테스트 셋으로 예측 성능 평가
rf_clf = RandomForestClassifier(n_estimators=100, random_state=0, max_depth=8)
rf_clf.fit(X_train , y_train)
pred = rf_clf.predict(X_test)
accuracy = accuracy_score(y_test , pred)
print('랜덤 포레스트 정확도: {0:.4f}'.format(accuracy))

 

2. GridSearchCV로 최적의 하이퍼 파라미터 찾고 학습 및 예측

from sklearn.model_selection import GridSearchCV

params = {
    'max_depth': [8, 16, 24],
    'min_samples_leaf' : [1, 6, 12],
    'min_samples_split' : [2, 8, 16]
}
# RandomForestClassifier 객체 생성 후 GridSearchCV 수행
rf_clf = RandomForestClassifier(n_estimators=100, random_state=0, n_jobs=-1)
grid_cv = GridSearchCV(rf_clf , param_grid=params , cv=2, n_jobs=-1 )
grid_cv.fit(X_train , y_train)

print('최적 하이퍼 파라미터:\n', grid_cv.best_params_)
print('최고 예측 정확도: {0:.4f}'.format(grid_cv.best_score_))

 

2-1. 최적의 파라미터 값으로 랜덤포레스트 학습 및 예측 정확도 확인

rf_clf1 = RandomForestClassifier(n_estimators=100,  min_samples_leaf=6, max_depth=16,
                                 min_samples_split=2, random_state=0)
rf_clf1.fit(X_train , y_train)
pred = rf_clf1.predict(X_test)
print('예측 정확도: {0:.4f}'.format(accuracy_score(y_test , pred)))

📌GridSearchCV로 최적의 파라미터 찾은 후 학습 및 예측했을 때, 정확도 상승

 

2-2. 피처 중요도 확인

ftr_importances_values = rf_clf1.feature_importances_
ftr_importances = pd.Series(ftr_importances_values,index=X_train.columns)
ftr_importances.sort_values(ascending=False)[:20]
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

ftr_importances_values = rf_clf1.feature_importances_
ftr_importances = pd.Series(ftr_importances_values,index=X_train.columns  )
ftr_top20 = ftr_importances.sort_values(ascending=False)[:20]

plt.figure(figsize=(8,6))
plt.title('Feature importances Top 20')
sns.barplot(x=ftr_top20 , y = ftr_top20.index)
plt.show()

728x90