diff options
Diffstat (limited to 'astro')
-rwxr-xr-x | astro/ps2d.py | 62 |
1 files changed, 50 insertions, 12 deletions
diff --git a/astro/ps2d.py b/astro/ps2d.py index 38f9f9f..9009826 100755 --- a/astro/ps2d.py +++ b/astro/ps2d.py @@ -33,6 +33,7 @@ import os import sys import argparse import logging +from functools import lru_cache import numpy as np from scipy import fftpack @@ -63,11 +64,13 @@ OmegaM0 = 0.27 cosmo = FlatLambdaCDM(H0=H0, Om0=OmegaM0) +@lru_cache() def freq2z(freq): z = freq21cm / freq - 1.0 return z +@lru_cache() def get_frequencies(wcs, nfreq): pix = np.zeros(shape=(nfreq, 3), dtype=int) pix[:, -1] = np.arange(nfreq) @@ -171,10 +174,10 @@ class PS2D: logger.info("Applying window along frequency axis ...") self.cube *= self.window[:, np.newaxis, np.newaxis] - logger.info("Calculating 3D FFT ...") + logger.info("3D FFTing data cube ...") cubefft = fftpack.fftshift(fftpack.fftn(self.cube)) - logger.info("Calculating 3D PS ...") + logger.info("Calculating 3D power spectrum ...") ps3d = np.abs(cubefft) ** 2 # [K^2] # Normalization norm1 = 1 / (self.Nx * self.Ny * self.Nz) @@ -245,24 +248,46 @@ class PS2D: hdu.writeto(outfile, clobber=clobber) logger.info("Wrote 2D power spectrum to file: %s" % outfile) - def plot(self, fig, ax): + def plot(self, ax, ax_err, colormap="jet"): """ Plot the calculated 2D power spectrum. """ x = self.k_perp y = self.k_los + + if self.meanstd: + title = "2D Power Spectrum (mean)" + title_err = "Error (standard deviation)" + else: + title = "2D Power Spectrum (median)" + title_err = "Error (68% IQR)" + + # median/mean mappable = ax.pcolormesh(x[1:], y[1:], np.log10(self.ps2d[0, 1:, 1:]), - cmap="jet") + cmap=colormap) ax.set(xscale="log", yscale="log", xlim=(x[1], x[-1]), ylim=(y[1], y[-1]), xlabel=r"k$_{\perp}$ [Mpc$^{-1}$]", ylabel=r"k$_{||}$ [Mpc$^{-1}$]", - title="2D Power Spectrum") - cb = fig.colorbar(mappable, ax=ax, pad=0.01, aspect=30) + title=title) + cb = ax.figure.colorbar(mappable, ax=ax, pad=0.01, aspect=30) cb.ax.set_xlabel(r"[%s$^2$ Mpc$^3$]" % self.unit) - fig.tight_layout() - return (fig, ax) + + # error (68% IQR / standard deviation) + error = 0.5 * (self.ps2d[1, :, :] + self.ps2d[2, :, :]) + mappable = ax_err.pcolormesh(x[1:], y[1:], + np.log10(error[1:, 1:]), + cmap=colormap) + ax_err.set(xscale="log", yscale="log", + xlim=(x[1], x[-1]), ylim=(y[1], y[-1]), + xlabel=r"k$_{\perp}$ [Mpc$^{-1}$]", + ylabel=r"k$_{||}$ [Mpc$^{-1}$]", + title=title_err) + cb = ax_err.figure.colorbar(mappable, ax=ax_err, pad=0.01, aspect=30) + cb.ax.set_xlabel(r"[%s$^2$ Mpc$^3$]" % self.unit) + + return (ax, ax_err) @property def Nx(self): @@ -281,6 +306,7 @@ class PS2D: return self.cube.shape[0] @property + @lru_cache() def d_xy(self): """ The sampling interval along the (X, Y) spatial dimensions, @@ -294,6 +320,7 @@ class PS2D: return d_xy @property + @lru_cache() def d_z(self): """ The sampling interval along the Z line-of-sight dimension, @@ -310,6 +337,7 @@ class PS2D: return d_z @property + @lru_cache() def fs_xy(self): """ The sampling frequency along the (X, Y) spatial dimensions: @@ -319,6 +347,7 @@ class PS2D: return 1/self.d_xy @property + @lru_cache() def fs_z(self): """ The sampling frequency along the Z line-of-sight dimension. @@ -327,6 +356,7 @@ class PS2D: return 1/self.d_z @property + @lru_cache() def df_xy(self): """ The spatial frequency bin size (i.e., resolution) along the @@ -336,6 +366,7 @@ class PS2D: return self.fs_xy / self.Nx @property + @lru_cache() def df_z(self): """ The spatial frequency bin size (i.e., resolution) along the @@ -352,10 +383,12 @@ class PS2D: return 2*np.pi * self.df_xy @property + @lru_cache() def dk_z(self): return 2*np.pi * self.df_z @property + @lru_cache() def k_xy(self): """ The k-space coordinates along the (X, Y) spatial dimensions, @@ -372,12 +405,14 @@ class PS2D: return k_xy @property + @lru_cache() def k_z(self): f_z = fftpack.fftshift(fftpack.fftfreq(self.Nz, d=self.d_z)) k_z = 2*np.pi * f_z return k_z @property + @lru_cache() def k_perp(self): """ Comoving wavenumbers perpendicular to the LoS @@ -389,6 +424,7 @@ class PS2D: return k_x[k_x >= 0] @property + @lru_cache() def k_los(self): """ Comoving wavenumbers along the LoS @@ -524,12 +560,14 @@ def main(): ps2d.save(outfile=args.outfile, clobber=args.clobber) if args.plot: - fig = Figure(figsize=(9, 8)) + fig = Figure(figsize=(16, 8), dpi=150) FigureCanvas(fig) - ax = fig.add_subplot(1, 1, 1) - ps2d.plot(ax=ax, fig=fig) + ax = fig.add_subplot(1, 2, 1) + ax_err = fig.add_subplot(1, 2, 2) + ps2d.plot(ax=ax, ax_err=ax_err) + fig.tight_layout() plotfile = os.path.splitext(args.outfile)[0] + ".png" - fig.savefig(plotfile, format="png", dpi=150) + fig.savefig(plotfile) logger.info("Plotted 2D PSD and saved to image: %s" % plotfile) |