15 ublas::matrix<double> X =
t(data.
collapse());
17 cout <<
"Starting QDA Training\n"
18 <<
"X = matrix[" << X.size1() <<
"," << X.size2() <<
"]\n";
23 int nClasses = using_classes;
24 int classesSize = classes.size();
26 if(using_classes<classesSize){
27 cout <<
"More than the requested number of classes was collected. Only the first " << nClasses <<
" will be trained."<<endl;
29 else if(using_classes > classesSize){
30 cerr <<
"Data was only collected for " << classes.size() <<
" classes. Cannot train " << using_classes <<
" classes." <<endl;
35 int nSamples =
nrow(X);
36 int nFeatures =
ncol(X);
37 covariances.resize(nClasses);
38 covariancesInv.resize(nClasses);
39 for (
int k = 0; k < nClasses; k++) {
41 this->inturruptionPoint();
48 priors =
rep(0, nClasses);
52 cout <<
"Classes: " << classes << endl << flush;
54 for(
int k=0; k< nClasses; k++) {
56 this->inturruptionPoint();
60 std::vector<bool> mask =
createMask(classes[k], Y);
61 ublas::matrix<double> Z(
count(k,classes), nFeatures);
65 int nSamplesThisClass =
nrow(Z);
68 priors[k] = double(nSamplesThisClass) / nSamples;
74 matrix<double> temp =
createMatrix(ublas::vector<double>(row(means,k)),
75 nSamplesThisClass, nFeatures,
true);
76 matrix<double> Zc = Z - temp;
79 this->inturruptionPoint();
81 covariances[k] = prod(
t(Zc), Zc) / nSamplesThisClass;
82 covariancesInv[k] =
solve(covariances[k]);
83 covariancesDet[k] =
det(covariances[k]);
96 trained_classes = nClasses;
// QDA::use -- classify each row of `data` with the trained QDA model.
// For every class k it evaluates, per row x, the quadratic discriminant
//   d_k(x) = -0.5 * (x - m_k)^T * covariancesInv[k] * (x - m_k)
//            - 0.5 * log(covariancesDet[k]) + log(priors[k])
// where m_k is row k of `means`, and predicts classes[argmax_k d_k(x)].
// It also turns the discriminants into per-row class probabilities stored in
// this->probabilities.
// NOTE(review): the embedded line numbers jump (e.g. 101 -> 103, 141 -> 149),
// so parts of this function are not shown in this listing.
101 ublas::vector<int> QDA::use(
const ublas::matrix<double> & data)
// Reject an empty input matrix up front.
103 if(data.size1() == 0)
105 cerr <<
"QDA use error: 0 size data matrix given.\n";
// presumably this empty vector is returned to the caller; the return
// statement is not visible in this listing -- confirm.
106 ublas::vector<int> ret;
110 ublas::matrix<double> X = data;
// The number of classes scored comes from using_classes, not trained_classes.
// NOTE(review): if using_classes can exceed the number of trained classes,
// row(means,k) / covariancesInv[k] below would index past the trained data;
// no such check is visible here.
114 int nClasses = using_classes;
115 int nSamples =
nrow(X);
116 int nFeatures =
ncol(X);
// One column of discriminant values per class, one row per sample.
118 ublas::matrix<double> disc_functions(nSamples,nClasses);
121 for(
int k=0; k< nClasses; k++)
// Center every sample on the class-k mean. The trailing `true` presumably
// makes createMatrix replicate the mean vector row-wise -- confirm in cppR.
123 matrix<double> temp =
createMatrix(ublas::vector<double>(row(means,k)),
124 nSamples, nFeatures,
true);
125 ublas::matrix<double> Xc = X - temp;
// Part of d_k that does not depend on the sample.
// NOTE(review): a singular or non-positive-definite covariance gives
// covariancesDet[k] <= 0, making log() return -inf or NaN; no guard is visible.
126 double scalarpart = -0.5 * log(covariancesDet[k]) + log(priors[k]);
// Quadratic form per row: rowSums((Xc * Inv) o Xc), assuming compProd is
// an element-wise product -- confirm in cppR.
127 ublas::matrix<double> a = prod( Xc,covariancesInv[k]);
128 ublas::matrix<double> sa =
compProd(a,Xc);
129 ublas::vector<double> vectorpart = -0.5 *
rowSums(sa);
130 for(
unsigned int vi=0; vi<vectorpart.size(); vi++)
131 disc_functions(vi,k) = vectorpart[vi] + scalarpart;
// Predicted label for each sample = label of the class with the largest
// discriminant value.
135 ublas::vector<int> predicted_classes;
136 predicted_classes.resize(nSamples);
138 for(
int j=0; j<
nrow(disc_functions); j++)
140 ublas::vector<double> disc_functionsRow = row(disc_functions,j);
141 predicted_classes[j] = classes[
whichMax(disc_functionsRow)];
// Convert discriminants to class probabilities with a row-wise softmax.
// Subtracting each row's maximum before exp() avoids overflow and does not
// change the result once each row is normalized.
// NOTE(review): original lines 142-148 are not shown; this section may be
// conditional (e.g. on compute_probs) -- confirm.
149 ublas::vector<double> max_disc =
150 cppR::rowApply<double>(disc_functions,&cppR::max<double>);
// presumably createMatrix fills column-wise here (no byrow flag), so every
// column of max_discs equals max_disc -- confirm in cppR.
154 matrix<double> max_discs =
155 createMatrix(max_disc,disc_functions.size1(),disc_functions.size2());
// NOTE(review): this local `probabilities` shadows the member
// this->probabilities written further below.
158 matrix<double> probabilities =
159 apply(matrix<double>(disc_functions - max_discs), exp);
// sum_p is computed on lines not shown (160-164); presumably it holds the
// per-row sums of `probabilities` -- confirm.
165 matrix<double> sum_p_rep =
166 createMatrix(sum_p,probabilities.size1(),probabilities.size2());
167 probabilities =
compDiv(probabilities, sum_p_rep);
// Store each sample's probability row on the object as a std::vector.
170 this->probabilities.resize(probabilities.size1());
171 for(
unsigned i=0;i<probabilities.size1();i++)
172 this->probabilities[i] =
173 asStdVector(ublas::vector<double>(row(probabilities,i)));
179 return predicted_classes;
// QDA::save -- collect the model state into a name -> SerializedObject map
// for persistence. Each visible entry stores serialize() of the member with
// the same name as its key.
// NOTE(review): the embedded line numbers skip 188-190 and 194, and the end of
// the function is not shown, so other entries (and presumably `return ret;`)
// are not visible in this listing.
// Keep the key set in sync with QDA::load.
185 map<string, SerializedObject>
QDA::save()
const
187 map<string, SerializedObject> ret;
191 ret[
"covariances"] =
serialize(covariances);
192 ret[
"covariancesInv"] =
serialize(covariancesInv);
193 ret[
"covariancesDet"] =
serialize(covariancesDet);
195 ret[
"trained_classes"] =
serialize(trained_classes);
196 ret[
"using_classes"] =
serialize(using_classes);
197 ret[
"using_lags"] =
serialize(using_lags);
198 ret[
"trained_lags"] =
serialize(trained_lags);
// QDA::load -- restore model members from a map of SerializedObjects keyed
// by member name (the format written by QDA::save).
// NOTE(review): `objects` is taken by value, and std::map::operator[]
// default-constructs an entry for any missing key. No presence check is
// visible here; what deserialize() does with a default SerializedObject
// cannot be determined from this listing -- confirm.
// NOTE(review): save() also writes "covariances", "using_lags" and
// "trained_lags". Those reads are not visible here (original lines 204-208,
// 211 and 214-220 are omitted) -- confirm they are restored.
203 void QDA::load(map<string, SerializedObject> objects)
209 deserialize(objects[
"covariancesInv"],covariancesInv);
210 deserialize(objects[
"covariancesDet"],covariancesDet);
212 deserialize(objects[
"trained_classes"],trained_classes);
213 deserialize(objects[
"using_classes"],using_classes);
// QDA::getParamsList -- describe the classifier's user-settable parameters
// as a name -> CEBL::Param map.
// The visible entry is "probs", whose description asks whether QDA should
// compute probabilities in use(). It is presumably built from compute_probs,
// which setParamsList reads back. The Param construction (original lines 224
// and 226) and the return statement are not shown in this listing.
221 std::map<std::string, CEBL::Param> QDA::getParamsList()
223 std::map<std::string, CEBL::Param> params;
225 "Should QDA compute probabilities when you use the classifier?",
227 params[
"probs"] = probs;
// QDA::setParamsList -- apply user-chosen parameters. The visible statement
// copies the "probs" flag into compute_probs.
// NOTE(review): p["probs"] uses std::map::operator[], which inserts a
// default-constructed CEBL::Param into the caller's map if the key is absent.
// What getBool() returns for such a Param is not visible here.
233 void QDA::setParamsList( std::map<std::string, CEBL::Param> &p)
235 compute_probs = p[
"probs"].getBool();