train.h

Go to the documentation of this file.
00001 /***************************************************************************/
00020 #ifndef AURESERVOIR_TRAIN_H__
00021 #define AURESERVOIR_TRAIN_H__
00022 
00023 #include "utilities.h"
00024 #include "delaysum.h"
00025 
00026 namespace aureservoir
00027 {
00028 
/**
 * Identifiers for the available readout training algorithms.
 *
 * The enumerators are numbered explicitly with exactly the values the
 * compiler would assign implicitly (0, 1, 2, 3 in declaration order), so
 * the numeric values are part of the visible contract.
 *
 * Each identifier appears to correspond by name to one of the trainer
 * classes declared further down in this header (e.g. TRAIN_PI <-> TrainPI);
 * the code that maps the enum to a class is not in this header.
 * NOTE(review): confirm the mapping against the ESN implementation.
 */
enum TrainAlgorithm
{
  TRAIN_PI       = 0, ///< presumably pseudo-inverse training (TrainPI)
  TRAIN_LS       = 1, ///< presumably least-squares training (TrainLS)
  TRAIN_RIDGEREG = 2, ///< presumably ridge-regression training (TrainRidgeReg)
  TRAIN_DS_PI    = 3  ///< presumably delay&sum readout + pseudo-inverse (TrainDSPI)
};
00041 
00042 template <typename T> class ESN;
00043 
00056 template <typename T>
00057 class TrainBase
00058 {
00059  public:
00060 
00062   TrainBase(ESN<T> *esn) { esn_=esn; }
00063 
00065   virtual ~TrainBase() {}
00066 
00076   virtual void train(const typename ESN<T>::DEMatrix &in,
00077                      const typename ESN<T>::DEMatrix &out,
00078                      int washout) throw(AUExcept) = 0;
00079 
00080  protected:
00081 
00083   void checkParams(const typename ESN<T>::DEMatrix &in,
00084                    const typename ESN<T>::DEMatrix &out,
00085                    int washout) throw(AUExcept);
00086 
00087 
00089   void collectStates(const typename ESN<T>::DEMatrix &in,
00090                      const typename ESN<T>::DEMatrix &out,
00091                      int washout);
00092 
00094   void squareStates();
00095 
00097   void clearData()
00098   { M.resize(1,1); O.resize(1,1); }
00099 
00101   ESN<T> *esn_;
00102 
00104   typename ESN<T>::DEMatrix M;
00106   typename ESN<T>::DEMatrix O;
00107 };
00108 
00135 template <typename T>
00136 class TrainPI : public TrainBase<T>
00137 {
00138   using TrainBase<T>::esn_;
00139   using TrainBase<T>::M;
00140   using TrainBase<T>::O;
00141 
00142  public:
00143   TrainPI(ESN<T> *esn) : TrainBase<T>(esn) {}
00144   virtual ~TrainPI() {}
00145 
00147   virtual void train(const typename ESN<T>::DEMatrix &in,
00148                      const typename ESN<T>::DEMatrix &out,
00149                      int washout) throw(AUExcept);
00150 };
00151 
00169 template <typename T>
00170 class TrainLS : public TrainBase<T>
00171 {
00172   using TrainBase<T>::esn_;
00173   using TrainBase<T>::M;
00174   using TrainBase<T>::O;
00175 
00176  public:
00177   TrainLS(ESN<T> *esn) : TrainBase<T>(esn) {}
00178   virtual ~TrainLS() {}
00179 
00181   virtual void train(const typename ESN<T>::DEMatrix &in,
00182                      const typename ESN<T>::DEMatrix &out,
00183                      int washout) throw(AUExcept);
00184 };
00185 
00209 template <typename T>
00210 class TrainRidgeReg : public TrainBase<T>
00211 {
00212   using TrainBase<T>::esn_;
00213   using TrainBase<T>::M;
00214   using TrainBase<T>::O;
00215 
00216  public:
00217   TrainRidgeReg(ESN<T> *esn) : TrainBase<T>(esn) {}
00218   virtual ~TrainRidgeReg() {}
00219 
00221   virtual void train(const typename ESN<T>::DEMatrix &in,
00222                      const typename ESN<T>::DEMatrix &out,
00223                      int washout) throw(AUExcept);
00224 };
00225 
00242 template <typename T>
00243 class TrainDSPI : public TrainBase<T>
00244 {
00245   using TrainBase<T>::esn_;
00246   using TrainBase<T>::M;
00247   using TrainBase<T>::O;
00248 
00249  public:
00250   TrainDSPI(ESN<T> *esn) : TrainBase<T>(esn) {}
00251   virtual ~TrainDSPI() {}
00252 
00254   virtual void train(const typename ESN<T>::DEMatrix &in,
00255                      const typename ESN<T>::DEMatrix &out,
00256                      int washout) throw(AUExcept);
00257 };
00258 
00259 } // end of namespace aureservoir
00260 
00261 #endif // AURESERVOIR_TRAIN_H__

Generated on Wed Mar 12 21:16:05 2008 for aureservoir by  doxygen 1.5.3