From 1f4a944064bc42284c33e6b755353d191cf288e8 Mon Sep 17 00:00:00 2001 From: astrojhgu Date: Mon, 15 Dec 2008 07:26:12 +0000 Subject: git-svn-id: file:///home/svn/opt_utilities@1 ed2142bd-67ad-457f-ba7c-d818d4011675 --- core/default_data_set.hpp | 47 ++ core/fitter.hpp | 1043 +++++++++++++++++++++++++++++++++++++++++++++ core/freeze_param.hpp | 204 +++++++++ core/opt_exception.hpp | 106 +++++ core/opt_traits.hpp | 64 +++ core/optimizer.hpp | 289 +++++++++++++ 6 files changed, 1753 insertions(+) create mode 100644 core/default_data_set.hpp create mode 100644 core/fitter.hpp create mode 100644 core/freeze_param.hpp create mode 100644 core/opt_exception.hpp create mode 100644 core/opt_traits.hpp create mode 100644 core/optimizer.hpp (limited to 'core') diff --git a/core/default_data_set.hpp b/core/default_data_set.hpp new file mode 100644 index 0000000..be06c04 --- /dev/null +++ b/core/default_data_set.hpp @@ -0,0 +1,47 @@ +#ifndef DEFAULT_DATA_SET +#define DEFAULT_DATA_SET +#include "fitter.hpp" +#include + + +namespace opt_utilities +{ + +template +class default_data_set + :public data_set +{ +private: + std::vector > data_vec; + + data_set* do_clone()const + { + return new default_data_set(*this); + } + + + const data& do_get_data(int i)const + { + return data_vec.at(i); + } + + int do_size()const + { + return data_vec.size(); + } + + void do_push_back(const data& d) + { + data_vec.push_back(d); + } + + void do_clear() + { + data_vec.clear(); + } + +}; +} + +#endif +//EOF diff --git a/core/fitter.hpp b/core/fitter.hpp new file mode 100644 index 0000000..ed90a83 --- /dev/null +++ b/core/fitter.hpp @@ -0,0 +1,1043 @@ +#ifndef FITTER_HPP +#define FITTER_HPP +#include "opt_exception.hpp" +#include "optimizer.hpp" +#include +#include +#include +#include +#include +namespace opt_utilities +{ + + /////////////////////////////////// + //////class data/////////////////// + //////contain single data point//// + /////////////////////////////////// + template + class statistic; + + template + class param_modifier; + + + template + class data + { + private: + Tx x,x_lower_err,x_upper_err; + Ty y,y_lower_err,y_upper_err; + public: + data(const Tx& _x,const Ty& _y,const Ty& _y_lower_err,const Ty& _y_upper_err,const Tx& _x_lower_err,const Tx& _x_upper_err) + { + opt_eq(x,_x); + opt_eq(x_lower_err,_x_lower_err); + opt_eq(x_upper_err,_x_upper_err); + opt_eq(y,_y); + opt_eq(y_lower_err,_y_lower_err); + opt_eq(y_upper_err,_y_upper_err); + + } + + data() + :x(), + x_lower_err(), + x_upper_err(), + y(), + y_lower_err(), + y_upper_err() + {} + + data(const data& rhs) + { + opt_eq(x,rhs.x); + opt_eq(x_lower_err,rhs.x_lower_err); + opt_eq(x_upper_err,rhs.x_upper_err); + opt_eq(y,rhs.y); + opt_eq(y_lower_err,rhs.y_lower_err); + opt_eq(y_upper_err,rhs.y_upper_err); + } + + data& operator=(const data& rhs) + { + opt_eq(x,rhs.x); + opt_eq(x_lower_err,rhs.x_lower_err); + opt_eq(x_upper_err,rhs.x_upper_err); + opt_eq(y,rhs.y); + opt_eq(y_lower_err,rhs.y_lower_err); + opt_eq(y_upper_err,rhs.y_upper_err); + return *this; + } + + public: + const Tx& get_x()const + { + return x; + } + + const Tx& get_x_lower_err()const + { + return x_lower_err; + } + + const Tx& get_x_upper_err()const + { + return x_upper_err; + } + + const Ty& get_y()const + { + return y; + } + + const Ty& get_y_lower_err()const + { + return y_lower_err; + } + + const Ty& get_y_upper_err()const + { + return y_upper_err; + } + + + void set_x(const Tx& _x) + { + opt_eq(x,_x); + } + + void set_x_lower_err(const Tx& _x) + { + opt_eq(x_lower_err,_x); + } + + void set_x_upper_err(const Tx& _x) + { + opt_eq(x_upper_err,_x); + } + + void set_y(const Ty& _y) + { + opt_eq(y,_y); + } + + void set_y_lower_err(const Ty& _y) + { + opt_eq(y_lower_err,_y); + } + + void set_y_upper_err(const Ty& _y) + { + opt_eq(y_upper_err,_y); + } + + + }; + + //////////////////////////////////// + ///class data_set/////////////////// + ///contain a set of data//////////// + //////////////////////////////////// + + + template + class data_set + { + private: + + virtual const data& do_get_data(int i)const=0; + virtual int do_size()const=0; + virtual void do_push_back(const data&)=0; + virtual void do_clear()=0; + virtual data_set* do_clone()const=0; + + virtual void do_destroy() + { + delete this; + } + public: + const data& get_data(int i)const + { + return this->do_get_data(i); + } + int size()const + { + return do_size(); + } + void push_back(const data& d) + { + return do_push_back(d); + } + void clear() + { + do_clear(); + } + + data_set* clone()const + { + return this->do_clone(); + } + + void destroy() + { + do_destroy(); + } + + virtual ~data_set(){} + + }; + + /////////////////////////////////////////////// + /////class param_info////////////////////////// + /////record the information of one parameter/// + /////including the name, default value///////// + /////////////////////////////////////////////// + + + template + class param_info + { + private: + Tstr name; + //bool frozen; + typename element_type_trait::element_type default_value; + + public: + param_info(const Tstr& _name, + const typename element_type_trait::element_type& _v) + :name(_name),default_value(_v){} + + param_info() + :name() + {} + + param_info(const param_info& rhs) + :name(rhs.name) + { + opt_eq(default_value,rhs.default_value); + } + + param_info& operator=(const param_info& rhs) + { + name=rhs.name; + opt_eq(default_value,rhs.default_value); + return *this; + } + + const Tstr& get_name()const + { + return this->name; + } + + const typename element_type_trait::element_type& get_default_value()const + { + return default_value; + } + + void set_default_value(const typename element_type_trait::element_type& x) + { + opt_eq(default_value,x); + } + + void set_name(const Tstr& _name) + { + name=_name; + } + }; + + + + + + template + class model + { + private: + std::vector > param_info_list; + param_info null_param; + // int num_free_params; + param_modifier* p_param_modifier; + private: + virtual model* do_clone()const=0; + + virtual void do_destroy() + { + delete this; + } + public: + model* clone()const + { + return do_clone(); + } + + void destroy() + { + do_destroy(); + } + public: + model() + :p_param_modifier(0) + {} + + model(const model& rhs) + :p_param_modifier(0) + { + if(rhs.p_param_modifier!=0) + { + set_param_modifier(*(rhs.p_param_modifier)); + } + param_info_list=rhs.param_info_list; + null_param=rhs.null_param; + + } + + model& operator=(const model& rhs) + { + if(this==&rhs) + { + return *this; + } + if(rhs.p_param_modifier!=0) + { + set_param_modifier(*(rhs.p_param_modifier)); + } + param_info_list=rhs.param_info_list; + null_param=rhs.null_param; + return *this; + } + + virtual ~model() + { + if(p_param_modifier) + { + //delete p_param_modifier; + p_param_modifier->destroy(); + } + } + + void set_param_modifier(const param_modifier& pm) + { + if(p_param_modifier!=0) + { + //delete p_param_modifier; + p_param_modifier->destroy(); + } + p_param_modifier=pm.clone(); + p_param_modifier->set_model(*this); + } + + void set_param_modifier() + { + if(p_param_modifier!=0) + { + //delete p_param_modifier; + p_param_modifier->destroy(); + } + p_param_modifier=0; + } + + param_modifier& get_param_modifier() + { + if(p_param_modifier==0) + { + throw param_modifier_undefined(); + } + return *p_param_modifier; + } + + Tstr report_param_status(const Tstr& s)const + { + if(p_param_modifier==0) + { + return Tstr(); + } + + return p_param_modifier->report_param_status(s); + + } + + + const param_info& get_param_info(const Tstr& pname) + { + for(typename std::vector >::iterator i=param_info_list.begin(); + i!=param_info_list.end();++i) + { + if(i->get_name()==pname) + { + return *i; + } + } + std::cerr<<"Param unfound!"<& get_param_info(int n)const + { + return param_info_list[n%get_num_params()]; + } + + + Tp get_all_params()const + { + Tp result; + resize(result,param_info_list.size()); + for(size_t i=0;iget_num_free_params(); + } + return get_num_params(); + } + + void set_param_value(const Tstr& pname, + const typename element_type_trait::element_type& v) + { + //int porder=0; + for(typename std::vector >::iterator i=param_info_list.begin(); + i!=param_info_list.end();++i) + { + if(i->get_name()==pname) + { + i->set_default_value(v); + return; + } + } + std::cerr<<"param "<& pinfo) + { + param_info_list.push_back(pinfo); + // this->num_free_params++; + } + + void clear_param_info() + { + // this->num_free_params=0; + param_info_list.clear(); + } + + + + public: + Tstr to_string()const + { + return do_to_string(); + } + + Tp reform_param(const Tp& p)const + { + if(p_param_modifier==0) + { + return p; + } + return p_param_modifier->reform(p); + } + + Tp deform_param(const Tp& p)const + { + if(p_param_modifier==0) + { + return p; + } + return p_param_modifier->deform(p); + } + + + Ty eval(const Tx& x,const Tp& p) + { + return do_eval(x,reform_param(p)); + } + + virtual Ty do_eval(const Tx& x,const Tp& p)=0; + + private: + virtual Tstr do_to_string()const + { + return Tstr(); + } + + }; + + + + template + class fitter + { + public: + model* p_model; + statistic* p_statistic; + data_set* p_data_set; + optimizer optengine; + public: + fitter() + :p_model(0),p_statistic(0),p_data_set(0),optengine() + {} + + + fitter(const fitter& rhs) + :p_model(0),p_statistic(0),p_data_set(0),optengine() + { + if(rhs.p_model!=0) + { + set_model(*(rhs.p_model)); + } + if(rhs.p_statistic!=0) + { + set_statistic(*(rhs.p_statistic)); + assert(p_statistic->p_fitter!=0); + } + if(rhs.p_data_set!=0) + { + load_data(*(rhs.p_data_set)); + } + optengine=rhs.optengine; + } + + fitter& operator=(const fitter& rhs) + { + if(this==&rhs) + { + return *this; + } + if(rhs.p_model!=0) + { + set_model(*(rhs.p_model)); + } + if(rhs.p_statistic!=0) + { + set_statistic(*(rhs.p_statistic)); + } + if(rhs.p_data_set!=0) + { + load_data(*(rhs.p_data_set)); + } + optengine=rhs.optengine; + return *this; + } + + virtual ~fitter() + { + if(p_model!=0) + { + //delete p_model; + p_model->destroy(); + } + if(p_statistic!=0) + { + //delete p_statistic; + p_statistic->destroy(); + } + if(p_data_set!=0) + { + //delete p_data_set; + p_data_set->destroy(); + } + } + + + Ty eval_model(const Tx& x,const Tp& p) + { + if(p_model==0) + { + throw model_undefined(); + } + return p_model->eval(x,p); + } + + + public: + void set_model(const model& m) + { + if(p_model!=0) + { + //delete p_model; + p_model->destroy(); + } + p_model=m.clone(); + //p_model=&m; + // current_param.resize(m.get_num_params()); + } + /* + void set(const model& m) + { + set_model(m); + } + */ + void set_statistic(const statistic& s) + { + if(p_statistic!=0) + { + //delete p_statistic; + p_statistic->destroy(); + } + p_statistic=s.clone(); + //p_statistic=&s; + p_statistic->set_fitter(*this); + } + /* + void set(const statistic& s) + { + set_statistic(s); + } + */ + + void set_param_modifier(const param_modifier& pm) + { + if(p_model==0) + { + throw model_undefined(); + } + p_model->set_param_modifier(pm); + } + + void set_param_modifier() + { + if(p_model==0) + { + throw model_undefined(); + } + p_model->set_param_modifier(); + } + + param_modifier& get_param_modifier() + { + if(p_model==0) + { + throw model_undefined(); + } + return p_model->get_param_modifier(); + } + + Tstr report_param_status(const Tstr& s)const + { + if(p_model==0) + { + throw model_undefined(); + } + return p_model->report_param_status(s); + } + + /* + void set(const param_modifier& pm) + { + set_param_modifier(pm); + } + */ + void load_data(const data_set& da) + { + if(p_data_set!=0) + { + //delete p_data_set; + p_data_set->destroy(); + } + p_data_set=da.clone(); + if(p_statistic!=0) + { + p_statistic->set_fitter(*this); + } + } + + const data_set& datas()const + { + if(p_data_set==0) + { + throw data_unloaded(); + } + return *(this->p_data_set); + } + + model& model() + { + if(p_model==0) + { + throw model_undefined(); + } + return *(this->p_model); + } + + statistic& statistic() + { + if(p_statistic==0) + { + throw statistic_undefined(); + } + return *(this->p_statistic); + } + + opt_method& method() + { + return optengine.method(); + } + + + public: + void set_param_value(const Tstr& pname, + const typename element_type_trait::element_type& v) + { + if(p_model==0) + { + throw model_undefined(); + } + p_model->set_param_value(pname,v); + } + + void set_param_value(const Tp& param) + { + if(p_model==0) + { + throw model_undefined(); + } + p_model->set_param_value(param); + } + + typename element_type_trait::element_type get_param_value(const Tstr& pname)const + { + if(p_model==0) + { + throw model_undefined(); + } + return p_model->get_param_info(pname).get_default_value(); + } + + const param_info& get_param_info(const Tstr& pname)const + { + if(p_model==0) + { + throw model_undefined(); + } + return p_model->get_param_info(pname); + } + + const param_info& get_param_info(int n)const + { + if(p_model==0) + { + throw model_undefined(); + } + return p_model->get_param_info(n); + } + + const int get_param_order(const Tstr& pname)const + { + if(p_model==0) + { + throw model_undefined(); + } + return p_model->get_param_order(pname); + } + + int get_num_params()const + { + if(p_model==0) + { + throw model_undefined(); + } + return p_model->get_num_params(); + } + + Tp get_all_params()const + { + if(p_model==0) + { + throw model_undefined(); + } + //return current_param; + return p_model->get_all_params(); + } + + void set_method(const opt_method& pm) + { + //assert(p_optimizer!=0); + optengine.set_opt_method(pm); + } + /* + void set(const opt_method& pm) + { + set_method(pm); + } + */ + void set_precision(typename element_type_trait::element_type y) + { + optengine.set_precision(y); + } + + Tp fit() + { + // assert(p_model!=0); + if(p_model==0) + { + throw model_undefined(); + } + if(p_data_set==0) + { + throw data_unloaded(); + } + //assert(p_optimizer!=0); + //assert(p_data_set!=0); + //assert(p_statistic!=0); + if(p_statistic==0) + { + throw statistic_undefined(); + } + + optengine.set_func_obj(*p_statistic); + Tp current_param; + opt_eq(current_param,p_model->get_all_params()); + Tp start_point; + opt_eq(start_point,p_model->deform_param(current_param)); + // std::cout<get_all_params(); + } + optengine.set_start_point(start_point); + + Tp result; + opt_eq(result,optengine.optimize()); + + Tp decurrent_param; + opt_eq(decurrent_param,p_model->reform_param(result)); + //current_param.resize(decurrent_param.size()); + resize(current_param,get_size(decurrent_param)); + opt_eq(current_param,decurrent_param); + p_model->set_param_value(current_param); + // return current_param; + return p_model->get_all_params(); + } + + }; + + + template + class statistic + :public func_obj + { + public: + fitter* p_fitter; + + private: + virtual statistic* do_clone()const=0; + + virtual void do_destroy() + { + delete this; + } + + public: + statistic* clone()const + { + return this->do_clone(); + } + + void destroy() + { + return do_destroy(); + } + statistic() + :p_fitter(0) + {} + + statistic(const statistic& rhs) + :p_fitter(rhs.p_fitter) + {} + + statistic& operator=(const statistic& rhs) + { + if(this==&rhs) + { + return *this; + } + p_fitter=rhs.p_fitter; + return *this; + } + + virtual ~statistic() + {} + + virtual void set_fitter(fitter& pfitter) + { + p_fitter=&pfitter; + } + + virtual const fitter& get_fitter()const + { + if(p_fitter==0) + { + throw fitter_unset(); + } + return *p_fitter; + } + + Ty eval_model(const Tx& x,const Tp& p) + { + if(p_fitter==0) + { + throw fitter_unset(); + } + return p_fitter->eval_model(x,p); + } + + + const data_set& datas()const + { + if(p_fitter==0) + { + throw fitter_unset(); + } + return p_fitter->datas(); + } + + }; + + template + class param_modifier + { + private: + model* p_model; + public: + Tp reform(const Tp& p)const + { + return do_reform(p); + } + Tp deform(const Tp& p)const + { + return do_deform(p); + } + + param_modifier* clone()const + { + return do_clone(); + } + + void destroy() + { + do_destroy(); + } + + public: + + param_modifier() + :p_model(0) + {} + + param_modifier(const param_modifier& rhs) + :p_model(rhs.p_model) + {} + + param_modifier& operator=(const param_modifier& rhs) + { + if(this==&rhs) + { + return *this; + } + p_model=rhs.p_model; + return *this; + } + + public: + void set_model(model& pf) + { + p_model=&pf; + update(); + } + + const model& get_model()const + { + if(p_model==0) + { + std::cout<<"dajf;asdjfk;"; + throw model_undefined(); + } + return *(this->p_model); + } + + int get_num_free_params()const + { + return do_get_num_free_params(); + } + + Tstr report_param_status(const Tstr& name)const + { + return do_report_param_status(name); + } + + + virtual ~param_modifier(){} + private: + virtual Tp do_reform(const Tp& p)const=0; + virtual Tp do_deform(const Tp& p)const=0; + virtual int do_get_num_free_params()const=0; + virtual Tstr do_report_param_status(const Tstr&)const=0; + virtual void update(){} + + virtual param_modifier* do_clone()const=0; + + virtual void do_destroy() + { + delete this; + } + + }; + + + +}; + + +#endif +//EOF diff --git a/core/freeze_param.hpp b/core/freeze_param.hpp new file mode 100644 index 0000000..8ba677a --- /dev/null +++ b/core/freeze_param.hpp @@ -0,0 +1,204 @@ +#ifndef FREEZE_PARAM_HPP +#define FREEZE_PARAM_HPP +#include "fitter.hpp" +#include +#include + +namespace opt_utilities +{ + template + class freeze_param + :public param_modifier + { + private: + std::set param_names; + std::vector param_num; + int num_free; + + public: + freeze_param() + { + + } + + freeze_param(const Tstr& name) + { + param_names.insert(name); + } + + private: + freeze_param* do_clone()const + { + return new freeze_param(*this); + } + + + + void update() + { + param_num.clear(); + for(typename std::set::const_iterator i=param_names.begin(); + i!=param_names.end();++i) + { + try + { + param_num.push_back(this->get_model().get_param_order(*i)); + } + catch(opt_exception& e) + { + param_names.erase(*i); + throw; + } + + } + } + + int do_get_num_free_params()const + { + return this->get_model().get_num_params()-param_num.size(); + } + + bool is_frozen(int i)const + { + if(find(param_num.begin(),param_num.end(),i)==param_num.end()) + { + return false; + } + return true; + } + + + Tp do_reform(const Tp& p)const + { + int nparams=(this->get_model().get_num_params()); + Tp reformed_p(nparams); + int i=0,j=0; + for(i=0;i<(int)nparams;++i) + { + if(this->is_frozen(i)) + { + const param_info& pinf=this->get_model().get_param_info(i); + //std::cout<<"frozen:"<get_model().get_num_params()); + for(;i<(int)get_size(p);++i) + { + //std::cout<is_frozen(i)) + { + //opt_eq(get_element(deformed_p,j),get_element(p,i)); + set_element(deformed_p,j,get_element(p,i)); + j++; + } + } + + assert(j==do_get_num_free_params()); + return deformed_p; + } + + + Tstr do_report_param_status(const Tstr& name)const + { + if(param_names.find(name)==param_names.end()) + { + return "thawed"; + } + return "frozen"; + } + + public: + freeze_param operator+(const freeze_param& fp)const + { + freeze_param result(*this); + for(typename std::set::const_iterator i=fp.param_names.begin(); + i!=fp.param_names.end(); + ++i) + { + result.param_names.insert(*i); + } + return result; + } + + freeze_param& operator+=(const freeze_param& fp) + { + //param_names.insert(param_names.end(), + //fp.param_names.begin(), + //fp.param_names.end()); + for(typename std::set::const_iterator i=fp.param_names.begin(); + i!=fp.param_names.end(); + ++i) + { + param_names.insert(*i); + } + try + { + update(); + } + catch(opt_exception& e) + { + throw; + } + return *this; + } + + freeze_param& operator-=(const freeze_param& fp) + { + //param_names.insert(param_names.end(), + //fp.param_names.begin(), + //fp.param_names.end()); + for(typename std::set::const_iterator i=fp.param_names.begin(); + i!=fp.param_names.end(); + ++i) + { + param_names.erase(*i); + } + try + { + update(); + } + catch(opt_exception& e) + { + throw; + } + return *this; + } + + }; + + template + freeze_param freeze(const Tstr& name) + { + return freeze_param(name); + } + +}; + + +#endif +//EOF diff --git a/core/opt_exception.hpp b/core/opt_exception.hpp new file mode 100644 index 0000000..925e4b5 --- /dev/null +++ b/core/opt_exception.hpp @@ -0,0 +1,106 @@ +#ifndef OPT_EXCEPTION +#define OPT_EXCEPTION +#include +#include +namespace opt_utilities +{ + class opt_exception + :public std::exception + { + private: + std::string _what; + public: + opt_exception() + {}; + + ~opt_exception()throw() + {} + + opt_exception(const std::string& str) + :_what(str) + {} + + const char* what()const throw() + { + return _what.c_str(); + } + }; + + class target_function_undefined + :public opt_exception + { + public: + target_function_undefined() + :opt_exception("target function undefined") + {} + }; + + class opt_method_undefined + :public opt_exception + { + public: + opt_method_undefined() + :opt_exception("opt method undefined") + {} + }; + + class fitter_unset + :public opt_exception + { + public: + fitter_unset() + :opt_exception("fitter_unset") + {} + }; + + class model_undefined + :public opt_exception + { + public: + model_undefined() + :opt_exception("model_undefined") + {} + }; + + class data_unloaded + :public opt_exception + { + public: + data_unloaded() + :opt_exception("data not loaded") + {} + }; + + + class statistic_undefined + :public opt_exception + { + public: + statistic_undefined() + :opt_exception("statistic undefined") + {} + }; + + class param_not_found + :public opt_exception + { + public: + param_not_found() + :opt_exception("param name invalid") + {} + }; + + class param_modifier_undefined + :public opt_exception + { + public: + param_modifier_undefined() + :opt_exception("param modifier undefined") + {} + }; + +}; + + +#endif +//EOF diff --git a/core/opt_traits.hpp b/core/opt_traits.hpp new file mode 100644 index 0000000..c539a63 --- /dev/null +++ b/core/opt_traits.hpp @@ -0,0 +1,64 @@ +#ifndef ARRAY_OPERATION +#define ARRAY_OPERATION +#include +namespace opt_utilities +{ + /////////Useful function/////////////////////////////////// + template + inline size_t get_size(const T& x) + { + return x.size(); + } + + template + class element_type_trait + { + public: + typedef typename T::value_type element_type; + }; + + template + class return_type_trait + { + public: + typedef T value_type; + typedef T& reference_type; + typedef const T& const_reference_type; + }; + + template + inline typename return_type_trait::element_type>::const_reference_type get_element(const T& x,size_t i) + { + return x[i]; + } + /* + template + inline typename element_type_trait::element_type& get_element(T& x,size_t i) + { + return x[i]; + } + */ + + template + inline void set_element(T& x,size_t i, + const TX& v) + { + x[i]=v; + } + + template + inline void resize(T& x,size_t s) + { + x.resize(s); + } + + template + inline Tl& opt_eq(Tl& lhs,const Tr& rhs) + { + return (lhs=rhs); + } +}; + + + +#endif diff --git a/core/optimizer.hpp b/core/optimizer.hpp new file mode 100644 index 0000000..ca73775 --- /dev/null +++ b/core/optimizer.hpp @@ -0,0 +1,289 @@ +#ifndef OPTIMZER_H_ +#define OPTIMZER_H_ +//#define DEBUG +#include +#include "opt_traits.hpp" +#include "opt_exception.hpp" +#include +#include +#ifdef DEBUG +#include +using namespace std; +#endif + +namespace opt_utilities +{ + /////////Forward declare/////////////////////////////////// + template + class optimizer; + + template + class func_obj; + + template + class opt_method; + + + //////////////Target Function///////////////////// + ///An eval function should be implemented///////// + ///The eval function return the function value//// + ///which is wrapped by the func_obj/////////////// + ////////////////////////////////////////////////// + template + class func_obj + :public std::unary_function + { + private: + virtual rT do_eval(const pT&)=0; + virtual func_obj* do_clone()const=0; + virtual void do_destroy() + { + delete this; + } + + public: + public: + func_obj* clone()const + { + return do_clone(); + } + + void destroy() + { + do_destroy(); + } + + rT operator()(const pT& p) + { + return do_eval(p); + } + + + rT eval(const pT& p) + { + return do_eval(p); + }; + virtual ~func_obj(){}; + // virtual XT walk(XT,YT)=0; + }; + + + ///////////////Optimization method////////////////////// + + template + class opt_method + { + public: + virtual void do_set_optimizer(optimizer&)=0; + virtual void do_set_precision(rT)=0; + virtual pT do_optimize()=0; + virtual void do_set_start_point(const pT& p)=0; + virtual opt_method* do_clone()const=0; + + virtual void do_destroy() + { + delete this; + } + public: + void set_optimizer(optimizer& op) + { + do_set_optimizer(op); + }; + + void set_precision(rT x) + { + do_set_precision(x); + } + + void set_start_point(const pT& p) + { + do_set_start_point(p); + } + + pT optimize() + { + return do_optimize(); + }; + + opt_method* clone()const + { + return do_clone(); + } + + void destroy() + { + do_destroy(); + } + + virtual ~opt_method(){}; + }; + + + ///////////Optimizer//////////////////////////////////// + template + class optimizer + { + public: + + private: + + ////////////pointer to an optimization method objection//////////// + ////////////The optimization method implements a certain method /// + ////////////Currently only Mont-carlo method is implemented//////// + opt_method* p_opt_method; + func_obj* p_func_obj; + + public: + optimizer() + :p_opt_method(0),p_func_obj(0) + {} + + optimizer(func_obj& fc,const opt_method& om) + :p_func_obj(fc.clone()),p_opt_method(om.clone()) + { + p_opt_method->set_optimizer(*this); + } + + optimizer(const optimizer& rhs) + :p_opt_method(0),p_func_obj(0) + { + if(rhs.p_func_obj!=0) + { + set_func_obj(*(rhs.p_func_obj)); + } + if(rhs.p_opt_method!=0) + { + set_opt_method(*(rhs.p_opt_method)); + } + } + + optimizer& operator=(const optimizer& rhs) + { + if(this==&rhs) + { + return *this; + } + if(rhs.p_func_obj!=0) + { + set_func_obj(*(rhs.p_func_obj)); + } + if(rhs.p_opt_method!=0) + { + set_opt_method(*(rhs.p_opt_method)); + } + return *this; + } + + + virtual ~optimizer() + { + if(p_func_obj!=0) + { + //delete p_func_obj; + p_func_obj->destroy(); + } + if(p_opt_method!=0) + { + //delete p_opt_method; + p_opt_method->destroy(); + } + }; + + public: + ////////////Re-set target function object/////////////////////////// + void set_func_obj(const func_obj& fc) + { + if(p_func_obj!=0) + { + //delete p_func_obj; + p_func_obj->destroy(); + } + p_func_obj=fc.clone(); + if(p_opt_method!=0) + { + p_opt_method->set_optimizer(*this); + } + } + + ////////////Re-set optimization method////////////////////////////// + void set_opt_method(const opt_method& om) + { + if(p_opt_method!=0) + { + //delete p_opt_method; + p_opt_method->destroy(); + } + + p_opt_method=om.clone(); + p_opt_method->set_optimizer(*this); + } + + opt_method& method() + { + if(p_opt_method==0) + { + throw opt_method_undefined(); + } + return *(this->p_opt_method); + } + + void set_precision(rT x) + { + if(p_opt_method==0) + { + throw opt_method_undefined(); + } + p_opt_method->set_precision(x); + } + + void set_start_point(const pT& x) + { + if(p_opt_method==0) + { + throw opt_method_undefined(); + } + p_opt_method->set_start_point(x); + } + + ////////////Just call the eval function in the target function object/// + ////////////In case the pointer to a target function is uninitialed///// + ////////////a zero-value is returned//////////////////////////////////// + rT eval(const pT& x) + { + if(p_func_obj==0) + { + throw target_function_undefined(); + } + return p_func_obj->eval(x); + } + + + + ////////////Just call the optimize() function in the optimization method// + ////////////If no optimization method is given, an zero-value is returned/ + pT optimize() + { + if(p_opt_method==0) + { + throw opt_method_undefined(); + } + if(p_func_obj==0) + { + throw target_function_undefined(); + } + return p_opt_method->optimize(); + } + + ////////////Function that offers the access to the target function object/// + func_obj* ptr_func_obj() + { + return p_func_obj; + } + + }; +}; + +#endif +//EOF + + -- cgit v1.2.2