Warning: Declaration of action_plugin_tablewidth::register(&$controller) should be compatible with DokuWiki_Action_Plugin::register(Doku_Event_Handler $controller) in /s/bach/b/class/cs545/public_html/fall16/lib/plugins/tablewidth/action.php on line 93
code:model_selection [CS545 fall 2016]

User Tools

Site Tools


code:model_selection

model selection in scikit-learn

model_selection.py
 
"""classifier evaluation using scikit-learn
 
more details at:
http://scikit-learn.org/stable/modules/cross_validation.html
http://scikit-learn.org/stable/tutorial/statistical_inference/model_selection.html
"""
 
import numpy as np
from sklearn import svm
from sklearn import metrics
# sklearn.cross_validation and sklearn.grid_search were deprecated in
# scikit-learn 0.18 and removed in 0.20; sklearn.model_selection replaces both.
from sklearn.model_selection import StratifiedKFold, GridSearchCV, cross_val_predict

# read in the heart dataset: column 0 is used as the class label and the
# remaining columns as the features.

data = np.genfromtxt("../data/heart_scale.data", delimiter=",")
X = data[:, 1:]
y = data[:, 0]

# first let's do regular (stratified, shuffled) 5-fold cross-validation:

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# The old cross_validation.StratifiedKFold exposed a ``test_folds`` array with
# the test-fold index of every example; the model_selection splitter does not,
# so rebuild it from split().  Because random_state is fixed, these are the
# same folds that cross_val_predict uses below.
test_folds = np.empty(len(y), dtype=int)
for fold_index, (_, test_index) in enumerate(cv.split(X, y)):
    test_folds[test_index] = fold_index
print(test_folds)

classifier = svm.SVC(kernel='linear', C=1)

# each example is predicted by the model trained on the folds it is not in
y_predict = cross_val_predict(classifier, X, y, cv=cv)
print(metrics.accuracy_score(y, y_predict))


# grid search

# let's perform model selection using grid search over the soft-margin
# constant C (six values from 1e-2 to 1e3, log-spaced):

Cs = np.logspace(-2, 3, 6)
classifier = GridSearchCV(estimator=svm.LinearSVC(), param_grid=dict(C=Cs))
classifier.fit(X, y)

# print the best cross-validated accuracy, classifier and parameters.
# Note: best_score_ was used to pick C, so it is an optimistic estimate.
print(classifier.best_score_)
print(classifier.best_estimator_)
print(classifier.best_params_)

# performing nested cross validation: cross_val_predict clones the
# GridSearchCV object and re-runs the whole search on the training part of
# each outer fold, so the held-out fold never influences the choice of C.

y_predict = cross_val_predict(classifier, X, y, cv=cv)
print(metrics.accuracy_score(y, y_predict))


# if we want to do grid search over multiple parameters, pass a list of
# grids; each dict is searched independently (gamma only applies to 'rbf'):
param_grid = [
    {'C': [1, 10, 100, 1000], 'kernel': ['linear']},
    {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
]
classifier = GridSearchCV(estimator=svm.SVC(), param_grid=param_grid)

# nested cross validation again; no explicit fit() is needed because
# cross_val_predict fits a fresh clone on every outer training fold.
y_predict = cross_val_predict(classifier, X, y, cv=cv)
print(metrics.accuracy_score(y, y_predict))
code/model_selection.txt · Last modified: 2016/10/06 14:58 by asa