00001
00020 #ifndef AURESERVOIR_ESN_H__
00021 #define AURESERVOIR_ESN_H__
00022
00023 #include <iostream>
00024 #include <map>
00025 #include <algorithm>
00026
00027 #include "utilities.h"
00028 #include "activations.h"
00029 #include "init.h"
00030 #include "simulate.h"
00031 #include "train.h"
00032
00033 namespace aureservoir
00034 {
00035
00063 template <typename T = float>
00064 class ESN
00065 {
00066 public:
00067
00069 typedef std::map<InitParameter,T> ParameterMap;
00070
00071 typedef typename SPMatrix<T>::Type SPMatrix;
00072 typedef typename DEMatrix<T>::Type DEMatrix;
00073 typedef typename DEVector<T>::Type DEVector;
00074
00076 ESN();
00077
00079 ESN(const ESN<T> &src);
00080
00082 const ESN& operator= (const ESN<T>& src);
00083
00085 ~ESN();
00086
00088
00089
00094 void init()
00095 throw(AUExcept)
00096 { init_->init(); }
00097
00110 double adapt(const DEMatrix &in)
00111 throw(AUExcept);
00112
00123 inline void train(const DEMatrix &in, const DEMatrix &out, int washout)
00124 throw(AUExcept)
00125 { train_->train(in, out, washout); }
00126
00134 inline void simulate(const DEMatrix &in, DEMatrix &out)
00135 { sim_->simulate(in, out); }
00136
00140 void resetState()
00141 {
00142 std::fill_n( x_.data(), x_.length(), 0 );
00143 std::fill_n( sim_->last_out_.data(), outputs_, 0 );
00144 }
00145
00147
00148
00149
00163 double adapt(T *inmtx, int inrows, int incols) throw(AUExcept);
00164
00177 inline void train(T *inmtx, int inrows, int incols,
00178 T *outmtx, int outrows, int outcols,
00179 int washout) throw(AUExcept);
00180
00192 inline void simulate(T *inmtx, int inrows, int incols,
00193 T *outmtx, int outrows, int outcols) throw(AUExcept);
00194
00203 inline void simulateStep(T *invec, int insize, T *outvec, int outsize)
00204 throw(AUExcept);
00205
00207
00208
00209
00210
00218 void setBPCutoff(const DEVector &f1, const DEVector &f2) throw(AUExcept);
00219
00227 void setBPCutoff(T *f1vec, int f1size, T *f2vec, int f2size)
00228 throw(AUExcept);
00229
00243 void setIIRCoeff(const DEMatrix &B, const DEMatrix &A, int series=1)
00244 throw(AUExcept);
00245
00259 void setIIRCoeff(T *bmtx, int brows, int bcols,
00260 T *amtx, int arows, int acols,
00261 int series=1) throw(AUExcept);
00262
00264
00265
00266
00272 void post();
00273
00275 int getSize() const { return neurons_; };
00277 int getInputs() const { return inputs_; };
00279 int getOutputs() const { return outputs_; };
00281 double getNoise() const { return noise_; }
00282
00288 T getInitParam(InitParameter key) { return init_params_[key]; }
00289
00291 InitAlgorithm getInitAlgorithm() const
00292 { return static_cast<InitAlgorithm>(net_info_.at(INIT_ALG)); }
00294 TrainAlgorithm getTrainAlgorithm() const
00295 { return static_cast<TrainAlgorithm>(net_info_.at(TRAIN_ALG)); }
00297 SimAlgorithm getSimAlgorithm() const
00298 { return static_cast<SimAlgorithm>(net_info_.at(SIMULATE_ALG)); }
00299
00301 ActivationFunction getReservoirAct() const
00302 { return static_cast<ActivationFunction>(net_info_.at(RESERVOIR_ACT)); }
00304 ActivationFunction getOutputAct() const
00305 { return static_cast<ActivationFunction>(net_info_.at(OUTPUT_ACT)); }
00306
00308
00309
00310
00312 const DEMatrix &getWin() { return Win_; }
00314 const SPMatrix &getW() { return W_; }
00316 const DEMatrix &getWback() { return Wback_; }
00318 const DEMatrix &getWout() { return Wout_; }
00320 const DEVector &getX() { return x_; }
00326 DEMatrix getDelays() throw(AUExcept) { return sim_->getDelays(); }
00327
00329
00330
00331
00335 void getWin(T **mtx, int *rows, int *cols);
00339 void getWback(T **mtx, int *rows, int *cols);
00343 void getWout(T **mtx, int *rows, int *cols);
00345 void getX(T **vec, int *length);
00352 void getW(T *wmtx, int wrows, int wcols) throw(AUExcept);
00360 void getDelays(T *wmtx, int wrows, int wcols) throw(AUExcept);
00361
00363
00364
00365
00367 void setInitAlgorithm(InitAlgorithm alg=INIT_STD)
00368 throw(AUExcept);
00370 void setTrainAlgorithm(TrainAlgorithm alg=TRAIN_PI)
00371 throw(AUExcept);
00373 void setSimAlgorithm(SimAlgorithm alg=SIM_STD)
00374 throw(AUExcept);
00375
00377 void setSize(int neurons=10) throw(AUExcept);
00379 void setInputs(int inputs=1) throw(AUExcept);
00381 void setOutputs(int outputs=1) throw(AUExcept);
00382
00385 void setNoise(double noise) throw(AUExcept);
00386
00388 void setInitParam(InitParameter key, T value=0.);
00389
00391 void setReservoirAct(ActivationFunction f=ACT_TANH) throw(AUExcept);
00393 void setOutputAct(ActivationFunction f=ACT_LINEAR) throw(AUExcept);
00394
00401
00402
00404
00405
00406
00408 void setWin(const DEMatrix &Win) throw(AUExcept);
00410 void setW(const DEMatrix &W) throw(AUExcept);
00412 void setWback(const DEMatrix &Wback) throw(AUExcept);
00414 void setWout(const DEMatrix &Wout) throw(AUExcept);
00416 void setX(const DEVector &x) throw(AUExcept);
00417
00423 void setLastOutput(const DEVector &last) throw(AUExcept);
00424
00426
00427
00428
00434 void setWin(T *inmtx, int inrows, int incols) throw(AUExcept);
00435
00441 void setW(T *inmtx, int inrows, int incols) throw(AUExcept);
00442
00448 void setWback(T *inmtx, int inrows, int incols) throw(AUExcept);
00449
00455 void setWout(T *inmtx, int inrows, int incols) throw(AUExcept);
00456
00462 void setX(T *invec, int insize) throw(AUExcept);
00463
00469 void setLastOutput(T *last, int size) throw(AUExcept);
00470
00472
00473 protected:
00474
00476 InitBase<T> *init_;
00477
00479 TrainBase<T> *train_;
00480
00482 SimBase<T> *sim_;
00483
00484
00487 DEMatrix Win_;
00488
00490 SPMatrix W_;
00491
00494 DEMatrix Wback_;
00495
00498 DEMatrix Wout_;
00499
00502 DEVector x_;
00503
00504
00509 void (*reservoirAct_)(T *data, int size);
00510
00515 void (*outputAct_)(T *data, int size);
00516
00521 void (*outputInvAct_)(T *data, int size);
00522
00523
00525 int neurons_;
00527 int inputs_;
00529 int outputs_;
00530
00532 double noise_;
00533
00534
00536 ParameterMap init_params_;
00537
00539 enum NetInfo
00540 {
00541 RESERVOIR_ACT,
00542 OUTPUT_ACT,
00543 INIT_ALG,
00544 TRAIN_ALG,
00545 SIMULATE_ALG
00546 };
00547 typedef std::map<NetInfo, int> InfoMap;
00548
00550 InfoMap net_info_;
00551
00553 string getActString(int act);
00555 string getInitString(int alg);
00557 string getSimString(int alg);
00559 string getTrainString(int alg);
00560
00561
00563
00564 friend class InitBase<T>;
00565 friend class InitStd<T>;
00566 friend class TrainBase<T>;
00567 friend class TrainPI<T>;
00568 friend class TrainLS<T>;
00569 friend class TrainRidgeReg<T>;
00570 friend class TrainDSPI<T>;
00571 friend class SimBase<T>;
00572 friend class SimStd<T>;
00573 friend class SimSquare<T>;
00574 friend class SimLI<T>;
00575 friend class SimBP<T>;
00576 friend class SimFilter<T>;
00577 friend class SimFilter2<T>;
00578 friend class SimFilterDS<T>;
00580 };
00581
00582 }
00583
00584 #include <aureservoir/esn.hpp>
00585 #include <aureservoir/init.hpp>
00586 #include <aureservoir/simulate.hpp>
00587 #include <aureservoir/train.hpp>
00588
00589 #endif // AURESERVOIR_ESN_H__