aboutsummaryrefslogtreecommitdiffstats
path: root/math
diff options
context:
space:
mode:
authorastrojhgu <astrojhgu@ed2142bd-67ad-457f-ba7c-d818d4011675>2010-06-04 08:34:16 +0000
committerastrojhgu <astrojhgu@ed2142bd-67ad-457f-ba7c-d818d4011675>2010-06-04 08:34:16 +0000
commitabcaca09e68e61062b2bf860c6d3328923869234 (patch)
treeb9e7a8e3f5ff4074908da3551c09cc668b503e9c /math
parent109473cacf85c7e2a2831d56531399ab4621654e (diff)
downloadopt-utilities-abcaca09e68e61062b2bf860c6d3328923869234.tar.bz2
c
git-svn-id: file:///home/svn/opt_utilities@119 ed2142bd-67ad-457f-ba7c-d818d4011675
Diffstat (limited to 'math')
-rw-r--r--math/num_diff.hpp113
1 files changed, 43 insertions, 70 deletions
diff --git a/math/num_diff.hpp b/math/num_diff.hpp
index 27a1674..0bd2937 100644
--- a/math/num_diff.hpp
+++ b/math/num_diff.hpp
@@ -14,55 +14,15 @@
namespace opt_utilities
{
-
- /**
- \brief differentiable function
- \tparam rT the return type
- \tparam the parameter type
- */
- template <typename rT,typename pT>
- class dfunc_obj
- :public func_obj<rT,pT>
- {
- private:
- virtual pT do_diff(const pT& p)=0;
-
- public:
- /**
- calculate the differentiation
- \param p the self-var
- \return the gradient at p
- */
- pT diff(const pT& p)
- {
- return do_diff(p);
- }
- };
-
-
- /**
- When trying to diff an object function, when it is not differentiable,
- this exception will be thrown.
- */
- class underivable
- :public opt_exception
- {
- public:
- underivable()
- :opt_exception("underivable")
- {}
- };
-
/**
calculate the numerical differential of a func_obj
- */
+ */
template <typename rT,typename pT>
- pT numdiff(func_obj<rT,pT>& f,const pT& p)
+ rT gradient(func_obj<rT,pT>& f,const pT& p,size_t n)
{
rT ep=std::sqrt(std::numeric_limits<rT>::epsilon());
- pT result;
- resize(result,get_size(p));
+ rT result;
pT p2;
resize(p2,get_size(p));
pT p1;
@@ -73,42 +33,55 @@ namespace opt_utilities
set_element(p2,i,get_element(p,i));
set_element(p1,i,get_element(p,i));
}
- for(size_t i=0;i<get_size(p);++i)
- {
- typename element_type_trait<pT>::element_type h=
- std::max(get_element(p,i),rT(1))*ep;
-
- set_element(p2,i,get_element(p,i)+h);
- set_element(p1,i,get_element(p,i)-h);
-
- rT v2=f(p2);
- rT v1=f(p1);
- set_element(result,i,
- (v2-v1)/h/2
- );
- set_element(p2,i,get_element(p,i));
- set_element(p1,i,get_element(p,i));
- }
+ typename element_type_trait<pT>::element_type h=
+ std::max(get_element(p,n),rT(1))*ep;
+ set_element(p2,n,get_element(p,n)+h);
+ set_element(p1,n,get_element(p,n)-h);
+
+ rT v2=f(p2);
+ rT v1=f(p1);
+
+ result=(v2-v1)/h/2;
return result;
}
- /**
- Help function to calculate the gradient of an objection function
- func_obj, whether it is differentiable or not. If it is differentiable,
- the gradient will be calculated by calling the diff member in the func_obj,
- or a numerical calculation will be performed.
- */
template <typename rT,typename pT>
- pT diff(func_obj<rT,pT>& f,const pT& p)
+ rT hessian(func_obj<rT,pT>& f,const pT& p,size_t m,size_t n)
{
- dfunc_obj<rT,pT>* pdf=dynamic_cast<dfunc_obj<rT,pT>*>(&f);
- if(pdf)
+ rT ep=std::sqrt(std::numeric_limits<rT>::epsilon());
+ typename element_type_trait<pT>::element_type hn=
+ std::max(get_element(p,n),rT(1))*ep;
+ typename element_type_trait<pT>::element_type hm=
+ std::max(get_element(p,m),rT(1))*ep;
+ pT p11;
+ resize(p11,get_size(p));
+ pT p00;
+ resize(p00,get_size(p));
+ pT p10;
+ resize(p10,get_size(p));
+ pT p01;
+ resize(p01,get_size(p));
+
+ for(size_t i=0;i<get_size(p);++i)
{
- return pdf->diff(p);
+ set_element(p11,i,get_element(p,i));
+ set_element(p00,i,get_element(p,i));
+ set_element(p01,i,get_element(p,i));
+ set_element(p10,i,get_element(p,i));
}
- return numdiff(f,p);
+ set_element(p11,m,get_element(p11,m)+hm);
+ set_element(p11,n,get_element(p11,n)+hn);
+ set_element(p00,m,get_element(p00,m)-hm);
+ set_element(p00,n,get_element(p00,n)-hn);
+ set_element(p10,m,get_element(p10,m)+hm);
+ set_element(p10,n,get_element(p10,n)-hn);
+ set_element(p01,m,get_element(p01,m)-hm);
+ set_element(p01,n,get_element(p01,n)+hn);
+
+ rT result=(f(p11)+f(p00)-f(p01)-f(p10))/(4*hm*hn);
+ return result;
}
}