GridSearchCV
GridSearchCV 를 통해 우리는 교차 검증과 하이퍼 파라미터 튜닝을 동시에 수행할 수 있습니다. 하이퍼 파라미터는 머신러닝 알고리즘 중 중요하다고 생각되는 구성 요소이며 이 값을 조정하는 튜닝 과정을 통해 알고리즘의 예측 성능을 개선할 수 있다.
교차 검증 기반으로 지정된 파라미터들을 순차적으로 적용해보며 최적의 파라미터 값을 찾아준다. 이 때문에 최적의 파라미터를 알 수 있으나 시간이 오래 걸린다.
주요 파라미터
- estimator(string) : classifier, regressor, pipeline
- param_grid(dict) : 사용될 파라미터명, 값 을 dict 형태로
- scoring(string) : 예측 성능을 측정할 평가 방법 지정
- cv(int) : 교차 검증을 위하 분할되는 학습/테스트 세트의 개수 지정
- refit(bool) : true인 경우 최적의 하이퍼 파라미터를 찾은 후 입력된 개체를 해당 하이퍼 파라미터로 재학습
예시
sklearn에서 제공하는 데이터인 iris 데이터를 이용하여 예시를 확인하도록 해보자
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
iris_data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.2, random_state=121)
dtree = DecisionTreeClassifier()
parameters = {'max_depth': [1,2,3], 'min_samples_split':[2,3]}
1. 데이터 분리 후 수행할 파라미터의 범위 정의
parameters: 딕셔너리 형태로 파라미터들을 저장 (후에 GridSearchCV 실행 시 지정된 파라미터들을 순차적으로 돌면서 최적의 파라미터 값을 찾아 줌)
* 현재 max_depth, min_samples_split 가 각각 3,2 개 이므로 총 6번의 loop문을 돌면서 성능 평가
import pandas as pd
grid_dtree = GridSearchCV(dtree, param_grid=parameters, cv=3, refit=True)
2. GridSearchCV 를 이용한 학습 수행
grid_dtree.fit(X_train, y_train)
3. 학습 수행
scores_df = pd.DataFrame(grid_dtree.cv_results_)
scores_df[['params', 'mean_test_score','rank_test_score', 'split0_test_score', 'split1_test_score','split2_test_score']]
4. 결과값 확인
4-1. 결과값 전체 표로 보기
output :
params : 적용된 개별 하이퍼 파라미터의 값
mean_test_score : 세트에 대해 수행된 성능들의 평균값
rank_test_score : mean_test_score을 바탕으로 평가 순위
split_test_score : 각 세트별 정확도
4-2. 최고 성능의 결과값만 보기
print('GridSearchCV 최적 파라미터:', grid_dtree.best_params_)
print('GridSearchCV 최고 정확도:{0:.4f}'.format(grid_dtree.best_score_))
output:
GridSearchCV 최적 파라미터: {'max_depth': 3, 'min_samples_split': 2}
GridSearchCV 최고 정확도:0.9750
5. 정확도가 가장 높았던 파라미터를 이용해 테스트 데이터 예측 후 성능 평가
- 앞에서 refit=True 로 지정했기 떄문에 best_estimator_에 저장이 되어있음
estimator = grid_dtree.best_estimator_
pred = estimator.predict(X_test)
print('테스트 데이터 세트 정확도: {0:.4f}'.format(accuracy_score(y_test,pred)))
output:
테스트 데이터 세트 정확도: 0.9667
출처 : 파이썬 머신러닝 완벽 가이드
'인공지능 > 데이터 분석' 카테고리의 다른 글
[사이킷런] 데이터 전처리 2. 피처스케일링 (0) | 2021.08.16 |
---|---|
[사이킷런] 데이터 전처리 1. 데이터 인코딩 (0) | 2021.08.16 |
[사이킷런] 과적합의 문제와 교차 검증 모델(KFold, StratifiedKFold, cross_val_score) (0) | 2021.08.16 |
[사이킷런] train_test_split, DecisionTreeClassifier 체험 (0) | 2021.08.16 |
[numpy] transpose, linalg.inv, dot (0) | 2020.10.04 |