CEBL  2.1
SessionManager.cpp
Go to the documentation of this file.
1 
7 #include "../CEBLModel.hpp"
8 #include "Serialization.hpp"
9 #include "Session.hpp"
10 #include "SessionManager.hpp"
11 #include "../TextUtils.hpp"
12 
13 // Model internal classes.
14 // SessionManager has to be a friend to each of these classes.
15 #include "DataSource.hpp"
16 #include "Training.hpp"
17 #include "ClassifiersConfig.hpp"
18 #include "FeaturesConfig.hpp"
19 #include "FilterConfig.hpp"
20 #include "DecisionConfig.hpp"
21 
22 
23 //----------------------------------------------------------------------
24 // CONSTRUCTORS / DESTRUCTORS
25 
26 
// Constructor body.
// NOTE(review): the signature line is missing from this listing; judging by
// the body it receives the CEBLModel pointer as `model` -- confirm against
// SessionManager.hpp.
// Stores the model pointer (the visible code never deletes it, so it is not
// owned here) and allocates a fresh Session, which the destructor deletes.
28 {
29  this->model = model;
30  current_session = new Session();
31 }
32 
// Destructor body.
// NOTE(review): the signature line is missing from this listing; presumably
// SessionManager::~SessionManager() -- confirm.
// Releases the Session allocated in the constructor.
// NOTE(review): current_session is a raw owning pointer. Unless the header
// disables copying, a copied SessionManager would delete the same Session
// twice -- confirm the copy constructor/assignment are private or deleted.
34 {
35  delete current_session;
36 }
37 
38 
39 //----------------------------------------------------------------------
40 // PUBLIC METHODS
41 
// Body of save().
// NOTE(review): the signature line is missing from this listing; presumably
// void SessionManager::save() -- confirm.
// Refreshes the session from the model and writes it via Session::save()
// with no filename (presumably to the file it was previously saved to or
// loaded from -- confirm in Session).
// Throws FileException("Session has not yet been saved.") when
// shouldSaveAs() is true; callers must then use saveAs(filename).
43 {
44  if(shouldSaveAs())
45  {
46  throw FileException("Session has not yet been saved.");
47  }
48  else
49  {
     // Copy the model's current state into the session before writing it.
50  this->updateSession();
51  current_session->save();
52  }
53 }
54 
55 void SessionManager::saveAs(string filename)
56 {
57  this->updateSession();
58  current_session->save(filename.c_str());
59 }
60 
61 void SessionManager::load(string filename)
62 {
63  current_session->load(filename.c_str());
64  this->updateModel();
65 }
66 
// Body of shouldSaveAs().
// NOTE(review): the signature line is missing from this listing; presumably
// bool SessionManager::shouldSaveAs() -- confirm.
// Delegates to Session::shouldSaveAs(). save() throws FileException when this
// returns true, so callers should use saveAs(filename) in that case.
68 {
69  return current_session->shouldSaveAs();
70 }
71 
72 string SessionManager::encodeKey(string str)
73 {
74  std::replace(str.begin(), str.end(), ' ', '_');
75  return str;
76 }
77 
78 string SessionManager::decodeKey(string str)
79 {
80  std::replace(str.begin(), str.end(), '_', ' ');
81  return str;
82 }
83 
84 
85 //----------------------------------------------------------------------
86 // UPDATE MODEL AND SESSION
87 
88 
89 
90 void SessionManager::updateModel()
91 {
92  using namespace TextUtils;
93 
94  Session &s = *current_session;
95  CEBLModel &m = *model;
96  string temp;
97 
98  //--------------------------------------------------
99  //data source
100  s.setCurrentSection("data_source");
101  if(s.exists("source"))
102  {
103  int temp = s.get<int>("source");
104  m.dataSetSource(temp);
105  }
106  if(s.exists("stored_buffer"))
107  {
108  EEGData temp = s.get<ublas::matrix<double> >("stored_buffer");
109  m.getDataSource()->setDataBuffer(temp);
110  }
111 
112  //--------------------------------------------------
113  //process
114  s.setCurrentSection("data_process");
115  if(s.exists("reference_enabled"))
116  {
117  bool temp = s.get<bool>("reference_enabled");
118  m.processSetReferenceEnabled(temp);
119  }
120  if(s.exists("remove_enabled"))
121  {
122  bool temp = s.get<bool>("remove_enabled");
123  m.processSetRemoveEnabled(temp);
124  }
125  if(s.exists("filter_enabled"))
126  {
127  bool temp = s.get<bool>("filter_enabled");
128  m.processSetFilterEnabled(temp);
129  }
130 
131  //--------------------------------------------------
132  //channels
133  s.setCurrentSection("channels");
134  if(s.exists("configuration_string"))
135  {
136  string temp = s.get<string>("configuration_string");
137  m.channelsSetConfigurationFromString(temp);
138  }
139 
140  //--------------------------------------------------
141  //training
142  s.setCurrentSection("training");
143  if(s.exists("training_data_is_loaded"))
144  {
145  bool temp = s.get<bool>("training_data_is_loaded");
146  m.getTraining()->setDataIsLoaded(temp);
147  if(temp)
148  {
149  if(s.exists("training_data"))
150  {
151  EEGTrainingData temp = s.get<EEGTrainingData>("training_data");
152  m.getTraining()->setTrainingData(temp);
153  }
154  }
155  }
156  if(s.exists("num_classes"))
157  {
158  int temp = s.get<int>("num_classes");
159  m.trainingSetNumClasses(temp);
160  }
161  if(s.exists("num_sequences"))
162  {
163  int temp = s.get<int>("num_sequences");
164  m.trainingSetNumSequences(temp);
165  }
166  if(s.exists("sequence_length"))
167  {
168  int temp = s.get<int>("sequence_length");
169  m.trainingSetSequenceLength(temp);
170  }
171  if(s.exists("pause_length"))
172  {
173  int temp = s.get<int>("pause_length");
174  m.trainingSetPauseLength(temp);
175  }
176 
177  if(s.exists("class_labels"))
178  {
179  std::vector<string> temp = s.get<std::vector<string> >("class_labels");
180  m.trainingSetClassLabels(temp);
181  }
182 
183  //--------------------------------------------------
184  //features
185  {
186  s.setCurrentSection("features");
187  if(s.exists("selected"))
188  {
189  string temp = s.get<string>("selected");
190  m.featuresSetSelected(temp);
191  }
192 
193  FeaturesConfig * cc = m.getFeaturesConfig();
194  PluginLoader<Feature> *plugins = cc->getPluginLoader();
195  //now load each plugin
196  if(s.exists("names"))
197  {
198  vector<string> names;
199  s.get("names",&names);
200  for(unsigned i=0;i<names.size();i++)
201  {
202  Feature * p = plugins->getPlugin(names[i]);
203  if(p != NULL)
204  p->load(s.get<map<string, SerializedObject> >(encodeKey(names[i])
205  +"_internals"));
206 
207  }
208  }
209  }
210  //--------------------------------------------------
211  //classifiers
212  {
213  s.setCurrentSection("classifiers");
214  if(s.exists("selected"))
215  {
216  string temp = s.get<string>("selected");
217  m.classifiersSetSelected(temp);
218  }
219 
220  ClassifiersConfig * cc = m.getClassifiersConfig();
221  PluginLoader<Classifier> *plugins = cc->getPluginLoader();
222  //now load each plugin
223  if(s.exists("names"))
224  {
225  vector<string> names;
226  s.get("names",&names);
227  for(unsigned i=0;i<names.size();i++)
228  {
229  Classifier * p = plugins->getPlugin(names[i]);
230  if(p != NULL)
231  p->load(s.get<map<string, SerializedObject> >(encodeKey(names[i])
232  +"_internals"));
233  }
234  }
235  }
236  //--------------------------------------------------
237  //decisions
238  {
239  s.setCurrentSection("decision");
240  if(s.exists("selected"))
241  {
242  string temp = s.get<string>("selected");
243  m.decisionSetSelected(temp);
244  }
245 
246  DecisionConfig * cc = m.getDecisionConfig();
247  PluginLoader<Decision> *plugins = cc->getPluginLoader();
248  //now load each plugin
249  if(s.exists("names"))
250  {
251  vector<string> names;
252  s.get("names",&names);
253  for(unsigned i=0;i<names.size();i++)
254  {
255  Decision * p = plugins->getPlugin(names[i]);
256  if(p != NULL)
257  p->load(s.get<map<string, SerializedObject> >(encodeKey(names[i])
258  +"_internals"));
259  }
260  }
261  }
262  //--------------------------------------------------
263  //filters
264  {
265  s.setCurrentSection("filter");
266  if(s.exists("selected"))
267  {
268  string temp = s.get<string>("selected");
269  m.filterSetSelected(temp);
270  }
271  if(s.exists("lags"))
272  {
273  int temp = s.get<int>("lags");
274  m.filterSetNumLags(temp);
275  }
276  if(s.exists("components"))
277  {
278  string temp = s.get<string>("components");
279  m.filterSetSelectedComponentsString(temp);
280  }
281 
282 
283  FilterConfig * cc = m.getFilterConfig();
284  PluginLoader<Filter> *plugins = cc->getPluginLoader();
285  //now load each plugin
286  if(s.exists("names"))
287  {
288  vector<string> names;
289  s.get("names",&names);
290  for(unsigned i=0;i<names.size();i++)
291  {
292  Filter * p = plugins->getPlugin(names[i]);
293  if(p != NULL)
294  p->load(s.get<map<string, SerializedObject> >(encodeKey(names[i])
295  +"_internals"));
296  }
297  }
298  }
299 }
300 
301 //------------------------------------------------------------------------------
302 
303 void SessionManager::updateSession()
304 {
305  Session &s = *current_session;
306  CEBLModel &m = *model;
307 
308  //--------------------------------------------------
309  //data source
310  s.setCurrentSection("data_source");
311  s ("source", m.dataGetSource())
312  ("stored_buffer",m.dataGetStoredData().getMatrix());
313 
314  //--------------------------------------------------
315  //process
316  s.setCurrentSection("data_process");
317  s ("reference_enabled", m.processGetReferenceEnabled())
318  ("remove_enabled", m.processGetRemoveEnabled())
319  ("filter_enabled", m.processGetFilterEnabled());
320 
321  //--------------------------------------------------
322  //channels
323  s.setCurrentSection("channels");
324  s ("configuration_string", m.channelsGetConfigurationString());
325 
326  //--------------------------------------------------
327  //device stream
328  s.setCurrentSection("device");
329  s ("location", m.deviceGetLocation())
330  ("sample_rate",m.deviceGetSampleRate())
331  ("block_size",m.deviceGetBlockSize());
332 
333  //--------------------------------------------------
334  //file stream
335  s.setCurrentSection("file_data_stream");
336  s ("filename", m.fileStreamGetFilename());
337 
338  //--------------------------------------------------
339  //training
340  s.setCurrentSection("training");
341  s ("training_data_is_loaded", m.trainingDataIsLoaded())
342  ("training_data", m.trainingGetData())
343  ("num_classes", m.trainingGetNumClasses())
344  ("num_sequences", m.trainingGetNumSequences())
345  ("sequence_length", m.trainingGetSequenceLength())
346  ("pause_length", m.trainingGetPauseLength())
347  ("class_labels", m.trainingGetClassLabels());
348 
349  //--------------------------------------------------
350  //features
351  {
353  PluginLoader<Feature> *plugins = cc->getPluginLoader();
354  vector<string> names = plugins->getNames();
355  s.setCurrentSection("features");
356  s ("selected", m.featuresGetSelected())
357  ("names",names);
358  //now save each plugin
359  for(unsigned i=0;i<names.size();i++)
360  {
361  s(encodeKey(names[i]) + "_internals",
362  plugins->getPlugin(names[i])->save());
363  }
364  }
365  //--------------------------------------------------
366  //classifiers
367  {
369  PluginLoader<Classifier> *plugins = cc->getPluginLoader();
370  vector<string> names = plugins->getNames();
371  s.setCurrentSection("classifiers");
372  s ("selected", m.classifiersGetSelected())
373  ("names",names);
374  //now save each plugin
375  for(unsigned i=0;i<names.size();i++)
376  {
377  s(encodeKey(names[i]) + "_internals",
378  plugins->getPlugin(names[i])->save());
379  }
380  }
381 
382  //--------------------------------------------------
383  //filter
384  {
385  FilterConfig * cc = m.getFilterConfig();
386  PluginLoader<Filter> *plugins = cc->getPluginLoader();
387  vector<string> names = plugins->getNames();
388  s.setCurrentSection("filter");
389  s ("selected", m.filterGetSelected())
390  ("names",names)
391  ("components",m.filterGetSelectedComponentsString())
392  ("lags",m.filterGetNumLags());
393  //now save each plugin
394  for(unsigned i=0;i<names.size();i++)
395  {
396  s(encodeKey(names[i]) + "_internals",
397  plugins->getPlugin(names[i])->save());
398  }
399  }
400 
401  //--------------------------------------------------
402  //decision
403  {
405  PluginLoader<Decision> *plugins = cc->getPluginLoader();
406  vector<string> names = plugins->getNames();
407  s.setCurrentSection("decision");
408  s ("selected", m.decisionGetSelected())
409  ("names",names);
410  //now save each plugin
411  for(unsigned i=0;i<names.size();i++)
412  {
413  s(encodeKey(names[i]) + "_internals",
414  plugins->getPlugin(names[i])->save());
415  }
416  }
417 }