// Default-value assignments from an initializer whose signature is not visible in
// this excerpt (presumably the MSPRT constructor -- TODO confirm).
// Starting threshold; init() later caches its natural log in log_threshold.
10 this->threshold = 0.8;
// Starting class count; init(int) overwrites it with its argument.
12 this->num_classes = 3;
// Name stored for this plugin.
13 this->plugin_name =
"MSPRT";
// Builds the parameter map for this decision plugin.
// Only the "thresh" entry is visible in this excerpt; setParamsList() also reads a
// "g" entry, so a matching registration presumably exists in lines not shown here
// -- TODO confirm.
23 std::map<std::string, CEBL::Param> MSPRT::getParamsList()
25 std::map<std::string, CEBL::Param> params;
// Param constructed from the label "Threshold", an empty second string
// (presumably a description -- verify against CEBL::Param) and the current threshold.
26 CEBL::Param thresh(
"Threshold",
"", this->threshold);
// Stored under the key "thresh", the same key setParamsList() looks up.
30 params[
"thresh"] = thresh;
// Applies new parameter values from `p` and re-initializes when either changed.
// NOTE(review): `p` is taken by non-const reference and indexed with operator[];
// on a std::map a missing "thresh" or "g" key is inserted as a default-constructed
// CEBL::Param rather than reported. What getDouble() yields for such a Param is not
// visible here -- confirm.
43 void MSPRT::setParamsList( std::map<std::string, CEBL::Param> &p)
// Remember the current values so a change can be detected below.
46 double old_threshold = this->threshold;
47 double old_g = this->g;
// Read the new values from the map.
49 this->threshold = p[
"thresh"].getDouble();
50 this->g = p[
"g"].getDouble();
// Exact floating-point inequality: any difference at all triggers a re-init.
52 if(old_g != g || old_threshold != threshold)
// init() recomputes log_threshold from the new threshold.
54 this->init(num_classes);
// Updates log_probs from the accumulated `sums`, scaled by `g`.
// Statements between the size check and the normalisation (for example whatever
// folds `probs` into `sums`) are not visible in this excerpt.
// NOTE(review): `probs` is passed by value, so the vector is copied on each call.
59 void MSPRT::updateWithProbabilities(std::vector<double> probs)
// Guard: the accumulator and the incoming vector must have the same length.
61 if(sums.size() != probs.size())
// Reports the mismatch on cerr. Whether the function returns afterwards is not
// visible here. NOTE(review): the message contains the typo "vectory".
63 cerr <<
"MSPRT: size of probability vectory doesn't seem right. Did you initialize the decision\n";
// Scale the accumulated sums by g.
67 ublas::vector<double> y = g * sums;
// log(sum(exp(y))), assuming apply() maps exp over each element and sum() totals
// them (both helpers are defined outside this excerpt).
// NOTE(review): no max-shift is applied before exp, so large entries of y could
// overflow -- confirm the expected range of y.
68 double logsumexp = log(
sum(
apply(y,exp)));
// Subtract that scalar from every element (rep() presumably builds a constant
// vector of length y.size()), giving normalised log-probabilities.
69 this->log_probs = y -
rep(logsumexp,y.size());
// Compares the current class probabilities against the threshold.
// The body of the final `if` and the return statement are not visible in this
// excerpt, so the returned vector's contents are not described here.
72 std::vector<double> MSPRT::decideClasses()
74 ublas::vector<double> percents;
// 1 / exp(log_threshold); init() sets log_threshold = log(threshold), so this is
// 1 / threshold.
75 double expthresh = 1.0 / exp(log_threshold);
// exp(log_probs) * expthresh: each entry expressed as a fraction of the threshold
// (assuming apply() maps exp over each element).
76 percents =
apply(log_probs, exp) * expthresh;
// A ratio >= 1.0 means at least one entry has reached the threshold.
79 if(
max(percents) >= 1.0)
// (Re)initializes the decision state for `num_classes` classes.
// Only two assignments are visible in this excerpt; any further setup is in lines
// not shown.
86 void MSPRT::init(
int num_classes)
// Cache the natural log of the threshold; decideClasses() reads log_threshold.
88 this->log_threshold = log(this->threshold);
89 this->num_classes = num_classes;
// Fragment of a serialization routine whose signature is not visible in this
// excerpt. It fills a map of named SerializedObject entries.
99 map<string, SerializedObject> ret;
// Store log_threshold under "log_threshold", the key the deserialization code reads.
102 ret[
"log_threshold"] =
serialize(log_threshold);
// Fragment of a deserialization routine whose signature is not visible in this
// excerpt. It reads the entry keyed "log_threshold" from `objects` into
// log_threshold.
113 deserialize(objects[
"log_threshold"],log_threshold);