diff options
-rw-r--r-- | fg21sim/extragalactic/clusters/solver.py | 50 |
1 files changed, 36 insertions, 14 deletions
diff --git a/fg21sim/extragalactic/clusters/solver.py b/fg21sim/extragalactic/clusters/solver.py index 50ea75f..bc9a8fd 100644 --- a/fg21sim/extragalactic/clusters/solver.py +++ b/fg21sim/extragalactic/clusters/solver.py @@ -188,6 +188,37 @@ class FokkerPlanckSolver: (1 - np.exp(-w[~mask]))) return W + @staticmethod + def bound_w(w, wmin=1e-8, wmax=1e3): + """ + Bound the absolute values of w within [wmin, wmax]. + + To avoid the underflow/overflow during later W/Wplus/Wminus + calculations. + """ + with np.errstate(invalid="ignore"): + # Ignore NaN's + m1 = (np.abs(w) < wmin) + m2 = (np.abs(w) > wmax) + ww = np.array(w) + ww[m1] = wmin * np.sign(ww[m1]) + ww[m2] = wmax * np.sign(ww[m2]) + return ww + + def Wplus(self, w): + # References: Ref.[1],Eq.(32) + ww = self.bound_w(w) + W = self.W(ww) + Wplus = W * np.exp(ww/2) + return Wplus + + def Wminus(self, w): + # References: Ref.[1],Eq.(32) + ww = self.bound_w(w) + W = self.W(ww) + Wminus = W * np.exp(-ww/2) + return Wminus + def tridiagonal_coefs(self, tc, uc): """ Calculate the coefficients for the tridiagonal system of linear @@ -219,19 +250,10 @@ class FokkerPlanckSolver: C_mhalf = self.X_mhalf(C) w_phalf = dx_phalf * B_phalf / C_phalf w_mhalf = dx_mhalf * B_mhalf / C_mhalf - # Avoid overflow when w is too large - w_max = 300 - with np.errstate(invalid="ignore"): - mask_phalf = (np.abs(w_phalf) > w_max) - mask_mhalf = (np.abs(w_mhalf) > w_max) - w_phalf[mask_phalf] = w_max * (np.sign(w_phalf[mask_phalf])) - w_mhalf[mask_mhalf] = w_max * (np.sign(w_mhalf[mask_mhalf])) - W_phalf = self.W(w_phalf) - W_mhalf = self.W(w_mhalf) - Wplus_phalf = W_phalf * np.exp(w_phalf/2) - Wplus_mhalf = W_mhalf * np.exp(w_mhalf/2) - Wminus_phalf = W_phalf * np.exp(-w_phalf/2) - Wminus_mhalf = W_mhalf * np.exp(-w_mhalf/2) + Wplus_phalf = self.Wplus(w_phalf) + Wplus_mhalf = self.Wplus(w_mhalf) + Wminus_phalf = self.Wminus(w_phalf) + Wminus_mhalf = self.Wminus(w_mhalf) # a = (dt/dx) * (C_mhalf/dx_mhalf) * Wminus_mhalf a[0] = 0.0 # Fix a[0] which is NaN @@ -299,6 +321,6 @@ class FokkerPlanckSolver: i = 0 while tc < tstop: i += 1 - logger.info("[%d/%d] t=%.3f ..." % (i, nstep, tc)) + logger.debug("[%d/%d] t=%.3f ..." % (i, nstep, tc)) tc, uc = self.solve_step(tc, uc) return uc |