diff options
author | Aaron LI <aly@aaronly.me> | 2017-12-04 21:22:10 +0800 |
---|---|---|
committer | Aaron LI <aly@aaronly.me> | 2017-12-04 21:22:10 +0800 |
commit | 0bb53d7565b00d27e2a0d5c8bd949586b344cfcb (patch) | |
tree | 7aae04b8f9bf1c9954687f2a656050105d72c14c /astro/fits | |
parent | df238cb5cb76356315ec5478ff67748121d491bf (diff) | |
download | atoolbox-0bb53d7565b00d27e2a0d5c8bd949586b344cfcb.tar.bz2 |
astro/fitscube.py: Add get_slice() method and simplify "info" subcommand
Diffstat (limited to 'astro/fits')
-rwxr-xr-x | astro/fits/fitscube.py | 32 |
1 files changed, 18 insertions, 14 deletions
diff --git a/astro/fits/fitscube.py b/astro/fits/fitscube.py index a297662..13d9be0 100755 --- a/astro/fits/fitscube.py +++ b/astro/fits/fitscube.py @@ -177,10 +177,23 @@ class FITSCube: """ return (self.data[i, :, :] for i in range(self.nslice)) + def get_slice(self, i, csize=None): + """ + Get the i-th (0-based) slice image, and crop out the central box + of size ``csize`` if specified. + """ + if csize is None: + return self.data[i, :, :] + else: + rows, cols = self.height, self.width + rc, cc = rows//2, cols//2 + cs1, cs2 = csize//2, (csize+1)//2 + return self.data[i, (rc-cs1):(rc+cs2), (cc-cs1):(cc+cs2)] + @property def unit(self): """ - Data cube unit. + Cube data unit. """ return self.header.get("BUNIT") @@ -219,19 +232,10 @@ def cmd_info(args): if args.meanstd: mean = np.zeros(cube.nslice) std = np.zeros(cube.nslice) - if args.center: - print("Central spatial box size: %d" % args.center) - rows, cols = cube.height, cube.width - rc, cc = rows//2, cols//2 - cs1, cs2 = args.center//2, (args.center+1)//2 - for i, image in enumerate(cube.slices): - data = image[(rc-cs1):(rc+cs2), (cc-cs1):(cc+cs2)] - mean[i] = np.mean(data) - std[i] = np.std(data) - else: - for i, image in enumerate(cube.slices): - mean[i] = np.mean(image) - std[i] = np.std(image) + for i in range(cube.nslice): + image = cube.get_slice(i, csize=args.center) + mean[i] = np.mean(image) + std[i] = np.std(image) print("Slice <z> <mean> +/- <std>:") for i, z in enumerate(zvalues): print("* %12.4e: %-12.4e %-12.4e" % (z, mean[i], std[i])) |