Machine learning packages, functions (10)

호진·2021년 11월 21일
0
post-thumbnail

핵심 패키지와 함수

scikit-learn

  • cross_validate()는 교차 검증을 수행하는 함수다.
    첫 번재 매개변수에 교차 검증을 수행할 모델 객체를 전달한다. 두 번째와 세 번째 매개변수에 특성과 타깃 데이터를 전달한다.
    scoring 매개변수에 검증에 사용할 평가 지표를 지정할 수 있다. 기본적으로 분류 모델은 정확도를 의미하는 'accuracy', 회귀 모델은 결정 계수를 의미하는 'r2'가 된다.
    cv 매개변수에 교차 검증 폴드 수나 스플리터 객체를 지정할 수 있다. 기본값은 5이다. 회귀일 때는 KFold 클래스를 사용하고 분류일 때는 StratifiedKFold 클래스를 사용하여 5-폴드 교차 검증을 수행한다.
    n_jobs 매개변수는 교차 검증을 수행할 때 사용할 CPU 코어 수를 지정한다. 기본값은 1로 하나의 코어를 사용하고 -1로 지정하면 시스템에 있는 모든 코어를 사용한다.

  • GridSearchCV는 교차 검증으로 하이퍼파라미터 탐색을 수행한다. 최상의 모델을 찾은 후 훈련 세트 전체를 사용해 최종 모델을 훈련한다.
    첫 번째 매개변수로 그리드 서치를 수행할 모델 객체를 전달한다. 두 번째 매개변수에는 탐색할 모델의 매개변수와 값을 전달한다.
    scoring, cv, n_jobs, return_train_score 매개변수는 cross_validate() 함수와 동일하다.

  • RandomizedSearchCV는 교차 검증으로 랜덤한 하이퍼파라미터 탐색을 수행한다. 최상의 모델을 찾은 후 훈련 세트 전체를 사용해 최종 모델을 훈련한다. 첫 번째 매개변수로 그리드서치를 수행할 모델 객체를 전달한다. 두 번째 매개변수에는 탐색할 모델의 매개변수와 확률 분포 객체를 전달한다.
    scoring, cv, n_jobs, return_train_score 매개변수는 cross_validate() 함수와 동일하다.

profile
💭(。•̀ᴗ-)✧

0개의 댓글