12 #include "../CEBLModel.hpp"
13 #include "../cppR/cppR.hpp"
15 #include "../TextUtils.hpp"
27 this->num_sequences = 3;
28 this->sequence_length = 5;
29 this->pause_length = 1;
32 this->classification_feedback =
false;
33 this->classifier_trained =
false;
36 this->data_is_loaded =
false;
37 this->training_data.clear();
38 this->training_data_filtered.clear();
39 this->training_data_filtered.setFiltered(
true);
42 this->data_file_loaded =
false;
45 this->training_is_active =
false;
46 this->current_training_class = 0;
47 this->current_training_sequence = 0;
48 this->timeout_length = 100;
49 this->waiting =
false;
50 this->training_failed =
false;
51 this->training_index = -1;
54 this->setNumClasses(3);
75 return class_labels[class_num];
90 return sequence_length;
100 return training_data;
105 return data_is_loaded;
110 return data_file_loaded;
115 return data_filename;
120 return training_is_active;
125 return training_failed;
130 return failure_message;
140 return current_training_class;
145 return current_training_sequence;
157 this->num_classes = n;
159 if(class_labels.size() < unsigned(num_classes))
161 int start = class_labels.size();
162 for(
int i=start;i<num_classes;i++)
171 this->num_sequences = n;
176 this->sequence_length = n;
181 this->pause_length = n;
186 this->class_labels = labels;
191 this->class_labels[class_number] = label;
197 this->setDataIsLoaded(
true);
198 this->data_file_loaded =
true;
199 this->data_filename = filename;
204 this->training_data.clear();
205 this->training_data_filtered.clear();
210 if(this->training_data_filtered.numClasses() == this->training_data.numClasses())
212 int filter_lags = this->model->filterGetNumLags();
213 std::vector<int> removed_components = this->model->filterGetSelectedComponents();
214 ublas::matrix<double> filter_matrix = this->model->filterGetFilterMatrix();
217 this->training_data_filtered,
235 if(training_is_active && !classification_feedback)
242 this->model->dataStart();
243 if(!model->dataIsStarted())
247 this->model->dataStop();
254 this->initializeTraining();
258 cout <<
"starting timeout\n";
261 void Training::initializeTraining()
263 this->training_is_active =
true;
264 this->waiting =
false;
265 this->training_failed =
false;
266 this->training_index = -1;
267 this->training_data.clear();
268 this->training_data.reserve(num_classes,num_sequences);
269 this->training_data.setChannelNames(this->model->channelsGetEnabledNames());
270 this->training_data.setClassLabels(this->class_labels);
271 this->training_data_filtered.clear();
280 this->training_data.setSequenceOrder(training_class_ordering);
288 training_is_active =
false;
297 stopFailure(
"Stopped manually.");
302 void Training::stopSuccess()
304 training_failed =
false;
305 training_is_active =
false;
306 data_is_loaded =
true;
307 classifier_trained =
false;
311 void Training::stopFailure(
string msg)
313 if(training_is_active)
315 training_failed =
true;
316 training_is_active =
false;
317 data_is_loaded =
false;
318 failure_message = msg;
320 classifier_trained =
false;
321 cout <<
"Training stopped: " << msg <<
"\n";
326 void Training::timeoutFunction()
328 int wait_time = 1000 * pause_length;
329 int train_time = 1000 * sequence_length;
335 if(training_index < 0)
338 training_timer.restart();
342 if(
unsigned(training_index) >= training_class_ordering.size())
344 stopFailure(
"Training index out of range.");
347 current_training_class = training_class_ordering.at(training_index);
348 current_training_sequence =
349 current_class_sequence[current_training_class];
353 if(training_timer.elapsed() > wait_time)
358 cout <<
"* collecting for class " << current_training_class <<
" sequence " << current_class_sequence[current_training_class] <<
"\n";
363 this->model->dataStart();
367 stopFailure(e.what());
371 training_timer.restart();
380 EEGData new_data = this->model->dataReadAllRaw();
381 EEGData unfiltered = this->model->getDataProcess()->process(new_data,
true,
true,
false);
382 bool filter_enabled = this->model->getDataProcess()->getFilterEnabled();
383 bool filter_trained = this->model->getFilterConfig()->isTrained();
384 int cls = current_training_class;
385 int seq = current_class_sequence[cls];
387 training_data.
append(cls, seq, unfiltered);
392 if(filter_enabled && filter_trained)
394 filtered = this->model->getDataProcess()->process(new_data,
true,
true,
true);
395 training_data_filtered.
append(cls, seq, filtered);
400 if(classification_feedback && classifier_trained)
402 if(filter_enabled && filter_trained)
403 this->classifySamples(filtered);
405 this->classifySamples(unfiltered);
410 stopFailure(e.what());
414 if(training_timer.elapsed() > train_time)
416 training_timer.restart();
417 cout <<
"* done collecting for class " << current_training_class <<
" sequence " << current_class_sequence[current_training_class] <<
"\n";
422 this->model->dataStop();
426 stopFailure(e.what());
430 if(
unsigned(training_index) >= training_class_ordering.size()-1)
433 if(!classification_feedback)
436 cout <<
"* training complete\n";
442 trainClassifierAndContinue();
449 current_class_sequence[current_training_class]++;
453 current_training_class = training_class_ordering.at(training_index);
454 training_timer.restart();
462 this->model->dataStop();
473 return model->realtimeIsTrainingClassifier();
476 void Training::classifySamples(
EEGData samples)
481 features = model->featuresExtract(samples);
485 cerr <<
"Caught exception when extracting features.\n";
486 stopFailure(
"Failed to extract features.");
490 ublas::vector<int> classes;
493 classes = model->classifierUse(features);
497 cerr <<
"Caught exception when classifying: " << e.what() <<
"\n";
498 stopFailure(
"Failed to classify.");
503 cerr <<
"Caught exception when classifying: No Message\n";
504 stopFailure(
"Failed to classify.");
511 using namespace cppR;
513 if(model->classifierGetUseProbs())
515 model->decisionUpdateWithProbabilities(model->classifierGetLastProbs());
519 model->decisionUpdateWithClassification(classes);
522 this->class_proportions = model->decisionDecideClasses();
527 void Training::trainClassifierAndContinue()
529 cout <<
"training classifier\n";
530 training_failed =
false;
531 data_is_loaded =
true;
532 classifier_trained =
false;
535 stopFailure(
"Tried to train classifier without training data.");
541 cout << training_data <<
"\n";
542 model->realtimeTrainClassifierThreaded();
543 while(model->realtimeIsTrainingClassifier())
547 if(!model->realtimeIsReady() || model->realtimeLastTrainFailed())
549 stopFailure(
"Failed to train classifier.");
554 model->decisionInit(num_classes);
555 classifier_trained =
true;
556 initializeTraining();