00001
00020 #ifndef AURESERVOIR_TRAIN_H__
00021 #define AURESERVOIR_TRAIN_H__
00022
00023 #include "utilities.h"
00024 #include "delaysum.h"
00025
00026 namespace aureservoir
00027 {
00028
/// Identifiers for the readout training algorithms known to this header.
///
/// The values are written out explicitly so they stay fixed even if the
/// enumerators are later reordered.
enum TrainAlgorithm
{
  TRAIN_PI       = 0, ///< presumably selects TrainPI (pseudo-inverse) - confirm in the ESN setup code
  TRAIN_LS       = 1, ///< presumably selects TrainLS (least squares) - confirm
  TRAIN_RIDGEREG = 2, ///< presumably selects TrainRidgeReg (ridge regression) - confirm
  TRAIN_DS_PI    = 3  ///< presumably selects TrainDSPI (delay&sum variant of PI) - confirm
};
00041
/// Forward declaration so the trainer templates below can refer to ESN<T>;
/// the class itself is defined elsewhere.
template <class T>
class ESN;
00043
00056 template <typename T>
00057 class TrainBase
00058 {
00059 public:
00060
00062 TrainBase(ESN<T> *esn) { esn_=esn; }
00063
00065 virtual ~TrainBase() {}
00066
00076 virtual void train(const typename ESN<T>::DEMatrix &in,
00077 const typename ESN<T>::DEMatrix &out,
00078 int washout) throw(AUExcept) = 0;
00079
00080 protected:
00081
00083 void checkParams(const typename ESN<T>::DEMatrix &in,
00084 const typename ESN<T>::DEMatrix &out,
00085 int washout) throw(AUExcept);
00086
00087
00089 void collectStates(const typename ESN<T>::DEMatrix &in,
00090 const typename ESN<T>::DEMatrix &out,
00091 int washout);
00092
00094 void squareStates();
00095
00097 void clearData()
00098 { M.resize(1,1); O.resize(1,1); }
00099
00101 ESN<T> *esn_;
00102
00104 typename ESN<T>::DEMatrix M;
00106 typename ESN<T>::DEMatrix O;
00107 };
00108
00135 template <typename T>
00136 class TrainPI : public TrainBase<T>
00137 {
00138 using TrainBase<T>::esn_;
00139 using TrainBase<T>::M;
00140 using TrainBase<T>::O;
00141
00142 public:
00143 TrainPI(ESN<T> *esn) : TrainBase<T>(esn) {}
00144 virtual ~TrainPI() {}
00145
00147 virtual void train(const typename ESN<T>::DEMatrix &in,
00148 const typename ESN<T>::DEMatrix &out,
00149 int washout) throw(AUExcept);
00150 };
00151
00169 template <typename T>
00170 class TrainLS : public TrainBase<T>
00171 {
00172 using TrainBase<T>::esn_;
00173 using TrainBase<T>::M;
00174 using TrainBase<T>::O;
00175
00176 public:
00177 TrainLS(ESN<T> *esn) : TrainBase<T>(esn) {}
00178 virtual ~TrainLS() {}
00179
00181 virtual void train(const typename ESN<T>::DEMatrix &in,
00182 const typename ESN<T>::DEMatrix &out,
00183 int washout) throw(AUExcept);
00184 };
00185
00209 template <typename T>
00210 class TrainRidgeReg : public TrainBase<T>
00211 {
00212 using TrainBase<T>::esn_;
00213 using TrainBase<T>::M;
00214 using TrainBase<T>::O;
00215
00216 public:
00217 TrainRidgeReg(ESN<T> *esn) : TrainBase<T>(esn) {}
00218 virtual ~TrainRidgeReg() {}
00219
00221 virtual void train(const typename ESN<T>::DEMatrix &in,
00222 const typename ESN<T>::DEMatrix &out,
00223 int washout) throw(AUExcept);
00224 };
00225
00242 template <typename T>
00243 class TrainDSPI : public TrainBase<T>
00244 {
00245 using TrainBase<T>::esn_;
00246 using TrainBase<T>::M;
00247 using TrainBase<T>::O;
00248
00249 public:
00250 TrainDSPI(ESN<T> *esn) : TrainBase<T>(esn) {}
00251 virtual ~TrainDSPI() {}
00252
00254 virtual void train(const typename ESN<T>::DEMatrix &in,
00255 const typename ESN<T>::DEMatrix &out,
00256 int washout) throw(AUExcept);
00257 };
00258
00259 }
00260
00261 #endif // AURESERVOIR_TRAIN_H__