esn.h

Go to the documentation of this file.
00001 /***************************************************************************/
00020 #ifndef AURESERVOIR_ESN_H__
00021 #define AURESERVOIR_ESN_H__
00022 
00023 #include <iostream>
00024 #include <map>
00025 #include <algorithm>
00026 
00027 #include "utilities.h"
00028 #include "activations.h"
00029 #include "init.h"
00030 #include "simulate.h"
00031 #include "train.h"
00032 
00033 namespace aureservoir
00034 {
00035 
00063 template <typename T = float>
00064 class ESN
00065 {
      /// ESN<T> bundles a set of weight matrices and a state vector with three
      /// exchangeable algorithm objects (init_, train_, sim_). T is the numeric
      /// element type and defaults to float.
      ///
      /// Most member functions are only declared in this class body.
      /// NOTE(review): their definitions are presumably in esn.hpp, which is
      /// included at the end of this header - confirm.
      /// NOTE(review): "ESN" presumably stands for Echo State Network - confirm.
00066  public:
00067 
      /// Maps an initialization parameter key to its value of type T.
00069   typedef std::map<InitParameter,T> ParameterMap;
00070 
      /// Short class-local names for the `Type` member of the namespace-level
      /// SPMatrix / DEMatrix / DEVector templates instantiated with T.
      /// NOTE(review): presumably sparse matrix, dense matrix and dense vector
      /// types - confirm against their declarations.
00071   typedef typename SPMatrix<T>::Type SPMatrix;
00072   typedef typename DEMatrix<T>::Type DEMatrix;
00073   typedef typename DEVector<T>::Type DEVector;
00074 
      /// Default constructor (declared only).
00076   ESN();
00077 
      /// Copy constructor (declared only).
00079   ESN(const ESN<T> &src);
00080 
      /// Copy assignment (declared only). Note that it returns a reference to
      /// const, so the result of an assignment cannot be modified through it.
00082   const ESN& operator= (const ESN<T>& src);
00083 
      /// Destructor (declared only); it is not virtual.
00085   ~ESN();
00086 
00088 
00089 
      /// Runs the currently installed initialization algorithm by forwarding to
      /// init_->init(). The exception specification allows only AUExcept.
00094   void init()
00095     throw(AUExcept)
00096   { init_->init(); }
00097 
      /// Declared only: takes an input matrix and returns a double.
      /// NOTE(review): the meaning of the returned value is not visible in this
      /// header - check the definition.
00110   double adapt(const DEMatrix &in)
00111     throw(AUExcept);
00112 
      /// Forwards all three arguments unchanged to train_->train().
      /// @param in       input matrix
      /// @param out      output matrix
      /// @param washout  integer passed through to the training algorithm
      /// NOTE(review): presumably the number of initial steps to discard - confirm.
00123   inline void train(const DEMatrix &in, const DEMatrix &out, int washout)
00124     throw(AUExcept)
00125   { train_->train(in, out, washout); }
00126 
      /// Forwards to sim_->simulate(in, out); `out` is a non-const reference and
      /// is handed to the simulation algorithm as such. Unlike its neighbours,
      /// this overload carries no exception specification.
00134   inline void simulate(const DEMatrix &in, DEMatrix &out)
00135   { sim_->simulate(in, out); }
00136 
      /// Resets stored state to zero: writes 0 to all x_.length() elements of x_
      /// and to the first outputs_ elements of sim_->last_out_.
00140   void resetState()
00141   {
00142     std::fill_n( x_.data(), x_.length(), 0 );
00143     std::fill_n( sim_->last_out_.data(), outputs_, 0 );
00144   }
00145 
00147 
00148 
00149 
      /// Raw-buffer overloads of adapt / train / simulate (declared only). Each
      /// matrix is passed as a T* together with explicit row and column counts
      /// instead of a DEMatrix.
      /// NOTE(review): presumably meant for callers that cannot build a DEMatrix
      /// (e.g. language bindings); the expected memory layout of the buffers is
      /// not visible here - confirm in the definitions.
00163   double adapt(T *inmtx, int inrows, int incols) throw(AUExcept);
00164 
00177   inline void train(T *inmtx, int inrows, int incols,
00178                     T *outmtx, int outrows, int outcols,
00179                     int washout) throw(AUExcept);
00180 
00192   inline void simulate(T *inmtx, int inrows, int incols,
00193                        T *outmtx, int outrows, int outcols) throw(AUExcept);
00194 
      /// Declared only: takes one input vector and one output vector, each as a
      /// T* with an explicit element count.
00203   inline void simulateStep(T *invec, int insize, T *outvec, int outsize)
00204     throw(AUExcept);
00205 
00207 
00208 
00209 
00210 
      /// Declared only, in a DEVector form and a raw-buffer form.
      /// NOTE(review): the name suggests band-pass cutoff values and is
      /// presumably only meaningful for some simulation algorithms - confirm.
00218   void setBPCutoff(const DEVector &f1, const DEVector &f2) throw(AUExcept);
00219 
00227   void setBPCutoff(T *f1vec, int f1size, T *f2vec, int f2size)
00228     throw(AUExcept);
00229 
      /// Declared only, in a DEMatrix form and a raw-buffer form. `series`
      /// defaults to 1 in both overloads.
      /// NOTE(review): B and A are presumably IIR filter numerator/denominator
      /// coefficients - confirm in the definition.
00243   void setIIRCoeff(const DEMatrix &B, const DEMatrix &A, int series=1)  
00244     throw(AUExcept);
00245 
00259   void setIIRCoeff(T *bmtx, int brows, int bcols,
00260                    T *amtx, int arows, int acols,
00261                    int series=1) throw(AUExcept);
00262 
00264 
00265 
00266 
      /// Declared only; takes no arguments and returns nothing.
      /// NOTE(review): purpose is not visible in this header - see the definition.
00272   void post();
00273 
      /// Plain accessors returning the stored neurons_, inputs_, outputs_ and
      /// noise_ members by value.
00275   int getSize() const { return neurons_; };
00277   int getInputs() const { return inputs_; };
00279   int getOutputs() const { return outputs_; };
00281   double getNoise() const { return noise_; }
00282 
      /// Returns init_params_[key]. Because this uses std::map::operator[], a key
      /// that was never set is inserted with a value-initialized T and that value
      /// is returned; for the same reason the method cannot be const as written.
00288   T getInitParam(InitParameter key) { return init_params_[key]; }
00289 
      /// The five getters below read the int stored in net_info_ under the
      /// matching NetInfo key and cast it back to the corresponding enum type.
      /// std::map::at() throws std::out_of_range if the key has not been stored.
00291   InitAlgorithm getInitAlgorithm() const
00292   { return static_cast<InitAlgorithm>(net_info_.at(INIT_ALG)); }
00294   TrainAlgorithm getTrainAlgorithm() const
00295   { return static_cast<TrainAlgorithm>(net_info_.at(TRAIN_ALG)); }
00297   SimAlgorithm getSimAlgorithm() const
00298   { return static_cast<SimAlgorithm>(net_info_.at(SIMULATE_ALG)); }
00299 
00301   ActivationFunction getReservoirAct() const
00302   { return static_cast<ActivationFunction>(net_info_.at(RESERVOIR_ACT)); }
00304   ActivationFunction getOutputAct() const
00305   { return static_cast<ActivationFunction>(net_info_.at(OUTPUT_ACT)); }
00306 
00308 
00309 
00310 
      /// Return const references to the internal Win_, W_, Wback_, Wout_ and x_
      /// members (no copy is made). These accessors are not const-qualified, so
      /// they cannot be called on a const ESN.
00312   const DEMatrix &getWin() { return Win_; }
00314   const SPMatrix &getW() { return W_; }
00316   const DEMatrix &getWback() { return Wback_; }
00318   const DEMatrix &getWout() { return Wout_; }
00320   const DEVector &getX() { return x_; }
      /// Returns, by value, the result of sim_->getDelays().
00326   DEMatrix getDelays() throw(AUExcept) { return sim_->getDelays(); }
00327 
00329 
00330 
00331 
      /// Raw-buffer getters (declared only). getWin / getWback / getWout / getX
      /// report their result through pointer-to-pointer and pointer-to-size
      /// output parameters; getW / getDelays take a buffer plus its dimensions.
      /// NOTE(review): whether the T** variants expose internal storage or a
      /// copy, and who owns it, is not visible here - confirm in the definitions.
00335   void getWin(T **mtx, int *rows, int *cols);
00339   void getWback(T **mtx, int *rows, int *cols);
00343   void getWout(T **mtx, int *rows, int *cols);
00345   void getX(T **vec, int *length);
00352   void getW(T *wmtx, int wrows, int wcols) throw(AUExcept);
00360   void getDelays(T *wmtx, int wrows, int wcols) throw(AUExcept);
00361 
00363 
00364 
00365 
      /// Algorithm selectors (declared only). Calling them without an argument
      /// selects INIT_STD, TRAIN_PI and SIM_STD respectively.
00367   void setInitAlgorithm(InitAlgorithm alg=INIT_STD)
00368     throw(AUExcept);
00370   void setTrainAlgorithm(TrainAlgorithm alg=TRAIN_PI)
00371     throw(AUExcept);
00373   void setSimAlgorithm(SimAlgorithm alg=SIM_STD)
00374     throw(AUExcept);
00375 
      /// Dimension setters (declared only); defaults are 10 neurons, 1 input
      /// and 1 output.
00377   void setSize(int neurons=10) throw(AUExcept);
00379   void setInputs(int inputs=1) throw(AUExcept);
00381   void setOutputs(int outputs=1) throw(AUExcept);
00382 
      /// Declared only; `noise` has no default value.
00385   void setNoise(double noise) throw(AUExcept);
00386 
      /// Declared only; `value` defaults to 0. This setter has no exception
      /// specification, unlike most of the other setters.
00388   void setInitParam(InitParameter key, T value=0.);
00389 
      /// Activation selectors (declared only); defaults are ACT_TANH for the
      /// reservoir and ACT_LINEAR for the output.
00391   void setReservoirAct(ActivationFunction f=ACT_TANH) throw(AUExcept);
00393   void setOutputAct(ActivationFunction f=ACT_LINEAR) throw(AUExcept);
00394 
      /// Disabled string-based setter, left commented out.
00401 //   void setParameter(string param, string value) throw(AUExcept);
00402 
00404 
00405 
00406 
      /// Setters for the weight matrices and state vector (declared only).
      /// Note that setW takes a DEMatrix although the W_ member is an SPMatrix.
00408   void setWin(const DEMatrix &Win) throw(AUExcept);
00410   void setW(const DEMatrix &W) throw(AUExcept);
00412   void setWback(const DEMatrix &Wback) throw(AUExcept);
00414   void setWout(const DEMatrix &Wout) throw(AUExcept);
00416   void setX(const DEVector &x) throw(AUExcept);
00417 
      /// Declared only.
      /// NOTE(review): presumably writes sim_->last_out_, the buffer that
      /// resetState() zeroes - confirm in the definition.
00423   void setLastOutput(const DEVector &last) throw(AUExcept);
00424 
00426 
00427 
00428 
      /// Raw-buffer forms of the setters above (declared only): a T* plus
      /// explicit dimensions instead of a DEMatrix / DEVector.
00434   void setWin(T *inmtx, int inrows, int incols) throw(AUExcept);
00435 
00441   void setW(T *inmtx, int inrows, int incols) throw(AUExcept);
00442 
00448   void setWback(T *inmtx, int inrows, int incols) throw(AUExcept);
00449 
00455   void setWout(T *inmtx, int inrows, int incols) throw(AUExcept);
00456 
00462   void setX(T *invec, int insize) throw(AUExcept);
00463 
00469   void setLastOutput(T *last, int size) throw(AUExcept);
00470 
00472 
00473  protected:
00474 
      /// Initialization algorithm object; init() calls through this pointer.
      /// NOTE(review): raw pointer - ownership and lifetime are handled outside
      /// this listing (constructor/destructor/setters); confirm there.
00476   InitBase<T> *init_;
00477 
      /// Training algorithm object; train() calls through this pointer.
00479   TrainBase<T> *train_;
00480 
      /// Simulation algorithm object; simulate(), getDelays() and resetState()
      /// access it.
00482   SimBase<T> *sim_;
00483 
00484 
      /// Weight matrix returned by getWin().
      /// NOTE(review): presumably the input weights - confirm.
00487   DEMatrix Win_;
00488 
      /// Weight matrix returned by getW(); the only SPMatrix member.
      /// NOTE(review): presumably the reservoir weights - confirm.
00490   SPMatrix W_;
00491 
      /// Weight matrix returned by getWback().
      /// NOTE(review): presumably the output feedback weights - confirm.
00494   DEMatrix Wback_;
00495 
      /// Weight matrix returned by getWout().
      /// NOTE(review): presumably the trained output weights - confirm.
00498   DEMatrix Wout_;
00499 
      /// State vector returned by getX() and zeroed by resetState().
00502   DEVector x_;
00503 
00504 
      /// Function pointers taking a buffer of `size` elements of T and returning
      /// nothing.
      /// NOTE(review): presumably they apply the reservoir activation, the
      /// output activation and the inverse output activation in place - confirm
      /// against activations.h.
00509   void (*reservoirAct_)(T *data, int size);
00510 
00515   void (*outputAct_)(T *data, int size);
00516 
00521   void (*outputInvAct_)(T *data, int size);
00522 
00523 
      /// Values reported by getSize(), getInputs() and getOutputs().
00525   int neurons_;
00527   int inputs_;
00529   int outputs_;
00530 
      /// Value reported by getNoise().
00532   double noise_;
00533 
00534 
      /// Storage behind getInitParam() / setInitParam().
00536   ParameterMap init_params_;
00537 
      /// Keys of net_info_: one entry per configurable activation function or
      /// algorithm choice. The public get*Act / get*Algorithm accessors look
      /// these keys up.
00539   enum NetInfo
00540   {
00541     RESERVOIR_ACT,  
00542     OUTPUT_ACT,     
00543     INIT_ALG,       
00544     TRAIN_ALG,      
00545     SIMULATE_ALG    
00546   };
00547   typedef std::map<NetInfo, int> InfoMap;
00548 
      /// Current choices stored as plain ints, keyed by NetInfo.
00550   InfoMap net_info_;
00551 
      /// Declared only: each takes an int and returns a string.
      /// NOTE(review): presumably map the ints stored in net_info_ to readable
      /// names - confirm. `string` is used unqualified, so it must be brought
      /// into scope by one of the included headers.
00553   string getActString(int act);
00555   string getInitString(int alg);
00557   string getSimString(int alg);
00559   string getTrainString(int alg);
00560 
00561 
00563 
      /// The initialization, training and simulation algorithm classes are
      /// friends, which gives them direct access to the protected members above.
00564   friend class InitBase<T>;
00565   friend class InitStd<T>;
00566   friend class TrainBase<T>;
00567   friend class TrainPI<T>;
00568   friend class TrainLS<T>;
00569   friend class TrainRidgeReg<T>;
00570   friend class TrainDSPI<T>;
00571   friend class SimBase<T>;
00572   friend class SimStd<T>;
00573   friend class SimSquare<T>;
00574   friend class SimLI<T>;
00575   friend class SimBP<T>;
00576   friend class SimFilter<T>;
00577   friend class SimFilter2<T>;
00578   friend class SimFilterDS<T>;
00580 };
00581 
00582 } // end of namespace aureservoir
00583 
00584 #include <aureservoir/esn.hpp>
00585 #include <aureservoir/init.hpp>
00586 #include <aureservoir/simulate.hpp>
00587 #include <aureservoir/train.hpp>
00588 
00589 #endif // AURESERVOIR_ESN_H__

Generated on Wed Mar 12 21:16:05 2008 for aureservoir by doxygen 1.5.3