esn_example.cpp

00001 /***************************************************************************/
00020 #include "aureservoir/aureservoir.h"
00021 
00022 #define TYPE double
00023 
00024 #include <iostream>
00025 #include <complex>
00026 
00027 using namespace aureservoir;
00028 using namespace std;
00029 
00030 int main(int argc, char *argv[])
00031 {
00032   ESN< TYPE > net;
00033 
00034   try
00035   {
00036     cout << "## INITIALIZATION ##\n";
00037 
00038     int train_size = 50;
00039     int ins = 3;
00040     int outs = 2;
00041 
00042     net.setSize(15);
00043     net.setInputs(ins);
00044     net.setOutputs(outs);
00045     net.setInitParam(CONNECTIVITY, 0.8);
00046     net.setInitParam(IN_CONNECTIVITY, 0.6);
00047     net.setInitParam(FB_CONNECTIVITY, 0.4);
00048 
00049     net.init();
00050 
00051     // print current net parameters
00052     net.post();
00053     cout << endl << "input weights W_in: " << net.getWin();
00054     cout << endl << "feedback weights W_back: " << net.getWback();
00055 //     cout << endl << "reservoir weight matrix W: " << net.getW() << endl;
00056 
00057 
00058     cout << "\n## TRAINING ##\n";
00059 
00060     ESN<TYPE>::DEMatrix in(ins,train_size), out(outs,train_size);
00061 
00062     for(int i=1; i<=train_size; ++i)
00063     {
00064       for(int j=1; j<=ins; ++j)
00065         in(j,i) = Rand<TYPE>::uniform();
00066 
00067       for(int j=1; j<=outs; ++j)
00068         out(j,i) = Rand<TYPE>::uniform();
00069     }
00070 
00071     net.train(in, out, 20);
00072 
00073     cout << "\ntrained output weights W_out: " << net.getWout() << endl;
00074 
00075 
00076     cout << "## SIMULATION ##\n";
00077 
00078     int run_size = 10;
00079     ESN<TYPE>::DEMatrix indata(ins,run_size), result(outs,run_size);
00080 
00081     for(int i=1; i<=run_size; ++i)
00082     {
00083       for(int j=1; j<=ins; ++j)
00084         indata(j,i) = Rand<>::uniform();
00085     }
00086 
00087     net.simulate( indata, result );
00088 
00089     cout << endl << "simulation results: " << result << endl;
00090   }
00091   catch(AUExcept e)
00092   { cout << "Exception: " << e.what() << endl; }
00093 
00094   return 0;
00095 }

Generated on Wed Mar 12 21:16:05 2008 for aureservoir by  doxygen 1.5.3