aboutsummaryrefslogtreecommitdiffstats
path: root/methods/conjugate_gradient/conjugate_gradient.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'methods/conjugate_gradient/conjugate_gradient.hpp')
-rw-r--r--methods/conjugate_gradient/conjugate_gradient.hpp129
1 files changed, 77 insertions, 52 deletions
diff --git a/methods/conjugate_gradient/conjugate_gradient.hpp b/methods/conjugate_gradient/conjugate_gradient.hpp
index edec92d..c22aac8 100644
--- a/methods/conjugate_gradient/conjugate_gradient.hpp
+++ b/methods/conjugate_gradient/conjugate_gradient.hpp
@@ -1,6 +1,6 @@
/**
\file conjugate_gradient.hpp
- \brief conjugate gradient optimization method
+ \brief powerll optimization method
\author Junhua Gu
*/
@@ -13,9 +13,10 @@
#include <cassert>
#include <cmath>
#include "../linmin/linmin.hpp"
+#include <math/num_diff.hpp>
#include <algorithm>
#include <iostream>
-#include <math/num_diff.hpp>
+
namespace opt_utilities
{
/**
@@ -34,10 +35,11 @@ namespace opt_utilities
func_obj<rT,pT>* p_fo;
optimizer<rT,pT>* p_optimizer;
volatile bool bstop;
+ //typedef blitz::Array<rT,2> array2d_type;
const char* do_get_type_name()const
{
- return "conjugate gradient";
+ return "conjugate gradient method";
}
private:
array1d_type start_point;
@@ -45,6 +47,9 @@ namespace opt_utilities
private:
rT threshold;
+ array1d_type g;
+ array1d_type h;
+ array1d_type xi;
private:
rT func(const pT& x)
{
@@ -54,24 +59,86 @@ namespace opt_utilities
private:
+ void clear_xi()
+ {
+ }
+
+ void init_xi(int n)
+ {
+ clear_xi();
+ g=array1d_type(n);
+ h=array1d_type(n);
+ xi=array1d_type(n);
+ }
+
+ void cg(array1d_type& p,const T ftol,
+ int& iter,T& fret)
+ {
+ const int ITMAX=200;
+ const T EPS=std::numeric_limits<T>::epsilon();
+
+ int j,its;
+ int n=p.size();
+ T gg,gam,fp,dgg;
+ fp=func(p);
+ xi=gradient(*p_fo,p);
+ for(j=0;j<n;++j)
+ {
+ g[j]=-xi[j];
+ xi[j]=h[j]=g[j];
+ }
+ for(its=1;its<=ITMAX;++its)
+ {
+ iter=its;
+ linmin(p,xi,fret,*p_fo);
+ //std::cerr<<"######:"<<its<<"\t"<<abs(fret-fp)/(abs(fret)+fabs(fp)+EPS)<<std::endl;
+ if(2.0*abs(fret-fp)<=ftol*(abs(fret)+fabs(fp)+EPS))
+ {
+ return;
+ }
+ fp=func(p);
+ xi=gradient(*p_fo,p);
+ dgg=gg=0;
+ for(j=0;j<n;++j)
+ {
+ gg+=g[j]*g[j];
+ //dgg+=(xi[j]+g[j])*xi[j];
+ dgg+=xi[j]*xi[j];
+ }
+ std::cerr<<its<<"\t"<<gg<<std::endl;
+ if(gg==0.0)
+ {
+ return;
+ }
+ gam=dgg/gg;
+ for(j=0;j<n;++j)
+ {
+ g[j]=-xi[j];
+ xi[j]=h[j]=g[j]+gam*h[j];
+ }
+ }
+ std::cerr<<"Too many iterations in cg"<<std::endl;
+ }
+
public:
conjugate_gradient()
- :threshold(1e-4)
+ :threshold(1e-4),g(0),h(0),xi(0)
{}
virtual ~conjugate_gradient()
{
+ clear_xi();
};
conjugate_gradient(const conjugate_gradient<rT,pT>& rhs)
:opt_method<rT,pT>(rhs),p_fo(rhs.p_fo),p_optimizer(rhs.p_optimizer),
start_point(rhs.start_point),
end_point(rhs.end_point),
- threshold(rhs.threshold)
+ threshold(rhs.threshold),g(0),h(0),xi(0)
{
}
@@ -128,54 +195,12 @@ namespace opt_utilities
pT do_optimize()
{
bstop=false;
+ init_xi((int)get_size(start_point));
+
+ int iter=100;
opt_eq(end_point,start_point);
- pT xn;
- opt_eq(xn,start_point);
- pT Delta_Xn1(gradient(*p_fo,start_point));
- for(size_t i=0;i<get_size(start_point);++i)Delta_Xn1[i]=-Delta_Xn1[i];
- rT alpha=0;
- linmin(start_point,Delta_Xn1,alpha,(*p_fo));
- for(size_t i=0;i<get_size(start_point);++i)xn[i]=start_point[i]+alpha*Delta_Xn1[i];
- pT LX;
- opt_eq(LX,Delta_Xn1);
- for(int n=1;;++n)
- {
- pT Delta_Xn(gradient(*p_fo,xn));
- for(size_t i=0;i<get_size(start_point);++i)Delta_Xn[i]=-Delta_Xn[i];
- ////calc beta n
- rT betan;
- rT b1(0),b2(0);
- for(size_t i=0;i<get_size(start_point);++i)
- {
- b1+=Delta_Xn[i]*(Delta_Xn[i]-Delta_Xn1[i]);
- b2+=Delta_Xn1[i]*Delta_Xn1[i];
- }
- if(b2==0)
- {
- end_point=xn;
- return end_point;
- }
- betan=max(rT(0),b1/b2);
- ////
- for(size_t i=0;i<get_size(start_point);++i)
- LX[i]=Delta_Xn[i]+betan*LX[i];
- linmin(xn,LX,alpha,(*p_fo));
- for(size_t i=0;i<get_size(start_point);++i)
- xn[i]+=alpha*LX[i];
- rT delta=0;
- rT xn_abs=0;
- for(size_t i=0;i<get_size(start_point);++i)
- {
- delta+=LX[i]*LX[i];
- xn_abs+=xn[i]*xn[i];
- }
- if(delta*alpha*alpha<threshold)
- {
- opt_eq(end_point,xn);
- break;
- }
- opt_eq(Delta_Xn1,Delta_Xn);
- }
+ rT fret;
+ cg(end_point,threshold,iter,fret);
return end_point;
}