aboutsummaryrefslogtreecommitdiffstats
path: root/interface/pymodel.hpp
blob: 8165d6d8b7bfc50e4e0545b9e05e53554b77a275 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
/**
   \file pymodel.hpp
   \brief Model wrapper around Python functions
   \author Junhua Gu
 */



#ifndef PYMODEL
#define PYMODEL
#include <boost/python.hpp>
#include <core/fitter.hpp>
#include <core/opt_traits.hpp>


namespace opt_utilities
{

  /**
     \brief Build a Python list from an indexable container.

     Every element reported by get_size()/get_element() is appended, in
     index order, to a fresh boost::python::list.
     \param x container to convert (read only)
     \return the list [get_element(x,0), ..., get_element(x,get_size(x)-1)]
   */
  template<typename T>
  boost::python::object convert_to_object(const T& x)
  {
    boost::python::list items;
    const size_t count=get_size(x);
    for(size_t idx=0;idx!=count;++idx)
      {
	items.append(get_element(x,idx));
      }
    return items;
  }

  /// Scalar overload: wrap a double as a Python object (a Python float).
  static inline boost::python::object convert_to_object(double x)
  {
    const boost::python::object wrapped(x);
    return wrapped;
  }

  /// Scalar overload: wrap a float as a Python object (a Python float).
  static inline boost::python::object convert_to_object(float x)
  {
    const boost::python::object wrapped(x);
    return wrapped;
  }

  
  /**
     \brief Convert a Python sequence into a container of type T.

     The second argument is only a type tag that selects this overload;
     its value is ignored.
     \param o Python object supporting len() and integer indexing
     \return a T constructed with len(o) elements, where element i is
             extracted from o[i] as element_type_trait<T>::element_type
   */
  template <typename T>
  T convert_from_object(const boost::python::object& o,const T&)
  {
    typedef typename element_type_trait<T>::element_type element_type;
    T result(boost::python::len(o));
    // size_t index, as in convert_to_object, avoids a signed/unsigned
    // comparison against get_size()
    for(size_t i=0;i<get_size(result);++i)
      {
	result[i]=boost::python::extract<element_type>(o[i]);
      }
    return result;
  }

  /// Scalar tag overload: extract the Python value \p o as a double.
  static inline double convert_from_object(const boost::python::object& o,double)
  {
    const double value=boost::python::extract<double>(o);
    return value;
  }
  
  /**
     Scalar tag overload for float: extract the Python value \p o.
     Note: the value is extracted and returned as double, not float;
     any narrowing to float happens at the call site.
   */
  static inline double convert_from_object(const boost::python::object& o,float)
  {
    const double value=boost::python::extract<double>(o);
    return value;
  }


  /**
     \brief Model whose evaluation is delegated to a Python function.

     attach() imports a Python module, selects the model function and
     loads the parameter names/initial values from module attributes.
     Evaluation calls func(convert_to_object(x), [p0, p1, ...]) and
     converts the result back with convert_from_object.
     \tparam Ty type of the model value
     \tparam Tx type of the independent variable
     \tparam Tp type of the parameter vector
   */
  template <typename Ty,typename Tx,typename Tp>
  class pymodel
    :public model<Ty,Tx,Tp,std::string>
  {
  private:
    /// Python callable set by attach()
    boost::python::object pyfunc;
    /// "module.function", set by attach() and reported by do_get_type_name()
    std::string type_name;
  public:
    /// Construct an unattached model, starting the embedded Python
    /// interpreter if it is not already initialized.
    pymodel()
    {
      if(!Py_IsInitialized())
	{
	  Py_Initialize();
	}
    }

    /// Copy the Python callable, type name and parameter list of \p rhs.
    pymodel(const pymodel& rhs)
      :pyfunc(rhs.pyfunc),type_name(rhs.type_name)
    {
      for(int i=0;i<rhs.get_num_params();++i)
	{
	  this->push_param_info(rhs.get_param_info(i));
	}
    }

    /// Replace this model's callable, type name and parameter list with
    /// those of \p rhs.
    pymodel& operator=(const pymodel& rhs)
    {
      if(this==&rhs)
	{
	  return *this;
	}
      pyfunc=rhs.pyfunc;
      type_name=rhs.type_name;
      // Drop the current parameters first; otherwise rhs's parameters
      // would be appended after them and the list would grow on every
      // assignment.
      this->clear_param_info();
      for(int i=0;i<rhs.get_num_params();++i)
	{
	  this->push_param_info(rhs.get_param_info(i));
	}
      return *this;
    }

    ~pymodel()
    {}

  public:
    /**
       \brief Bind this model to a function defined in a Python module.

       Sets the type name to "module_name.func_name", clears the current
       parameter list, imports \p module_name, and uses its attribute
       \p func_name as the model function. The attributes \p arg_name
       (parameter names) and \p arg_value (initial values) are converted
       to Python lists and paired by index. The number of parameters is
       len(arg_name); arg_value presumably must be at least that long
       (a shorter list makes the Python indexing fail, which surfaces as
       a boost::python exception -- confirm with callers).
       \param module_name name of the module to import
       \param arg_name module attribute holding the parameter names
       \param arg_value module attribute holding the initial parameter values
       \param func_name module attribute holding the model function
     */
    void attach(const std::string& module_name,
		const std::string& arg_name,
		const std::string& arg_value,
		const std::string& func_name)
    {
      typedef typename element_type_trait<Tp>::element_type param_element_type;

      type_name=module_name+"."+func_name;
      this->clear_param_info();
      boost::python::object mod(boost::python::import(module_name.c_str()));
      pyfunc=mod.attr(func_name.c_str());
      boost::python::list args_names(mod.attr(arg_name.c_str()));
      boost::python::list args_values(mod.attr(arg_value.c_str()));

      const int nparams=boost::python::len(args_names);
      for(int i=0;i<nparams;++i)
	{
	  boost::python::object pname_obj=args_names[i];
	  std::string pname=boost::python::extract<std::string>(pname_obj);
	  param_element_type pvalue=
	    boost::python::extract<param_element_type>(args_values[i]);
	  // push_param_info lives in a dependent base class, so it must be
	  // qualified with this-> to be found by two-phase lookup.
	  this->push_param_info(param_info<Tp,std::string>(pname,pvalue));
	}
    }
  private:
    /// Return a heap-allocated copy of this model (via the copy constructor).
    model<Ty,Tx,Tp,std::string>* do_clone()const
    {
      return new pymodel(*this);
    }

    /// Delete this object; assumes it was allocated with new (e.g. by do_clone()).
    void do_destroy()
    {
      delete this;
    }

    /// Evaluate the model: call pyfunc(convert_to_object(x), [p...]) and
    /// convert the Python result with the convert_from_object overload
    /// selected by the type of \p x.
    Ty do_eval(const Tx& x,const Tp& p)
    {
      boost::python::list args;
      for(size_t i=0;i<get_size(p);++i)
	{
	  args.append(get_element(p,i));
	}
      return convert_from_object(pyfunc(convert_to_object(x),args),x);
    }

    /// Return the "module.function" name recorded by attach().
    const char* do_get_type_name()const
    {
      return type_name.c_str();
    }
  };
}


#endif
//EOF