aboutsummaryrefslogtreecommitdiffstats
path: root/astro
diff options
context:
space:
mode:
authorAaron LI <aly@aaronly.me>2017-12-04 21:22:10 +0800
committerAaron LI <aly@aaronly.me>2017-12-04 21:22:10 +0800
commit0bb53d7565b00d27e2a0d5c8bd949586b344cfcb (patch)
tree7aae04b8f9bf1c9954687f2a656050105d72c14c /astro
parentdf238cb5cb76356315ec5478ff67748121d491bf (diff)
downloadatoolbox-0bb53d7565b00d27e2a0d5c8bd949586b344cfcb.tar.bz2
astro/fitscube.py: Add get_slice() method and simplify "info" subcommand
Diffstat (limited to 'astro')
-rwxr-xr-xastro/fits/fitscube.py32
1 files changed, 18 insertions, 14 deletions
diff --git a/astro/fits/fitscube.py b/astro/fits/fitscube.py
index a297662..13d9be0 100755
--- a/astro/fits/fitscube.py
+++ b/astro/fits/fitscube.py
@@ -177,10 +177,23 @@ class FITSCube:
"""
return (self.data[i, :, :] for i in range(self.nslice))
+ def get_slice(self, i, csize=None):
+ """
+ Get the i-th (0-based) slice image, and crop out the central box
+ of size ``csize`` if specified.
+ """
+ if csize is None:
+ return self.data[i, :, :]
+ else:
+ rows, cols = self.height, self.width
+ rc, cc = rows//2, cols//2
+ cs1, cs2 = csize//2, (csize+1)//2
+ return self.data[i, (rc-cs1):(rc+cs2), (cc-cs1):(cc+cs2)]
+
@property
def unit(self):
"""
- Data cube unit.
+ Cube data unit.
"""
return self.header.get("BUNIT")
@@ -219,19 +232,10 @@ def cmd_info(args):
if args.meanstd:
mean = np.zeros(cube.nslice)
std = np.zeros(cube.nslice)
- if args.center:
- print("Central spatial box size: %d" % args.center)
- rows, cols = cube.height, cube.width
- rc, cc = rows//2, cols//2
- cs1, cs2 = args.center//2, (args.center+1)//2
- for i, image in enumerate(cube.slices):
- data = image[(rc-cs1):(rc+cs2), (cc-cs1):(cc+cs2)]
- mean[i] = np.mean(data)
- std[i] = np.std(data)
- else:
- for i, image in enumerate(cube.slices):
- mean[i] = np.mean(image)
- std[i] = np.std(image)
+ for i in range(cube.nslice):
+ image = cube.get_slice(i, csize=args.center)
+ mean[i] = np.mean(image)
+ std[i] = np.std(image)
print("Slice <z> <mean> +/- <std>:")
for i, z in enumerate(zvalues):
print("* %12.4e: %-12.4e %-12.4e" % (z, mean[i], std[i]))