diff options
author | Aaron LI <aly@aaronly.me> | 2017-10-28 15:51:19 +0800 |
---|---|---|
committer | Aaron LI <aly@aaronly.me> | 2017-10-28 15:52:07 +0800 |
commit | 9aa50b115cd2650ff8cae9391a0b59560589dc89 (patch) | |
tree | f896b390f2ce62ecf1038148bbd7fb7795e150b3 /astro/calc_psd.py | |
parent | 8838f44c2e79a58dfc7752a7ddb0380fb387a948 (diff) | |
download | atoolbox-9aa50b115cd2650ff8cae9391a0b59560589dc89.tar.bz2 |
calc_psd.py: Removed unnecessary AstroImage class
Diffstat (limited to 'astro/calc_psd.py')
-rwxr-xr-x | astro/calc_psd.py | 216 |
1 files changed, 43 insertions, 173 deletions
diff --git a/astro/calc_psd.py b/astro/calc_psd.py index 4661c62..730c6f0 100755 --- a/astro/calc_psd.py +++ b/astro/calc_psd.py @@ -82,7 +82,7 @@ class PSD: print("DONE", flush=True) return self.psd2d - def calc_radial_psd1d(self): + def calc_psd(self): """ Computes the radially averaged power spectral density from the provided 2D power spectral density. @@ -216,161 +216,44 @@ class PSD: return (fig, ax) -class AstroImage: +def open_image(infile): """ - Manipulate the astronimcal counts image, as well as the corresponding - exposure map and background map. + Open the slice image and return its header and 2D image data. + + NOTE + ---- + The input slice image may have following dimensions: + * NAXIS=2: [Y, X] + * NAXIS=3: [FREQ=1, Y, X] + * NAXIS=4: [STOKES=1, FREQ=1, Y, X] + + NOTE + ---- + Only open slice image that has only ONE frequency and ONE Stokes + parameter. + + Returns + ------- + header : `~astropy.io.fits.Header` + image : 2D `~numpy.ndarray` + The 2D [Y, X] image part of the slice image. """ - # input counts image - image = None - # exposure map with respect to the input counts image - expmap = None - # background map (e.g., stowed background) - bkgmap = None - # exposure time of the input image - exposure = None - # exposure time of the background map - exposure_bkg = None - - def __init__(self, image, expmap=None, bkgmap=None): - self.load_image(image) - self.load_expmap(expmap) - self.load_bkgmap(bkgmap) - - @staticmethod - def open_image(infile): - """ - Open the slice image and return its header and 2D image data. - - NOTE - ---- - The input slice image may have following dimensions: - * NAXIS=2: [Y, X] - * NAXIS=3: [FREQ=1, Y, X] - * NAXIS=4: [STOKES=1, FREQ=1, Y, X] - - NOTE - ---- - Only open slice image that has only ONE frequency and ONE Stokes - parameter. - - Returns - ------- - header : `~astropy.io.fits.Header` - image : 2D `~numpy.ndarray` - The 2D [Y, X] image part of the slice image. - """ - with fits.open(infile) as f: - header = f[0].header - data = f[0].data - if data.ndim == 2: - # NAXIS=2: [Y, X] - image = data - elif data.ndim == 3 and data.shape[0] == 1: - # NAXIS=3: [FREQ=1, Y, X] - image = data[0, :, :] - elif data.ndim == 4 and data.shape[0] == 1 and data.shape[1] == 1: - # NAXIS=4: [STOKES=1, FREQ=1, Y, X] - image = data[0, 0, :, :] - else: - raise ValueError("Slice '{0}' has invalid dimensions: {1}".format( - infile, data.shape)) - return (header, image) - - def load_image(self, image): - print("Loading image ... ", end="", flush=True) - self.header, self.image = self.open_image(image) - self.exposure = self.header.get("EXPOSURE") - print("DONE", flush=True) - - def load_expmap(self, expmap): - if expmap: - print("Loading exposure map ... ", end="", flush=True) - __, self.expmap = self.open_image(expmap) - print("DONE", flush=True) - - def load_bkgmap(self, bkgmap): - if bkgmap: - print("Loading background map ... ", end="", flush=True) - header, self.bkgmap = self.open_image(bkgmap) - self.exposure_bkg = header.get("EXPOSURE") - print("DONE", flush=True) - - def fix_shapes(self, tolerance=2): - """ - Fix the shapes of self.expmap and self.bkgmap to make them have - the same shape as the self.image. - - NOTE: - * if the image is bigger than the reference image, then its - columns on the right and rows on the botton are clipped; - * if the image is smaller than the reference image, then padding - columns on the right and rows on the botton are added. - * Original images are REPLACED! - - Arguments: - * tolerance: allow absolute difference between images - """ - def _fix_shape(img, ref, tol=tolerance): - if img.shape == ref.shape: - print("SKIPPED", flush=True) - return img - elif np.allclose(img.shape, ref.shape, atol=tol): - print(img.shape, "->", ref.shape, flush=True) - rows, cols = img.shape - rows_ref, cols_ref = ref.shape - # rows - if rows > rows_ref: - img_fixed = img[:rows_ref, :] - else: - img_fixed = np.row_stack((img, - np.zeros((rows_ref-rows, cols), dtype=img.dtype))) - # columns - if cols > cols_ref: - img_fixed = img_fixed[:, :cols_ref] - else: - img_fixed = np.column_stack((img_fixed, - np.zeros((rows_ref, cols_ref-cols), dtype=img.dtype))) - return img_fixed - else: - raise ValueError("shape difference exceeds tolerance: " + \ - "(%d, %d) vs. (%d, %d)" % (img.shape + ref.shape)) - # - if self.bkgmap is not None: - print("Fixing shape for bkgmap ... ", end="", flush=True) - self.bkgmap = _fix_shape(self.bkgmap, self.image) - if self.expmap is not None: - print("Fixing shape for expmap ... ", end="", flush=True) - self.expmap = _fix_shape(self.expmap, self.image) - - def subtract_bkg(self): - print("Subtracting background ... ", end="", flush=True) - self.image -= (self.bkgmap / self.exposure_bkg * self.exposure) - print("DONE", flush=True) - - def correct_exposure(self, cut=0.015): - """ - Correct the image for exposure by dividing by the expmap to - create the exposure-corrected image. - - Arguments: - * cut: the threshold percentage with respect to the maximum - exposure map value; and those pixels with lower values - than this threshold will be excluded/clipped (set to ZERO) - if set to None, then skip clipping image - """ - print("Correcting image for exposure ... ", end="", flush=True) - with np.errstate(divide="ignore", invalid="ignore"): - self.image /= self.expmap - # set invalid values to ZERO - self.image[ ~ np.isfinite(self.image) ] = 0.0 - print("DONE", flush=True) - if cut is not None: - # clip image according the exposure threshold - print("Clipping image (%s) ... " % cut, end="", flush=True) - threshold = cut * np.max(self.expmap) - self.image[ self.expmap < threshold ] = 0.0 - print("DONE", flush=True) + with fits.open(infile) as f: + header = f[0].header + data = f[0].data + if data.ndim == 2: + # NAXIS=2: [Y, X] + image = data + elif data.ndim == 3 and data.shape[0] == 1: + # NAXIS=3: [FREQ=1, Y, X] + image = data[0, :, :] + elif data.ndim == 4 and data.shape[0] == 1 and data.shape[1] == 1: + # NAXIS=4: [STOKES=1, FREQ=1, Y, X] + image = data[0, 0, :, :] + else: + raise ValueError("Slice '{0}' has invalid dimensions: {1}".format( + infile, data.shape)) + return (header, image) def main(): @@ -378,10 +261,6 @@ def main(): description="Calculate radially averaged power spectral density") parser.add_argument("-C", "--clobber", dest="clobber", action="store_true", help="overwrite the output files if already exist") - parser.add_argument("-b", "--bkgmap", dest="bkgmap", default=None, - help="background image (for background subtraction)") - parser.add_argument("-e", "--expmap", dest="expmap", default=None, - help="exposure map (for exposure correction)") parser.add_argument("-i", "--infile", dest="infile", required=True, help="input FITS image") parser.add_argument("-o", "--outfile", dest="outfile", required=True, @@ -400,24 +279,15 @@ def main(): if (not args.clobber) and os.path.exists(plotfile): raise OSError("output plot file '%s' already exists" % plotfile) - # Load image data - image = AstroImage(image=args.infile, - expmap=args.expmap, - bkgmap=args.bkgmap) - image.fix_shapes() - if args.bkgmap: - image.subtract_bkg() - if args.expmap: - image.correct_exposure() - - # Calculate the power spectral density - psd = PSD(img=image.image, normalize=True) + header, image = open_image(args.infile) + psd = PSD(img=image, normalize=True) psd.calc_psd2d() - freqs, psd1d, psd1d_err = psd.calc_radial_psd1d() + freqs, psd, psd_err = psd.calc_psd() # Write out PSD results - psd_data = np.column_stack((freqs, psd1d, psd1d_err)) - np.savetxt(args.outfile, psd_data, header="freqs psd1d psd1d_err") + psd_data = np.column_stack((freqs, psd, psd_err)) + np.savetxt(args.outfile, psd_data, header="freqs psd psd_err") + print("Saved PSD data to: %s" % args.outfile) if args.plot: # Make and save a plot |