diff options
-rwxr-xr-x | astro/calc_psd.py | 130 |
1 files changed, 80 insertions, 50 deletions
diff --git a/astro/calc_psd.py b/astro/calc_psd.py index 5153541..6029169 100755 --- a/astro/calc_psd.py +++ b/astro/calc_psd.py @@ -18,6 +18,7 @@ Credit import os import argparse +from functools import lru_cache import numpy as np from astropy.io import fits @@ -34,19 +35,6 @@ class PSD: Computes the 2D power spectral density and the radially averaged power spectral density (i.e., 1D power spectrum). """ - # 2D image data - img = None - # value and unit of 1 pixel for the input image - pixel = (None, None) - # whether to normalize the power spectral density by image size - normalize = True - # 2D power spectral density - psd2d = None - # 1D (radially averaged) power spectral density - freqs = None - psd1d = None - psd1d_err = None - def __init__(self, image, pixel=(1.0, "pixel"), normalize=True, step=None): self.image = np.array(image, dtype=np.float) self.shape = self.image.shape @@ -56,10 +44,42 @@ class PSD: self.pixel = pixel self.normalize = normalize self.step = step + if step is not None and step <= 1: + raise ValueError("step must be greater than 1") @property + @lru_cache() def radii(self): - pass + """ + The radial (frequency) points where to calculate the powers. + If ``self.step`` is ``None``, then the powers at every frequency + point are calculated. If ``self.step`` is specified, then a + log-even grid is adopted, which can greatly save computation time + for large images. + """ + dim_half = (self.shape[0] + 1) // 2 + x = np.arange(dim_half) + if self.step is None: + return x + else: + xmax = x.max() + x2 = list(x[x*(self.step-1) <= 1]) + v1 = x[len(x2)] + while v1 < xmax: + x2.append(v1) + v1 *= self.step + x2.append(xmax) + return np.array(x2) + + @property + @lru_cache() + def frequencies(self): + """ + The (spatial) frequencies w.r.t. the above radii. + """ + radii = self.radii + freqs = (1 / (self.shape[0] * self.pixel[0])) * radii + return freqs def calc_psd2d(self): """ @@ -77,9 +97,9 @@ class PSD: otherwise has unit ${pixel_unit}^2. """ print("Calculating 2D power spectral density ... ", end="", flush=True) - rows, cols = self.img.shape + rows, cols = self.shape # Compute the power spectral density (i.e., power spectrum) - imgf = np.fft.fftshift(np.fft.fft2(self.img)) + imgf = np.fft.fftshift(np.fft.fft2(self.image)) if self.normalize: norm = rows * cols * self.pixel[0]**2 else: @@ -103,36 +123,40 @@ class PSD: if not hasattr(self, "ps2d") or self.psd2d is None: self.calc_psd2d() - print("Radially averaging 2D power spectral density ... ") - psd2d = self.psd2d - dim = psd2d.shape[0] + print("Radially averaging 2D power spectral density ... ", + end="", flush=True) + dim = self.shape[0] dim_half = (dim+1) // 2 # NOTE: # The zero-frequency component is shifted to position of index # (0-based): (ceil((n-1) / 2), ceil((m-1) / 2)) - px = np.arange(dim_half-dim, dim_half) - x, y = np.meshgrid(px, px) + px = np.arange(dim_half-dim, dim_half) + x, y = np.meshgrid(px, px) rho, phi = self.cart2pol(x, y) - rho = np.around(rho).astype(np.int) - radial_psd = np.zeros(dim_half) - radial_psd_err = np.zeros(dim_half) - print(" -> radially averaging ... ", end="", flush=True) - for r in range(dim_half): - # Get the indices of the elements satisfying rho[i,j]==r - ii, jj = (rho == r).nonzero() - # Calculate the mean value at a given radii - data = psd2d[ii, jj] - radial_psd[r] = np.nanmean(data) - radial_psd_err[r] = np.nanstd(data) - # Calculate frequencies - f = np.fft.fftfreq(dim, d=self.pixel[0]) - freqs = np.abs(f[:dim_half]) - # - self.freqs = freqs - self.psd1d = radial_psd - self.psd1d_err = radial_psd_err + + radii = self.radii + nr = len(radii) + if nr > 100: + print("\n ... many points to calculate, may take a while ... ", + end="", flush=True) + else: + print(" %d data points ... " % nr, end="", flush=True) + psd1d = np.zeros(nr) + psd1d_err = np.zeros(nr) + for i, r in enumerate(radii): + if (i+1) % 100 == 0: + percent = 100 * (i+1) / nr + print("%.1f%% ... " % percent, end="", flush=True) + ii, jj = (rho <= r).nonzero() + rho[ii, jj] = np.inf + data = self.psd2d[ii, jj] + psd1d[i] = np.mean(data) + psd1d_err[i] = np.std(data) print("DONE", flush=True) - return (freqs, radial_psd, radial_psd_err) + + self.psd1d = psd1d + self.psd1d_err = psd1d_err + return (self.frequencies, psd1d, psd1d_err) @staticmethod def cart2pol(x, y): @@ -159,14 +183,14 @@ class PSD: if ax is None: fig, ax = plt.subplots(1, 1) # - xmin = self.freqs[1] / 1.2 # ignore the first 0 - xmax = self.freqs[-1] + xmin = self.frequencies[1] / 1.2 # ignore the first 0 + xmax = self.frequencies[-1] ymin = np.nanmin(self.psd1d) / 10.0 ymax = np.nanmax(self.psd1d + self.psd1d_err) # - eb = ax.errorbar(self.freqs, self.psd1d, yerr=self.psd1d_err, - fmt="none") - ax.plot(self.freqs, self.psd1d, "ko") + ax.errorbar(self.frequencies, self.psd1d, yerr=self.psd1d_err, + fmt="none") + ax.plot(self.frequencies, self.psd1d, "ko") ax.set_xscale("log") ax.set_yscale("log") ax.set_xlim(xmin, xmax) @@ -226,6 +250,13 @@ def main(): description="Calculate radially averaged power spectral density") parser.add_argument("-C", "--clobber", dest="clobber", action="store_true", help="overwrite the output files if already exist") + parser.add_argument("-s", "--step", dest="step", type=float, default=None, + help="step ratio between 2 consecutive radial " + + "frequency points, must be > 1, thus a log-even " + + "grid is adopted; if not specified, then the power " + + "at every frequency point will be calculated, " + + "i.e., using a even grid, which may be very slow " + + "for very large images!") parser.add_argument("-i", "--infile", dest="infile", required=True, help="input FITS image") parser.add_argument("-o", "--outfile", dest="outfile", required=True, @@ -245,9 +276,8 @@ def main(): raise OSError("output plot file '%s' already exists" % plotfile) header, image = open_image(args.infile) - psd = PSD(image=image, normalize=True) - psd.calc_psd2d() - freqs, psd, psd_err = psd.calc_psd() + psdobj = PSD(image=image, normalize=True, step=args.step) + freqs, psd, psd_err = psdobj.calc_psd() # Write out PSD results psd_data = np.column_stack((freqs, psd, psd_err)) @@ -259,9 +289,9 @@ def main(): fig = Figure(figsize=(10, 8)) FigureCanvas(fig) ax = fig.add_subplot(111) - psd.plot(ax=ax, fig=fig) + psdobj.plot(ax=ax, fig=fig) fig.savefig(plotfile, format="png", dpi=150) - print("Plotted PSD and saved as: %s" % plotfile) + print("Plotted PSD and saved to image: %s" % plotfile) if __name__ == "__main__": |