diff options
author | Aaron LI <aly@aaronly.me> | 2017-10-28 16:10:18 +0800 |
---|---|---|
committer | Aaron LI <aly@aaronly.me> | 2017-10-28 17:01:23 +0800 |
commit | d97cc408aa7c9dd297efce43dfed22311a3a98b2 (patch) | |
tree | 7be7e05aab84ec2417cba251e811761268ced06e | |
parent | 9aa50b115cd2650ff8cae9391a0b59560589dc89 (diff) | |
download | atoolbox-d97cc408aa7c9dd297efce43dfed22311a3a98b2.tar.bz2 |
calc_psd.py: Remove unnecessary psd2d padding
-rwxr-xr-x | astro/calc_psd.py | 75 |
1 files changed, 20 insertions, 55 deletions
diff --git a/astro/calc_psd.py b/astro/calc_psd.py index 730c6f0..5153541 100755 --- a/astro/calc_psd.py +++ b/astro/calc_psd.py @@ -6,10 +6,7 @@ """ Compute the radially averaged power spectral density (i.e., power spectrum) -of a 2D image. - -XXX: If the input image is NOT SQUARE; then are the horizontal frequencies - the same as the vertical frequencies ?? +of a 2D image (in FITS format). The input image must be square. Credit ------ @@ -50,10 +47,19 @@ class PSD: psd1d = None psd1d_err = None - def __init__(self, img, pixel=(1.0, "pixel"), normalize=True): - self.img = img.astype(np.float) + def __init__(self, image, pixel=(1.0, "pixel"), normalize=True, step=None): + self.image = np.array(image, dtype=np.float) + self.shape = self.image.shape + if self.shape[0] != self.shape[1]: + raise ValueError("input image is not square!") + self.pixel = pixel self.normalize = normalize + self.step = step + + @property + def radii(self): + pass def calc_psd2d(self): """ @@ -94,11 +100,12 @@ class PSD: frequency radial_psd_err: standard deviations of each radial_psd """ - print("Calculating radial (1D) power spectral density ... ", - end="", flush=True) - print("padding ... ", end="", flush=True) - psd2d = self.pad_square(self.psd2d, value=np.nan) - dim = psd2d.shape[0] + if not hasattr(self, "ps2d") or self.psd2d is None: + self.calc_psd2d() + + print("Radially averaging 2D power spectral density ... ") + psd2d = self.psd2d + dim = psd2d.shape[0] dim_half = (dim+1) // 2 # NOTE: # The zero-frequency component is shifted to position of index @@ -109,7 +116,7 @@ class PSD: rho = np.around(rho).astype(np.int) radial_psd = np.zeros(dim_half) radial_psd_err = np.zeros(dim_half) - print("radially averaging ... ", end="", flush=True) + print(" -> radially averaging ... ", end="", flush=True) for r in range(dim_half): # Get the indices of the elements satisfying rho[i,j]==r ii, jj = (rho == r).nonzero() @@ -145,48 +152,6 @@ class PSD: y = rho * np.sin(phi) return (x, y) - @staticmethod - def pad_square(data, value=np.nan): - """ - Symmetrically pad the supplied data matrix to make it square. - The padding rows are equally added to the top and bottom, - as well as the columns to the left and right sides. - The padded rows/columns are filled with the specified value. - """ - mat = data.copy() - rows, cols = mat.shape - dim_diff = abs(rows - cols) - dim_max = max(rows, cols) - if rows > cols: - # pad columns - if dim_diff // 2 == 0: - cols_left = np.zeros((rows, dim_diff/2)) - cols_left[:] = value - cols_right = np.zeros((rows, dim_diff/2)) - cols_right[:] = value - mat = np.hstack((cols_left, mat, cols_right)) - else: - cols_left = np.zeros((rows, np.floor(dim_diff/2))) - cols_left[:] = value - cols_right = np.zeros((rows, np.floor(dim_diff/2)+1)) - cols_right[:] = value - mat = np.hstack((cols_left, mat, cols_right)) - elif rows < cols: - # pad rows - if dim_diff // 2 == 0: - rows_top = np.zeros((dim_diff/2, cols)) - rows_top[:] = value - rows_bottom = np.zeros((dim_diff/2, cols)) - rows_bottom[:] = value - mat = np.vstack((rows_top, mat, rows_bottom)) - else: - rows_top = np.zeros((np.floor(dim_diff/2), cols)) - rows_top[:] = value - rows_bottom = np.zeros((np.floor(dim_diff/2)+1, cols)) - rows_bottom[:] = value - mat = np.vstack((rows_top, mat, rows_bottom)) - return mat - def plot(self, ax=None, fig=None): """ Make a plot of the radial (1D) PSD with matplotlib. @@ -280,7 +245,7 @@ def main(): raise OSError("output plot file '%s' already exists" % plotfile) header, image = open_image(args.infile) - psd = PSD(img=image, normalize=True) + psd = PSD(image=image, normalize=True) psd.calc_psd2d() freqs, psd, psd_err = psd.calc_psd() |