aboutsummaryrefslogtreecommitdiffstats
path: root/core/num_diff.hpp
blob: 27a16740fec1b712ff86a168460f57c0232d174f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/**
   \file num_diff.hpp
 */


#ifndef NUMDIFF_HPP
#define NUMDIFF_HPP
#define OPT_HEADER
#include <core/optimizer.hpp>
#include <core/opt_traits.hpp>
#include <algorithm>
#include <limits>
#include <cmath>

namespace opt_utilities
{

  /**
     \brief a func_obj that can additionally report its own gradient
     \tparam rT the return type
     \tparam pT the parameter type
   */
  template <typename rT,typename pT>
  class dfunc_obj
    :public func_obj<rT,pT>
  {
  public:
    /**
       Evaluate the gradient; simply forwards to the do_diff hook.
       \param p the point at which the gradient is requested
       \return the value produced by do_diff for p
     */
    pT diff(const pT& p)
    {
      return do_diff(p);
    }

  private:
    /**
       Customisation point: concrete classes implement the actual
       gradient computation here.
       \param p the point at which the gradient is requested
       \return the gradient at p
     */
    virtual pT do_diff(const pT& p)=0;
  };
  

  /**
     Exception type meant to be thrown when a gradient is requested from
     an objective function that is not differentiable.
   */
  class underivable : public opt_exception
  {
  public:
    /**
       Passes the fixed string "underivable" to the opt_exception base.
     */
    underivable() : opt_exception("underivable") {}
  };

  /**
     Estimate the gradient of a func_obj numerically by central differences.

     Component i of the result is (f(p+h*e_i)-f(p-h*e_i))/(2*h), where e_i
     is the i-th unit direction and h=max(|p_i|,1)*sqrt(epsilon of rT).
     The step is scaled by the magnitude of p_i: if the signed value were
     used, a large negative coordinate would get the minimal step, which can
     be smaller than the spacing of representable numbers around p_i, so
     that p_i+h and p_i-h both round to p_i and the estimate becomes 0.

     \tparam rT the return type of f
     \tparam pT the parameter type of f
     \param f the function to differentiate; it is evaluated twice for
              every element of p
     \param p the point at which the gradient is estimated
     \return an object of type pT with get_size(p) elements holding the
             estimated partial derivatives
   */
  template <typename rT,typename pT>
    pT numdiff(func_obj<rT,pT>& f,const pT& p)
  {
    typedef typename element_type_trait<pT>::element_type elem_t;

    const rT ep=std::sqrt(std::numeric_limits<rT>::epsilon());
    const size_t n=get_size(p);

    pT result;
    resize(result,n);
    pT p2;
    resize(p2,n);
    pT p1;
    resize(p1,n);

    // Both probe points start out as element-wise copies of p.
    for(size_t i=0;i<n;++i)
      {
        set_element(p2,i,get_element(p,i));
        set_element(p1,i,get_element(p,i));
      }
    for(size_t i=0;i<n;++i)
      {
        const elem_t x=get_element(p,i);
        // Scale by |x| (not x) so negative coordinates get a usable step.
        const elem_t h=std::max<elem_t>(std::abs(x),elem_t(1))*ep;

        // Perturb only coordinate i: forward in p2, backward in p1.
        set_element(p2,i,x+h);
        set_element(p1,i,x-h);

        const rT v2=f(p2);
        const rT v1=f(p1);
        set_element(result,i,
                    (v2-v1)/h/2
                    );

        // Undo the perturbation before moving to the next coordinate.
        set_element(p2,i,x);
        set_element(p1,i,x);
      }
    return result;
  }


  /**
     Helper that computes the gradient of an objective function func_obj
     regardless of whether it is differentiable. If f is in fact a
     dfunc_obj, its diff member is called; otherwise the gradient is
     estimated numerically with numdiff.

     \tparam rT the return type of f
     \tparam pT the parameter type of f
     \param f the objective function
     \param p the point at which the gradient is wanted
     \return the gradient of f at p
   */
  template <typename rT,typename pT>
    pT diff(func_obj<rT,pT>& f,const pT& p)
  {
    typedef dfunc_obj<rT,pT> analytic_t;

    // A null result means f does not provide its own gradient.
    analytic_t* const analytic=dynamic_cast<analytic_t*>(&f);
    if(!analytic)
      {
        return numdiff(f,p);
      }
    return analytic->diff(p);
  }
}

#endif