# Warning: Declaration of action_plugin_tablewidth::register(&$controller) should be compatible with DokuWiki_Action_Plugin::register(Doku_Event_Handler $controller) in /s/bach/b/class/cs545/public_html/fall16/lib/plugins/tablewidth/action.php on line 93
import numpy
import theano
import theano.tensor as T
class LogisticRegression(object):
    """Binary logistic regression with an L2 penalty, built on Theano.

    The model is trained with full-batch gradient descent: every epoch
    evaluates the cost on the whole training set and takes one step on
    the weight vector and the bias.

    Parameters
    ----------
    learning_rate : float
        Step size applied to the gradient at every update.
    regularizer : float
        Multiplier of the squared L2 norm of the weights in the cost
        (the bias is not penalized).
    epochs : int
        Number of full-batch updates performed by ``fit``.

    After ``fit`` the instance exposes ``w`` and ``b``, the Theano shared
    variables holding the learned weights and bias.
    """

    def __init__(self, learning_rate=0.1, regularizer=0.1, epochs=500):
        self.learning_rate = learning_rate
        self.regularizer = regularizer
        self.epochs = epochs

    def fit(self, X, labels):
        """Train the model on ``X`` with targets ``labels``.

        Parameters
        ----------
        X : array-like, two-dimensional
            One row per training example, one column per feature.
        labels : array-like, one-dimensional
            One target per row of ``X``; the cross-entropy below treats
            them as 0/1 values.

        Raises
        ------
        ValueError
            If ``X`` is not a non-empty 2-D array or ``labels`` does not
            hold exactly one entry per row of ``X``.
        """
        # The symbolic inputs below are float64 (dmatrix / dvector), so
        # convert once up front; this also gives a reliable shape.
        X = numpy.asarray(X, dtype="float64")
        labels = numpy.asarray(labels, dtype="float64")
        if X.ndim != 2 or X.shape[0] == 0:
            raise ValueError(
                "X must be a non-empty 2-D array, got shape %r" % (X.shape,)
            )
        if labels.shape != (X.shape[0],):
            raise ValueError(
                "labels must be 1-D with one entry per row of X, "
                "got shape %r" % (labels.shape,)
            )
        num_features = X.shape[1]

        # Trainable parameters: random initial weights, zero bias.
        w = theano.shared(value=numpy.random.randn(num_features), name='w')
        b = theano.shared(0.0, name="b")

        x = T.dmatrix("x")
        y = T.dvector("y")

        # Probability that target = 1
        p_1 = 1 / (1 + T.exp(-T.dot(x, w) - b))
        # The predictions thresholded
        prediction = p_1 > 0.5
        # The cross-entropy loss function
        cross_ent = -y * T.log(p_1) - (1 - y) * T.log(1 - p_1)
        # The cost function to minimize
        cost = cross_ent.mean() + self.regularizer * (w ** 2).sum()

        # Both gradients in a single call.
        gw, gb = T.grad(cost, [w, b])

        # One call to ``train`` performs one gradient step and returns the
        # cost evaluated before that step.
        train = theano.function(
            inputs=[x, y],
            outputs=cost,
            updates=[
                (w, w - self.learning_rate * gw),
                (b, b - self.learning_rate * gb),
            ],
        )
        self._predict = theano.function(inputs=[x], outputs=prediction)

        # train the model:
        final_cost = None
        for _ in range(self.epochs):
            final_cost = train(X, labels)
        # With epochs == 0 there is no cost to report.
        if final_cost is not None:
            print("In sample error: ", final_cost)

        self.w = w
        self.b = b

    def predict(self, X):
        """Return thresholded predictions for the rows of ``X``.

        An entry is truthy when the estimated probability of class 1
        exceeds 0.5.

        Raises
        ------
        AttributeError
            If the model has not been fitted yet.
        """
        predict_fn = getattr(self, "_predict", None)
        if predict_fn is None:
            raise AttributeError(
                "this LogisticRegression instance is not fitted yet; "
                "call fit() before predict()"
            )
        return predict_fn(X)
def main():
    """Train the classifier on the heart dataset and print test accuracy.

    Reads ``../../data/heart_scale.data`` as comma-separated values, takes
    the first column as the label and the remaining columns as features,
    holds out 40% of the rows for testing, and prints the accuracy on the
    held-out part.
    """
    # Imported here so that using the class as a library does not require
    # scikit-learn.
    from sklearn import metrics
    try:
        # Current location of train_test_split.
        from sklearn.model_selection import train_test_split
    except ImportError:
        # Legacy module, removed from recent scikit-learn releases.
        from sklearn.cross_validation import train_test_split

    # read in the heart dataset
    data = numpy.genfromtxt("../../data/heart_scale.data", delimiter=",")
    X = data[:, 1:]
    y = data[:, 0]
    # Rescale labels: -1 -> 0 and +1 -> 1 (assumes the file stores -1/+1
    # labels -- TODO confirm against the data file).
    y = (y + 1) / 2

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.4, random_state=0
    )
    classifier = LogisticRegression()
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)
    print("test set accuracy: ", metrics.accuracy_score(y_test, y_pred))


if __name__ == '__main__':
    main()