[머신러닝] K-폴드 교차 검증, stratified K폴드(층화추출)

2024. 2. 14. 21:25·머신러닝
728x90

K-폴드 교차 검증

K개의 데이터 폴드 세트를 만들어서 K번만큼 각 폴드 세트에 학습과 검증 평가를 반복적으로 수행하는 방법

출처 : [Must Have] 머신러닝 딥러닝 문제해결 전략 https://wikidocs.net/223699

 

1) 붓꽃 데이터 세트와 DecisionTreeClassifier 생성 후, 5개의 폴드 세트로 분리하는 KFold 객체 생성

from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
import numpy as np

iris = load_iris()
features = iris.data #독립변수
label = iris.target #종속변수
dt_clf = DecisionTreeClassifier(random_state=156)

# 5개의 폴드 세트로 분리하는 KFold 객체와 폴드 세트별 정확도를 담을 리스트 객체 생성.
kfold = KFold(n_splits=5)

cv_accuracy = []
print('붓꽃 데이터 세트 크기:',features.shape[0])

2) 생성된 KFold 객체의 split()을 호출해 전체 붗꽃 데이터를 5개의 폴드 데이터 세트로 분리

split() : 학습용/검증용 데이터로 분할할 수 있는 인덱스 반환

n_iter = 0

# KFold객체의 split( ) 호출하면 폴드 별 학습용, 검증용 테스트의 로우 인덱스를 array로 반환  
for train_index, test_index in kfold.split(features):
    
    # kfold.split( )으로 반환된 인덱스를 이용하여 학습용, 검증용 테스트 데이터 추출
    X_train, X_test = features[train_index], features[test_index]
    y_train, y_test = label[train_index], label[test_index]
    
    #학습 및 예측 
    dt_clf.fit(X_train , y_train)    
    pred = dt_clf.predict(X_test)
    n_iter += 1
    
    # 반복 시 마다 정확도 측정 
    accuracy = np.round(accuracy_score(y_test,pred), 4)
    train_size = X_train.shape[0]
    test_size = X_test.shape[0]
    print('\n#{0} 교차 검증 정확도 :{1}, 학습 데이터 크기: {2}, 검증 데이터 크기: {3}'
          .format(n_iter, accuracy, train_size, test_size))
    print('#{0} 검증 세트 인덱스:{1}'.format(n_iter,test_index))
    cv_accuracy.append(accuracy)
    
# 개별 iteration별 정확도를 합하여 평균 정확도 계산 
print('\n## 평균 검증 정확도:', np.mean(cv_accuracy))

 

Stratified K 폴드

종속변수(label)가 불균형한(imbalanced) 분포도를 가진 데이터 집합을 위한 K폴드 방식

불균형한 분포도를 가진 데이터 집합은 특정 레이블 값이 특이하게 많거나 매우 적어서 값의 분포가 한쪽으로 치우치는 것

 

1. K폴드의 문제점

1) K폴드는 레이블 데이터 집합이 원본 데이터 집합의 레이블 분포를 학습 및 테스트 세트에 제대로 분배하지 못함

import pandas as pd
iris = load_iris()
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['label']=iris.target
iris_df['label'].value_counts()

 

우선 데이터 세트를 확인해 봤을 때, 레이블 값 0, 1, 2는 모두 50개로 동일함

 

2) KFold로 검증 

kfold = KFold(n_splits=3)
# kfold.split(X)는 폴드 세트를 3번 반복할 때마다 달라지는 학습/테스트 용 데이터 로우 인덱스 번호 반환. 
n_iter =0
for train_index, test_index  in kfold.split(iris_df):
    n_iter += 1
    label_train= iris_df['label'].iloc[train_index]
    label_test= iris_df['label'].iloc[test_index]
    print('## 교차 검증: {0}'.format(n_iter))
    print('학습 레이블 데이터 분포:\n', label_train.value_counts())
    print('검증 레이블 데이터 분포:\n', label_test.value_counts())

 

검증 결과 KFold로 분할된 레이블 데이터 세트가 전체 레이블 값의 분포도를 반영하지 못하고 있다는 것을 확인 함

 

2. K폴드의 문제점을 보완한 것이 StratifiedKFold

동일한 데이터 분할을 StratifiedKFold로 수행하고 학습/검증 데이터를 확인

# 층화추출 방식으로 데이터셋 분리

from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=3)
n_iter=0

for train_index, test_index in skf.split(iris_df, iris_df['label']):
    n_iter += 1
    label_train= iris_df['label'].iloc[train_index]
    label_test= iris_df['label'].iloc[test_index]
    print('## 교차 검증: {0}'.format(n_iter))
    print('학습 레이블 데이터 분포:\n', label_train.value_counts())
    print('검증 레이블 데이터 분포:\n', label_test.value_counts())

 

학습 레이블과 검증 레이블 데이터 값의 분포도가 거의 동일하게 할당되었음을 확인

 

✍️Stratified K폴드는 원본 데이터의 레이블 분포도 특성을 반영한 학습 및 검증 데이터 세트를 만들 수 있으므로 왜곡된 레이블 데이터 세트에서는 반드시 Stratified K폴드를 이용해 교차 검증을 해야 함!!

 

✍️회귀에서는 Stratified K폴드가 지원되지 않음

회귀의 결정값은 연속된 숫자값이기 때문에 분포를 정하는 의미가 없기 때문!

728x90

'머신러닝' 카테고리의 다른 글

[머신러닝] 피처스케일링 - StandardScaler, MinMaxScaler  (0) 2024.02.19
[머신러닝] 데이터 전처리 - Label Encoding, One-Hot Encoding, get_dummies()  (0) 2024.02.19
[머신러닝] 하이퍼파라미터 최적화 방법 GridSearchCV vs RandomizedSearchCV  (0) 2024.02.19
[머신러닝] 사이킷런 활용하여 붓꽃 품종 예측하기  (0) 2024.02.14
머신러닝 - pycaret 설치  (0) 2024.02.14
'머신러닝' 카테고리의 다른 글
  • [머신러닝] 데이터 전처리 - Label Encoding, One-Hot Encoding, get_dummies()
  • [머신러닝] 하이퍼파라미터 최적화 방법 GridSearchCV vs RandomizedSearchCV
  • [머신러닝] 사이킷런 활용하여 붓꽃 품종 예측하기
  • 머신러닝 - pycaret 설치
GinaKim
GinaKim
안녕하세요! 반갑습니다 :)
  • GinaKim
    디디
    GinaKim
  • 전체
    오늘
    어제
    • 분류 전체보기 (91)
      • Python (43)
        • Python 기초문법 (25)
        • 데이터 시각화 (5)
        • 통계 (8)
        • 크롤링 (5)
      • git (5)
      • streamlit (5)
      • django (5)
      • 머신러닝 (18)
      • Spark (4)
      • Google Cloud Platform (8)
      • Tableau (0)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
GinaKim
[머신러닝] K-폴드 교차 검증, stratified K폴드(층화추출)
상단으로

티스토리툴바