From bc3c8b44041285652c5596bbc191d3c174802be6 Mon Sep 17 00:00:00 2001 From: Aaron LI Date: Tue, 11 Jul 2017 20:39:07 +0800 Subject: astro/ps2d.py: add argument "--no-window" and update logging format Signed-off-by: Aaron LI --- astro/ps2d.py | 60 +++++++++++++++++++++++++++++++---------------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/astro/ps2d.py b/astro/ps2d.py index e110776..9ce91bf 100755 --- a/astro/ps2d.py +++ b/astro/ps2d.py @@ -23,7 +23,9 @@ from astropy.cosmology import FlatLambdaCDM import astropy.constants as ac -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S") logger = logging.getLogger(os.path.basename(sys.argv[0])) @@ -70,32 +72,30 @@ class PS2D: self.cosmo = FlatLambdaCDM(H0=H0, Om0=OmegaM0) # Transverse comoving distance at zc; unit: [Mpc] self.DMz = self.cosmo.comoving_transverse_distance(self.zc).value - self.set_window(name=window, width=width) - - def set_window(self, name, width="extended"): - self.window = { - "name": name, - "func": getattr(windows, name), - "width": width - } - filter = self.window["func"](self.window_width, sym=False) - if len(filter) > self.nfreq: - # cut the filter - midx = int(len(filter) / 2) # index of the peak element - nleft = int(self.nfreq / 2) # number of element on the left - nright = int((self.nfreq-1) / 2) # number of element on the right - filter = filter[(midx-nleft):(midx+nright+1)] - self.window["filter"] = filter - logger.info("Set window: %s (%s)" % (name, width)) + self.window = self.gen_window(name=window, width=width) - @property - def window_width(self): - if self.window["width"] == "extended": - w = self.window["func"](self.nfreq, sym=False) + def gen_window(self, name=None, width="extended"): + if name is None: + window = np.ones(self.nfreq) + return window + + window_func = getattr(windows, name) + if width == "extended": + w = window_func(self.nfreq, sym=False) ex = 1.0 / (w.sum() / self.nfreq) - return int(ex * self.nfreq) + width_pix = int(ex * self.nfreq) else: - return self.nfreq + width_pix = self.nfreq + + window = window_func(width_pix, sym=False) + if len(window) > self.nfreq: + # cut the filter + midx = int(len(window) / 2) # index of the peak element + nleft = int(self.nfreq / 2) # number of element on the left + nright = int((self.nfreq-1) / 2) # number of element on the right + window = window[(midx-nleft):(midx+nright+1)] + logger.info("Generated window: %s (%s/%d)" % (name, width, width_pix)) + return window def pad_cube(self): # Pad the image cube to be square in spatial dimensions. @@ -103,15 +103,14 @@ class PS2D: __, ny, nz = self.cube.shape if ny != nz: logger.info("Padding image to be square ...") - raise RuntimeError("image must be square!") + raise NotImplementedError def calc_ps3d(self): """ Calculate the 3D power spectrum of the image cube. """ logger.info("Applying window to frequency axis ...") - w = self.window["filter"] - cube2 = self.cube * w[:, np.newaxis, np.newaxis] + cube2 = self.cube * self.window[:, np.newaxis, np.newaxis] logger.info("Calculating 3D FFT and PS ...") cubefft = fftpack.fftshift(fftpack.fftn(cube2)) self.ps3d = np.abs(cubefft) ** 2 @@ -253,6 +252,9 @@ def main(): parser.add_argument("-p", "--pixelsize", dest="pixelsize", type=float, required=True, help="image cube pixel size; unit: [arcmin]") + parser.add_argument("--no-window", dest="no_window", + action="store_true", + help="do not apply window along frequency axis") parser.add_argument("-i", "--infile", dest="infile", required=True, help="input FITS image cube") parser.add_argument("-o", "--outfile", dest="outfile", required=True, @@ -267,7 +269,9 @@ def main(): logger.info("%d frequencies [MHz]:" % nfreq) for f in frequencies: logger.info("* %.2f" % f) - ps2d = PS2D(cube=cube, pixelsize=args.pixelsize, frequencies=frequencies) + window = None if args.no_window else "nuttall" + ps2d = PS2D(cube=cube, pixelsize=args.pixelsize, frequencies=frequencies, + window=window) ps2d.calc_ps3d() ps2d.calc_ps2d() ps2d.save(outfile=args.outfile, clobber=args.clobber) -- cgit v1.2.2