aboutsummaryrefslogtreecommitdiffstats
path: root/astro/calc_psd.py
diff options
context:
space:
mode:
Diffstat (limited to 'astro/calc_psd.py')
-rwxr-xr-xastro/calc_psd.py130
1 files changed, 80 insertions, 50 deletions
diff --git a/astro/calc_psd.py b/astro/calc_psd.py
index 5153541..6029169 100755
--- a/astro/calc_psd.py
+++ b/astro/calc_psd.py
@@ -18,6 +18,7 @@ Credit
import os
import argparse
+from functools import lru_cache
import numpy as np
from astropy.io import fits
@@ -34,19 +35,6 @@ class PSD:
Computes the 2D power spectral density and the radially averaged power
spectral density (i.e., 1D power spectrum).
"""
- # 2D image data
- img = None
- # value and unit of 1 pixel for the input image
- pixel = (None, None)
- # whether to normalize the power spectral density by image size
- normalize = True
- # 2D power spectral density
- psd2d = None
- # 1D (radially averaged) power spectral density
- freqs = None
- psd1d = None
- psd1d_err = None
-
def __init__(self, image, pixel=(1.0, "pixel"), normalize=True, step=None):
self.image = np.array(image, dtype=np.float)
self.shape = self.image.shape
@@ -56,10 +44,42 @@ class PSD:
self.pixel = pixel
self.normalize = normalize
self.step = step
+ if step is not None and step <= 1:
+ raise ValueError("step must be greater than 1")
@property
+ @lru_cache()
def radii(self):
- pass
+ """
+ The radial (frequency) points where to calculate the powers.
+ If ``self.step`` is ``None``, then the powers at every frequency
+ point are calculated. If ``self.step`` is specified, then a
+ log-even grid is adopted, which can greatly save computation time
+ for large images.
+ """
+ dim_half = (self.shape[0] + 1) // 2
+ x = np.arange(dim_half)
+ if self.step is None:
+ return x
+ else:
+ xmax = x.max()
+ x2 = list(x[x*(self.step-1) <= 1])
+ v1 = x[len(x2)]
+ while v1 < xmax:
+ x2.append(v1)
+ v1 *= self.step
+ x2.append(xmax)
+ return np.array(x2)
+
+ @property
+ @lru_cache()
+ def frequencies(self):
+ """
+ The (spatial) frequencies w.r.t. the above radii.
+ """
+ radii = self.radii
+ freqs = (1 / (self.shape[0] * self.pixel[0])) * radii
+ return freqs
def calc_psd2d(self):
"""
@@ -77,9 +97,9 @@ class PSD:
otherwise has unit ${pixel_unit}^2.
"""
print("Calculating 2D power spectral density ... ", end="", flush=True)
- rows, cols = self.img.shape
+ rows, cols = self.shape
# Compute the power spectral density (i.e., power spectrum)
- imgf = np.fft.fftshift(np.fft.fft2(self.img))
+ imgf = np.fft.fftshift(np.fft.fft2(self.image))
if self.normalize:
norm = rows * cols * self.pixel[0]**2
else:
@@ -103,36 +123,40 @@ class PSD:
if not hasattr(self, "ps2d") or self.psd2d is None:
self.calc_psd2d()
- print("Radially averaging 2D power spectral density ... ")
- psd2d = self.psd2d
- dim = psd2d.shape[0]
+ print("Radially averaging 2D power spectral density ... ",
+ end="", flush=True)
+ dim = self.shape[0]
dim_half = (dim+1) // 2
# NOTE:
# The zero-frequency component is shifted to position of index
# (0-based): (ceil((n-1) / 2), ceil((m-1) / 2))
- px = np.arange(dim_half-dim, dim_half)
- x, y = np.meshgrid(px, px)
+ px = np.arange(dim_half-dim, dim_half)
+ x, y = np.meshgrid(px, px)
rho, phi = self.cart2pol(x, y)
- rho = np.around(rho).astype(np.int)
- radial_psd = np.zeros(dim_half)
- radial_psd_err = np.zeros(dim_half)
- print(" -> radially averaging ... ", end="", flush=True)
- for r in range(dim_half):
- # Get the indices of the elements satisfying rho[i,j]==r
- ii, jj = (rho == r).nonzero()
- # Calculate the mean value at a given radii
- data = psd2d[ii, jj]
- radial_psd[r] = np.nanmean(data)
- radial_psd_err[r] = np.nanstd(data)
- # Calculate frequencies
- f = np.fft.fftfreq(dim, d=self.pixel[0])
- freqs = np.abs(f[:dim_half])
- #
- self.freqs = freqs
- self.psd1d = radial_psd
- self.psd1d_err = radial_psd_err
+
+ radii = self.radii
+ nr = len(radii)
+ if nr > 100:
+ print("\n ... many points to calculate, may take a while ... ",
+ end="", flush=True)
+ else:
+ print(" %d data points ... " % nr, end="", flush=True)
+ psd1d = np.zeros(nr)
+ psd1d_err = np.zeros(nr)
+ for i, r in enumerate(radii):
+ if (i+1) % 100 == 0:
+ percent = 100 * (i+1) / nr
+ print("%.1f%% ... " % percent, end="", flush=True)
+ ii, jj = (rho <= r).nonzero()
+ rho[ii, jj] = np.inf
+ data = self.psd2d[ii, jj]
+ psd1d[i] = np.mean(data)
+ psd1d_err[i] = np.std(data)
print("DONE", flush=True)
- return (freqs, radial_psd, radial_psd_err)
+
+ self.psd1d = psd1d
+ self.psd1d_err = psd1d_err
+ return (self.frequencies, psd1d, psd1d_err)
@staticmethod
def cart2pol(x, y):
@@ -159,14 +183,14 @@ class PSD:
if ax is None:
fig, ax = plt.subplots(1, 1)
#
- xmin = self.freqs[1] / 1.2 # ignore the first 0
- xmax = self.freqs[-1]
+ xmin = self.frequencies[1] / 1.2 # ignore the first 0
+ xmax = self.frequencies[-1]
ymin = np.nanmin(self.psd1d) / 10.0
ymax = np.nanmax(self.psd1d + self.psd1d_err)
#
- eb = ax.errorbar(self.freqs, self.psd1d, yerr=self.psd1d_err,
- fmt="none")
- ax.plot(self.freqs, self.psd1d, "ko")
+ ax.errorbar(self.frequencies, self.psd1d, yerr=self.psd1d_err,
+ fmt="none")
+ ax.plot(self.frequencies, self.psd1d, "ko")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim(xmin, xmax)
@@ -226,6 +250,13 @@ def main():
description="Calculate radially averaged power spectral density")
parser.add_argument("-C", "--clobber", dest="clobber", action="store_true",
help="overwrite the output files if already exist")
+ parser.add_argument("-s", "--step", dest="step", type=float, default=None,
+ help="step ratio between 2 consecutive radial " +
+ "frequency points, must be > 1, thus a log-even " +
+ "grid is adopted; if not specified, then the power " +
+ "at every frequency point will be calculated, " +
+ "i.e., using a even grid, which may be very slow " +
+ "for very large images!")
parser.add_argument("-i", "--infile", dest="infile", required=True,
help="input FITS image")
parser.add_argument("-o", "--outfile", dest="outfile", required=True,
@@ -245,9 +276,8 @@ def main():
raise OSError("output plot file '%s' already exists" % plotfile)
header, image = open_image(args.infile)
- psd = PSD(image=image, normalize=True)
- psd.calc_psd2d()
- freqs, psd, psd_err = psd.calc_psd()
+ psdobj = PSD(image=image, normalize=True, step=args.step)
+ freqs, psd, psd_err = psdobj.calc_psd()
# Write out PSD results
psd_data = np.column_stack((freqs, psd, psd_err))
@@ -259,9 +289,9 @@ def main():
fig = Figure(figsize=(10, 8))
FigureCanvas(fig)
ax = fig.add_subplot(111)
- psd.plot(ax=ax, fig=fig)
+ psdobj.plot(ax=ax, fig=fig)
fig.savefig(plotfile, format="png", dpi=150)
- print("Plotted PSD and saved as: %s" % plotfile)
+ print("Plotted PSD and saved to image: %s" % plotfile)
if __name__ == "__main__":