diff options
author | Aaron LI <aly@aaronly.me> | 2017-11-03 10:33:51 +0800 |
---|---|---|
committer | Aaron LI <aly@aaronly.me> | 2017-11-03 10:34:06 +0800 |
commit | 498e32c7c579970341c6dad9317ae79e2ef2c53f (patch) | |
tree | e92985992dfeb20b548400f1f7be2d4e8ca7ea11 | |
parent | 1200c585dac0405cdb3106663752db8b5db53c35 (diff) | |
download | atoolbox-498e32c7c579970341c6dad9317ae79e2ef2c53f.tar.bz2 |
astro/calc_psd.py: Improve plot function
Rename argument -p/--plot to -P/--plot
-rwxr-xr-x | astro/calc_psd.py | 33 |
1 files changed, 14 insertions, 19 deletions
diff --git a/astro/calc_psd.py b/astro/calc_psd.py index f93335c..018c52a 100755 --- a/astro/calc_psd.py +++ b/astro/calc_psd.py @@ -177,13 +177,10 @@ class PSD: y = rho * np.sin(phi) return (x, y) - def plot(self, ax=None, fig=None): + def plot(self, ax): """ - Make a plot of the 1D radial PSD with matplotlib. + Make a plot of the 1D radial power spectrum. """ - if ax is None: - fig, ax = plt.subplots(1, 1) - freqs = self.frequencies xmin = freqs[1] / 1.2 # ignore the first 0 xmax = freqs[-1] * 1.1 @@ -191,16 +188,13 @@ class PSD: ymax = np.max(self.psd1d[1:] + self.psd1d_err[1:]) * 2 ax.errorbar(freqs, self.psd1d, yerr=self.psd1d_err, fmt="none") - ax.plot(freqs, self.psd1d, "ko") - ax.set_xscale("log") - ax.set_yscale("log") - ax.set_xlim(xmin, xmax) - ax.set_ylim(ymin, ymax) - ax.set_title("Radial (Azimuthally Averaged) Power Spectral Density") - ax.set_xlabel(r"k [%s$^{-1}$]" % self.pixel[1]) - ax.set_ylabel("Power") - fig.tight_layout() - return (fig, ax) + ax.plot(freqs, self.psd1d, marker="o") + ax.set(xscale="log", yscale="log", + xlim=(xmin, xmax), ylim=(ymin, ymax), + title="Radial (Azimuthally Averaged) Power Spectral Density", + xlabel=r"k [%s$^{-1}$]" % self.pixel[1], + ylabel="Power") + return ax def open_image(infile): @@ -255,7 +249,7 @@ def main(): "at every frequency point will be calculated, " + "i.e., using a even grid, which may be very slow " + "for very large images!") - parser.add_argument("-p", "--plot", dest="plot", action="store_true", + parser.add_argument("-P", "--plot", dest="plot", action="store_true", help="plot the PSD and save as a PNG image") parser.add_argument("-i", "--infile", dest="infile", nargs="+", help="input FITS image(s); if multiple images " + @@ -298,10 +292,11 @@ def main(): if args.plot: # Make and save a plot - fig = Figure(figsize=(10, 8)) + fig = Figure(figsize=(8, 8)) FigureCanvas(fig) - ax = fig.add_subplot(111) - psdobj.plot(ax=ax, fig=fig) + ax = fig.add_subplot(1, 1, 1) + psdobj.plot(ax=ax) + fig.tight_layout() fig.savefig(plotfile, format="png", dpi=150) print("Plotted PSD and saved to image: %s" % plotfile) |