CEBL  2.1
Training.cpp
Go to the documentation of this file.
1 
9 #include "Training.hpp"
10 #include "DataProcess.hpp"
11 #include "FilterConfig.hpp"
12 #include "../CEBLModel.hpp"
13 #include "../cppR/cppR.hpp"
14 #include "DataIO.hpp"
15 #include "../TextUtils.hpp"
16 using namespace cppR;
17 
18 //----------------------------------------------------------------------
19 // CONSTRUCTORS / DESTRUCTORS
20 
21 
23 { // Constructor body (signature elided in this listing): stores the model pointer and sets default options and an idle session state.
24  this->model = model;
25 
26  //training options (sequence/pause lengths are multiplied by 1000 in timeoutFunction before comparing with the timer)
27  this->num_sequences = 3;
28  this->sequence_length = 5;
29  this->pause_length = 1;
30 
31  //feedback options: no classification feedback and no trained classifier until requested
32  this->classification_feedback = false;
33  this->classifier_trained = false;
34 
35  //data: start with empty unfiltered and filtered containers
36  this->data_is_loaded = false;
37  this->training_data.clear();
38  this->training_data_filtered.clear();
39  this->training_data_filtered.setFiltered(true); // presumably tags this container as holding filtered data
40 
41  //file loading
42  this->data_file_loaded = false;
43 
44  //training process: idle; training_index -1 makes timeoutFunction initialize on its first call
45  this->training_is_active = false;
46  this->current_training_class = 0;
47  this->current_training_sequence = 0;
48  this->timeout_length = 100; //ms
49  this->waiting = false;
50  this->training_failed = false;
51  this->training_index = -1;
52 
53  //set to 3 classes by default (this also creates the default labels "Class 0".."Class 2")
54  this->setNumClasses(3);
55 }
56 
58 { // Destructor body (signature elided in this listing): empty, no explicit cleanup.
59  // NOTE(review): does not halt/join a running training thread itself; confirm the base class or owner handles that.
60 }
61 
62 
63 //----------------------------------------------------------------------
64 
65 //GETTING OPERATIONS
66 
67 
68 std::vector<string> Training::getClassLabels()
69 {
70  return class_labels;
71 }
72 
73 string Training::getClassLabel(int class_num)
74 {
75  return class_labels[class_num];
76 }
77 
79 { // accessor (signature elided): configured number of classes
80  return num_classes;
81 }
82 
84 { // accessor (signature elided): number of sequences recorded per class
85  return num_sequences;
86 }
87 
89 { // accessor (signature elided): length of one recorded sequence (x1000 in timeoutFunction)
90  return sequence_length;
91 }
92 
94 { // accessor (signature elided): pause between sequences (x1000 in timeoutFunction)
95  return pause_length;
96 }
97 
99 { // accessor (signature elided): the unfiltered training data; by value or reference depends on the header declaration
100  return training_data;
101 }
102 
104 { // accessor (signature elided): true after a successful session or loadData
105  return data_is_loaded;
106 }
107 
109 { // accessor (signature elided): true once data has been loaded from a file
110  return data_file_loaded;
111 }
112 
114 { // accessor (signature elided): name of the last file passed to loadData
115  return data_filename;
116 }
117 
119 { // accessor (signature elided): true while a training session is running
120  return training_is_active;
121 }
122 
124 { // accessor (signature elided): true if the last session ended via stopFailure
125  return training_failed;
126 }
127 
129 { // accessor (signature elided): message passed to the last stopFailure
130  return failure_message;
131 }
132 
134 { // accessor (signature elided): true during the pause before a sequence
135  return waiting;
136 }
137 
139 { // accessor (signature elided): class index currently being presented
140  return current_training_class;
141 }
142 
144 { // accessor (signature elided): sequence number of the current class
145  return current_training_sequence;
146 }
147 
148 //--------------------------------------------------------------------------------
149 
150 //SETTING OPERATIONS
151 
152 
153 
154 
156 { // sets the number of classes (signature elided; invoked as setNumClasses(3) by the constructor)
157  this->num_classes = n;
158  //class labels: append default "Class <i>" labels for new indices; existing labels are kept and none are removed when n shrinks
159  if(class_labels.size() < unsigned(num_classes))
160  {
161  int start = class_labels.size();
162  for(int i=start;i<num_classes;i++)
163  {
164  class_labels.push_back("Class " + TextUtils::IntToString(i));
165  }
166  }
167 }
168 
170 { // setter (signature elided): number of sequences recorded per class; used by the next initializeTraining
171  this->num_sequences = n;
172 }
173 
175 { // setter (signature elided): length of one sequence (multiplied by 1000 in timeoutFunction)
176  this->sequence_length = n;
177 }
178 
180 { // setter (signature elided): pause between sequences (multiplied by 1000 in timeoutFunction)
181  this->pause_length = n;
182 }
183 
184 void Training::setClassLabels(std::vector<string> labels)
185 {
186  this->class_labels = labels;
187 }
188 
189 void Training::setClassLabel(int class_number, string label)
190 {
191  this->class_labels[class_number] = label;
192 }
193 
194 void Training::loadData(string filename)
195 {
196  this->training_data = DataIO::loadTrainingDataFromFile(filename);
197  this->setDataIsLoaded(true);
198  this->data_file_loaded = true;
199  this->data_filename = filename;
200 }
201 
203 { // (signature elided) discards both the unfiltered and the filtered training data
204  this->training_data.clear();
205  this->training_data_filtered.clear(); // NOTE(review): data_is_loaded / data_file_loaded are not reset here; confirm intended
206 }
207 
208 void Training::saveData(string filename)
209 {
210  if(this->training_data_filtered.numClasses() == this->training_data.numClasses())
211  {
212  int filter_lags = this->model->filterGetNumLags();
213  std::vector<int> removed_components = this->model->filterGetSelectedComponents();
214  ublas::matrix<double> filter_matrix = this->model->filterGetFilterMatrix();
215  DataIO::saveTrainingSessionToFile(this->training_data,
216  filename,
217  this->training_data_filtered,
218  filter_lags,
219  removed_components,
220  filter_matrix);
221  }
222  else
223  {
224  DataIO::saveTrainingDataToFile(this->training_data,
225  filename);
226  }
227 }
228 
229 //--------------------------------------------------------------------------------
230 // TRAINING PROCESS
231 
232 
234 { // (signature elided) starts a training session: probes the data source, resets session state, starts the periodic timeout thread
235  if(training_is_active && !classification_feedback) // NOTE(review): with feedback enabled an active session is not rejected and gets re-initialized; confirm intended
236  {
237  return;
238  }
239  //check source to see if it is ready: start it, confirm it reports started, stop it again; any failure becomes a TrainingException
240  try
241  {
242  this->model->dataStart();
243  if(!model->dataIsStarted())
244  {
245  throw TrainingException("Data source is not ready.");
246  }
247  this->model->dataStop();
248  }
249  catch(...) // NOTE(review): replaces any underlying error detail with a generic message
250  {
251  throw TrainingException("Data source is not ready.");
252  }
253 
254  this->initializeTraining();
255 
256  //start the thread (timeoutFunction then drives the session)
257  timeoutStart();
258  cout << "starting timeout\n";
259 }
260 
261 void Training::initializeTraining()
262 {
263  this->training_is_active = true;
264  this->waiting = false;
265  this->training_failed = false;
266  this->training_index = -1;
267  this->training_data.clear();
268  this->training_data.reserve(num_classes,num_sequences);
269  this->training_data.setChannelNames(this->model->channelsGetEnabledNames());
270  this->training_data.setClassLabels(this->class_labels);
271  this->training_data_filtered.clear();
272 
273  //create randomized list of classes to train on
274  ublas::vector<int> ord = sample(rep(vectorRange(0,num_classes-1),
275  num_sequences));
276  training_class_ordering = asStdVector(ord);
277  current_class_sequence = asStdVector(rep(0,num_classes));
278 
279  // printVector(training_class_ordering);
280  this->training_data.setSequenceOrder(training_class_ordering);
281 
282 }
283 
285 { // (signature elided) stops training on request; a manual stop is recorded as a failure with message "Stopped manually."
286  if(!isStarted()) // training thread not running: only clear the active flag
287  {
288  training_is_active = false;
289  return;
290  }
291  else
292  {
293  // let the thread finish
294  haltAndJoin();
295 
296  //stop the training process (no-op inside stopFailure if the session already ended)
297  stopFailure("Stopped manually.");
298  }
299 }
300 
301 //stopping functions
302 void Training::stopSuccess()
303 {
304  training_failed = false;
305  training_is_active = false;
306  data_is_loaded = true;
307  classifier_trained = false;
308  halt = true;
309 }
310 
311 void Training::stopFailure(string msg)
312 {
313  if(training_is_active)
314  {
315  training_failed = true;
316  training_is_active = false;
317  data_is_loaded = false;
318  failure_message = msg;
319  halt = true;
320  classifier_trained = false;
321  cout << "Training stopped: " << msg << "\n";
322  }
323 }
324 
325 
326 void Training::timeoutFunction()
327 {
328  int wait_time = 1000 * pause_length; // 1 second
329  int train_time = 1000 * sequence_length;
330 
331  if(!this->halt)
332  {
333 
334  //initialize
335  if(training_index < 0)
336  {
337  training_index = 0;
338  training_timer.restart();
339  waiting = true;
340  }
341  //set the current training class
342  if(unsigned(training_index) >= training_class_ordering.size())
343  {
344  stopFailure("Training index out of range.");
345  return;
346  }
347  current_training_class = training_class_ordering.at(training_index);
348  current_training_sequence =
349  current_class_sequence[current_training_class];
350  //if we are waiting, don't collect data, but wait
351  if(waiting)
352  {
353  if(training_timer.elapsed() > wait_time)
354  {
355  //start the data source and begin collecting
356  waiting = false;
357  //set current class etc and start recording
358  cout << "* collecting for class " << current_training_class << " sequence " << current_class_sequence[current_training_class] <<"\n";
359 
360  //try starting the data source
361  try
362  {
363  this->model->dataStart();
364  }
365  catch(exception & e)
366  {
367  stopFailure(e.what());
368  }
369 
370  //reset the timer
371  training_timer.restart();
372  }
373  }
374  //we are not waiting, now we are collecting
375  else
376  {
377  //save the recorded data
378  try
379  {
380  EEGData new_data = this->model->dataReadAllRaw();
381  EEGData unfiltered = this->model->getDataProcess()->process(new_data,true,true,false);
382  bool filter_enabled = this->model->getDataProcess()->getFilterEnabled();
383  bool filter_trained = this->model->getFilterConfig()->isTrained();
384  int cls = current_training_class;
385  int seq = current_class_sequence[cls];
386  //save unfiltered data
387  training_data.append(cls, seq, unfiltered);
388  // cout << "class " << cls << " seq " << seq << " has " << training_data.get(cls,seq).numSamples() << "\n";
389 
390  //also save the filtered version if filter is enabled and trained
391  EEGData filtered;
392  if(filter_enabled && filter_trained)
393  {
394  filtered = this->model->getDataProcess()->process(new_data,true,true,true);
395  training_data_filtered.append(cls, seq, filtered);
396  }
397 
398  //if we are providing classification feedback, classify these
399  //new samples using the previously trained classifier
400  if(classification_feedback && classifier_trained)
401  {
402  if(filter_enabled && filter_trained)
403  this->classifySamples(filtered);
404  else
405  this->classifySamples(unfiltered);
406  }
407  }
408  catch(exception & e)
409  {
410  stopFailure(e.what());
411  }
412 
413  //see if it is time to stop recording
414  if(training_timer.elapsed() > train_time)
415  {
416  training_timer.restart();
417  cout << "* done collecting for class " << current_training_class << " sequence " << current_class_sequence[current_training_class] << "\n";
418 
419  //try stopping the data source
420  try
421  {
422  this->model->dataStop();
423  }
424  catch(exception & e)
425  {
426  stopFailure(e.what());
427  }
428  //we are done
429  //cout << "index = " << training_index << ", size = " << training_class_ordering.size() << "\n";
430  if(unsigned(training_index) >= training_class_ordering.size()-1)
431  {
432  //if we are just training. finish.
433  if(!classification_feedback)
434  {
435  stopSuccess();
436  cout << "* training complete\n";
437  }
438  //if we want feedback, we need to now train the classifier
439  //and do it again
440  else
441  {
442  trainClassifierAndContinue();
443  }
444  }
445  else
446  {
447  //increment class index
448  training_index++;
449  current_class_sequence[current_training_class]++;
450 
451  //go back to waiting
452  waiting = true;
453  current_training_class = training_class_ordering.at(training_index);
454  training_timer.restart();
455  }
456  }
457  }
458  }
459 
460  if(this->halt)
461  {
462  this->model->dataStop();
463  return;
464  }
465 }
466 
467 
468 //------------------------------------------------------------------------------
469 // TRAINING FEEDBACK
470 
472 { // (signature elided) forwards to the model: whether the model reports it is currently training the classifier
473  return model->realtimeIsTrainingClassifier();
474 }
475 
476 void Training::classifySamples(EEGData samples)
477 {
478  EEGData features;
479  try
480  {
481  features = model->featuresExtract(samples);
482  }
483  catch(...)
484  {
485  cerr << "Caught exception when extracting features.\n";
486  stopFailure("Failed to extract features.");
487  return;
488  }
489  //classify
490  ublas::vector<int> classes;
491  try
492  {
493  classes = model->classifierUse(features);
494  }
495  catch(exception &e)
496  {
497  cerr << "Caught exception when classifying: " << e.what() << "\n";
498  stopFailure("Failed to classify.");
499  return;
500  }
501  catch(...)
502  {
503  cerr << "Caught exception when classifying: No Message\n";
504  stopFailure("Failed to classify.");
505  return;
506  }
507 
508  //----------------------------------------
509  //update proportions
510  {
511  using namespace cppR;
512  // update decision
513  if(model->classifierGetUseProbs())
514  {
515  model->decisionUpdateWithProbabilities(model->classifierGetLastProbs());
516  }
517  else
518  {
519  model->decisionUpdateWithClassification(classes);
520  }
521  // get decision proportions
522  this->class_proportions = model->decisionDecideClasses();
523  }
524 }
525 
526 
527 void Training::trainClassifierAndContinue()
528 {
529  cout << "training classifier\n";
530  training_failed = false;
531  data_is_loaded = true;
532  classifier_trained = false;
533  if(!data_is_loaded)
534  {
535  stopFailure("Tried to train classifier without training data.");
536  return;
537  }
538  else
539  {
540  //train classifier on most recent training data
541  cout << training_data << "\n";
542  model->realtimeTrainClassifierThreaded();
543  while(model->realtimeIsTrainingClassifier())
544  {
545  this->sleep(1000);
546  }
547  if(!model->realtimeIsReady() || model->realtimeLastTrainFailed())
548  {
549  stopFailure("Failed to train classifier.");
550  }
551  //if it succeeded, start training again
552  else
553  {
554  model->decisionInit(num_classes);
555  classifier_trained = true;
556  initializeTraining();
557  }
558  }
559 }