User Tools

Site Tools


code:multi_class

Multi-class classification in scikit-learn

Let's use a one-vs-the-rest classifier on the iris dataset and compare it with a one-vs-one classifier. Each flower in the data is described by four features, and the flowers belong to three iris species (the classes).

multi_class.py
 
import numpy as np
from sklearn import datasets
from sklearn.multiclass import OneVsRestClassifier,OneVsOneClassifier
from sklearn.svm import LinearSVC,SVC
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate
 
# Load the iris dataset: X holds the feature matrix, y the class labels.

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Stratified 5-fold cross-validation. Shuffling with a fixed random_state
# makes the folds reproducible, so every classifier below is evaluated on
# exactly the same train/test splits and their scores are comparable.
cv_generator = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# One-vs-the-rest: fits one binary classifier per class, each separating
# that class from all the others.
classifier = OneVsRestClassifier(LinearSVC())
# To try a linear SVM trained with libsvm instead of liblinear, use:
# classifier = OneVsRestClassifier(SVC(C=1, kernel='linear'))
results_ovr = cross_validate(classifier, X, y, cv=cv_generator,
                             scoring='accuracy', return_train_score=False)
# A bare expression only echoes its value in an interactive session, so
# print the mean test accuracy explicitly; this also works as a script.
print(f"one-vs-rest, linear SVM:  {np.mean(results_ovr['test_score']):.3f}")

# One-vs-one: fits one binary classifier per pair of classes and combines
# their predictions by voting.
classifier = OneVsOneClassifier(LinearSVC())
results_ovo = cross_validate(classifier, X, y, cv=cv_generator,
                             scoring='accuracy', return_train_score=False)
print(f"one-vs-one, linear SVM:   {np.mean(results_ovo['test_score']):.3f}")

# Does this mean that one-vs-one is better? Not necessarily: the outcome
# depends on the base classifier. Here is one-vs-the-rest again, this time
# with a nonlinear (RBF-kernel) SVM.
classifier = OneVsRestClassifier(SVC(C=1, kernel='rbf', gamma=0.5))
results_ovr_nonlinear = cross_validate(classifier, X, y, cv=cv_generator,
                                       scoring='accuracy',
                                       return_train_score=False)
print(f"one-vs-rest, RBF SVM:     "
      f"{np.mean(results_ovr_nonlinear['test_score']):.3f}")
code/multi_class.txt · Last modified: 2018/11/05 11:49 by asa