train.hpp

/***************************************************************************/

namespace aureservoir
{

template <typename T>
void TrainBase<T>::checkParams(const typename ESN<T>::DEMatrix &in,
                               const typename ESN<T>::DEMatrix &out,
                               int washout)
  throw(AUExcept)
{
  if( in.numCols() != out.numCols() )
    throw AUExcept("TrainBase::train: input and output must have the same number of columns!");
  if( in.numRows() != esn_->inputs_ )
    throw AUExcept("TrainBase::train: wrong input row size!");
  if( out.numRows() != esn_->outputs_ )
    throw AUExcept("TrainBase::train: wrong output row size!");

  if( esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_SQUARE )
  {
    if( (in.numCols()-washout) < esn_->neurons_+esn_->inputs_ )
      throw AUExcept("TrainBase::train: too few training data!");
  }
  else
  {
    if( (in.numCols()-washout) < 2*(esn_->neurons_+esn_->inputs_) )
      throw AUExcept("TrainBase::train: too few training data!");
  }

  // reallocate the data buffer for the simulation algorithm
  esn_->sim_->reallocate();
}
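
// Editor's note on the size check above (illustrative example): the readout
// is later obtained from the least-squares system M * Wout_^T = O, which is
// only (over)determined when M has at least as many rows as columns. With
// e.g. 100 neurons and 2 inputs, at least 102 effective training steps
// (in.numCols()-washout) are needed, or 204 when SIM_SQUARE doubles the
// number of feature columns.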

template <typename T>
void TrainBase<T>::collectStates(const typename ESN<T>::DEMatrix &in,
                                 const typename ESN<T>::DEMatrix &out,
                                 int washout)
{
  int steps = in.numCols();

  // collects the desired outputs of all timesteps in O
  O.resize(steps-washout, esn_->outputs_);

  // collects reservoir activations and inputs of all timesteps in M
  // (the squared algorithm needs a matrix with twice as many columns)
  if( esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_SQUARE )
    M.resize(steps-washout, esn_->neurons_+esn_->inputs_);
  else
    M.resize(steps-washout, 2*(esn_->neurons_+esn_->inputs_));

  typename ESN<T>::DEMatrix sim_in(esn_->inputs_, 1),
                            sim_out(esn_->outputs_, 1);
  for(int n=1; n<=steps; ++n)
  {
    sim_in(_,1) = in(_,n);
    esn_->simulate(sim_in, sim_out);

    // for teacher forcing with feedback in single-step simulation
    // we need to set the correct last output
    esn_->sim_->last_out_(_,1) = out(_,n);

    // store internal states and inputs after the washout
    if( n > washout )
    {
      M(n-washout, _(1,esn_->neurons_)) = esn_->x_;
      M(n-washout, _(esn_->neurons_+1,esn_->neurons_+esn_->inputs_)) = sim_in(_,1);
    }
  }

  // collect desired outputs
  O = flens::transpose( out( _, _(washout+1,steps) ) );
}
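
// Illustrative shapes (editor's example, not part of the original source):
// for a net with 100 neurons, 3 inputs and 2 outputs, collectStates(in, out, 20)
// with in.numCols() == 500 yields
//
//   M : 480 x 103, row n = [ x(20+n)^T  u(20+n)^T ]
//   O : 480 x 2,   row n = desired output y(20+n)^T
//
// so the readout weights can later be found from M * Wout_^T = O.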

template <typename T>
void TrainBase<T>::squareStates()
{
  // append the squared states and inputs as additional columns
  int Msize = esn_->neurons_+esn_->inputs_;
  int Mrows = M.numRows();
  for(int i=1; i<=Mrows; ++i) {
    for(int j=1; j<=Msize; ++j) {
      M(i,j+Msize) = pow( M(i,j), 2 );
    }
  }
}
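
// Example row layout after squareStates() (illustrative): with 2 neurons and
// 1 input, a row whose left half is [ x1  x2  u1 ] becomes
//
//   [ x1  x2  u1  x1^2  x2^2  u1^2 ]
//
// which matches the additional squared units of the SIM_SQUARE algorithm.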


template <typename T>
void TrainPI<T>::train(const typename ESN<T>::DEMatrix &in,
                       const typename ESN<T>::DEMatrix &out,
                       int washout)
  throw(AUExcept)
{
  this->checkParams(in,out,washout);

  // 1. teacher forcing, collect states
  this->collectStates(in,out,washout);

  // add additional squared states when using SIM_SQUARE
  if( esn_->net_info_[ESN<T>::SIMULATE_ALG] == SIM_SQUARE )
    this->squareStates();


  // 2. offline weight computation

  // undo the output activation function
  esn_->outputInvAct_( O.data(), O.numRows()*O.numCols() );

  // calc weights with the pseudoinverse: Wout_^T = pinv(M) * O
  flens::lss( M, O );
  esn_->Wout_ = flens::transpose( O(_( 1, M.numCols() ),_) );

  this->clearData();
}
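
// Minimal usage sketch (editor's addition; the setter names follow the
// public ESN interface from memory and may differ, so treat this as
// illustrative pseudocode rather than the library's exact API):
//
//   ESN<double> net;
//   net.setSize(100);                        // reservoir neurons
//   net.setInputs(2);
//   net.setOutputs(1);
//   net.init();                              // random reservoir initialization
//
//   ESN<double>::DEMatrix in(2,1000), out(1,1000);
//   // ... fill in/out with training data ...
//   net.train(in, out, 100);                 // TrainPI with a washout of 100
//   net.simulate(in, out);                   // run the trained network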


template <typename T>
void TrainLS<T>::train(const typename ESN<T>::DEMatrix &in,
                       const typename ESN<T>::DEMatrix &out,
                       int washout)
  throw(AUExcept)
{
  this->checkParams(in,out,washout);

  // 1. teacher forcing, collect states
  this->collectStates(in,out,washout);

  // add additional squared states when using SIM_SQUARE
  if( esn_->net_info_[ESN<T>::SIMULATE_ALG] == SIM_SQUARE )
    this->squareStates();


  // 2. offline weight computation

  // undo the output activation function
  esn_->outputInvAct_( O.data(), O.numRows()*O.numCols() );

  // calc weights with a least-squares solver: M * Wout_^T = O
  flens::ls( flens::NoTrans, M, O );
  esn_->Wout_ = flens::transpose( O(_( 1, M.numCols() ),_) );

  this->clearData();
}
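
// Design note (editor's addition, hedged): flens::ls() appears to wrap the
// QR-based LAPACK least-squares driver, which assumes M has full column
// rank, whereas flens::lss() used in TrainPI wraps the SVD-based driver,
// which also copes with rank-deficient state matrices. TrainLS is therefore
// the faster choice, TrainPI the more robust one when reservoir states are
// (nearly) linearly dependent.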


template <typename T>
void TrainRidgeReg<T>::train(const typename ESN<T>::DEMatrix &in,
                             const typename ESN<T>::DEMatrix &out,
                             int washout)
  throw(AUExcept)
{
  this->checkParams(in,out,washout);

  // 1. teacher forcing, collect states
  this->collectStates(in,out,washout);

  // add additional squared states when using SIM_SQUARE
  if( esn_->net_info_[ESN<T>::SIMULATE_ALG] == SIM_SQUARE )
    this->squareStates();


  // 2. offline weight computation

  // undo the output activation function
  esn_->outputInvAct_( O.data(), O.numRows()*O.numCols() );


  // calc weights with ridge regression (.T = transpose):
  // Wout_.T = (M.T*M + alpha^2*I)^-1 * M.T * O

  // get the regularization factor and square it
  T alpha = pow(esn_->init_params_[TIKHONOV_FACTOR],2);

  // temporary objects
  typename ESN<T>::DEMatrix T1(esn_->neurons_+esn_->inputs_,
                               esn_->neurons_+esn_->inputs_);
  flens::DenseVector<flens::Array<int> > t2( M.numCols() );

  // M.T * M
  T1 = flens::transpose(M)*M;

  // ans + alpha^2*I
  for(int i=1; i<=T1.numRows(); ++i)
    T1(i,i) += alpha;

  // calc the inverse: (ans)^-1
  flens::trf(T1, t2);
  flens::tri(T1, t2);

  // ans * M.T
  esn_->Wout_ = T1 * flens::transpose(M);

  // ans * O
  T1 = esn_->Wout_ * O;

  // result = ans.T
  esn_->Wout_ = flens::transpose(T1);


  this->clearData();
}
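
// Editor's note on the computation above: with S = neurons+inputs this is
// the Tikhonov-regularized normal-equations solution
//
//   Wout_^T = (M^T M + alpha^2 I_S)^(-1) M^T O
//
// Forming M^T M squares the condition number of M, and trf/tri perform an
// LU factorization followed by an explicit inverse; a direct solve of
// (M^T M + alpha^2 I) X = M^T O would be numerically gentler, but the
// explicit form keeps the code close to the textbook formula.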


template <typename T>
void TrainDSPI<T>::train(const typename ESN<T>::DEMatrix &in,
                         const typename ESN<T>::DEMatrix &out,
                         int washout)
  throw(AUExcept)
{
  this->checkParams(in,out,washout);


  // 1. teacher forcing, collect states

  int steps = in.numCols();

  // collects the desired outputs of one output neuron at a time in O
  O.resize(steps-washout, 1);

  // collects reservoir activations and inputs of all timesteps in M
  // (the washout is removed only after the delays have been applied)
  M.resize(steps, esn_->neurons_+esn_->inputs_);

  typename ESN<T>::DEMatrix sim_in(esn_->inputs_, 1),
                            sim_out(esn_->outputs_, 1);
  for(int n=1; n<=steps; ++n)
  {
    sim_in(_,1) = in(_,n);
    esn_->simulate(sim_in, sim_out);

    // for teacher forcing with feedback in single-step simulation
    // we need to set the correct last output
    esn_->sim_->last_out_(_,1) = out(_,n);

    // store internal states and inputs
    M(n,_(1,esn_->neurons_)) = esn_->x_;
    M(n,_(esn_->neurons_+1,esn_->neurons_+esn_->inputs_)) = sim_in(_,1);
  }


  // 2. delay calculation for the delay&sum readout

  // check for the right simulation algorithm
  if( esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_FILTER_DS &&
      esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_SQUARE )
    throw AUExcept("TrainDSPI::train: you need to use SIM_FILTER_DS or SIM_SQUARE for this training algorithm!");

  // get maxdelay
  int maxdelay;
  if( esn_->init_params_.find(DS_MAXDELAY) == esn_->init_params_.end() )
  {
    // set maxdelay to 0 if we have squared state updates
    if( esn_->net_info_[ESN<T>::SIMULATE_ALG] == SIM_SQUARE )
      maxdelay = 0;
    else
      maxdelay = 1000;
  }
  else
    maxdelay = (int) esn_->init_params_[DS_MAXDELAY];

  // see if we use GCC or simple crosscorrelation, the default is GCC
  int filter;
  if( esn_->init_params_.find(DS_USE_CROSSCORR) == esn_->init_params_.end() )
    filter = 1;
  else
    filter = 0;
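
  // Editor's note (background, hedged): CalcDelay<T>::gcc estimates the
  // delay of each neuron/input signal x relative to the target y from the
  // peak of their (generalized) cross-correlation, evaluated in the
  // frequency domain:
  //
  //   delay = argmax_{0<=k<=maxdelay} IFFT( W(f) * X(f) * conj(Y(f)) )[k]
  //
  // where W(f) is the GCC weighting (filter=0 reduces it to plain
  // cross-correlation); the exact conjugation/ordering convention is the
  // one implemented by CalcDelay<T>::gcc.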

  // delay calculation

  int delay = 0;
  int fftsize = (int) pow( 2, ceil(log(steps)/log(2)) ); // next power of 2, e.g. steps=300 -> 512
  typename CDEVector<T>::Type X,Y;
  typename DEVector<T>::Type x,y,rest;
  typename DEMatrix<T>::Type T1(1,esn_->neurons_+esn_->inputs_);
  typename DEMatrix<T>::Type Mtmp(M.numRows(),M.numCols());

  for(int i=1; i<=esn_->outputs_; ++i)
  {
    // calc the FFT of the target vector
    y = out(i,_);
    rfft( y, Y, fftsize );

    // calc delays to reservoir neurons and inputs
    for(int j=1; j<=esn_->neurons_+esn_->inputs_; ++j)
    {
      // calc the FFT of the neuron/input vector
      x = M(_,j);
      rfft( x, X, fftsize );

      // calc the delay with GCC
      delay = CalcDelay<T>::gcc(X,Y,maxdelay,filter);

      if( delay != 0 )
      {
        // shift the signal by the right amount
        rest = M( _(M.numRows()-delay+1,M.numRows()), j );
        Mtmp( _(1,delay), j ) = 0.;
        Mtmp( _(delay+1,M.numRows()), j ) = M( _(1,M.numRows()-delay), j );

        // init the delay line with the rest of the buffer
        esn_->sim_->initDelayLine((i-1)*(esn_->neurons_+esn_->inputs_)+j-1, rest);
      }
      else
        Mtmp(_,j) = M(_,j);
    }


    // 3. offline weight computation, separately for each output

    // collect desired outputs
    O(_,1) = out( i, _(washout+1,steps) );

    // undo the output activation function
    esn_->outputInvAct_( O.data(), O.numRows()*O.numCols() );

    // cut the washout; square and append states if we use squared state updates
    if( esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_SQUARE )
    {
      M = Mtmp( _(washout+1,steps), _);
    }
    else
    {
      M.resize( steps-washout, Mtmp.numCols()*2 );
      M( _, _(1,Mtmp.numCols()) ) = Mtmp( _(washout+1,steps), _);
      this->squareStates();
    }

    // calc weights with the pseudoinverse: Wout_^T = pinv(M) * O
    flens::lss( M, O );
    T1 = flens::transpose( O(_( 1, M.numCols() ),_) );
    esn_->Wout_(i,_) = T1(1,_);


    // 4. restore the simulation matrix M

    if( i < esn_->outputs_ )
    {
      M.resize( Mtmp.numRows(), Mtmp.numCols() );

      // undo the delays and store the signals in M again
      for(int j=1; j<=esn_->neurons_+esn_->inputs_; ++j)
      {
        rest = esn_->sim_->getDelayBuffer(i-1,j-1);
        delay = rest.length();

        if( delay != 0 )
        {
          M( _(1,steps-delay), j ) = Mtmp( _(delay+1,steps), j );
          M( _(steps-delay+1,steps), j ) = rest;
        }
        else
          M(_,j) = Mtmp(_,j);
      }
    }
  }

  this->clearData();
}
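
// Minimal usage sketch for the delay&sum readout (editor's addition; the
// setter names and the TRAIN_DS_PI constant follow this library's naming
// from memory and may differ, so treat this as illustrative pseudocode):
//
//   ESN<double> net;
//   net.setSize(100);
//   net.setInputs(1);
//   net.setOutputs(1);
//   net.setSimAlgorithm(SIM_FILTER_DS);   // delay&sum simulation
//   net.setTrainAlgorithm(TRAIN_DS_PI);   // this trainer
//   net.setInitParam(DS_MAXDELAY, 100);   // cap estimated delays at 100 steps
//   net.init();
//   net.train(in, out, washout);          // estimates one delay per neuron/input
//                                         // and output, then solves the readout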

} // end of namespace aureservoir
