CEBL  2.1
RealTimeClassification.cpp
Go to the documentation of this file.
1 
8 #include <boost/thread/thread.hpp>
9 #include <boost/thread/mutex.hpp>
10 #include <boost/bind.hpp>
11 #include "../CEBLModel.hpp"
12 #include "cppR/cppR.hpp"
13 //----------------------------------------------------------------------
14 // CONSTRUCTORS / DESTRUCTORS
15 
16 
18 {
19  this->model = model;
20  this->is_classifying = false;
21  this->timeout_length = 100;
22  this->currently_training_classifier = false;
23  this->train_classifier_thread = NULL;
24  this->selected_class = -1;
25  this->training_failed = false;
26 }
27 
29 {
30 }
31 
32 //----------------------------------------------------------------------
33 //GETTING OPERATIONS
35 {
36  return model->classifierIsTrained();
37 }
40 {
41  return is_classifying;
42 }
45 {
46  std::vector<int> ret = classification_queue;
48  return ret;
49 }
50 
53 {
54  return classification_queue;
55 }
56 
57 
58 //----------------------------------------------------------------------
59 //SETTING OPERATIONS
61 {
62  {
63  boost::mutex::scoped_lock lock(thread_lock);
64  classification_queue.resize(0);
65  }
66 }
67 
68 
69 //----------------------------------------------------------------------
70 //CONTROL CLASSIFICATION
72 {
73  this->training_failed = false;
74  this->currently_training_classifier = true;
75  this->halt_training = false;
76  //--------------------------------------------------
77  // process and featurize training data
78  EEGTrainingData training_data = this->model->trainingGetData();
79  EEGTrainingData processed_data;
80  EEGData temp;
81  processed_data.reserve(training_data.numClasses(),training_data.numSequences());
82  for(int cls = 0; cls < training_data.numClasses(); cls++)
83  {
84  for(int seq = 0; seq < training_data.numSequences(cls); seq++)
85  {
86  //halt the training of classifier
87  if(halt_training)
88  {
89  this->currently_training_classifier = false;
90  this->training_failed = true;
91  return;
92  }
93  //process the data
94  temp = training_data.get(cls,seq);
95  temp = model->processData(temp);
96  //extract features
97  try
98  {
99  model->featureReset();
100  cout << "Featurizing class " << cls << ", seq " << seq << "\n";
101  temp = model->featuresExtract(temp);
102  }
103  catch(exception &e)
104  {
105  this->training_failed = true;
106  cerr << e.what() << ". Make sure you have installed the most recent feature plugins.\n";
107  this->currently_training_classifier = false;
108  return;
109  }
110  //add to new training data
111  processed_data.set(cls,seq,temp);
112  //halt the training of classifier
113  if(halt_training)
114  {
115  this->currently_training_classifier = false;
116  this->training_failed = true;
117  return;
118  }
119  }
120  }
121 
122  //--------------------------------------------------
123  // train classifier on features
124  try
125  {
126  this->model->classifierTrain(processed_data);
127  }
128  catch(exception &e)
129  {
130  cerr << e.what() << "\n";
131  this->currently_training_classifier = false;
132  this->training_failed = true;
133  return;
134  }
135  this->currently_training_classifier = false;
136 
137 }
138 
140 {
141  if(isReady())
142  {
143  this->model->dataStart();
144  //init the decision maker
145  model->decisionInit(model->trainingGetNumClasses());
146  try
147  {
148  this->timeoutStart();
149  is_classifying = true;
150  }
151  catch(...)
152  {
153  throw ClassificationException("Failed to start timeout for real-time classification.");
154  }
155  }
156  else
157  {
158  throw ClassificationException("Classifier is not ready.");
159  }
160 
161 }
163 {
164  this->haltAndJoin();
165  this->model->dataStop();
166  this->is_classifying = false;
167  this->clearClassificationQueue();
168 }
169 
171 {
172  //read processed data
173  EEGData new_data = model->dataReadAll();
174  if(new_data.size1() == 0)
175  return;
176 
177  //extract features from the data
178  EEGData features;
179  try
180  {
181  features = model->featuresExtract(new_data);
182  }
183  catch(...)
184  {
185  cerr << "Caught exception when extracting features.\n";
186  halt = true;
187  this->is_classifying = false;
188  return;
189  }
190  //classify
191  ublas::vector<int> classes;
192  try
193  {
194  classes = model->classifierUse(features);
195  }
196  catch(exception &e)
197  {
198  cerr << "Caught exception when classifying: " << e.what() << "\n";
199  halt = true;
200  this->is_classifying = false;
201  return;
202  }
203  catch(...)
204  {
205  cerr << "Caught exception when classifying: No Message\n";
206  halt = true;
207  this->is_classifying = false;
208  return;
209  }
210  //add these classes to the queue
211  {
212  boost::mutex::scoped_lock lock(thread_lock);
213  for(unsigned i=0;i<classes.size();i++)
214  classification_queue.push_back(classes[i]);
215  }
216 
217  //----------------------------------------
218  //update proportions
219  {
220  using namespace cppR;
221 
222  // update decision
223  if(model->classifierGetUseProbs())
224  {
225  std::vector<std::vector<double> > probs
226  = model->classifierGetLastProbs();
227  if(probs.size() > 0)
228  {
229  model->decisionUpdateWithProbabilities(probs);
230  }
231  else
232  {
233  cerr << "Realtime Classificaton: Probabilities are empty. Using predicted classes instead.\n";
234  model->decisionUpdateWithClassification(classes);
235  }
236  }
237  else
238  {
239  model->decisionUpdateWithClassification(classes);
240  }
241 
242  // get decision proportions
243  this->class_proportions = model->decisionDecideClasses();
244  ublas::vector<double> props =
245  asUblasVector(this->class_proportions);
246 
247  // if a proportion has reached 100%, set selected class
248  if(max(props) >= 1.0)
249  {
250  this->selected_class = whichMax(props);
251  }
252  }
253 }
254 
255 
256 
257 //----------------------------------------------------------------------
258 // Threaded Train Classifier
259 
261 {
262  return currently_training_classifier;
263 }
264 
266 {
267  obj->currently_training_classifier = true;
268  try
269  {
270  obj->trainClassifier();
271  }
272  catch(exception &e)
273  {
274  cerr << "Exception caught when training classifier: " << e.what() << "\n";
275  }
276  obj->currently_training_classifier = false;
277  delete obj->train_classifier_thread;
278  obj->train_classifier_thread = NULL;
279  cout << "Classifier training stopped.\n";
280 }
281 
283 {
284  train_classifier_thread = new boost::thread(boost::bind(&runTrainClassifier, this));
285 }
286 
288 {
289  if(currently_training_classifier && train_classifier_thread != NULL)
290  {
291  halt_training = true;
292  //try halting the classifier
293  model->classifierHaltTrain();
294  model->featuresHalt();
295  }
296 }
297 
298