diff options
-rwxr-xr-x | astro/fg_fitsub.py | 167 |
1 files changed, 167 insertions, 0 deletions
#!/usr/bin/env python3
#
# Copyright (c) 2017 Weitian LI <weitian@aaronly.me>
# MIT License
#

"""
Fit the spectral-smooth foreground along the frequency axis (i.e.,
line of sight for the 21 cm signal) using a polynomial in the linear
scale for a series of simulated images (e.g., made by WSClean), then
subtract the fitted smooth component to remove/reduce the foreground
contamination.

References
----------
* Liu, Tegmark & Zaldarriaga 2009, MNRAS, 394, 1575
"""

import os
import sys
import argparse
import time
import logging

import numpy as np
from astropy.io import fits


logging.basicConfig(level=logging.INFO,
                    format="[%(levelname)s:%(lineno)d] %(message)s")
logger = logging.getLogger()


def open_image(infile):
    """
    Open the slice image and return its header and 2D image data.

    NOTE
    ----
    The input slice image may have one of the following layouts:
    * NAXIS=2: [Y, X]
    * NAXIS=3: [FREQ=1, Y, X]
    * NAXIS=4: [STOKES=1, FREQ=1, Y, X]

    NOTE
    ----
    Only slice images holding exactly ONE frequency and ONE Stokes
    parameter are accepted.

    Parameters
    ----------
    infile : str
        Path of the FITS slice image; only the primary HDU is read.

    Returns
    -------
    header : `~astropy.io.fits.Header`
    image : 2D `~numpy.ndarray`
        The 2D [Y, X] image part of the slice image.

    Raises
    ------
    ValueError
        If the data is not 2D and its leading axes are not all of
        length one.
    """
    with fits.open(infile) as f:
        header = f[0].header
        data = f[0].data
    if data.ndim == 2:
        # NAXIS=2: [Y, X]
        image = data
    elif data.ndim == 3 and data.shape[0] == 1:
        # NAXIS=3: [FREQ=1, Y, X]
        image = data[0, :, :]
    elif data.ndim == 4 and data.shape[0] == 1 and data.shape[1] == 1:
        # NAXIS=4: [STOKES=1, FREQ=1, Y, X]
        image = data[0, 0, :, :]
    else:
        raise ValueError("Slice '{0}' has invalid dimensions: {1}".format(
            infile, data.shape))
    return (header, image)


def get_frequency(header):
    """
    Look up the frequency of a slice image from its header.

    The ``FREQ`` keyword is used as-is when present (taken to be in
    MHz).  Otherwise, if ``CTYPE3`` equals ``"FREQ"``, ``CRVAL3`` is
    divided by 1e6 (i.e., taken to be in Hz and converted to MHz).

    Parameters
    ----------
    header : `~astropy.io.fits.Header`
        Header of the slice image (any mapping raising ``KeyError`` on
        missing keywords works).

    Returns
    -------
    freq : float or None
        Frequency in MHz, or ``None`` if it cannot be determined.
    """
    try:
        return header["FREQ"]  # [MHz]
    except KeyError:
        pass
    try:
        if header["CTYPE3"] == "FREQ":
            return header["CRVAL3"] / 1e6  # [MHz]
    except KeyError:
        pass
    return None


def fit_subtract(cube, freqs, poly_order):
    """
    Fit a polynomial along the frequency axis of every pixel and
    subtract the fitted values from the data.

    The fit is done one image row at a time: all pixels of a row share
    the same frequencies, so they are fitted by a single
    ``numpy.polyfit`` call, which bounds the temporary memory to one
    [NFREQ, X] plane.

    Parameters
    ----------
    cube : 3D `~numpy.ndarray`
        Image cube of shape [NFREQ, Y, X]; it is not modified.
    freqs : 1D array-like
        Frequency of each slice, of length NFREQ.
    poly_order : int
        Order of the polynomial to fit.

    Returns
    -------
    residual : 3D `~numpy.ndarray`
        New cube of the same shape as ``cube`` holding the data minus
        the fitted polynomial.  A floating-point ``cube`` keeps its
        dtype; any other dtype is promoted to float.

    Raises
    ------
    ValueError
        If the number of frequencies does not match the cube, or if
        there are fewer slices than polynomial coefficients.
    """
    nz, ny, nx = cube.shape
    freqs = np.asarray(freqs, dtype=float)
    if freqs.shape != (nz,):
        raise ValueError("expected %d frequencies, got shape %s" %
                         (nz, freqs.shape))
    if nz < poly_order + 1:
        raise ValueError("need at least %d slices to fit a polynomial "
                         "of order %d, got %d" %
                         (poly_order + 1, poly_order, nz))

    if np.issubdtype(cube.dtype, np.floating):
        residual = cube.copy()
    else:
        # In-place subtraction of float values from an integer array
        # is rejected by numpy, so work in floating point instead.
        residual = cube.astype(float)

    # Columns are freqs**poly_order, ..., freqs**0, i.e. the same
    # (highest power first) ordering as the coefficients of polyfit.
    design = np.vander(freqs, poly_order + 1)
    step = max(1, ny // 10)
    tstart = time.perf_counter()
    for iy in range(ny):
        # One column per pixel of this row: coefs is [poly_order+1, X]
        coefs = np.polyfit(freqs, cube[:, iy, :], deg=poly_order)
        residual[:, iy, :] -= np.dot(design, coefs)
        done = iy + 1
        if done % step == 0 and done < ny:
            telapsed = (time.perf_counter() - tstart) / 60.0
            teta = telapsed * (ny - done) / done
            logger.info("%d/%d rows [%.1f%%] || "
                        "elapsed %.1f / ETA %.1f [min] ...",
                        done, ny, 100.0 * done / ny, telapsed, teta)
    return residual


def main():
    """
    Command-line entry point: load the slice images, fit and subtract
    the polynomial foreground, and write one subtracted image per
    input slice (same base name) into the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Subtract foreground through polynomial fitting")
    parser.add_argument("-C", "--clobber", dest="clobber", action="store_true",
                        help="overwrite existing output file")
    parser.add_argument("-p", "--poly-order", dest="poly_order",
                        type=int, default=2,
                        help="order of polynomial used for fitting " +
                        "(default: 2, i.e., quadratic)")
    parser.add_argument("-o", "--outdir", dest="outdir", required=True,
                        help="output directory to store the subtracted images")
    parser.add_argument("-i", "--infiles", dest="infiles", nargs="+",
                        required=True,
                        help="input images slices (in order)")
    args = parser.parse_args()

    nfiles = len(args.infiles)
    if args.poly_order < 0:
        parser.error("polynomial order must be non-negative")
    if nfiles < args.poly_order + 1:
        parser.error("need at least %d input slices for polynomial order %d"
                     % (args.poly_order + 1, args.poly_order))

    outfiles = [os.path.join(args.outdir, os.path.basename(infile))
                for infile in args.infiles]
    if len(set(outfiles)) != nfiles:
        raise ValueError("input files share base names; "
                         "their outputs would overwrite each other")

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
        logger.info("Created output directory: %s", args.outdir)

    # Validate every output before removing anything, so that a
    # rejected run leaves the existing files untouched.
    existing = [(infile, outfile)
                for infile, outfile in zip(args.infiles, outfiles)
                if os.path.exists(outfile)]
    for infile, outfile in existing:
        if os.path.realpath(infile) == os.path.realpath(outfile):
            # Clobbering here would delete the input before it is read.
            raise OSError("Output file is also an input file: %s" % outfile)
        if not args.clobber:
            raise OSError("Output file already exists: %s" % outfile)
    for __, outfile in existing:
        os.remove(outfile)
        logger.warning("Removed existing output file: %s", outfile)

    logger.info("Number of images: %d", nfiles)
    headers = []
    images = []
    freqs = np.zeros(nfiles)
    for i, infile in enumerate(args.infiles):
        header, image = open_image(infile)
        headers.append(header)
        images.append(image)
        freq = get_frequency(header)
        if freq is None:
            raise ValueError("no frequency for image: %s" % infile)
        freqs[i] = freq
        logger.info("Loaded slice #%d: %s @ %.2f [MHz]", i+1, infile, freq)

    cube = np.stack(images)
    del images  # the stacked cube holds its own copy of the data
    nz, ny, nx = cube.shape
    logger.info("Image cube dimensions: %dx%d * %d (slices)", nx, ny, nz)
    logger.info("Polynomial fitting order: %d", args.poly_order)
    t1 = time.perf_counter()
    cubeout = fit_subtract(cube, freqs, args.poly_order)
    logger.info("Done fitting and subtracting foreground!")

    for i, outfile in enumerate(outfiles):
        image = cubeout[i, :, :]
        header = headers[i]
        # Record the command line that produced this image
        header.add_history(" ".join(sys.argv))
        fits.PrimaryHDU(data=image, header=header).writeto(outfile)
        logger.info("Wrote subtracted image slice: %s", outfile)

    t2 = time.perf_counter()
    telapsed = (t2 - t1) / 60.0
    logger.info("Running time: %.1f [min]", telapsed)


if __name__ == "__main__":
    main()