#include <train.h>
Abstract base class for training algorithms. To add a new training algorithm, derive from this class and implement train().
Public Member Functions | |
TrainBase (ESN< T > *esn) | |
virtual | ~TrainBase () |
virtual void | train (const typename ESN< T >::DEMatrix &in, const typename ESN< T >::DEMatrix &out, int washout)=0 throw (AUExcept) |
Protected Member Functions | |
void | clearData () |
class TrainBase Implementation | |
void | checkParams (const typename ESN< T >::DEMatrix &in, const typename ESN< T >::DEMatrix &out, int washout) throw (AUExcept) |
void | collectStates (const typename ESN< T >::DEMatrix &in, const typename ESN< T >::DEMatrix &out, int washout) |
void | squareStates () |
Protected Attributes | |
ESN< T > * | esn_ |
ESN< T >::DEMatrix | M |
ESN< T >::DEMatrix | O |
aureservoir::TrainBase< T >::TrainBase | ( | ESN< T > * | esn | ) | [inline] |
Constructor.
virtual aureservoir::TrainBase< T >::~TrainBase | ( | ) | [inline, virtual] |
Destructor.
virtual void aureservoir::TrainBase< T >::train | ( | const typename ESN< T >::DEMatrix & | in, | |
const typename ESN< T >::DEMatrix & | out, | |||
int | washout | |||
) | throw (AUExcept) [pure virtual] |
Training algorithm (pure virtual; implemented by each derived training class).
in | matrix of input values (inputs x timesteps) | |
out | matrix of desired output values (outputs x timesteps) for teacher forcing | |
washout | washout time in samples, used to get rid of the transient dynamics caused by the network's initial state |
Implemented in aureservoir::TrainPI< T >, aureservoir::TrainLS< T >, aureservoir::TrainRidgeReg< T >, and aureservoir::TrainDSPI< T >.
void aureservoir::TrainBase< T >::checkParams | ( | const typename ESN< T >::DEMatrix & | in, | |
const typename ESN< T >::DEMatrix & | out, | |||
int | washout | |||
) | throw (AUExcept) [inline, protected] |
Checks the in, out and washout parameters; may throw AUExcept.
void aureservoir::TrainBase< T >::collectStates | ( | const typename ESN< T >::DEMatrix & | in, | |
const typename ESN< T >::DEMatrix & | out, | |||
int | washout | |||
) | [inline, protected] |
Collects the network states using the simulation algorithm.
void aureservoir::TrainBase< T >::squareStates | ( | ) | [inline, protected] |
Squares the states (used for SIM_SQUARE).
void aureservoir::TrainBase< T >::clearData | ( | ) | [inline, protected] |
Frees the data allocated for M and O.
ESN<T>* aureservoir::TrainBase< T >::esn_ [protected] |
Pointer to the network whose data is used for training.
ESN<T>::DEMatrix aureservoir::TrainBase< T >::M [protected] |
Matrix of network states and inputs over all timesteps.
ESN<T>::DEMatrix aureservoir::TrainBase< T >::O [protected] |
Matrix of outputs over all timesteps.