00001
00020 namespace aureservoir
00021 {
00022
00024
00025
00026 template <typename T>
00027 void TrainBase<T>::checkParams(const typename ESN<T>::DEMatrix &in,
00028 const typename ESN<T>::DEMatrix &out,
00029 int washout)
00030 throw(AUExcept)
00031 {
00032 if( in.numCols() != out.numCols() )
00033 throw AUExcept("TrainBase::train: input and output must be same column size!");
00034 if( in.numRows() != esn_->inputs_ )
00035 throw AUExcept("TrainBase::train: wrong input row size!");
00036 if( out.numRows() != esn_->outputs_ )
00037 throw AUExcept("TrainBase::train: wrong output row size!");
00038
00039 if( esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_SQUARE )
00040 {
00041 if( (in.numCols()-washout) < esn_->neurons_+esn_->inputs_ )
00042 throw AUExcept("TrainBase::train: too few training data!");
00043 }
00044 else
00045 {
00046 if( (in.numCols()-washout) < 2*(esn_->neurons_+esn_->inputs_) )
00047 throw AUExcept("TrainBase::train: too few training data!");
00048 }
00049
00052
00053
00054 esn_->sim_->reallocate();
00055 }
00056
00057 template <typename T>
00058 void TrainBase<T>::collectStates(const typename ESN<T>::DEMatrix &in,
00059 const typename ESN<T>::DEMatrix &out,
00060 int washout)
00061 {
00062 int steps = in.numCols();
00063
00064
00065 O.resize(steps-washout, esn_->outputs_);
00066
00067
00068
00069 if( esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_SQUARE )
00070 M.resize(steps-washout, esn_->neurons_+esn_->inputs_);
00071 else
00072 M.resize(steps-washout, 2*(esn_->neurons_+esn_->inputs_));
00073
00074
00075 typename ESN<T>::DEMatrix sim_in(esn_->inputs_ ,1),
00076 sim_out(esn_->outputs_ ,1);
00077 for(int n=1; n<=steps; ++n)
00078 {
00079 sim_in(_,1) = in(_,n);
00080 esn_->simulate(sim_in, sim_out);
00081
00082
00083
00084 esn_->sim_->last_out_(_,1) = out(_,n);
00085
00086
00087
00088
00089 if( n > washout )
00090 {
00091 M(n-washout,_(1,esn_->neurons_)) = esn_->x_;
00092 M(n-washout,_(esn_->neurons_+1,esn_->neurons_+esn_->inputs_)) =
00093 sim_in(_,1);
00094 }
00095 }
00096
00097
00098 O = flens::transpose( out( _,_(washout+1,steps) ) );
00099 }
00100
00101 template <typename T>
00102 void TrainBase<T>::squareStates()
00103 {
00104
00106 int Msize = esn_->neurons_+esn_->inputs_;
00107 int Mrows = M.numRows();
00108 for(int i=1; i<=Mrows; ++i) {
00109 for(int j=1; j<=Msize; ++j) {
00110 M(i,j+Msize) = pow( M(i,j), 2 );
00111 } }
00112 }
00113
00115
00116
00117
00118 template <typename T>
00119 void TrainPI<T>::train(const typename ESN<T>::DEMatrix &in,
00120 const typename ESN<T>::DEMatrix &out,
00121 int washout)
00122 throw(AUExcept)
00123 {
00124 this->checkParams(in,out,washout);
00125
00126
00127 this->collectStates(in,out,washout);
00128
00129
00130 if( esn_->net_info_[ESN<T>::SIMULATE_ALG] == SIM_SQUARE )
00131 this->squareStates();
00132
00133
00134
00135
00136
00137 esn_->outputInvAct_( O.data(), O.numRows()*O.numCols() );
00138
00139
00140 flens::lss( M, O );
00141 esn_->Wout_ = flens::transpose( O(_( 1, M.numCols() ),_) );
00142
00143 this->clearData();
00144 }
00145
00147
00148
00149
00150 template <typename T>
00151 void TrainLS<T>::train(const typename ESN<T>::DEMatrix &in,
00152 const typename ESN<T>::DEMatrix &out,
00153 int washout)
00154 throw(AUExcept)
00155 {
00156 this->checkParams(in,out,washout);
00157
00158
00159 this->collectStates(in,out,washout);
00160
00161
00162 if( esn_->net_info_[ESN<T>::SIMULATE_ALG] == SIM_SQUARE )
00163 this->squareStates();
00164
00165
00166
00167
00168
00169 esn_->outputInvAct_( O.data(), O.numRows()*O.numCols() );
00170
00171
00172 flens::ls( flens::NoTrans, M, O );
00173 esn_->Wout_ = flens::transpose( O(_( 1, M.numCols() ),_) );
00174
00175 this->clearData();
00176 }
00177
00179
00180
00181
00182 template <typename T>
00183 void TrainRidgeReg<T>::train(const typename ESN<T>::DEMatrix &in,
00184 const typename ESN<T>::DEMatrix &out,
00185 int washout)
00186 throw(AUExcept)
00187 {
00188 this->checkParams(in,out,washout);
00189
00190
00191 this->collectStates(in,out,washout);
00192
00193
00194 if( esn_->net_info_[ESN<T>::SIMULATE_ALG] == SIM_SQUARE )
00195 this->squareStates();
00196
00197
00198
00199
00200
00201 esn_->outputInvAct_( O.data(), O.numRows()*O.numCols() );
00202
00203
00204
00205
00206
00207
00208 T alpha = pow(esn_->init_params_[TIKHONOV_FACTOR],2);
00209
00210
00211 typename ESN<T>::DEMatrix T1(esn_->neurons_+esn_->inputs_,
00212 esn_->neurons_+esn_->inputs_);
00213 flens::DenseVector<flens::Array<int> > t2( M.numCols() );
00214
00215
00216 T1 = flens::transpose(M)*M;
00217
00218
00219 for(int i=1; i<=T1.numRows(); ++i)
00220 T1(i,i) += alpha;
00221
00222
00223 flens::trf(T1, t2);
00224 flens::tri(T1, t2);
00225
00226
00227 esn_->Wout_ = T1 * flens::transpose(M);
00228
00229
00230 T1 = esn_->Wout_ * O;
00231
00232
00233 esn_->Wout_ = flens::transpose(T1);
00234
00235
00236 this->clearData();
00237 }
00238
00240
00241
00242
00243 template <typename T>
00244 void TrainDSPI<T>::train(const typename ESN<T>::DEMatrix &in,
00245 const typename ESN<T>::DEMatrix &out,
00246 int washout)
00247 throw(AUExcept)
00248 {
00249 this->checkParams(in,out,washout);
00250
00251
00252
00253
00254 int steps = in.numCols();
00255
00256
00257 O.resize(steps-washout, 1);
00258
00259
00260 M.resize(steps, esn_->neurons_+esn_->inputs_);
00261
00262 typename ESN<T>::DEMatrix sim_in(esn_->inputs_ ,1),
00263 sim_out(esn_->outputs_ ,1);
00264 for(int n=1; n<=steps; ++n)
00265 {
00266 sim_in(_,1) = in(_,n);
00267 esn_->simulate(sim_in, sim_out);
00268
00269
00270
00271 esn_->sim_->last_out_(_,1) = out(_,n);
00272
00273
00274 M(n,_(1,esn_->neurons_)) = esn_->x_;
00275 M(n,_(esn_->neurons_+1,esn_->neurons_+esn_->inputs_)) =
00276 sim_in(_,1);
00277 }
00278
00279
00280
00281
00282
00283 if( esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_FILTER_DS &&
00284 esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_SQUARE )
00285 throw AUExcept("TrainDSPI::train: you need to use SIM_FILTER_DS or SIM_SQUARE for this training algorithm!");
00286
00287
00288 int maxdelay;
00289 if( esn_->init_params_.find(DS_MAXDELAY) == esn_->init_params_.end() )
00290 {
00291
00292 if( esn_->net_info_[ESN<T>::SIMULATE_ALG] == SIM_SQUARE )
00293 maxdelay = 0;
00294 else
00295 maxdelay = 1000;
00296 }
00297 else
00298 maxdelay = (int) esn_->init_params_[DS_MAXDELAY];
00299
00300
00301 int filter;
00302 if( esn_->init_params_.find(DS_USE_CROSSCORR) == esn_->init_params_.end() )
00303 filter = 1;
00304 else
00305 filter = 0;
00306
00307
00308
00309 int delay = 0;
00310 int fftsize = (int) pow( 2, ceil(log(steps)/log(2)) );
00311 typename CDEVector<T>::Type X,Y;
00312 typename DEVector<T>::Type x,y,rest;
00313 typename DEMatrix<T>::Type T1(1,esn_->neurons_+esn_->inputs_);
00314 typename DEMatrix<T>::Type Mtmp(M.numRows(),M.numCols());
00315
00316 for(int i=1; i<=esn_->outputs_; ++i)
00317 {
00318
00319 y = out(i,_);
00320 rfft( y, Y, fftsize );
00321
00322
00323 for(int j=1; j<=esn_->neurons_+esn_->inputs_; ++j)
00324 {
00325
00326 x = M(_,j);
00327 rfft( x, X, fftsize );
00328
00329
00330 delay = CalcDelay<T>::gcc(X,Y,maxdelay,filter);
00331
00332 if( delay != 0 )
00333 {
00334
00335 rest = M( _(M.numRows()-delay+1,M.numRows()), j );
00336 Mtmp( _(1,delay), j ) = 0.;
00337 Mtmp( _(delay+1,M.numRows()), j ) = M( _(1,M.numRows()-delay), j );
00338
00339
00340 esn_->sim_->initDelayLine((i-1)*(esn_->neurons_+esn_->inputs_)+j-1, rest);
00341 }
00342 else
00343 Mtmp(_,j) = M(_,j);
00344 }
00345
00346
00347
00348
00349
00350 O(_,1) = out( i ,_(washout+1,steps) );
00351
00352
00353 esn_->outputInvAct_( O.data(), O.numRows()*O.numCols() );
00354
00355
00356 if( esn_->net_info_[ESN<T>::SIMULATE_ALG] != SIM_SQUARE )
00357 {
00358 M = Mtmp( _(washout+1,steps), _);
00359 }
00360 else
00361 {
00362 M.resize( steps-washout, Mtmp.numCols()*2 );
00363 M( _, _(1,Mtmp.numCols()) ) = Mtmp( _(washout+1,steps), _);
00364 this->squareStates();
00365 }
00366
00367
00368 flens::lss( M, O );
00369 T1 = flens::transpose( O(_( 1, M.numCols() ),_) );
00370 esn_->Wout_(i,_) = T1(1,_);
00371
00372
00373
00374
00375 if( i < esn_->outputs_ )
00376 {
00377 M.resize( Mtmp.numRows(), Mtmp.numCols() );
00378
00379
00380 for(int j=1; j<=esn_->neurons_+esn_->inputs_; ++j)
00381 {
00382 rest = esn_->sim_->getDelayBuffer(i-1,j-1);
00383 delay = rest.length();
00384
00385 if( delay != 0 )
00386 {
00387 M( _(1,steps-delay), j ) = Mtmp( _(delay+1,steps), j );
00388 M( _(steps-delay+1,steps), j ) = rest;
00389 }
00390 else
00391 M(_,j) = Mtmp(_,j);
00392 }
00393 }
00394 }
00395
00396 this->clearData();
00397 }
00398
00400
00401 }