aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAaron LI <aly@aaronly.me>2017-11-03 10:33:51 +0800
committerAaron LI <aly@aaronly.me>2017-11-03 10:34:06 +0800
commit498e32c7c579970341c6dad9317ae79e2ef2c53f (patch)
treee92985992dfeb20b548400f1f7be2d4e8ca7ea11
parent1200c585dac0405cdb3106663752db8b5db53c35 (diff)
downloadatoolbox-498e32c7c579970341c6dad9317ae79e2ef2c53f.tar.bz2
astro/calc_psd.py: Improve plot function
Rename argument -p/--plot to -P/--plot
-rwxr-xr-xastro/calc_psd.py33
1 files changed, 14 insertions, 19 deletions
diff --git a/astro/calc_psd.py b/astro/calc_psd.py
index f93335c..018c52a 100755
--- a/astro/calc_psd.py
+++ b/astro/calc_psd.py
@@ -177,13 +177,10 @@ class PSD:
y = rho * np.sin(phi)
return (x, y)
- def plot(self, ax=None, fig=None):
+ def plot(self, ax):
"""
- Make a plot of the 1D radial PSD with matplotlib.
+ Make a plot of the 1D radial power spectrum.
"""
- if ax is None:
- fig, ax = plt.subplots(1, 1)
-
freqs = self.frequencies
xmin = freqs[1] / 1.2 # ignore the first 0
xmax = freqs[-1] * 1.1
@@ -191,16 +188,13 @@ class PSD:
ymax = np.max(self.psd1d[1:] + self.psd1d_err[1:]) * 2
ax.errorbar(freqs, self.psd1d, yerr=self.psd1d_err, fmt="none")
- ax.plot(freqs, self.psd1d, "ko")
- ax.set_xscale("log")
- ax.set_yscale("log")
- ax.set_xlim(xmin, xmax)
- ax.set_ylim(ymin, ymax)
- ax.set_title("Radial (Azimuthally Averaged) Power Spectral Density")
- ax.set_xlabel(r"k [%s$^{-1}$]" % self.pixel[1])
- ax.set_ylabel("Power")
- fig.tight_layout()
- return (fig, ax)
+ ax.plot(freqs, self.psd1d, marker="o")
+ ax.set(xscale="log", yscale="log",
+ xlim=(xmin, xmax), ylim=(ymin, ymax),
+ title="Radial (Azimuthally Averaged) Power Spectral Density",
+ xlabel=r"k [%s$^{-1}$]" % self.pixel[1],
+ ylabel="Power")
+ return ax
def open_image(infile):
@@ -255,7 +249,7 @@ def main():
"at every frequency point will be calculated, " +
"i.e., using a even grid, which may be very slow " +
"for very large images!")
- parser.add_argument("-p", "--plot", dest="plot", action="store_true",
+ parser.add_argument("-P", "--plot", dest="plot", action="store_true",
help="plot the PSD and save as a PNG image")
parser.add_argument("-i", "--infile", dest="infile", nargs="+",
help="input FITS image(s); if multiple images " +
@@ -298,10 +292,11 @@ def main():
if args.plot:
# Make and save a plot
- fig = Figure(figsize=(10, 8))
+ fig = Figure(figsize=(8, 8))
FigureCanvas(fig)
- ax = fig.add_subplot(111)
- psdobj.plot(ax=ax, fig=fig)
+ ax = fig.add_subplot(1, 1, 1)
+ psdobj.plot(ax=ax)
+ fig.tight_layout()
fig.savefig(plotfile, format="png", dpi=150)
print("Plotted PSD and saved to image: %s" % plotfile)