diff options
Diffstat (limited to 'methods/conjugate_gradient/conjugate_gradient.hpp')
-rw-r--r-- | methods/conjugate_gradient/conjugate_gradient.hpp | 129 |
1 files changed, 77 insertions, 52 deletions
diff --git a/methods/conjugate_gradient/conjugate_gradient.hpp b/methods/conjugate_gradient/conjugate_gradient.hpp index edec92d..c22aac8 100644 --- a/methods/conjugate_gradient/conjugate_gradient.hpp +++ b/methods/conjugate_gradient/conjugate_gradient.hpp @@ -1,6 +1,6 @@ /** \file conjugate_gradient.hpp - \brief conjugate gradient optimization method + \brief powerll optimization method \author Junhua Gu */ @@ -13,9 +13,10 @@ #include <cassert> #include <cmath> #include "../linmin/linmin.hpp" +#include <math/num_diff.hpp> #include <algorithm> #include <iostream> -#include <math/num_diff.hpp> + namespace opt_utilities { /** @@ -34,10 +35,11 @@ namespace opt_utilities func_obj<rT,pT>* p_fo; optimizer<rT,pT>* p_optimizer; volatile bool bstop; + //typedef blitz::Array<rT,2> array2d_type; const char* do_get_type_name()const { - return "conjugate gradient"; + return "conjugate gradient method"; } private: array1d_type start_point; @@ -45,6 +47,9 @@ namespace opt_utilities private: rT threshold; + array1d_type g; + array1d_type h; + array1d_type xi; private: rT func(const pT& x) { @@ -54,24 +59,86 @@ namespace opt_utilities private: + void clear_xi() + { + } + + void init_xi(int n) + { + clear_xi(); + g=array1d_type(n); + h=array1d_type(n); + xi=array1d_type(n); + } + + void cg(array1d_type& p,const T ftol, + int& iter,T& fret) + { + const int ITMAX=200; + const T EPS=std::numeric_limits<T>::epsilon(); + + int j,its; + int n=p.size(); + T gg,gam,fp,dgg; + fp=func(p); + xi=gradient(*p_fo,p); + for(j=0;j<n;++j) + { + g[j]=-xi[j]; + xi[j]=h[j]=g[j]; + } + for(its=1;its<=ITMAX;++its) + { + iter=its; + linmin(p,xi,fret,*p_fo); + //std::cerr<<"######:"<<its<<"\t"<<abs(fret-fp)/(abs(fret)+fabs(fp)+EPS)<<std::endl; + if(2.0*abs(fret-fp)<=ftol*(abs(fret)+fabs(fp)+EPS)) + { + return; + } + fp=func(p); + xi=gradient(*p_fo,p); + dgg=gg=0; + for(j=0;j<n;++j) + { + gg+=g[j]*g[j]; + //dgg+=(xi[j]+g[j])*xi[j]; + dgg+=xi[j]*xi[j]; + } + std::cerr<<its<<"\t"<<gg<<std::endl; + if(gg==0.0) + { + return; + } + gam=dgg/gg; + for(j=0;j<n;++j) + { + g[j]=-xi[j]; + xi[j]=h[j]=g[j]+gam*h[j]; + } + } + std::cerr<<"Too many iterations in cg"<<std::endl; + } + public: conjugate_gradient() - :threshold(1e-4) + :threshold(1e-4),g(0),h(0),xi(0) {} virtual ~conjugate_gradient() { + clear_xi(); }; conjugate_gradient(const conjugate_gradient<rT,pT>& rhs) :opt_method<rT,pT>(rhs),p_fo(rhs.p_fo),p_optimizer(rhs.p_optimizer), start_point(rhs.start_point), end_point(rhs.end_point), - threshold(rhs.threshold) + threshold(rhs.threshold),g(0),h(0),xi(0) { } @@ -128,54 +195,12 @@ namespace opt_utilities pT do_optimize() { bstop=false; + init_xi((int)get_size(start_point)); + + int iter=100; opt_eq(end_point,start_point); - pT xn; - opt_eq(xn,start_point); - pT Delta_Xn1(gradient(*p_fo,start_point)); - for(size_t i=0;i<get_size(start_point);++i)Delta_Xn1[i]=-Delta_Xn1[i]; - rT alpha=0; - linmin(start_point,Delta_Xn1,alpha,(*p_fo)); - for(size_t i=0;i<get_size(start_point);++i)xn[i]=start_point[i]+alpha*Delta_Xn1[i]; - pT LX; - opt_eq(LX,Delta_Xn1); - for(int n=1;;++n) - { - pT Delta_Xn(gradient(*p_fo,xn)); - for(size_t i=0;i<get_size(start_point);++i)Delta_Xn[i]=-Delta_Xn[i]; - ////calc beta n - rT betan; - rT b1(0),b2(0); - for(size_t i=0;i<get_size(start_point);++i) - { - b1+=Delta_Xn[i]*(Delta_Xn[i]-Delta_Xn1[i]); - b2+=Delta_Xn1[i]*Delta_Xn1[i]; - } - if(b2==0) - { - end_point=xn; - return end_point; - } - betan=max(rT(0),b1/b2); - //// - for(size_t i=0;i<get_size(start_point);++i) - LX[i]=Delta_Xn[i]+betan*LX[i]; - linmin(xn,LX,alpha,(*p_fo)); - for(size_t i=0;i<get_size(start_point);++i) - xn[i]+=alpha*LX[i]; - rT delta=0; - rT xn_abs=0; - for(size_t i=0;i<get_size(start_point);++i) - { - delta+=LX[i]*LX[i]; - xn_abs+=xn[i]*xn[i]; - } - if(delta*alpha*alpha<threshold) - { - opt_eq(end_point,xn); - break; - } - opt_eq(Delta_Xn1,Delta_Xn); - } + rT fret; + cg(end_point,threshold,iter,fret); return end_point; } |