CEBL  2.1
MSPRT.cpp
Go to the documentation of this file.
1 #include "MSPRT.hpp"
2 #include "cppR/cppR.hpp"
3 
4 using namespace cppR;
5 
6 namespace CEBL {
7 
8  MSPRT::MSPRT()
9  {
10  this->threshold = 0.8;
11  this->g = 0.01;
12  this->num_classes = 3;
13  this->plugin_name = "MSPRT";
14  }
15 
16  MSPRT::~MSPRT()
17  {
18  }
19 
20 
21 
23  std::map<std::string, CEBL::Param> MSPRT::getParamsList()
24  {
25  std::map<std::string, CEBL::Param> params;
26  CEBL::Param thresh("Threshold", "", this->threshold);
27  thresh.setStep(0.001);
28  thresh.setMax(10);
29  thresh.setMin(0);
30  params["thresh"] = thresh;
31 
32  CEBL::Param g("g", "Gain parameter", this->g);
33  g.setStep(0.001);
34  g.setMax(1.0);
35  g.setMin(0.0);
36  params["g"] = g;
37 
38  return params;
39  }
40 
41 
43  void MSPRT::setParamsList( std::map<std::string, CEBL::Param> &p)
44  {
45 
46  double old_threshold = this->threshold;
47  double old_g = this->g;
48 
49  this->threshold = p["thresh"].getDouble();
50  this->g = p["g"].getDouble();
51 
52  if(old_g != g || old_threshold != threshold)
53  {
54  this->init(num_classes);
55  }
56  }
57 
58 
59  void MSPRT::updateWithProbabilities(std::vector<double> probs)
60  {
61  if(sums.size() != probs.size())
62  {
63  cerr << "MSPRT: size of probability vectory doesn't seem right. Did you initialize the decision\n";
64  return;
65  }
66  sums = sums + cppR::asUblasVector(probs);
67  ublas::vector<double> y = g * sums;
68  double logsumexp = log(sum(apply(y,exp)));
69  this->log_probs = y - rep(logsumexp,y.size());
70  }
71 
72  std::vector<double> MSPRT::decideClasses()
73  {
74  ublas::vector<double> percents;
75  double expthresh = 1.0 / exp(log_threshold);
76  percents = apply(log_probs, exp) * expthresh;
77 
78  //reset if we have selected a class
79  if(max(percents) >= 1.0)
80  {
81  init(num_classes);
82  }
83  return cppR::asStdVector(percents);
84  }
85 
86  void MSPRT::init(int num_classes)
87  {
88  this->log_threshold = log(this->threshold);
89  this->num_classes = num_classes;
90  this->sums = cppR::rep(0,num_classes);
91  }
92 
93  //----------------------------------------------------------------------
94  //SAVING and LOADING
95 
97  map<string, SerializedObject> MSPRT::save() const
98  {
99  map<string, SerializedObject> ret;
100  ret["sums"] = serialize(sums);
101  ret["log_probs"] = serialize(log_probs);
102  ret["log_threshold"] = serialize(log_threshold);
103  ret["threshold"] = serialize(threshold);
104  ret["g"] = serialize(g);
105  return ret;
106  }
107 
109  void MSPRT::load(map<string, SerializedObject> objects)
110  {
111  deserialize(objects["sums"],sums);
112  deserialize(objects["log_probs"],log_probs);
113  deserialize(objects["log_threshold"],log_threshold);
114  deserialize(objects["threshold"],threshold);
115  deserialize(objects["g"],g);
116  }
117 
118 }
119 
120 
121 
122 /*************************************************************/
123 //DYNAMIC LOADING
124 
126 {
127  return new CEBL::MSPRT;
128 }
129 
130 extern "C" void ObjectDestroy(CEBL::Decision* p)
131 {
132  delete p;
133 }
134