diff options
author | astrojhgu <astrojhgu@ed2142bd-67ad-457f-ba7c-d818d4011675> | 2010-06-04 08:34:16 +0000 |
---|---|---|
committer | astrojhgu <astrojhgu@ed2142bd-67ad-457f-ba7c-d818d4011675> | 2010-06-04 08:34:16 +0000 |
commit | abcaca09e68e61062b2bf860c6d3328923869234 (patch) | |
tree | b9e7a8e3f5ff4074908da3551c09cc668b503e9c /math | |
parent | 109473cacf85c7e2a2831d56531399ab4621654e (diff) | |
download | opt-utilities-abcaca09e68e61062b2bf860c6d3328923869234.tar.bz2 |
c
git-svn-id: file:///home/svn/opt_utilities@119 ed2142bd-67ad-457f-ba7c-d818d4011675
Diffstat (limited to 'math')
-rw-r--r-- | math/num_diff.hpp | 113 |
1 files changed, 43 insertions, 70 deletions
diff --git a/math/num_diff.hpp b/math/num_diff.hpp index 27a1674..0bd2937 100644 --- a/math/num_diff.hpp +++ b/math/num_diff.hpp @@ -14,55 +14,15 @@ namespace opt_utilities { - - /** - \brief differentiable function - \tparam rT the return type - \tparam the parameter type - */ - template <typename rT,typename pT> - class dfunc_obj - :public func_obj<rT,pT> - { - private: - virtual pT do_diff(const pT& p)=0; - - public: - /** - calculate the differentiation - \param p the self-var - \return the gradient at p - */ - pT diff(const pT& p) - { - return do_diff(p); - } - }; - - - /** - When trying to diff an object function, when it is not differentiable, - this exception will be thrown. - */ - class underivable - :public opt_exception - { - public: - underivable() - :opt_exception("underivable") - {} - }; - /** calculate the numerical differential of a func_obj - */ + */ template <typename rT,typename pT> - pT numdiff(func_obj<rT,pT>& f,const pT& p) + rT gradient(func_obj<rT,pT>& f,const pT& p,size_t n) { rT ep=std::sqrt(std::numeric_limits<rT>::epsilon()); - pT result; - resize(result,get_size(p)); + rT result; pT p2; resize(p2,get_size(p)); pT p1; @@ -73,42 +33,55 @@ namespace opt_utilities set_element(p2,i,get_element(p,i)); set_element(p1,i,get_element(p,i)); } - for(size_t i=0;i<get_size(p);++i) - { - typename element_type_trait<pT>::element_type h= - std::max(get_element(p,i),rT(1))*ep; - - set_element(p2,i,get_element(p,i)+h); - set_element(p1,i,get_element(p,i)-h); - - rT v2=f(p2); - rT v1=f(p1); - set_element(result,i, - (v2-v1)/h/2 - ); - set_element(p2,i,get_element(p,i)); - set_element(p1,i,get_element(p,i)); - } + typename element_type_trait<pT>::element_type h= + std::max(get_element(p,n),rT(1))*ep; + set_element(p2,n,get_element(p,n)+h); + set_element(p1,n,get_element(p,n)-h); + + rT v2=f(p2); + rT v1=f(p1); + + result=(v2-v1)/h/2; return result; } - /** - Help function to calculate the gradient of an objection function - func_obj, whether it is differentiable or not. If it is differentiable, - the gradient will be calculated by calling the diff member in the func_obj, - or a numerical calculation will be performed. - */ template <typename rT,typename pT> - pT diff(func_obj<rT,pT>& f,const pT& p) + rT hessian(func_obj<rT,pT>& f,const pT& p,size_t m,size_t n) { - dfunc_obj<rT,pT>* pdf=dynamic_cast<dfunc_obj<rT,pT>*>(&f); - if(pdf) + rT ep=std::sqrt(std::numeric_limits<rT>::epsilon()); + typename element_type_trait<pT>::element_type hn= + std::max(get_element(p,n),rT(1))*ep; + typename element_type_trait<pT>::element_type hm= + std::max(get_element(p,m),rT(1))*ep; + pT p11; + resize(p11,get_size(p)); + pT p00; + resize(p00,get_size(p)); + pT p10; + resize(p10,get_size(p)); + pT p01; + resize(p01,get_size(p)); + + for(size_t i=0;i<get_size(p);++i) { - return pdf->diff(p); + set_element(p11,i,get_element(p,i)); + set_element(p00,i,get_element(p,i)); + set_element(p01,i,get_element(p,i)); + set_element(p10,i,get_element(p,i)); } - return numdiff(f,p); + set_element(p11,m,get_element(p11,m)+hm); + set_element(p11,n,get_element(p11,n)+hn); + set_element(p00,m,get_element(p00,m)-hm); + set_element(p00,n,get_element(p00,n)-hn); + set_element(p10,m,get_element(p10,m)+hm); + set_element(p10,n,get_element(p10,n)-hn); + set_element(p01,m,get_element(p01,m)-hm); + set_element(p01,n,get_element(p01,n)+hn); + + rT result=(f(p11)+f(p00)-f(p01)-f(p10))/(4*hm*hn); + return result; } } |