This shows you the differences between two versions of the page.
Next revision | Previous revision Next revision Both sides next revision | ||
code:demo2d [2015/09/28 14:36] asa created |
code:demo2d [2015/09/28 14:46] asa |
||
---|---|---|---|
Line 1: | Line 1: | ||
- | Displaying the decision boundary of a classifier | + | ===== Displaying the decision boundary of a classifier ===== |
- | ================================================= | + | |
+ | Below is Python code for displaying the decision boundary of a classifier. | ||
+ | To use it: | ||
+ | |||
+ | (:source lang=python:) | ||
+ | |||
+ | import demo2d | ||
+ | from sklearn import svm | ||
+ | |||
+ | demo2d.get_data() | ||
+ | |||
+ | classifier = svm.LinearSVC() | ||
+ | |||
+ | demo2d.decision_surface(classifier) | ||
+ | |||
+ | (:sourcend:) | ||
+ | |||
+ | |||
+ | <file python demo2d.py> | ||
+ | |||
+ | import numpy | ||
+ | from numpy import arange | ||
+ | import matplotlib | ||
+ | from matplotlib import pylab | ||
+ | pylab.rcParams['contour.negative_linestyle'] = 'solid' | ||
+ | |||
+ | """ | ||
+ | demo2d: display decision boundaries and contours of the decision function | ||
+ | of a classifier on two dimensional data. | ||
+ | |||
+ | USAGE:: | ||
+ | |||
+ | First generate some data by calling | ||
+ | demo2d.get_data() | ||
+ | Data points are created by pressing the '1' or '2' key while the mouse | ||
+ | pointer is at the position on the figure where you want the point to be. | ||
+ | Press 'q' when you're done. | ||
+ | demo2d.decision_surface(classifier) then plots the decision boundary and | ||
+ | contours of the discriminant function of the given classifier on the data | ||
+ | that was generated. | ||
+ | demo2d.decision_surface can be called several times using different classifiers. | ||
+ | """ | ||
+ | |||
# Interactively collected data set: X accumulates [x, y] coordinate pairs and
# Y accumulates the matching class labels (0 or 1), in the same order.
X, Y = [], []

# matplotlib format strings, indexed by class label
# (an earlier variant used ['or', 'ob']).
plotStr = ['or', '+b']

# Bounds of the plotting area / evaluation grid.
xmin, xmax = -1, 1
ymin, ymax = -1, 1
+ | |||
def pick(event):
    """Key-press callback used while collecting data points.

    Pressing '1' or '2' while the pointer is inside the axes appends the
    pointer position to the module-level list ``X`` and the corresponding
    class label (0 or 1) to ``Y``, and draws the point using the marker
    style for that class.  Pressing 'q' disconnects this callback, unless
    no points have been collected yet, in which case the key is ignored.

    Parameters
    ----------
    event : object
        Event passed by matplotlib; ``key``, ``inaxes``, ``xdata`` and
        ``ydata`` are read from it.
    """
    key_to_class = {'1': 0, '2': 1}

    if event.key == 'q':
        if not X:
            # Nothing collected yet: keep listening for key presses.
            return
        # print(...) with a single argument behaves the same on Python 2 and 3.
        print('done creating data. close this window and use the '
              'decision_surface function')
        pylab.disconnect(binding_id)
        return

    label = key_to_class.get(event.key)
    if label is None or event.inaxes is None:
        # Either an unrelated key, or the pointer was outside the axes.
        return

    print('data coords %s %s' % (event.xdata, event.ydata))
    X.append([event.xdata, event.ydata])
    Y.append(label)
    # The class label doubles as the index into the marker-style table.
    pylab.plot([event.xdata], [event.ydata], plotStr[label])
    pylab.draw()
+ | |||
def get_data(**args):
    """Show a figure on which data points can be entered from the keyboard.

    Plots the four corners of the [xmin, xmax] x [ymin, ymax] area as black
    dots, registers ``pick`` as the handler for key-press events (the
    connection id is stored in the module-level ``binding_id``), and then
    calls ``pylab.show()``.

    Keyword arguments are accepted but not used.
    """
    global binding_id

    corner_xs = [xmin, xmin, xmax, xmax]
    corner_ys = [ymin, ymax, ymin, ymax]

    pylab.subplot(111)
    pylab.plot(corner_xs, corner_ys, '.k')
    pylab.title(
        "press the numbers 1 or 2 to generate data points and 'q' to quit")
    binding_id = pylab.connect('key_press_event', pick)
    pylab.show()
+ | |||
def scatter(X, Y, **args):
    """Plot every point of X with the marker style of its label in Y.

    Parameters
    ----------
    X : sequence
        Points; element 0 and element 1 of each point are used as the
        x and y coordinates.
    Y : sequence
        Labels, used as indices into the module-level ``plotStr`` table.
    markersize : optional keyword
        Marker size passed to ``pylab.plot`` (default 5).
    """
    size = args.get('markersize', 5)
    for idx, point in enumerate(X):
        pylab.plot(point[0], point[1], plotStr[Y[idx]], markersize=size)
+ | |||
+ | |||
def _decision_grid(classifier, delta):
    """Evaluate ``classifier.decision_function`` on a regular grid.

    The grid covers ``arange(xmin, xmax, delta)`` by
    ``arange(ymin, ymax, delta)``.  Returns a float array ``Z`` of shape
    ``(len(x), len(y))`` where ``Z[i, j]`` is the decision value at
    ``(x[i], y[j])``.
    """
    x = arange(xmin, xmax, delta)
    y = arange(ymin, ymax, delta)
    # indexing='ij' makes the flattened order x-major (row n = i*len(y) + j),
    # so reshape(len(x), len(y)) puts each value back at Z[i, j].
    grid_x, grid_y = numpy.meshgrid(x, y, indexing='ij')
    points = numpy.column_stack((grid_x.ravel(), grid_y.ravel()))
    values = numpy.asarray(classifier.decision_function(points), dtype=float)
    return values.reshape(len(x), len(y))


def decision_surface(classifier, fileName=None, **args):
    """Fit a classifier on the collected data and plot its decision surface.

    The classifier is fitted on the module-level ``X``/``Y`` data, its
    decision function is evaluated on a grid over the plotting area and
    shown as an image with contour lines, and the data points are drawn
    on top.

    Parameters
    ----------
    classifier : object
        Must provide ``fit(X, Y)`` and ``decision_function(points)``.
    fileName : str or None
        If given, the figure is saved to this file.
    numContours : optional keyword, default 3
        1 draws only the level-0 contour; 3 draws levels -1, 0 and 1;
        any other value is passed to ``pylab.contour`` as-is.
    title : optional keyword, default None
        Figure title.
    markersize : optional keyword, default 5
        Marker size for the data points.
    fontsize : optional keyword, default 'medium'
        Font size for the tick labels and the title.
    contourFontsize : optional keyword, default 10
        Font size for the contour labels.
    showColorbar : optional keyword, default False
        Whether to add a colorbar for the image.
    show : optional keyword
        Whether to call ``pylab.show()``; defaults to True when
        ``fileName`` is None and to False otherwise.
    delta : optional keyword, default 0.01
        Grid spacing.

    Raises
    ------
    ValueError
        If no data points have been collected.
    """
    if len(X) == 0:
        raise ValueError('no data to fit; generate data with get_data() first')

    classifier.fit(X, Y)

    numContours = args.get('numContours', 3)
    title = args.get('title')
    markersize = args.get('markersize', 5)
    fontsize = args.get('fontsize', 'medium')
    contourFontsize = args.get('contourFontsize', 10)
    showColorbar = args.get('showColorbar', False)
    # Saving to a file suppresses the interactive window unless asked for.
    show = args.get('show', fileName is None)
    delta = args.get('delta', 0.01)

    # imshow/contour expect rows to run along y, hence the transpose.
    surface = numpy.transpose(_decision_grid(classifier, delta))
    extent = (xmin, xmax, ymin, ymax)

    im = pylab.imshow(surface,
                      interpolation='bilinear', origin='lower',
                      cmap=pylab.cm.gray, extent=extent)

    if numContours == 1:
        # Decision boundary only.
        C = pylab.contour(surface, [0],
                          origin='lower',
                          linewidths=3,
                          colors='black',
                          extent=extent)
    elif numContours == 3:
        # Decision boundary (thick) plus the -1 / +1 levels (thin).
        C = pylab.contour(surface, [-1, 0, 1],
                          origin='lower',
                          linewidths=(1, 3, 1),
                          colors='black',
                          extent=extent)
    else:
        C = pylab.contour(surface, numContours,
                          origin='lower',
                          linewidths=2,
                          extent=extent)

    pylab.clabel(C,
                 inline=1,
                 fmt='%1.1f',
                 fontsize=contourFontsize)

    # plot the data
    scatter(X, Y, markersize=markersize)
    axes = pylab.gca()
    pylab.setp(pylab.getp(axes, 'xticklabels'), fontsize=fontsize)
    pylab.setp(pylab.getp(axes, 'yticklabels'), fontsize=fontsize)

    if title is not None:
        pylab.title(title, fontsize=fontsize)
    if showColorbar:
        pylab.colorbar(im)

    # colormap:
    pylab.hot()
    if fileName is not None:
        pylab.savefig(fileName)
    if show:
        pylab.show()
+ | |||
+ | |||
+ | </file> |