aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xastro/ps2d.py62
1 files changed, 50 insertions, 12 deletions
diff --git a/astro/ps2d.py b/astro/ps2d.py
index 38f9f9f..9009826 100755
--- a/astro/ps2d.py
+++ b/astro/ps2d.py
@@ -33,6 +33,7 @@ import os
import sys
import argparse
import logging
+from functools import lru_cache
import numpy as np
from scipy import fftpack
@@ -63,11 +64,13 @@ OmegaM0 = 0.27
cosmo = FlatLambdaCDM(H0=H0, Om0=OmegaM0)
+@lru_cache()
def freq2z(freq):
z = freq21cm / freq - 1.0
return z
+@lru_cache()
def get_frequencies(wcs, nfreq):
pix = np.zeros(shape=(nfreq, 3), dtype=int)
pix[:, -1] = np.arange(nfreq)
@@ -171,10 +174,10 @@ class PS2D:
logger.info("Applying window along frequency axis ...")
self.cube *= self.window[:, np.newaxis, np.newaxis]
- logger.info("Calculating 3D FFT ...")
+ logger.info("3D FFTing data cube ...")
cubefft = fftpack.fftshift(fftpack.fftn(self.cube))
- logger.info("Calculating 3D PS ...")
+ logger.info("Calculating 3D power spectrum ...")
ps3d = np.abs(cubefft) ** 2 # [K^2]
# Normalization
norm1 = 1 / (self.Nx * self.Ny * self.Nz)
@@ -245,24 +248,46 @@ class PS2D:
hdu.writeto(outfile, clobber=clobber)
logger.info("Wrote 2D power spectrum to file: %s" % outfile)
- def plot(self, fig, ax):
+ def plot(self, ax, ax_err, colormap="jet"):
"""
Plot the calculated 2D power spectrum.
"""
x = self.k_perp
y = self.k_los
+
+ if self.meanstd:
+ title = "2D Power Spectrum (mean)"
+ title_err = "Error (standard deviation)"
+ else:
+ title = "2D Power Spectrum (median)"
+ title_err = "Error (68% IQR)"
+
+ # median/mean
mappable = ax.pcolormesh(x[1:], y[1:],
np.log10(self.ps2d[0, 1:, 1:]),
- cmap="jet")
+ cmap=colormap)
ax.set(xscale="log", yscale="log",
xlim=(x[1], x[-1]), ylim=(y[1], y[-1]),
xlabel=r"k$_{\perp}$ [Mpc$^{-1}$]",
ylabel=r"k$_{||}$ [Mpc$^{-1}$]",
- title="2D Power Spectrum")
- cb = fig.colorbar(mappable, ax=ax, pad=0.01, aspect=30)
+ title=title)
+ cb = ax.figure.colorbar(mappable, ax=ax, pad=0.01, aspect=30)
cb.ax.set_xlabel(r"[%s$^2$ Mpc$^3$]" % self.unit)
- fig.tight_layout()
- return (fig, ax)
+
+ # error (68% IQR / standard deviation)
+ error = 0.5 * (self.ps2d[1, :, :] + self.ps2d[2, :, :])
+ mappable = ax_err.pcolormesh(x[1:], y[1:],
+ np.log10(error[1:, 1:]),
+ cmap=colormap)
+ ax_err.set(xscale="log", yscale="log",
+ xlim=(x[1], x[-1]), ylim=(y[1], y[-1]),
+ xlabel=r"k$_{\perp}$ [Mpc$^{-1}$]",
+ ylabel=r"k$_{||}$ [Mpc$^{-1}$]",
+ title=title_err)
+ cb = ax_err.figure.colorbar(mappable, ax=ax_err, pad=0.01, aspect=30)
+ cb.ax.set_xlabel(r"[%s$^2$ Mpc$^3$]" % self.unit)
+
+ return (ax, ax_err)
@property
def Nx(self):
@@ -281,6 +306,7 @@ class PS2D:
return self.cube.shape[0]
@property
+ @lru_cache()
def d_xy(self):
"""
The sampling interval along the (X, Y) spatial dimensions,
@@ -294,6 +320,7 @@ class PS2D:
return d_xy
@property
+ @lru_cache()
def d_z(self):
"""
The sampling interval along the Z line-of-sight dimension,
@@ -310,6 +337,7 @@ class PS2D:
return d_z
@property
+ @lru_cache()
def fs_xy(self):
"""
The sampling frequency along the (X, Y) spatial dimensions:
@@ -319,6 +347,7 @@ class PS2D:
return 1/self.d_xy
@property
+ @lru_cache()
def fs_z(self):
"""
The sampling frequency along the Z line-of-sight dimension.
@@ -327,6 +356,7 @@ class PS2D:
return 1/self.d_z
@property
+ @lru_cache()
def df_xy(self):
"""
The spatial frequency bin size (i.e., resolution) along the
@@ -336,6 +366,7 @@ class PS2D:
return self.fs_xy / self.Nx
@property
+ @lru_cache()
def df_z(self):
"""
The spatial frequency bin size (i.e., resolution) along the
@@ -352,10 +383,12 @@ class PS2D:
return 2*np.pi * self.df_xy
@property
+ @lru_cache()
def dk_z(self):
return 2*np.pi * self.df_z
@property
+ @lru_cache()
def k_xy(self):
"""
The k-space coordinates along the (X, Y) spatial dimensions,
@@ -372,12 +405,14 @@ class PS2D:
return k_xy
@property
+ @lru_cache()
def k_z(self):
f_z = fftpack.fftshift(fftpack.fftfreq(self.Nz, d=self.d_z))
k_z = 2*np.pi * f_z
return k_z
@property
+ @lru_cache()
def k_perp(self):
"""
Comoving wavenumbers perpendicular to the LoS
@@ -389,6 +424,7 @@ class PS2D:
return k_x[k_x >= 0]
@property
+ @lru_cache()
def k_los(self):
"""
Comoving wavenumbers along the LoS
@@ -524,12 +560,14 @@ def main():
ps2d.save(outfile=args.outfile, clobber=args.clobber)
if args.plot:
- fig = Figure(figsize=(9, 8))
+ fig = Figure(figsize=(16, 8), dpi=150)
FigureCanvas(fig)
- ax = fig.add_subplot(1, 1, 1)
- ps2d.plot(ax=ax, fig=fig)
+ ax = fig.add_subplot(1, 2, 1)
+ ax_err = fig.add_subplot(1, 2, 2)
+ ps2d.plot(ax=ax, ax_err=ax_err)
+ fig.tight_layout()
plotfile = os.path.splitext(args.outfile)[0] + ".png"
- fig.savefig(plotfile, format="png", dpi=150)
+ fig.savefig(plotfile)
logger.info("Plotted 2D PSD and saved to image: %s" % plotfile)