CEBL  2.1
Training.hpp
Go to the documentation of this file.
1 /*
2 * CEBL : CSU EEG Brain-Computer Interface Lab
3 *
4 * Author: Jeshua Bratman - jeshuabratman@gmail.com
5 *
6 * This file is part of CEBL.
7 *
8 * CEBL is free software; you can redistribute it and/or modify it.
9 * We only ask that if you use our code that you cite the source in
10 * your project or publication.
11 *
12 * EEG Group (www.cs.colostate.edu/eeg)
13 * Department of Computer Science
14 * Colorado State University
15 *
16 */
17 
25 #ifndef TRAINING_H
26 #define TRAINING_H
27 
28 #include "EEGTrainingData.hpp"
29 #include "Timer.hpp"
30 #include "TimeoutThread.hpp"
31 #include "SessionManager.hpp"
32 
33 #include <string>
34 #include <vector>
35 using namespace std;
36 
37 //forward declarations
38 class CEBLModel;
39 
40 class Training : public TimeoutThread
41 {
42 private:
43  //make SessionManager a friend
44  friend class SessionManager;
45 
46  CEBLModel *model;
47 
48  //training options
49  int num_classes;
50  int num_sequences;
51  int sequence_length;
52  int pause_length;
53  std::vector<string> class_labels;
54 
55  //feedback options
56  bool classification_feedback;
57  bool classifier_trained;
58  std::vector<double> class_proportions;
59 
64  void classifySamples(EEGData samples);
65 
68  void trainClassifierAndContinue();
69 
70 
71  //data
72  EEGTrainingData training_data;
73  EEGTrainingData training_data_filtered;
74  bool data_is_loaded;
75 
76  //file loading
77  bool data_file_loaded;
78  string data_filename;
79 
80  //training process
81  bool training_is_active;
82  int current_training_class;
83  int current_training_sequence;
84 
85  Timer training_timer;
86  int wait_length; // in ms
87  std::vector<int> training_class_ordering;
88  std::vector<int> current_class_sequence;
89  int training_index;
90  bool waiting;
91  bool training_failed;
92  string failure_message;
93 
94  void timeoutFunction();
95  void initializeTraining();
96 
97  //control of training
98  void stopSuccess();
99  void stopFailure(string msg);
100 
101 
102  //private access functions
103  void setTrainingData(EEGTrainingData data)
104  {
105  training_data = data;
106  }
107 
108  void setTrainingDataFiltered(EEGTrainingData data)
109  {
110  training_data_filtered = data;
111  }
112 
113  void setDataIsLoaded(bool is_loaded)
114  {
115  data_is_loaded = is_loaded;
116  }
117 
118 public:
119  Training(CEBLModel *);
120  ~Training();
121 
122  //GETTING OPERATIONS
124  std::vector<string> getClassLabels();
126  string getClassLabel(int class_num);
128  int getNumClasses();
130  int getNumSequences();
132  int getSequenceLength();
134  int getPauseLength();
136  EEGTrainingData getData();
138  bool dataIsLoaded();
140  bool isDataFileLoaded();
142  string getDataFilename();
144  bool isActive();
146  bool failed();
148  string getFailureMessage();
150  bool isPaused();
152  int getTrainingClass();
154  int getTrainingSequence();
155 
156  //SETTING OPERATIONS
158  void start();
160  void stop();
162  void setNumClasses(int);
164  void setNumSequences(int);
166  void setSequenceLength(int);
168  void setPauseLength(int);
170  void setClassLabels(std::vector<string> labels);
172  void setClassLabel(int class_number, string label);
174  void loadData(string filename);
176  void clearData();
178  void saveData(string filename);
179 
180 
181 
182  //FEEDBACK OPERATIONS
183  bool feedbackEnabled() { return classification_feedback; }
184  void setFeedbackEnabled(bool flag) { classification_feedback = flag; }
185  bool isTrainingClassifier();
186  std::vector<double> getClassProportions() { return class_proportions; }
187 
188 };
189 
190 #endif
191