aboutsummaryrefslogtreecommitdiffstats
path: root/astro
diff options
context:
space:
mode:
authorAaron LI <aly@aaronly.me>2017-10-28 16:10:18 +0800
committerAaron LI <aly@aaronly.me>2017-10-28 17:01:23 +0800
commitd97cc408aa7c9dd297efce43dfed22311a3a98b2 (patch)
tree7be7e05aab84ec2417cba251e811761268ced06e /astro
parent9aa50b115cd2650ff8cae9391a0b59560589dc89 (diff)
downloadatoolbox-d97cc408aa7c9dd297efce43dfed22311a3a98b2.tar.bz2
calc_psd.py: Remove unnecessary psd2d padding
Diffstat (limited to 'astro')
-rwxr-xr-xastro/calc_psd.py75
1 files changed, 20 insertions, 55 deletions
diff --git a/astro/calc_psd.py b/astro/calc_psd.py
index 730c6f0..5153541 100755
--- a/astro/calc_psd.py
+++ b/astro/calc_psd.py
@@ -6,10 +6,7 @@
"""
Compute the radially averaged power spectral density (i.e., power spectrum)
-of a 2D image.
-
-XXX: If the input image is NOT SQUARE; then are the horizontal frequencies
- the same as the vertical frequencies ??
+of a 2D image (in FITS format). The input image must be square.
Credit
------
@@ -50,10 +47,19 @@ class PSD:
psd1d = None
psd1d_err = None
- def __init__(self, img, pixel=(1.0, "pixel"), normalize=True):
- self.img = img.astype(np.float)
+ def __init__(self, image, pixel=(1.0, "pixel"), normalize=True, step=None):
+ self.image = np.array(image, dtype=np.float)
+ self.shape = self.image.shape
+ if self.shape[0] != self.shape[1]:
+ raise ValueError("input image is not square!")
+
self.pixel = pixel
self.normalize = normalize
+ self.step = step
+
+ @property
+ def radii(self):
+ pass
def calc_psd2d(self):
"""
@@ -94,11 +100,12 @@ class PSD:
frequency
radial_psd_err: standard deviations of each radial_psd
"""
- print("Calculating radial (1D) power spectral density ... ",
- end="", flush=True)
- print("padding ... ", end="", flush=True)
- psd2d = self.pad_square(self.psd2d, value=np.nan)
- dim = psd2d.shape[0]
+ if not hasattr(self, "ps2d") or self.psd2d is None:
+ self.calc_psd2d()
+
+ print("Radially averaging 2D power spectral density ... ")
+ psd2d = self.psd2d
+ dim = psd2d.shape[0]
dim_half = (dim+1) // 2
# NOTE:
# The zero-frequency component is shifted to position of index
@@ -109,7 +116,7 @@ class PSD:
rho = np.around(rho).astype(np.int)
radial_psd = np.zeros(dim_half)
radial_psd_err = np.zeros(dim_half)
- print("radially averaging ... ", end="", flush=True)
+ print(" -> radially averaging ... ", end="", flush=True)
for r in range(dim_half):
# Get the indices of the elements satisfying rho[i,j]==r
ii, jj = (rho == r).nonzero()
@@ -145,48 +152,6 @@ class PSD:
y = rho * np.sin(phi)
return (x, y)
- @staticmethod
- def pad_square(data, value=np.nan):
- """
- Symmetrically pad the supplied data matrix to make it square.
- The padding rows are equally added to the top and bottom,
- as well as the columns to the left and right sides.
- The padded rows/columns are filled with the specified value.
- """
- mat = data.copy()
- rows, cols = mat.shape
- dim_diff = abs(rows - cols)
- dim_max = max(rows, cols)
- if rows > cols:
- # pad columns
- if dim_diff // 2 == 0:
- cols_left = np.zeros((rows, dim_diff/2))
- cols_left[:] = value
- cols_right = np.zeros((rows, dim_diff/2))
- cols_right[:] = value
- mat = np.hstack((cols_left, mat, cols_right))
- else:
- cols_left = np.zeros((rows, np.floor(dim_diff/2)))
- cols_left[:] = value
- cols_right = np.zeros((rows, np.floor(dim_diff/2)+1))
- cols_right[:] = value
- mat = np.hstack((cols_left, mat, cols_right))
- elif rows < cols:
- # pad rows
- if dim_diff // 2 == 0:
- rows_top = np.zeros((dim_diff/2, cols))
- rows_top[:] = value
- rows_bottom = np.zeros((dim_diff/2, cols))
- rows_bottom[:] = value
- mat = np.vstack((rows_top, mat, rows_bottom))
- else:
- rows_top = np.zeros((np.floor(dim_diff/2), cols))
- rows_top[:] = value
- rows_bottom = np.zeros((np.floor(dim_diff/2)+1, cols))
- rows_bottom[:] = value
- mat = np.vstack((rows_top, mat, rows_bottom))
- return mat
-
def plot(self, ax=None, fig=None):
"""
Make a plot of the radial (1D) PSD with matplotlib.
@@ -280,7 +245,7 @@ def main():
raise OSError("output plot file '%s' already exists" % plotfile)
header, image = open_image(args.infile)
- psd = PSD(img=image, normalize=True)
+ psd = PSD(image=image, normalize=True)
psd.calc_psd2d()
freqs, psd, psd_err = psd.calc_psd()