00001 #ifndef _HOPFIELDNETWORK_H 00002 #define _HOPFIELDNETWORK_H 00003 00004 #include "Network.h" 00005 #include "RecurrentNeuron.h" 00006 #include "Matrix.h" 00007 #include <vector> 00008 00009 namespace annie 00010 { 00011 00016 real isPositive(real x); 00017 00030 class HopfieldNetwork : public Network 00031 { 00032 protected: 00034 bool _bipolar; 00035 00037 int _nPatterns; 00038 00040 Matrix *_weightMatrix; 00041 00043 RecurrentNeuron** _neurons; 00044 00045 virtual bool _equal(std::vector<int> &p1, std::vector<int> &p2); 00046 int _time; 00047 00048 public: 00055 HopfieldNetwork(int size); 00056 00064 HopfieldNetwork(int size, bool bias, bool bipolar); 00065 00071 HopfieldNetwork(const char *filename); 00072 00074 virtual ~HopfieldNetwork(); 00075 00082 virtual void addPattern(int pattern[]); 00083 00085 virtual real getEnergy(); 00086 00090 virtual real getEnergy(int pattern[]); 00091 00093 virtual int getSize(); 00094 00096 virtual void step(); 00097 00102 virtual int getTime(); 00103 00108 virtual int getPatternCount(); 00109 00111 virtual Matrix getWeightMatrix(); 00112 00114 virtual const char* getClassName(); 00115 00121 virtual void save(const char *filename); 00122 00132 virtual void setWeight(int i, int j, real weight); 00133 00134 /* Given an input pattern, keeps iterating through time till the network 00135 * output converges. Ofcourse, it is possible that this never happens 00136 * and hence a timeout has to be specified. 00137 * \todo Implement this! 00138 * @param pattern The initial input pattern given to the network 00139 * @param updateAll Determines type of updating (synchronous, asynchronous) 00140 * @param timeout The maximum number of iteration to try convergence for 00141 * @return false if the network output didn't converge till the timeout, true otherwise 00142 */ 00143 //virtual bool converge(int pattern[], bool updateAll, int timeout); 00144 00149 virtual real getBias(int i); 00150 00155 virtual void setBias(int i, real bias); 00156 00162 virtual void setInput(int pattern[]); 00163 00170 virtual void setInput(std::vector<int> &pattern); 00171 00175 virtual std::vector<int> getOutput(); 00176 00181 virtual std::vector<int> getNextOutput(); 00182 00190 virtual VECTOR getOutput(VECTOR &input); 00191 00199 virtual bool propagate(int pattern[], int timeout); 00200 00208 virtual bool propagate(std::vector<int> &pattern, int timeout); 00209 }; 00210 }; //namespace annie 00211 #endif // define _HOPFIELDNETWORK_H 00212
1.2.14 written by Dimitri van Heesch,
© 1997-2002