aboutsummaryrefslogtreecommitdiffstats
path: root/core/freeze_param.hpp
diff options
context:
space:
mode:
authorastrojhgu <astrojhgu@ed2142bd-67ad-457f-ba7c-d818d4011675>2008-12-15 07:26:12 +0000
committerastrojhgu <astrojhgu@ed2142bd-67ad-457f-ba7c-d818d4011675>2008-12-15 07:26:12 +0000
commit1f4a944064bc42284c33e6b755353d191cf288e8 (patch)
treec8cb2253dea5f395e0f867aa6976433bd3eb00de /core/freeze_param.hpp
downloadopt-utilities-1f4a944064bc42284c33e6b755353d191cf288e8.tar.bz2
git-svn-id: file:///home/svn/opt_utilities@1 ed2142bd-67ad-457f-ba7c-d818d4011675
Diffstat (limited to 'core/freeze_param.hpp')
-rw-r--r--core/freeze_param.hpp204
1 files changed, 204 insertions, 0 deletions
diff --git a/core/freeze_param.hpp b/core/freeze_param.hpp
new file mode 100644
index 0000000..8ba677a
--- /dev/null
+++ b/core/freeze_param.hpp
@@ -0,0 +1,204 @@
+#ifndef FREEZE_PARAM_HPP
+#define FREEZE_PARAM_HPP
+#include "fitter.hpp"
+#include <vector>
+#include <set>
+
+namespace opt_utilities
+{
+ template <typename Ty,typename Tx,typename Tp,typename Tstr=std::string>
+ class freeze_param
+ :public param_modifier<Ty,Tx,Tp,Tstr>
+ {
+ private:
+ std::set<Tstr> param_names;
+ std::vector<int> param_num;
+ int num_free;
+
+ public:
+ freeze_param()
+ {
+
+ }
+
+ freeze_param(const Tstr& name)
+ {
+ param_names.insert(name);
+ }
+
+ private:
+ freeze_param<Ty,Tx,Tp,Tstr>* do_clone()const
+ {
+ return new freeze_param<Ty,Tx,Tp,Tstr>(*this);
+ }
+
+
+
+ void update()
+ {
+ param_num.clear();
+ for(typename std::set<Tstr>::const_iterator i=param_names.begin();
+ i!=param_names.end();++i)
+ {
+ try
+ {
+ param_num.push_back(this->get_model().get_param_order(*i));
+ }
+ catch(opt_exception& e)
+ {
+ param_names.erase(*i);
+ throw;
+ }
+
+ }
+ }
+
+ int do_get_num_free_params()const
+ {
+ return this->get_model().get_num_params()-param_num.size();
+ }
+
+ bool is_frozen(int i)const
+ {
+ if(find(param_num.begin(),param_num.end(),i)==param_num.end())
+ {
+ return false;
+ }
+ return true;
+ }
+
+
+ Tp do_reform(const Tp& p)const
+ {
+ int nparams=(this->get_model().get_num_params());
+ Tp reformed_p(nparams);
+ int i=0,j=0;
+ for(i=0;i<(int)nparams;++i)
+ {
+ if(this->is_frozen(i))
+ {
+ const param_info<Tp,Tstr>& pinf=this->get_model().get_param_info(i);
+ //std::cout<<"frozen:"<<pinf.get_name()
+ // <<i<<"\t"<<j
+ // <<std::endl;
+ //opt_eq(get_element(reformed_p,i),pinf.get_default_value());
+ set_element(reformed_p,i,pinf.get_default_value());
+ }
+ else
+ {
+ //opt_eq(get_element(reformed_p,i),get_element(p,j));
+ set_element(reformed_p,i,get_element(p,j));
+ j++;
+ }
+ }
+ /*
+ for(int i=0;i<reformed_p.size();++i)
+ {
+ std::cout<<get_element(reformed_p,i)<<",";
+ }
+ */
+ //std::cout<<"\n";
+ return reformed_p;
+ // return p;
+ }
+
+ Tp do_deform(const Tp& p)const
+ {
+ Tp deformed_p(do_get_num_free_params());
+ int i(0),j(0);
+ assert(get_size(p)==this->get_model().get_num_params());
+ for(;i<(int)get_size(p);++i)
+ {
+ //std::cout<<is_frozen(j)<<"\n";
+ if(!this->is_frozen(i))
+ {
+ //opt_eq(get_element(deformed_p,j),get_element(p,i));
+ set_element(deformed_p,j,get_element(p,i));
+ j++;
+ }
+ }
+
+ assert(j==do_get_num_free_params());
+ return deformed_p;
+ }
+
+
+ Tstr do_report_param_status(const Tstr& name)const
+ {
+ if(param_names.find(name)==param_names.end())
+ {
+ return "thawed";
+ }
+ return "frozen";
+ }
+
+ public:
+ freeze_param operator+(const freeze_param& fp)const
+ {
+ freeze_param result(*this);
+ for(typename std::set<Tstr>::const_iterator i=fp.param_names.begin();
+ i!=fp.param_names.end();
+ ++i)
+ {
+ result.param_names.insert(*i);
+ }
+ return result;
+ }
+
+ freeze_param& operator+=(const freeze_param& fp)
+ {
+ //param_names.insert(param_names.end(),
+ //fp.param_names.begin(),
+ //fp.param_names.end());
+ for(typename std::set<Tstr>::const_iterator i=fp.param_names.begin();
+ i!=fp.param_names.end();
+ ++i)
+ {
+ param_names.insert(*i);
+ }
+ try
+ {
+ update();
+ }
+ catch(opt_exception& e)
+ {
+ throw;
+ }
+ return *this;
+ }
+
+ freeze_param& operator-=(const freeze_param& fp)
+ {
+ //param_names.insert(param_names.end(),
+ //fp.param_names.begin(),
+ //fp.param_names.end());
+ for(typename std::set<Tstr>::const_iterator i=fp.param_names.begin();
+ i!=fp.param_names.end();
+ ++i)
+ {
+ param_names.erase(*i);
+ }
+ try
+ {
+ update();
+ }
+ catch(opt_exception& e)
+ {
+ throw;
+ }
+ return *this;
+ }
+
+ };
+
+ template <typename Ty,typename Tx,typename Tp,typename Tstr>
+ freeze_param<Ty,Tx,Tp,Tstr> freeze(const Tstr& name)
+ {
+ return freeze_param<Ty,Tx,Tp,Tstr>(name);
+ }
+
+};
+
+
+#endif
+//EOF