diff options
Diffstat (limited to 'python')
-rwxr-xr-x | python/adjust_spectrum_error.py | 170 | ||||
-rwxr-xr-x | python/calc_radial_psd.py | 450 | ||||
-rwxr-xr-x | python/crosstalk_deprojection.py | 1808 | ||||
-rwxr-xr-x | python/fit_sbp.py | 807 | ||||
-rw-r--r-- | python/imapUTF7.py | 189 | ||||
-rwxr-xr-x | python/msvst_starlet.py | 646 | ||||
-rw-r--r-- | python/plot.py | 35 | ||||
-rw-r--r-- | python/plot_tprofiles_zzh.py | 126 | ||||
-rwxr-xr-x | python/randomize_events.py | 72 | ||||
-rwxr-xr-x | python/rebuild_ipod_db.py | 595 | ||||
-rwxr-xr-x | python/splitBoxRegion.py | 148 | ||||
-rw-r--r-- | python/splitCCDgaps.py | 107 | ||||
-rw-r--r-- | python/xkeywordsync.py | 533 |
13 files changed, 5686 insertions, 0 deletions
diff --git a/python/adjust_spectrum_error.py b/python/adjust_spectrum_error.py new file mode 100755 index 0000000..0f80ec7 --- /dev/null +++ b/python/adjust_spectrum_error.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Squeeze the spectrum according to the grouping specification, then +calculate the statistical errors for each group, and apply error +adjustments (e.g., incorporate the systematic uncertainties). +""" + +__version__ = "0.1.0" +__date__ = "2016-01-11" + + +import sys +import argparse + +import numpy as np +from astropy.io import fits + + +class Spectrum: + """ + Spectrum class to keep spectrum information and perform manipulations. + """ + header = None + channel = None + counts = None + grouping = None + quality = None + + def __init__(self, specfile): + f = fits.open(specfile) + spechdu = f['SPECTRUM'] + self.header = spechdu.header + self.channel = spechdu.data.field('CHANNEL') + self.counts = spechdu.data.field('COUNTS') + self.grouping = spechdu.data.field('GROUPING') + self.quality = spechdu.data.field('QUALITY') + f.close() + + def squeezeByGrouping(self): + """ + Squeeze the spectrum according to the grouping specification, + i.e., sum the counts belonging to the same group, and place the + sum as the first channel within each group with other channels + of counts zero's. + """ + counts_squeezed = [] + cnt_sum = 0 + cnt_num = 0 + first = True + for grp, cnt in zip(self.grouping, self.counts): + if first and grp == 1: + # first group + cnt_sum = cnt + cnt_num = 1 + first = False + elif grp == 1: + # save previous group + counts_squeezed.append(cnt_sum) + counts_squeezed += [ 0 for i in range(cnt_num-1) ] + # start new group + cnt_sum = cnt + cnt_num = 1 + else: + # group continues + cnt_sum += cnt + cnt_num += 1 + # last group + # save previous group + counts_squeezed.append(cnt_sum) + counts_squeezed += [ 0 for i in range(cnt_num-1) ] + self.counts_squeezed = np.array(counts_squeezed, dtype=np.int32) + + def calcStatErr(self, gehrels=False): + """ + Calculate the statistical errors for the grouped channels, + and save as the STAT_ERR column. + """ + idx_nz = np.nonzero(self.counts_squeezed) + stat_err = np.zeros(self.counts_squeezed.shape) + if gehrels: + # Gehrels + stat_err[idx_nz] = 1 + np.sqrt(self.counts_squeezed[idx_nz] + 0.75) + else: + stat_err[idx_nz] = np.sqrt(self.counts_squeezed[idx_nz]) + self.stat_err = stat_err + + @staticmethod + def parseSysErr(syserr): + """ + Parse the string format of syserr supplied in the commandline. + """ + items = map(str.strip, syserr.split(',')) + syserr_spec = [] + for item in items: + spec = item.split(':') + try: + spec = (int(spec[0]), int(spec[1]), float(spec[2])) + except: + raise ValueError("invalid syserr specficiation") + syserr_spec.append(spec) + return syserr_spec + + def applySysErr(self, syserr): + """ + Apply systematic error adjustments to the above calculated + statistical errors. + """ + syserr_spec = self.parseSysErr(syserr) + for lo, hi, se in syserr_spec: + err_adjusted = self.stat_err[(lo-1):(hi-1)] * np.sqrt(1+se) + self.stat_err_adjusted = err_adjusted + + def updateHeader(self): + """ + Update header accordingly. + """ + # POISSERR + self.header['POISSERR'] = False + + def write(self, filename, clobber=False): + """ + Write the updated/modified spectrum block to file. + """ + channel_col = fits.Column(name='CHANNEL', format='J', + array=self.channel) + counts_col = fits.Column(name='COUNTS', format='J', + array=self.counts_squeezed) + stat_err_col = fits.Column(name='STAT_ERR', format='D', + array=self.stat_err_adjusted) + grouping_col = fits.Column(name='GROUPING', format='I', + array=self.grouping) + quality_col = fits.Column(name='QUALITY', format='I', + array=self.quality) + spec_cols = fits.ColDefs([channel_col, counts_col, stat_err_col, + grouping_col, quality_col]) + spechdu = fits.BinTableHDU.from_columns(spec_cols, header=self.header) + spechdu.writeto(filename, clobber=clobber) + + +def main(): + parser = argparse.ArgumentParser( + description="Apply systematic error adjustments to spectrum.") + parser.add_argument("-V", "--version", action="version", + version="%(prog)s " + "%s (%s)" % (__version__, __date__)) + parser.add_argument("infile", help="input spectrum file") + parser.add_argument("outfile", help="output adjusted spectrum file") + parser.add_argument("-e", "--syserr", dest="syserr", required=True, + help="systematic error specification; " + \ + "syntax: ch1low:ch1high:syserr1,...") + parser.add_argument("-C", "--clobber", dest="clobber", + action="store_true", help="overwrite output file if exists") + parser.add_argument("-G", "--gehrels", dest="gehrels", + action="store_true", help="use Gehrels error?") + args = parser.parse_args() + + spec = Spectrum(args.infile) + spec.squeezeByGrouping() + spec.calcStatErr(gehrels=args.gehrels) + spec.applySysErr(syserr=args.syserr) + spec.updateHeader() + spec.write(args.outfile, clobber=args.clobber) + + +if __name__ == "__main__": + main() + + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/calc_radial_psd.py b/python/calc_radial_psd.py new file mode 100755 index 0000000..23bd819 --- /dev/null +++ b/python/calc_radial_psd.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Credit: +# [1] Radially averaged power spectrum of 2D real-valued matrix +# Evan Ruzanski +# 'raPsd2d.m' +# https://www.mathworks.com/matlabcentral/fileexchange/23636-radially-averaged-power-spectrum-of-2d-real-valued-matrix +# +# XXX: +# * If the input image is NOT SQUARE; then are the horizontal frequencies +# the same as the vertical frequencies ?? +# +# Aaron LI <aaronly.me@gmail.com> +# Created: 2015-04-22 +# Updated: 2016-04-28 +# +# Changelog: +# 2016-04-28: +# * Fix wrong meshgrid with respect to the shift zero-frequency component +# * Use "numpy.fft" instead of "scipy.fftpack" +# * Split method "pad_square()" from "calc_radial_psd()" +# * Hide numpy warning when dividing by zero +# * Add method "AstroImage.fix_shapes()" +# * Add support for background subtraction and exposure correction +# * Show verbose information during calculation +# * Add class "AstroImage" +# * Set default value for 'args.png' +# * Rename from 'radialPSD2d.py' to 'calc_radial_psd.py' +# 2016-04-26: +# * Adjust plot function +# * Update normalize argument; Add pixel argument +# 2016-04-25: +# * Update plot function +# * Add command line scripting support +# * Encapsulate the functions within class 'PSD' +# * Update docs/comments +# + +""" +Compute the radially averaged power spectral density (i.e., power spectrum). +""" + +__version__ = "0.5.0" +__date__ = "2016-04-28" + + +import sys +import os +import argparse + +import numpy as np +from astropy.io import fits + +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure + +plt.style.use("ggplot") + + +class PSD: + """ + Computes the 2D power spectral density and the radially averaged power + spectral density (i.e., 1D power spectrum). + """ + # 2D image data + img = None + # value and unit of 1 pixel for the input image + pixel = (None, None) + # whether to normalize the power spectral density by image size + normalize = True + # 2D power spectral density + psd2d = None + # 1D (radially averaged) power spectral density + freqs = None + psd1d = None + psd1d_err = None + + def __init__(self, img, pixel=(1.0, "pixel"), normalize=True): + self.img = img.astype(np.float) + self.pixel = pixel + self.normalize = normalize + + def calc_psd2d(self, verbose=False): + """ + Computes the 2D power spectral density of the given image. + Note that the low frequency components are shifted to the center + of the FFT'ed image. + + NOTE: + The zero-frequency component is shifted to position of index (0-based) + (ceil((n-1) / 2), ceil((m-1) / 2)), + where (n, m) are the number of rows and columns of the image/psd2d. + + Return: + 2D power spectral density, which is dimensionless if normalized, + otherwise has unit ${pixel_unit}^2. + """ + if verbose: + print("Calculating 2D power spectral density ... ", + end="", flush=True) + rows, cols = self.img.shape + ## Compute the power spectral density (i.e., power spectrum) + imgf = np.fft.fftshift(np.fft.fft2(self.img)) + if self.normalize: + norm = rows * cols * self.pixel[0]**2 + else: + norm = 1.0 # Do not normalize + self.psd2d = (np.abs(imgf) / norm) ** 2 + if verbose: + print("DONE", flush=True) + return self.psd2d + + def calc_radial_psd1d(self, verbose=False): + """ + Computes the radially averaged power spectral density from the + provided 2D power spectral density. + + Return: + (freqs, radial_psd, radial_psd_err) + freqs: spatial freqencies (unit: ${pixel_unit}^(-1)) + radial_psd: radially averaged power spectral density for each + frequency + radial_psd_err: standard deviations of each radial_psd + """ + if verbose: + print("Calculating radial (1D) power spectral density ... ", + end="", flush=True) + if verbose: + print("padding ... ", end="", flush=True) + psd2d = self.pad_square(self.psd2d, value=np.nan) + dim = psd2d.shape[0] + dim_half = (dim+1) // 2 + # NOTE: + # The zero-frequency component is shifted to position of index + # (0-based): (ceil((n-1) / 2), ceil((m-1) / 2)) + px = np.arange(dim_half-dim, dim_half) + x, y = np.meshgrid(px, px) + rho, phi = self.cart2pol(x, y) + rho = np.around(rho).astype(np.int) + radial_psd = np.zeros(dim_half) + radial_psd_err = np.zeros(dim_half) + if verbose: + print("radially averaging ... ", + end="", flush=True) + for r in range(dim_half): + # Get the indices of the elements satisfying rho[i,j]==r + ii, jj = (rho == r).nonzero() + # Calculate the mean value at a given radii + data = psd2d[ii, jj] + radial_psd[r] = np.nanmean(data) + radial_psd_err[r] = np.nanstd(data) + # Calculate frequencies + f = np.fft.fftfreq(dim, d=self.pixel[0]) + freqs = np.abs(f[:dim_half]) + # + self.freqs = freqs + self.psd1d = radial_psd + self.psd1d_err = radial_psd_err + if verbose: + print("DONE", end="", flush=True) + return (freqs, radial_psd, radial_psd_err) + + @staticmethod + def cart2pol(x, y): + """ + Convert Cartesian coordinates to polar coordinates. + """ + rho = np.sqrt(x**2 + y**2) + phi = np.arctan2(y, x) + return (rho, phi) + + @staticmethod + def pol2cart(rho, phi): + """ + Convert polar coordinates to Cartesian coordinates. + """ + x = rho * np.cos(phi) + y = rho * np.sin(phi) + return (x, y) + + @staticmethod + def pad_square(data, value=np.nan): + """ + Symmetrically pad the supplied data matrix to make it square. + The padding rows are equally added to the top and bottom, + as well as the columns to the left and right sides. + The padded rows/columns are filled with the specified value. + """ + mat = data.copy() + rows, cols = mat.shape + dim_diff = abs(rows - cols) + dim_max = max(rows, cols) + if rows > cols: + # pad columns + if dim_diff // 2 == 0: + cols_left = np.zeros((rows, dim_diff/2)) + cols_left[:] = value + cols_right = np.zeros((rows, dim_diff/2)) + cols_right[:] = value + mat = np.hstack((cols_left, mat, cols_right)) + else: + cols_left = np.zeros((rows, np.floor(dim_diff/2))) + cols_left[:] = value + cols_right = np.zeros((rows, np.floor(dim_diff/2)+1)) + cols_right[:] = value + mat = np.hstack((cols_left, mat, cols_right)) + elif rows < cols: + # pad rows + if dim_diff // 2 == 0: + rows_top = np.zeros((dim_diff/2, cols)) + rows_top[:] = value + rows_bottom = np.zeros((dim_diff/2, cols)) + rows_bottom[:] = value + mat = np.vstack((rows_top, mat, rows_bottom)) + else: + rows_top = np.zeros((np.floor(dim_diff/2), cols)) + rows_top[:] = value + rows_bottom = np.zeros((np.floor(dim_diff/2)+1, cols)) + rows_bottom[:] = value + mat = np.vstack((rows_top, mat, rows_bottom)) + return mat + + def plot(self, ax=None, fig=None): + """ + Make a plot of the radial (1D) PSD with matplotlib. + """ + if ax is None: + fig, ax = plt.subplots(1, 1) + # + xmin = self.freqs[1] / 1.2 # ignore the first 0 + xmax = self.freqs[-1] + ymin = np.nanmin(self.psd1d) / 10.0 + ymax = np.nanmax(self.psd1d + self.psd1d_err) + # + eb = ax.errorbar(self.freqs, self.psd1d, yerr=self.psd1d_err, + fmt="none") + ax.plot(self.freqs, self.psd1d, "ko") + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) + ax.set_title("Radially Averaged Power Spectral Density") + ax.set_xlabel(r"k (%s$^{-1}$)" % self.pixel[1]) + if self.normalize: + ax.set_ylabel("Power") + else: + ax.set_ylabel(r"Power (%s$^2$)" % self.pixel[1]) + fig.tight_layout() + return (fig, ax) + + +class AstroImage: + """ + Manipulate the astronimcal counts image, as well as the corresponding + exposure map and background map. + """ + # input counts image + image = None + # exposure map with respect to the input counts image + expmap = None + # background map (e.g., stowed background) + bkgmap = None + # exposure time of the input image + exposure = None + # exposure time of the background map + exposure_bkg = None + + def __init__(self, image, expmap=None, bkgmap=None, verbose=False): + self.load_image(image, verbose=verbose) + self.load_expmap(expmap, verbose=verbose) + self.load_bkgmap(bkgmap, verbose=verbose) + + def load_image(self, image, verbose=False): + if verbose: + print("Loading image ... ", end="", flush=True) + with fits.open(image) as imgfits: + self.image = imgfits[0].data.astype(np.float) + self.exposure = imgfits[0].header["EXPOSURE"] + if verbose: + print("DONE", flush=True) + + def load_expmap(self, expmap, verbose=False): + if expmap: + if verbose: + print("Loading exposure map ... ", end="", flush=True) + with fits.open(expmap) as imgfits: + self.expmap = imgfits[0].data.astype(np.float) + if verbose: + print("DONE", flush=True) + + def load_bkgmap(self, bkgmap, verbose=False): + if bkgmap: + if verbose: + print("Loading background map ... ", end="", flush=True) + with fits.open(bkgmap) as imgfits: + self.bkgmap = imgfits[0].data.astype(np.float) + self.exposure_bkg = imgfits[0].header["EXPOSURE"] + if verbose: + print("DONE", flush=True) + + def fix_shapes(self, tolerance=2, verbose=False): + """ + Fix the shapes of self.expmap and self.bkgmap to make them have + the same shape as the self.image. + + NOTE: + * if the image is bigger than the reference image, then its + columns on the right and rows on the botton are clipped; + * if the image is smaller than the reference image, then padding + columns on the right and rows on the botton are added. + * Original images are REPLACED! + + Arguments: + * tolerance: allow absolute difference between images + """ + def _fix_shape(img, ref, tol=tolerance, verbose=verbose): + if img.shape == ref.shape: + if verbose: + print("SKIPPED", flush=True) + return img + elif np.allclose(img.shape, ref.shape, atol=tol): + if verbose: + print(img.shape, "->", ref.shape, flush=True) + rows, cols = img.shape + rows_ref, cols_ref = ref.shape + # rows + if rows > rows_ref: + img_fixed = img[:rows_ref, :] + else: + img_fixed = np.row_stack((img, + np.zeros((rows_ref-rows, cols), dtype=img.dtype))) + # columns + if cols > cols_ref: + img_fixed = img_fixed[:, :cols_ref] + else: + img_fixed = np.column_stack((img_fixed, + np.zeros((rows_ref, cols_ref-cols), dtype=img.dtype))) + return img_fixed + else: + raise ValueError("shape difference exceeds tolerance: " + \ + "(%d, %d) vs. (%d, %d)" % (img.shape + ref.shape)) + # + if self.bkgmap is not None: + if verbose: + print("Fixing shape for bkgmap ... ", end="", flush=True) + self.bkgmap = _fix_shape(self.bkgmap, self.image) + if self.expmap is not None: + if verbose: + print("Fixing shape for expmap ... ", end="", flush=True) + self.expmap = _fix_shape(self.expmap, self.image) + + def subtract_bkg(self, verbose=False): + if verbose: + print("Subtracting background ... ", end="", flush=True) + self.image -= (self.bkgmap / self.exposure_bkg * self.exposure) + if verbose: + print("DONE", flush=True) + + def correct_exposure(self, cut=0.015, verbose=False): + """ + Correct the image for exposure by dividing by the expmap to + create the exposure-corrected image. + + Arguments: + * cut: the threshold percentage with respect to the maximum + exposure map value; and those pixels with lower values + than this threshold will be excluded/clipped (set to ZERO) + if set to None, then skip clipping image + """ + if verbose: + print("Correcting image for exposure ... ", end="", flush=True) + with np.errstate(divide="ignore", invalid="ignore"): + self.image /= self.expmap + # set invalid values to ZERO + self.image[ ~ np.isfinite(self.image) ] = 0.0 + if verbose: + print("DONE", flush=True) + if cut is not None: + # clip image according the exposure threshold + if verbose: + print("Clipping image (%s) ... " % cut, end="", flush=True) + threshold = cut * np.max(self.expmap) + self.image[ self.expmap < threshold ] = 0.0 + if verbose: + print("DONE", flush=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Compute the radially averaged power spectral density", + epilog="Version: %s (%s)" % (__version__, __date__)) + parser.add_argument("-V", "--version", action="version", + version="%(prog)s " + "%s (%s)" % (__version__, __date__)) + parser.add_argument("-v", "--verbose", dest="verbose", + action="store_true", help="show verbose information") + parser.add_argument("-C", "--clobber", dest="clobber", + action="store_true", + help="overwrite the output files if already exist") + parser.add_argument("-i", "--infile", dest="infile", + required=True, help="input image") + parser.add_argument("-b", "--bkgmap", dest="bkgmap", default=None, + help="background map (for background subtraction)") + parser.add_argument("-e", "--expmap", dest="expmap", default=None, + help="exposure map (for exposure correction)") + parser.add_argument("-o", "--outfile", dest="outfile", + required=True, help="output file to store the PSD data") + parser.add_argument("-p", "--png", dest="png", default=None, + help="plot the PSD and save (default: same basename as outfile)") + args = parser.parse_args() + + if args.png is None: + args.png = os.path.splitext(args.outfile)[0] + ".png" + + # Check output files whether already exists + if (not args.clobber) and os.path.exists(args.outfile): + raise ValueError("outfile '%s' already exists" % args.outfile) + if (not args.clobber) and os.path.exists(args.png): + raise ValueError("output png '%s' already exists" % args.png) + + # Load image data + image = AstroImage(image=args.infile, expmap=args.expmap, + bkgmap=args.bkgmap, verbose=args.verbose) + image.fix_shapes(verbose=args.verbose) + if args.bkgmap: + image.subtract_bkg(verbose=args.verbose) + if args.expmap: + image.correct_exposure(verbose=args.verbose) + + # Calculate the power spectral density + psd = PSD(img=image.image, normalize=True) + psd.calc_psd2d(verbose=args.verbose) + freqs, psd1d, psd1d_err = psd.calc_radial_psd1d(verbose=args.verbose) + + # Write out PSD results + psd_data = np.column_stack((freqs, psd1d, psd1d_err)) + np.savetxt(args.outfile, psd_data, header="freqs psd1d psd1d_err") + + # Make and save a plot + fig = Figure(figsize=(10, 8)) + canvas = FigureCanvas(fig) + ax = fig.add_subplot(111) + psd.plot(ax=ax, fig=fig) + fig.savefig(args.png, format="png", dpi=150) + + +if __name__ == "__main__": + main() + diff --git a/python/crosstalk_deprojection.py b/python/crosstalk_deprojection.py new file mode 100755 index 0000000..d5bab05 --- /dev/null +++ b/python/crosstalk_deprojection.py @@ -0,0 +1,1808 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# References: +# [1] Definition of RMF and ARF file formats +# https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html +# [2] The OGIP Spectral File Format +# https://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/summary/ogip_92_007_summary.html +# [3] CIAO: Auxiliary Response File +# http://cxc.harvard.edu/ciao/dictionary/arf.html +# [4] CIAO: Redistribution Matrix File +# http://cxc.harvard.edu/ciao/dictionary/rmf.html +# [5] astropy - FITS format code +# http://docs.astropy.org/en/stable/io/fits/usage/table.html#column-creation +# [6] XSPEC - Spectral Fitting +# https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/XspecSpectralFitting.html +# [7] Direct X-ray Spectra Deprojection +# https://www-xray.ast.cam.ac.uk/papers/dsdeproj/ +# Sanders & Fabian 2007, MNRAS, 381, 1381 +# +# +# Weitian LI +# Created: 2016-03-26 +# Updated: 2016-04-20 +# +# ChangeLog: +# 2016-04-20: +# * Add argument 'add_history' to some methods (to avoid many duplicated +# histories due to Monte Carlo) +# * Rename 'reset_header_keywords()' to 'fix_header_keywords()', +# and add mandatory spectral keywords if missing +# * Add method 'fix_header()' to class 'Crosstalk' and 'Deprojection', +# and fix the headers before write spectra +# 2016-04-19: +# * Ignore numpy error due to division by zero +# * Update tool description and sample configuration +# * Add two other main methods: `main_deprojection()' and `main_crosstalk()' +# * Add argument 'group_squeeze' to some methods for better performance +# * Rename from 'correct_crosstalk.py' to 'crosstalk_deprojection.py' +# 2016-04-18: +# * Implement deprojection function: class Deprojection +# * Support spectral grouping (supply the grouping specification) +# * Add grouping, estimate_errors, copy, randomize, etc. methods +# * Utilize the Monte Carlo techniques to estimate the final spectral errors +# * Collect all ARFs and RMFs within dictionaries +# 2016-04-06: +# * Fix `RMF: get_rmfimg()' for XMM EPIC RMF +# 2016-04-02: +# * Interpolate ARF in order to match the spectral channel energies +# * Add version and date information +# * Update documentations +# * Update header history contents +# 2016-04-01: +# * Greatly update the documentations (e.g., description, sample config) +# * Add class `RMF' +# * Add method `get_energy()' for class `ARF' +# * Split out class `SpectrumSet' from `Spectrum' +# * Implement background subtraction +# * Add config `subtract_bkg' and corresponding argument +# +# XXX/FIXME: +# * Deprojection: account for ARF differences across different regions +# +# TODO: +# * Split classes ARF, RMF, Spectrum, and SpectrumSet to a separate module +# + +__version__ = "0.5.2" +__date__ = "2016-04-20" + + +""" +Correct the crosstalk effect of XMM spectra by subtracting the photons +that scattered from the surrounding regions due to the finite PSF, and +by compensating the photons that scattered to the surrounding regions, +according to the generated crosstalk ARFs by SAS `arfgen'. + +After the crosstalk effect being corrected, the deprojection is performed +to deproject the crosstalk-corrected spectra to derive the spectra with +both the crosstalk effect and projection effect corrected. + + +Sample config file (in `ConfigObj' syntax): +----------------------------------------------------------- +# operation mode: deprojection, crosstalk, or both (default) +mode = both +# supply a *groupped* spectrum (from which the "GROUPING" and "QUALITY" +# are used to group all the following spectra) +grouping = spec_grp.pi +# whether to subtract the background before crosstalk correction +subtract_bkg = True +# whether to fix the negative channel values due to spectral subtractions +fix_negative = False +# Monte Carlo times for spectral error estimation +mc_times = 5000 +# show progress details and verbose information +verbose = True +# overwrite existing files +clobber = False + +# NOTE: +# ONLY specifiy ONE set of projected spectra (i.e., from the same detector +# of one observation), since ALL the following specified spectra will be +# used for the deprojection. + +[reg1] +... + +[reg2] +outfile = deprojcc_reg2.pi +spec = reg2.pi +arf = reg2.arf +rmf = reg2.rmf +bkg = reg2_bkg.pi + [[cross_in]] + [[[in1]]] + spec = reg1.pi + arf = reg1.arf + rmf = reg1.rmf + bkg = reg1_bkg.pi + cross_arf = reg_1-2.arf + [[[in2]]] + spec = reg3.pi + arf = reg3.arf + rmf = reg3.rmf + bkg = reg3_bkg.pi + cross_arf = reg_3-2.arf + [[cross_out]] + cross_arf = reg_2-1.arf, reg_2-3.arf + +[...] +... +----------------------------------------------------------- +""" + +WARNING = """ +********************************* WARNING ************************************ +The generated spectra are substantially modified (e.g., scale, add, subtract), +therefore, take special care when interpretating the fitting results, +especially the metal abundances and normalizations. +****************************************************************************** +""" + + +import sys +import os +import argparse +from datetime import datetime +from copy import copy + +import numpy as np +import scipy as sp +import scipy.interpolate +from astropy.io import fits +from configobj import ConfigObj + + +def group_data(data, grouping): + """ + Group the data with respect to the supplied `grouping' specification + (i.e., "GROUPING" columns of a spectrum). The channel counts of the + same group are summed up and assigned to the FIRST channel of this + group, while the OTHRE channels are all set to ZERO. + """ + data_grp = np.array(data).copy() + for i in reversed(range(len(data))): + if grouping[i] == 1: + # the beginning channel of a group + continue + else: + # other channels of a group + data_grp[i-1] += data_grp[i] + data_grp[i] = 0 + assert np.isclose(sum(data_grp), sum(data)) + return data_grp + + +class ARF: # {{{ + """ + Class to handle the ARF (ancillary/auxiliary response file), + which contains the combined instrumental effective area + (telescope/filter/detector) and the quantum efficiency (QE) as a + function of energy averaged over time. + The effective area is [cm^2], and the QE is [counts/photon]; they are + multiplied together to create the ARF, resulting in [cm^2 counts/photon]. + + **CAVEAT/NOTE**: + Generally, the "ENERG_LO" and "ENERG_HI" columns of an ARF are *different* + to the "E_MIN" and "E_MAX" columns of a RMF (which are corresponding + to the spectrum channel energies). + For the XMM EPIC *pn* and Chandra *ACIS*, the generated ARF does NOT have + the same number of data points to that of spectral channels, i.e., the + "ENERG_LO" and "ENERG_HI" columns of ARF is different to the "E_MIN" and + "E_MAX" columns of RMF. + Therefore it is necessary to interpolate and extrapolate the ARF curve + in order to match the spectrum (or RMF "EBOUNDS" extension). + As for the XMM EPIC *MOS1* and *MOS2*, the ARF data points match the + spectral channels, i.e., the energy positions of each ARF data point and + spectral channel are consistent. Thus the interpolation is not needed. + + References: + [1] CIAO: Auxiliary Response File + http://cxc.harvard.edu/ciao/dictionary/arf.html + [2] Definition of RMF and ARF file formats + https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html + """ + filename = None + fitsobj = None + # only consider the "SPECTRUM" extension + header = None + energ_lo = None + energ_hi = None + specresp = None + # function of the interpolated ARF + f_interp = None + # energies of the spectral channels + energy_channel = None + # spectral channel grouping specification + grouping = None + groupped = False + # groupped ARF channels with respect to the grouping + specresp_grp = None + + def __init__(self, filename): + self.filename = filename + self.fitsobj = fits.open(filename) + ext_specresp = self.fitsobj["SPECRESP"] + self.header = ext_specresp.header + self.energ_lo = ext_specresp.data["ENERG_LO"] + self.energ_hi = ext_specresp.data["ENERG_HI"] + self.specresp = ext_specresp.data["SPECRESP"] + + def get_data(self, groupped=False, group_squeeze=False, copy=True): + if groupped: + specresp = self.specresp_grp + if group_squeeze: + specresp = specresp[self.grouping == 1] + else: + specresp = self.specresp + if copy: + return specresp.copy() + else: + return specresp + + def get_energy(self, mean="geometric"): + """ + Return the mean energy values of the ARF. + + Arguments: + * mean: type of the mean energy: + + "geometric": geometric mean, i.e., e = sqrt(e_min*e_max) + + "arithmetic": arithmetic mean, i.e., e = 0.5*(e_min+e_max) + """ + if mean == "geometric": + energy = np.sqrt(self.energ_lo * self.energ_hi) + elif mean == "arithmetic": + energy = 0.5 * (self.energ_lo + self.energ_hi) + else: + raise ValueError("Invalid mean type: %s" % mean) + return energy + + def interpolate(self, x=None, verbose=False): + """ + Cubic interpolate the ARF curve using `scipy.interpolate' + + If the requested point is outside of the data range, the + fill value of *zero* is returned. + + Arguments: + * x: points at which the interpolation to be calculated. + + Return: + If x is None, then the interpolated function is returned, + otherwise, the interpolated data are returned. + """ + if not hasattr(self, "f_interp") or self.f_interp is None: + energy = self.get_energy() + arf = self.get_data(copy=False) + if verbose: + print("INFO: interpolating '%s' (this may take a while) ..." \ + % self.filename, file=sys.stderr) + f_interp = sp.interpolate.interp1d(energy, arf, kind="cubic", + bounds_error=False, fill_value=0.0, assume_sorted=True) + self.f_interp = f_interp + if x is not None: + return self.f_interp(x) + else: + return self.f_interp + + def apply_grouping(self, energy_channel, grouping, verbose=False): + """ + Group the ARF channels (INTERPOLATED with respect to the spectral + channels) by the supplied grouping specification. + + Arguments: + * energy_channel: energies of the spectral channel + * grouping: spectral grouping specification + + Return: `self.specresp_grp' + """ + if self.groupped: + return + if verbose: + print("INFO: Grouping spectrum '%s' ..." % self.filename, + file=sys.stderr) + self.energy_channel = energy_channel + self.grouping = grouping + # interpolate the ARF w.r.t the spectral channel energies + arf_interp = self.interpolate(x=energy_channel, verbose=verbose) + self.specresp_grp = group_data(arf_interp, grouping) + self.groupped = True +# class ARF }}} + + +class RMF: # {{{ + """ + Class to handle the RMF (redistribution matrix file), + which maps from energy space into detector pulse height (or position) + space. Since detectors are not perfect, this involves a spreading of + the observed counts by the detector resolution, which is expressed as + a matrix multiplication. + For X-ray spectral analysis, the RMF encodes the probability R(E,p) + that a detected photon of energy E will be assisgned to a given + channel value (PHA or PI) of p. + + The standard Legacy format [2] for the RMF uses a binary table in which + each row contains R(E,p) for a single value of E as a function of p. + Non-zero sequences of elements of R(E,p) are encoded using a set of + variable length array columns. This format is compact but hard to + manipulate and understand. + + **CAVEAT/NOTE**: + + See also the above ARF CAVEAT/NOTE. + + The "EBOUNDS" extension contains the `CHANNEL', `E_MIN' and `E_MAX' + columns. This `CHANNEL' is the same as that of a spectrum. Therefore, + the energy values determined from the `E_MIN' and `E_MAX' columns are + used to interpolate and extrapolate the ARF curve. + + The `ENERG_LO' and `ENERG_HI' columns of the "MATRIX" extension are + the same as that of a ARF. + + References: + [1] CIAO: Redistribution Matrix File + http://cxc.harvard.edu/ciao/dictionary/rmf.html + [2] Definition of RMF and ARF file formats + https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html + """ + filename = None + fitsobj = None + ## extension "MATRIX" + hdr_matrix = None + energ_lo = None + energ_hi = None + n_grp = None + f_chan = None + n_chan = None + # raw squeezed RMF matrix data + matrix = None + ## extension "EBOUNDS" + hdr_ebounds = None + channel = None + e_min = None + e_max = None + ## converted 2D RMF matrix/image from the squeezed binary table + # size: len(energ_lo) x len(channel) + rmfimg = None + + def __init__(self, filename): + self.filename = filename + self.fitsobj = fits.open(filename) + ## "MATRIX" extension + ext_matrix = self.fitsobj["MATRIX"] + self.hdr_matrix = ext_matrix.header + self.energ_lo = ext_matrix.data["ENERG_LO"] + self.energ_hi = ext_matrix.data["ENERG_HI"] + self.n_grp = ext_matrix.data["N_GRP"] + self.f_chan = ext_matrix.data["F_CHAN"] + self.n_chan = ext_matrix.data["N_CHAN"] + self.matrix = ext_matrix.data["MATRIX"] + ## "EBOUNDS" extension + ext_ebounds = self.fitsobj["EBOUNDS"] + self.hdr_ebounds = ext_ebounds.header + self.channel = ext_ebounds.data["CHANNEL"] + self.e_min = ext_ebounds.data["E_MIN"] + self.e_max = ext_ebounds.data["E_MAX"] + + def get_energy(self, mean="geometric"): + """ + Return the mean energy values of the RMF "EBOUNDS". + + Arguments: + * mean: type of the mean energy: + + "geometric": geometric mean, i.e., e = sqrt(e_min*e_max) + + "arithmetic": arithmetic mean, i.e., e = 0.5*(e_min+e_max) + """ + if mean == "geometric": + energy = np.sqrt(self.e_min * self.e_max) + elif mean == "arithmetic": + energy = 0.5 * (self.e_min + self.e_max) + else: + raise ValueError("Invalid mean type: %s" % mean) + return energy + + def get_rmfimg(self): + """ + Convert the RMF data in squeezed binary table (standard Legacy format) + to a 2D image/matrix. + """ + def _make_rmfimg_row(n_channel, dtype, f_chan, n_chan, mat_row): + # make sure that `f_chan' and `n_chan' are 1-D numpy array + f_chan = np.array(f_chan).reshape(-1) + f_chan -= 1 # FITS indices are 1-based + n_chan = np.array(n_chan).reshape(-1) + idx = np.concatenate([ np.arange(f, f+n) \ + for f, n in zip(f_chan, n_chan) ]) + rmfrow = np.zeros(n_channel, dtype=dtype) + rmfrow[idx] = mat_row + return rmfrow + # + if self.rmfimg is None: + # Make the 2D RMF matrix/image + n_energy = len(self.energ_lo) + n_channel = len(self.channel) + rmf_dtype = self.matrix[0].dtype + rmfimg = np.zeros(shape=(n_energy, n_channel), dtype=rmf_dtype) + for i in np.arange(n_energy)[self.n_grp > 0]: + rmfimg[i, :] = _make_rmfimg_row(n_channel, rmf_dtype, + self.f_chan[i], self.n_chan[i], self.matrix[i]) + self.rmfimg = rmfimg + return self.rmfimg + + def write_rmfimg(self, outfile, clobber=False): + rmfimg = self.get_rmfimg() + # merge headers + header = self.hdr_matrix.copy(strip=True) + header.extend(self.hdr_ebounds.copy(strip=True)) + outfits = fits.PrimaryHDU(data=rmfimg, header=header) + outfits.writeto(outfile, checksum=True, clobber=clobber) +# class RMF }}} + + +class Spectrum: # {{{ + """ + Class that deals with the X-ray spectrum file (usually *.pi). + """ + filename = None + # FITS object return by `fits.open()' + fitsobj = None + # header of "SPECTRUM" extension + header = None + # "SPECTRUM" extension data + channel = None + # name of the spectrum data column (i.e., type, "COUNTS" or "RATE") + spec_type = None + # unit of the spectrum data ("count" for "COUNTS", "count/s" for "RATE") + spec_unit = None + # spectrum data + spec_data = None + # estimated spectral errors for each channel/group + spec_err = None + # statistical errors for each channel/group + stat_err = None + # grouping and quality + grouping = None + quality = None + # whether the spectral data being groupped + groupped = False + # several important keywords + EXPOSURE = None + BACKSCAL = None + AREASCAL = None + RESPFILE = None + ANCRFILE = None + BACKFILE = None + # numpy dtype and FITS format code of the spectrum data + spec_dtype = None + spec_fits_format = None + # output filename for writing the spectrum if no filename provided + outfile = None + + def __init__(self, filename, outfile=None): + self.filename = filename + self.fitsobj = fits.open(filename) + ext_spec = self.fitsobj["SPECTRUM"] + self.header = ext_spec.header.copy(strip=True) + colnames = ext_spec.columns.names + if "COUNTS" in colnames: + self.spec_type = "COUNTS" + elif "RATE" in colnames: + self.spec_type = "RATE" + else: + raise ValueError("Invalid spectrum file") + self.channel = ext_spec.data.columns["CHANNEL"].array + col_spec_data = ext_spec.data.columns[self.spec_type] + self.spec_data = col_spec_data.array.copy() + self.spec_unit = col_spec_data.unit + self.spec_dtype = col_spec_data.dtype + self.spec_fits_format = col_spec_data.format + # grouping and quality + if "GROUPING" in colnames: + self.grouping = ext_spec.data.columns["GROUPING"].array + if "QUALITY" in colnames: + self.quality = ext_spec.data.columns["QUALITY"].array + # keywords + self.EXPOSURE = self.header.get("EXPOSURE") + self.BACKSCAL = self.header.get("BACKSCAL") + self.AREASCAL = self.header.get("AREASCAL") + self.RESPFILE = self.header.get("RESPFILE") + self.ANCRFILE = self.header.get("ANCRFILE") + self.BACKFILE = self.header.get("BACKFILE") + # output filename + self.outfile = outfile + + def get_data(self, group_squeeze=False, copy=True): + """ + Get the spectral data (i.e., self.spec_data). + + Arguments: + * group_squeeze: whether squeeze the spectral data according to + the grouping (i.e., exclude the channels that + are not the first channel of the group, which + also have value of ZERO). + This argument is effective only the grouping + being applied. + """ + if group_squeeze and self.groupped: + spec_data = self.spec_data[self.grouping == 1] + else: + spec_data = self.spec_data + if copy: + return spec_data.copy() + else: + return spec_data + + def get_channel(self, copy=True): + if copy: + return self.channel.copy() + else: + return self.channel + + def set_data(self, spec_data, group_squeeze=True): + """ + Set the spectral data of this spectrum to the supplied data. + """ + if group_squeeze and self.groupped: + assert sum(self.grouping == 1) == len(spec_data) + self.spec_data[self.grouping == 1] = spec_data + else: + assert len(self.spec_data) == len(spec_data) + self.spec_data = spec_data.copy() + + def add_stat_err(self, stat_err, group_squeeze=True): + """ + Add the "STAT_ERR" column as the statistical errors of each spectral + group, which are estimated by utilizing the Monte Carlo techniques. + """ + self.stat_err = np.zeros(self.spec_data.shape, + dtype=self.spec_data.dtype) + if group_squeeze and self.groupped: + assert sum(self.grouping == 1) == len(stat_err) + self.stat_err[self.grouping == 1] = stat_err + else: + assert len(self.stat_err) == len(stat_err) + self.stat_err = stat_err.copy() + self.header["POISSERR"] = False + + def apply_grouping(self, grouping=None, quality=None): + """ + Apply the spectral channel grouping specification to the spectrum. + + NOTE: + * The spectral data (i.e., self.spec_data) is MODIFIED! + * The spectral data within the same group are summed up. + * The self grouping is overwritten if `grouping' is supplied, as well + as the self quality. + """ + if grouping is not None: + self.grouping = grouping + if quality is not None: + self.quality = quality + self.spec_data = group_data(self.spec_data, self.grouping) + self.groupped = True + + def estimate_errors(self, gehrels=True): + """ + Estimate the statistical errors of each spectral group (after + applying grouping) for the source spectrum (and background spectrum). + + If `gehrels=True', the statistical error for a spectral group with + N photons is given by `1 + sqrt(N + 0.75)'; otherwise, the error + is given by `sqrt(N)'. + + Results: `self.spec_err' + """ + eps = 1.0e-10 + if gehrels: + self.spec_err = 1.0 + np.sqrt(self.spec_data + 0.75) + else: + self.spec_err = np.sqrt(self.spec_data) + # replace the zeros with a very small value (because + # `np.random.normal' requires `scale' > 0) + self.spec_err[self.spec_err <= 0.0] = eps + + def copy(self): + """ + Return a copy of this object, with the `np.ndarray' properties are + copied. + """ + new = copy(self) + for k, v in self.__dict__.items(): + if isinstance(v, np.ndarray): + setattr(new, k, v.copy()) + return new + + def randomize(self): + """ + Randomize the spectral data according to the estimated spectral + group errors by assuming the normal distribution. + + NOTE: this method should be called AFTER the `copy()' method. + """ + if self.spec_err is None: + raise ValueError("No valid 'spec_err' presents") + if self.groupped: + idx = self.grouping == 1 + self.spec_data[idx] = np.random.normal(self.spec_data[idx], + self.spec_err[idx]) + else: + self.spec_data = np.random.normal(self.spec_data, self.spec_err) + return self + + def fix_header_keywords(self, + reset_kw=["ANCRFILE", "RESPFILE", "BACKFILE"]): + """ + Reset the keywords to "NONE" to avoid confusion or mistakes, + and also add mandatory spectral keywords if missing. + + Reference: + [1] The OGIP Spectral File Format, Sec. 3.1.1 + https://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/summary/ogip_92_007_summary.html + """ + default_keywords = { + ## Mandatory keywords + #"EXTNAME" : "SPECTRUM", + "TELESCOP" : "NONE", + "INSTRUME" : "NONE", + "FILTER" : "NONE", + #"EXPOSURE" : <integration_time (s)>, + "BACKFILE" : "NONE", + "CORRFILE" : "NONE", + "CORRSCAL" : 1.0, + "RESPFILE" : "NONE", + "ANCRFILE" : "NONE", + "HDUCLASS" : "OGIP", + "HDUCLAS1" : "SPECTRUM", + "HDUVERS" : "1.2.1", + "POISSERR" : True, + #"CHANTYPE" : "PI", + #"DETCHANS" : <total_number_of_detector_channels>, + ## Optional keywords for further information + "BACKSCAL" : 1.0, + "AREASCAL" : 1.0, + # Type of spectral data: + # (1) "TOTAL": gross spectrum (source+bkg); + # (2) "NET": background-subtracted spectrum + # (3) "BKG" background spectrum + #"HDUCLAS2" : "NET", + # Details of the type of data: + # (1) "COUNT": data stored as counts + # (2) "RATE": data stored as counts/s + "HDUCLAS3" : { "COUNTS":"COUNT", + "RATE":"RATE" }.get(self.spec_type), + } + # add mandatory keywords if missing + for kw, value in default_keywords.items(): + if kw not in self.header: + self.header[kw] = value + # reset the specified keywords + for kw in reset_kw: + self.header[kw] = default_keywords.get(kw) + + def write(self, filename=None, clobber=False): + """ + Create a new "SPECTRUM" table/extension and replace the original + one, then write to output file. + """ + if filename is None: + filename = self.outfile + columns = [ + fits.Column(name="CHANNEL", format="I", array=self.channel), + fits.Column(name=self.spec_type, format=self.spec_fits_format, + unit=self.spec_unit, array=self.spec_data), + ] + if self.grouping is not None: + columns.append(fits.Column(name="GROUPING", + format="I", array=self.grouping)) + if self.quality is not None: + columns.append(fits.Column(name="QUALITY", + format="I", array=self.quality)) + if self.stat_err is not None: + columns.append(fits.Column(name="STAT_ERR", unit=self.spec_unit, + format=self.spec_fits_format, + array=self.stat_err)) + ext_spec_cols = fits.ColDefs(columns) + ext_spec = fits.BinTableHDU.from_columns(ext_spec_cols, + header=self.header) + self.fitsobj["SPECTRUM"] = ext_spec + self.fitsobj.writeto(filename, clobber=clobber, checksum=True) +# class Spectrum }}} + + +class SpectrumSet(Spectrum): # {{{ + """ + This class handles a set of spectrum, including the source spectrum, + RMF, ARF, and the background spectrum. + + **NOTE**: + The "COUNTS" column data are converted from "int32" to "float32", + since this spectrum will be subtracted/compensated according to the + ratios of ARFs. + """ + # ARF object for this spectrum + arf = None + # RMF object for this spectrum + rmf = None + # background Spectrum object for this spectrum + bkg = None + # inner and outer radius of the region from which the spectrum extracted + radius_inner = None + radius_outer = None + # total angular range of the spectral region + angle = None + + # numpy dtype and FITS format code to which the spectrum data be + # converted if the data is "COUNTS" + #_spec_dtype = np.float32 + #_spec_fits_format = "E" + _spec_dtype = np.float64 + _spec_fits_format = "D" + + def __init__(self, filename, outfile=None, arf=None, rmf=None, bkg=None): + super().__init__(filename, outfile) + # convert spectrum data type if necessary + if self.spec_data.dtype != self._spec_dtype: + self.spec_data = self.spec_data.astype(self._spec_dtype) + self.spec_dtype = self._spec_dtype + self.spec_fits_format = self._spec_fits_format + if arf is not None: + if isinstance(arf, ARF): + self.arf = arf + else: + self.arf = ARF(arf) + if rmf is not None: + if isinstance(rmf, RMF): + self.rmf = rmf + else: + self.rmf = RMF(rmf) + if bkg is not None: + if isinstance(bkg, Spectrum): + self.bkg = bkg + else: + self.bkg = Spectrum(bkg) + # convert background spectrum data type if necessary + if self.bkg.spec_data.dtype != self._spec_dtype: + self.bkg.spec_data = self.bkg.spec_data.astype(self._spec_dtype) + self.bkg.spec_dtype = self._spec_dtype + self.bkg.spec_fits_format = self._spec_fits_format + + def get_energy(self, mean="geometric"): + """ + Get the energy values of each channel if RMF present. + + NOTE: + The "E_MIN" and "E_MAX" columns of the RMF is required to calculate + the spectrum channel energies. + And the channel energies are generally different to the "ENERG_LO" + and "ENERG_HI" of the corresponding ARF. + """ + if self.rmf is None: + return None + else: + return self.rmf.get_energy(mean=mean) + + def get_arf(self, mean="geometric", groupped=True, copy=True): + """ + Get the interpolated ARF data w.r.t the spectral channel energies + if the ARF presents. + + Arguments: + * groupped: (bool) whether to get the groupped ARF + + Return: (groupped) interpolated ARF data + """ + if self.arf is None: + return None + else: + return self.arf.get_data(groupped=groupped, copy=copy) + + def read_xflt(self): + """ + Read the XFLT000# keywords from the header, check the validity (e.g., + "XFLT0001" should equals "XFLT0002", "XFLT0003" should equals 0). + Sum all the additional XFLT000# pairs (e.g., ) which describes the + regions angluar ranges. + """ + eps = 1.0e-6 + xflt0001 = float(self.header["XFLT0001"]) + xflt0002 = float(self.header["XFLT0002"]) + xflt0003 = float(self.header["XFLT0003"]) + # XFLT000# validity check + assert np.isclose(xflt0001, xflt0002) + assert abs(xflt0003) < eps + # outer radius of the region + self.radius_outer = xflt0001 + # angular regions + self.angle = 0.0 + num = 4 + while True: + try: + angle_begin = float(self.header["XFLT%04d" % num]) + angle_end = float(self.header["XFLT%04d" % (num+1)]) + num += 2 + except KeyError: + break + self.angle += (angle_end - angle_begin) + # if NO additional XFLT000# keys exist, assume "annulus" region + if self.angle < eps: + self.angle = 360.0 + + def scale(self): + """ + Scale the spectral data (and spectral group errors if present) of + the source spectrum (and background spectra if present) according + to the region angular size to make it correspond to the whole annulus + region (i.e., 360 degrees). + + NOTE: the spectral data and errors (i.e., `self.spec_data', and + `self.spec_err') is MODIFIED! + """ + self.spec_data *= (360.0 / self.angle) + if self.spec_err is not None: + self.spec_err *= (360.0 / self.angle) + # also scale the background spectrum if present + if self.bkg: + self.bkg.spec_data *= (360.0 / self.angle) + if self.bkg.spec_err is not None: + self.bkg.spec_err *= (360.0 / self.angle) + + def apply_grouping(self, grouping=None, quality=None, verbose=False): + """ + Apply the spectral channel grouping specification to the source + spectrum, the ARF (which is used during the later spectral + manipulations), and the background spectrum (if presents). + + NOTE: + * The spectral data (i.e., self.spec_data) is MODIFIED! + * The spectral data within the same group are summed up. + * The self grouping is overwritten if `grouping' is supplied, as well + as the self quality. + """ + super().apply_grouping(grouping=grouping, quality=quality) + # also group the ARF accordingly + self.arf.apply_grouping(energy_channel=self.get_energy(), + grouping=self.grouping, verbose=verbose) + # group the background spectrum if present + if self.bkg: + self.bkg.spec_data = group_data(self.bkg.spec_data, self.grouping) + + def estimate_errors(self, gehrels=True): + """ + Estimate the statistical errors of each spectral group (after + applying grouping) for the source spectrum (and background spectrum). + + If `gehrels=True', the statistical error for a spectral group with + N photons is given by `1 + sqrt(N + 0.75)'; otherwise, the error + is given by `sqrt(N)'. + + Results: `self.spec_err' (and `self.bkg.spec_err') + """ + super().estimate_errors(gehrels=gehrels) + eps = 1.0e-10 + # estimate the errors for background spectrum if present + if self.bkg: + if gehrels: + self.bkg.spec_err = 1.0 + np.sqrt(self.bkg.spec_data + 0.75) + else: + self.bkg.spec_err = np.sqrt(self.bkg.spec_data) + self.bkg.spec_err[self.bkg.spec_err <= 0.0] = eps + + def subtract_bkg(self, inplace=True, add_history=False, verbose=False): + """ + Subtract the background contribution from the source spectrum. + The `EXPOSURE' and `BACKSCAL' values are required to calculate + the fraction/ratio for the background subtraction. + + Arguments: + * inplace: whether replace the `spec_data' with the background- + subtracted spectrum data; If True, the attribute + `spec_bkg_subtracted' is also set to `True' when + the subtraction finished. + The keywords "BACKSCAL" and "AREASCAL" are set to 1.0. + + Return: + background-subtracted spectrum data + """ + ratio = (self.EXPOSURE / self.bkg.EXPOSURE) * \ + (self.BACKSCAL / self.bkg.BACKSCAL) * \ + (self.AREASCAL / self.bkg.AREASCAL) + operation = " SUBTRACT_BACKGROUND: %s - %s * %s" % \ + (self.filename, ratio, self.bkg.filename) + if verbose: + print(operation, file=sys.stderr) + spec_data_subbkg = self.spec_data - ratio * self.bkg.get_data() + if inplace: + self.spec_data = spec_data_subbkg + self.spec_bkg_subtracted = True + self.BACKSCAL = 1.0 + self.AREASCAL = 1.0 + # update header + self.header["BACKSCAL"] = 1.0 + self.header["AREASCAL"] = 1.0 + self.header["BACKFILE"] = "NONE" + self.header["HDUCLAS2"] = "NET" # background-subtracted spectrum + # also record history + if add_history: + self.header.add_history(operation) + return spec_data_subbkg + + def subtract(self, spectrumset, cross_arf, groupped=False, + group_squeeze=False, add_history=False, verbose=False): + """ + Subtract the photons that originate from the surrounding regions + but were scattered into this spectrum due to the finite PSF. + + The background of this spectrum and the given spectrum should + both be subtracted before applying this subtraction for crosstalk + correction, as well as the below `compensate()' procedure. + + NOTE: + 1. The crosstalk ARF must be provided, since the `spectrumset.arf' + is required to be its ARF without taking crosstalk into account: + spec1_new = spec1 - spec2 * (cross_arf_2_to_1 / arf2) + 2. The ARF are interpolated to match the energies of spetral channels. + """ + operation = " SUBTRACT: %s - (%s/%s) * %s" % (self.filename, + cross_arf.filename, spectrumset.arf.filename, + spectrumset.filename) + if verbose: + print(operation, file=sys.stderr) + energy = self.get_energy() + if groupped: + spectrumset.arf.apply_grouping(energy_channel=energy, + grouping=self.grouping, verbose=verbose) + cross_arf.apply_grouping(energy_channel=energy, + grouping=self.grouping, verbose=verbose) + arfresp_spec = spectrumset.arf.get_data(groupped=True, + group_squeeze=group_squeeze) + arfresp_cross = cross_arf.get_data(groupped=True, + group_squeeze=group_squeeze) + else: + arfresp_spec = spectrumset.arf.interpolate(x=energy, + verbose=verbose) + arfresp_cross = cross_arf.interpolate(x=energy, verbose=verbose) + with np.errstate(divide="ignore", invalid="ignore"): + arf_ratio = arfresp_cross / arfresp_spec + # fix nan/inf values due to division by zero + arf_ratio[ ~ np.isfinite(arf_ratio) ] = 0.0 + spec_data = self.get_data(group_squeeze=group_squeeze) - \ + spectrumset.get_data(group_squeeze=group_squeeze)*arf_ratio + self.set_data(spec_data, group_squeeze=group_squeeze) + # record history + if add_history: + self.header.add_history(operation) + + def compensate(self, cross_arf, groupped=False, group_squeeze=False, + add_history=False, verbose=False): + """ + Compensate the photons that originate from this regions but were + scattered into the surrounding regions due to the finite PSF. + + formula: + spec1_new = spec1 + spec1 * (cross_arf_1_to_2 / arf1) + """ + operation = " COMPENSATE: %s + (%s/%s) * %s" % (self.filename, + cross_arf.filename, self.arf.filename, self.filename) + if verbose: + print(operation, file=sys.stderr) + energy = self.get_energy() + if groupped: + cross_arf.apply_grouping(energy_channel=energy, + grouping=self.grouping, verbose=verbose) + arfresp_this = self.arf.get_data(groupped=True, + group_squeeze=group_squeeze) + arfresp_cross = cross_arf.get_data(groupped=True, + group_squeeze=group_squeeze) + else: + arfresp_this = self.arf.interpolate(x=energy, verbose=verbose) + arfresp_cross = cross_arf.interpolate(x=energy, verbose=verbose) + with np.errstate(divide="ignore", invalid="ignore"): + arf_ratio = arfresp_cross / arfresp_this + # fix nan/inf values due to division by zero + arf_ratio[ ~ np.isfinite(arf_ratio) ] = 0.0 + spec_data = self.get_data(group_squeeze=group_squeeze) + \ + self.get_data(group_squeeze=group_squeeze) * arf_ratio + self.set_data(spec_data, group_squeeze=group_squeeze) + # record history + if add_history: + self.header.add_history(operation) + + def fix_negative(self, add_history=False, verbose=False): + """ + The subtractions may lead to negative counts, it may be necessary + to fix these channels with negative values. + """ + neg_counts = self.spec_data < 0 + N = len(neg_counts) + neg_channels = np.arange(N, dtype=np.int)[neg_counts] + if len(neg_channels) > 0: + print("WARNING: %d channels have NEGATIVE counts" % \ + len(neg_channels), file=sys.stderr) + i = 0 + while len(neg_channels) > 0: + i += 1 + if verbose: + if i == 1: + print("*** Fixing negative channels: iter %d..." % i, + end="", file=sys.stderr) + else: + print("%d..." % i, end="", file=sys.stderr) + for ch in neg_channels: + neg_val = self.spec_data[ch] + if ch < N-2: + self.spec_data[ch] = 0 + self.spec_data[(ch+1):(ch+3)] -= 0.5 * np.abs(neg_val) + else: + # just set to zero if it is the last 2 channels + self.spec_data[ch] = 0 + # update negative channels indices + neg_counts = self.spec_data < 0 + neg_channels = np.arange(N, dtype=np.int)[neg_counts] + if i > 0: + print("FIXED!", file=sys.stderr) + # record history + if add_history: + self.header.add_history(" FIXED NEGATIVE CHANNELS") + + def set_radius_inner(self, radius_inner): + """ + Set the inner radius of the spectral region. + """ + assert radius_inner < self.radius_outer + self.radius_inner = radius_inner + + def copy(self): + """ + Return a copy of this object. + """ + new = super().copy() + if self.bkg: + new.bkg = self.bkg.copy() + return new + + def randomize(self): + """ + Randomize the source (and background if present) spectral data + according to the estimated spectral group errors by assuming the + normal distribution. + + NOTE: this method should be called AFTER the `copy()' method. + """ + super().randomize() + if self.bkg: + self.bkg.spec_data = np.random.normal(self.bkg.spec_data, + self.bkg.spec_err) + self.bkg.spec_data[self.grouping == -1] = 0.0 + return self +# class SpectrumSet }}} + + +class Crosstalk: # {{{ + """ + XMM-Newton PSF Crosstalk effect correction. + """ + # `SpectrumSet' object for the spectrum to be corrected + spectrumset = None + # NOTE/XXX: do NOT use list (e.g., []) here, otherwise, all the + # instances will share these list properties. + # `SpectrumSet' and `ARF' objects corresponding to the spectra from + # which the photons were scattered into this spectrum. + cross_in_specset = None + cross_in_arf = None + # `ARF' objects corresponding to the regions to which the photons of + # this spectrum were scattered into. + cross_out_arf = None + # grouping specification and quality data + grouping = None + quality = None + # whether the spectrum is groupped + groupped = False + + def __init__(self, config, arf_dict={}, rmf_dict={}, + grouping=None, quality=None): + """ + Arguments: + * config: a section of the whole config file (`ConfigObj' object) + """ + self.cross_in_specset = [] + self.cross_in_arf = [] + self.cross_out_arf = [] + # this spectrum to be corrected + self.spectrumset = SpectrumSet(filename=config["spec"], + outfile=config["outfile"], + arf=arf_dict.get(config["arf"], config["arf"]), + rmf=rmf_dict.get(config.get("rmf"), config.get("rmf")), + bkg=config.get("bkg")) + # spectra and cross arf from which photons were scattered in + for reg_in in config["cross_in"].values(): + specset = SpectrumSet(filename=reg_in["spec"], + arf=arf_dict.get(reg_in["arf"], reg_in["arf"]), + rmf=rmf_dict.get(reg_in.get("rmf"), reg_in.get("rmf")), + bkg=reg_in.get("bkg")) + self.cross_in_specset.append(specset) + self.cross_in_arf.append(arf_dict.get(reg_in["cross_arf"], + ARF(reg_in["cross_arf"]))) + # regions into which the photons of this spectrum were scattered into + if "cross_out" in config.sections: + cross_arf = config["cross_out"].as_list("cross_arf") + for arffile in cross_arf: + self.cross_out_arf.append(arf_dict.get(arffile, ARF(arffile))) + # grouping and quality + self.grouping = grouping + self.quality = quality + + def apply_grouping(self, verbose=False): + self.spectrumset.apply_grouping(grouping=self.grouping, + quality=self.quality, verbose=verbose) + # also group the related surrounding spectra + for specset in self.cross_in_specset: + specset.apply_grouping(grouping=self.grouping, + quality=self.quality, verbose=verbose) + self.groupped = True + + def estimate_errors(self, gehrels=True, verbose=False): + if verbose: + print("INFO: Estimating spectral errors ...") + self.spectrumset.estimate_errors(gehrels=gehrels) + # also estimate errors for the related surrounding spectra + for specset in self.cross_in_specset: + specset.estimate_errors(gehrels=gehrels) + + def do_correction(self, subtract_bkg=True, fix_negative=False, + group_squeeze=True, add_history=False, verbose=False): + """ + Perform the crosstalk correction. The background contribution + for each spectrum is subtracted first if `subtract_bkg' is True. + The basic correction procedures are recorded to the header. + """ + if add_history: + self.spectrumset.header.add_history("Crosstalk Correction BEGIN") + self.spectrumset.header.add_history(" TOOL: %s (v%s) @ %s" % (\ + os.path.basename(sys.argv[0]), __version__, + datetime.utcnow().isoformat())) + # background subtraction + if subtract_bkg: + if verbose: + print("INFO: subtract background ...", file=sys.stderr) + self.spectrumset.subtract_bkg(inplace=True, + add_history=add_history, verbose=verbose) + # also apply background subtraction to the surrounding spectra + for specset in self.cross_in_specset: + specset.subtract_bkg(inplace=True, + add_history=add_history, verbose=verbose) + # subtractions + if verbose: + print("INFO: apply subtractions ...", file=sys.stderr) + for specset, cross_arf in zip(self.cross_in_specset, + self.cross_in_arf): + self.spectrumset.subtract(spectrumset=specset, + cross_arf=cross_arf, groupped=self.groupped, + group_squeeze=group_squeeze, add_history=add_history, + verbose=verbose) + # compensations + if verbose: + print("INFO: apply compensations ...", file=sys.stderr) + for cross_arf in self.cross_out_arf: + self.spectrumset.compensate(cross_arf=cross_arf, + groupped=self.groupped, group_squeeze=group_squeeze, + add_history=add_history, verbose=verbose) + # fix negative values in channels + if fix_negative: + if verbose: + print("INFO: fix negative channel values ...", file=sys.stderr) + self.spectrumset.fix_negative(add_history=add_history, + verbose=verbose) + if add_history: + self.spectrumset.header.add_history("END Crosstalk Correction") + + def fix_header(self): + # fix header keywords + self.spectrumset.fix_header_keywords( + reset_kw=["RESPFILE", "ANCRFILE", "BACKFILE"]) + + def copy(self): + new = copy(self) + # properly handle the copy of spectrumsets + new.spectrumset = self.spectrumset.copy() + new.cross_in_specset = [ specset.copy() \ + for specset in self.cross_in_specset ] + return new + + def randomize(self): + self.spectrumset.randomize() + for specset in self.cross_in_specset: + specset.randomize() + return self + + def get_spectrum(self, copy=True): + if copy: + return self.spectrumset.copy() + else: + return self.spectrumset + + def write(self, filename=None, clobber=False): + self.spectrumset.write(filename=filename, clobber=clobber) +# class Crosstalk }}} + + +class Deprojection: # {{{ + """ + Perform the deprojection on a set of PROJECTED spectra with the + assumption of spherical symmetry of the source object, and produce + the DEPROJECTED spectra. + + NOTE: + * Assumption of the spherical symmetry + * Background should be subtracted before deprojection + * ARF differences of different regions are taken into account + + Reference & Credit: + [1] Direct X-ray Spectra Deprojection + https://www-xray.ast.cam.ac.uk/papers/dsdeproj/ + Sanders & Fabian 2007, MNRAS, 381, 1381 + """ + spectra = None + grouping = None + quality = None + + def __init__(self, spectra, grouping=None, quality=None, verbose=False): + """ + Arguments: + * spectra: a set of spectra from the inner-most to the outer-most + regions (e.g., spectra after correcting crosstalk effect) + * grouping: grouping specification for all the spectra + * quality: quality column for the spectra + """ + self.spectra = [] + for spec in spectra: + if not isinstance(spec, SpectrumSet): + raise ValueError("Not a 'SpectrumSet' object") + spec.read_xflt() + self.spectra.append(spec) + self.spectra = spectra + self.grouping = grouping + self.quality = quality + # sort spectra by `radius_outer' + self.spectra.sort(key=lambda x: x.radius_outer) + # set the inner radii + radii_inner = [0.0] + [ x.radius_outer for x in self.spectra[:-1] ] + for spec, rin in zip(self.spectra, radii_inner): + spec.set_radius_inner(rin) + if verbose: + print("Deprojection: loaded spectrum: radius: (%s, %s)" % \ + (spec.radius_inner, spec.radius_outer), + file=sys.stderr) + # check EXPOSURE validity (all spectra must have the same exposures) + exposures = [ spec.EXPOSURE for spec in self.spectra ] + assert np.allclose(exposures[:-1], exposures[1:]) + + def subtract_bkg(self, verbose=True): + for spec in self.spectra: + if not spec.bkg: + raise ValueError("Spectrum '%s' has NO background" % \ + spec.filename) + spec.subtract_bkg(inplace=True, verbose=verbose) + + def apply_grouping(self, verbose=False): + for spec in self.spectra: + spec.apply_grouping(grouping=self.grouping, quality=self.quality, + verbose=verbose) + + def estimate_errors(self, gehrels=True): + for spec in self.spectra: + spec.estimate_errors(gehrels=gehrels) + + def scale(self): + """ + Scale the spectral data according to the region angular size. + """ + for spec in self.spectra: + spec.scale() + + def do_deprojection(self, group_squeeze=True, + add_history=False, verbose=False): + # + # TODO/XXX: How to apply ARF correction here??? + # + num_spec = len(self.spectra) + tmp_spec_data = self.spectra[0].get_data(group_squeeze=group_squeeze) + spec_shape = tmp_spec_data.shape + spec_dtype = tmp_spec_data.dtype + spec_per_vol = [None] * num_spec + # + for shellnum in reversed(range(num_spec)): + if verbose: + print("DEPROJECTION: deprojecting shell %d ..." % shellnum, + file=sys.stderr) + spec = self.spectra[shellnum] + # calculate projected spectrum of outlying shells + proj_spec = np.zeros(spec_shape, spec_dtype) + for outer in range(shellnum+1, num_spec): + vol = self.projected_volume( + r1=self.spectra[outer].radius_inner, + r2=self.spectra[outer].radius_outer, + R1=spec.radius_inner, + R2=spec.radius_outer) + proj_spec += spec_per_vol[outer] * vol + # + this_spec = spec.get_data(group_squeeze=group_squeeze, copy=True) + deproj_spec = this_spec - proj_spec + # calculate the volume that this spectrum is from + this_vol = self.projected_volume( + r1=spec.radius_inner, r2=spec.radius_outer, + R1=spec.radius_inner, R2=spec.radius_outer) + # calculate the spectral data per unit volume + spec_per_vol[shellnum] = deproj_spec / this_vol + # set the spectral data to these deprojected values + self.set_spec_data(spec_per_vol, group_squeeze=group_squeeze) + # add history to header + if add_history: + self.add_history() + + def get_spec_data(self, group_squeeze=True, copy=True): + """ + Extract the spectral data of each spectrum after deprojection + performed. + """ + return [ spec.get_data(group_squeeze=group_squeeze, copy=copy) + for spec in self.spectra ] + + def set_spec_data(self, spec_data, group_squeeze=True): + """ + Set `spec_data' for each spectrum to the deprojected spectral data. + """ + assert len(spec_data) == len(self.spectra) + for spec, data in zip(self.spectra, spec_data): + spec.set_data(data, group_squeeze=group_squeeze) + + def add_stat_err(self, stat_err, group_squeeze=True): + """ + Add the "STAT_ERR" column to each spectrum. + """ + assert len(stat_err) == len(self.spectra) + for spec, err in zip(self.spectra, stat_err): + spec.add_stat_err(err, group_squeeze=group_squeeze) + + def add_history(self): + """ + Append a brief history about this tool to the header. + """ + history = "Deprojected by %s (v%s) @ %s" % ( + os.path.basename(sys.argv[0]), __version__, + datetime.utcnow().isoformat()) + for spec in self.spectra: + spec.header.add_history(history) + + def fix_header(self): + # fix header keywords + for spec in self.spectra: + spec.fix_header_keywords( + reset_kw=["RESPFILE", "ANCRFILE", "BACKFILE"]) + + def write(self, filenames=[], clobber=False): + """ + Write the deprojected spectra to output file. + """ + if filenames == []: + filenames = [ spec.outfile for spec in self.spectra ] + for spec, outfile in zip(self.spectra, filenames): + spec.write(filename=outfile, clobber=clobber) + + @staticmethod + def projected_volume(r1, r2, R1, R2): + """ + Calculate the projected volume of a spherical shell of radii r1 -> r2 + onto an annulus on the sky of radius R1 -> R2. + + This volume is the integral: + Int(R=R1,R2) Int(x=sqrt(r1^2-R^2),sqrt(r2^2-R^2)) 2*pi*R dx dR + = + Int(R=R1,R2) 2*pi*R * (sqrt(r2^2-R^2) - sqrt(r1^2-R^2)) dR + + Note that the above integral is only half the total volume + (i.e., front only). + """ + def sqrt_trunc(x): + if x > 0: + return np.sqrt(x) + else: + return 0.0 + # + p1 = sqrt_trunc(r1**2 - R2**2) + p2 = sqrt_trunc(r1**2 - R1**2) + p3 = sqrt_trunc(r2**2 - R2**2) + p4 = sqrt_trunc(r2**2 - R1**2) + return 2.0 * (2.0/3.0) * np.pi * ((p1**3 - p2**3) + (p4**3 - p3**3)) +# class Deprojection }}} + + +# Helper functions {{{ +def calc_median_errors(results): + """ + Calculate the median and errors for the spectral data gathered + through Monte Carlo simulations. + + TODO: investigate the errors calculation approach used here! + """ + results = np.array(results) + # `results' now has shape: (mc_times, num_spec, num_channel) + # sort by the Monte Carlo simulation axis + results.sort(0) + mc_times = results.shape[0] + medians = results[ int(mc_times * 0.5) ] + lowerpcs = results[ int(mc_times * 0.1585) ] + upperpcs = results[ int(mc_times * 0.8415) ] + errors = np.sqrt(0.5 * ((medians-lowerpcs)**2 + (upperpcs-medians)**2)) + return (medians, errors) + + +def set_argument(name, default, cmdargs, config): + value = default + if name in config.keys(): + value = config.as_bool(name) + value_cmd = vars(cmdargs)[name] + if value_cmd != default: + value = value_cmd # command arguments overwrite others + return value +# helper functions }}} + + +# main routine {{{ +def main(config, subtract_bkg, fix_negative, mc_times, + verbose=False, clobber=False): + # collect ARFs and RMFs into dictionaries (avoid interpolation every time) + arf_files = set() + rmf_files = set() + for region in config.sections: + config_reg = config[region] + arf_files.add(config_reg.get("arf")) + rmf_files.add(config_reg.get("rmf")) + for reg_in in config_reg["cross_in"].values(): + arf_files.add(reg_in.get("arf")) + arf_files.add(reg_in.get("cross_arf")) + if "cross_out" in config_reg.sections: + for arf in config_reg["cross_out"].as_list("cross_arf"): + arf_files.add(arf) + arf_files = arf_files - set([None]) + arf_dict = { arf: ARF(arf) for arf in arf_files } + rmf_files = rmf_files - set([None]) + rmf_dict = { rmf: RMF(rmf) for rmf in rmf_files } + if verbose: + print("INFO: arf_files:", arf_files, file=sys.stderr) + print("INFO: rmf_files:", rmf_files, file=sys.stderr) + + # get the GROUPING and QUALITY data + grouping_fits = fits.open(config["grouping"]) + grouping = grouping_fits["SPECTRUM"].data.columns["GROUPING"].array + quality = grouping_fits["SPECTRUM"].data.columns["QUALITY"].array + # squeeze the groupped spectral data, etc. + group_squeeze = True + + # crosstalk objects (BEFORE background subtraction) + crosstalks_cleancopy = [] + # crosstalk-corrected spectra + cc_spectra = [] + + # correct crosstalk effects for each region first + for region in config.sections: + if verbose: + print("INFO: processing '%s' ..." % region, file=sys.stderr) + crosstalk = Crosstalk(config.get(region), + arf_dict=arf_dict, rmf_dict=rmf_dict, + grouping=grouping, quality=quality) + crosstalk.apply_grouping(verbose=verbose) + crosstalk.estimate_errors(verbose=verbose) + # keep a (almost) clean copy of the crosstalk object + crosstalks_cleancopy.append(crosstalk.copy()) + if verbose: + print("INFO: doing crosstalk correction ...", file=sys.stderr) + crosstalk.do_correction(subtract_bkg=subtract_bkg, + fix_negative=fix_negative, group_squeeze=group_squeeze, + add_history=True, verbose=verbose) + cc_spectra.append(crosstalk.get_spectrum(copy=True)) + + # load back the crosstalk-corrected spectra for deprojection + if verbose: + print("INFO: preparing spectra for deprojection ...", file=sys.stderr) + deprojection = Deprojection(spectra=cc_spectra, grouping=grouping, + quality=quality, verbose=verbose) + if verbose: + print("INFO: scaling spectra according the region angular size...", + file=sys.stderr) + deprojection.scale() + if verbose: + print("INFO: doing deprojection ...", file=sys.stderr) + deprojection.do_deprojection(add_history=True, verbose=verbose) + deproj_results = [ deprojection.get_spec_data( + group_squeeze=group_squeeze, copy=True) ] + + # Monte Carlo for spectral group error estimation + print("INFO: Monte Carlo to estimate spectral errors (%d times) ..." % \ + mc_times, file=sys.stderr) + for i in range(mc_times): + if i % 100 == 0: + print("%d..." % i, end="", flush=True, file=sys.stderr) + # correct crosstalk effects + cc_spectra_copy = [] + for crosstalk in crosstalks_cleancopy: + # copy and randomize + crosstalk_copy = crosstalk.copy().randomize() + crosstalk_copy.do_correction(subtract_bkg=subtract_bkg, + fix_negative=fix_negative, group_squeeze=group_squeeze, + add_history=False, verbose=False) + cc_spectra_copy.append(crosstalk_copy.get_spectrum(copy=True)) + # deproject spectra + deprojection_copy = Deprojection(spectra=cc_spectra_copy, + grouping=grouping, quality=quality, verbose=False) + deprojection_copy.scale() + deprojection_copy.do_deprojection(add_history=False, verbose=False) + deproj_results.append(deprojection_copy.get_spec_data( + group_squeeze=group_squeeze, copy=True)) + print("DONE!", flush=True, file=sys.stderr) + + if verbose: + print("INFO: Calculating the median and errors for each spectrum ...", + file=sys.stderr) + medians, errors = calc_median_errors(deproj_results) + deprojection.set_spec_data(medians, group_squeeze=group_squeeze) + deprojection.add_stat_err(errors, group_squeeze=group_squeeze) + if verbose: + print("INFO: Writing the crosstalk-corrected and deprojected " + \ + "spectra with estimated statistical errors ...", file=sys.stderr) + deprojection.fix_header() + deprojection.write(clobber=clobber) +# main routine }}} + + +# main_deprojection routine {{{ +def main_deprojection(config, mc_times, verbose=False, clobber=False): + """ + Only perform the spectral deprojection. + """ + # collect ARFs and RMFs into dictionaries (avoid interpolation every time) + arf_files = set() + rmf_files = set() + for region in config.sections: + config_reg = config[region] + arf_files.add(config_reg.get("arf")) + rmf_files.add(config_reg.get("rmf")) + arf_files = arf_files - set([None]) + arf_dict = { arf: ARF(arf) for arf in arf_files } + rmf_files = rmf_files - set([None]) + rmf_dict = { rmf: RMF(rmf) for rmf in rmf_files } + if verbose: + print("INFO: arf_files:", arf_files, file=sys.stderr) + print("INFO: rmf_files:", rmf_files, file=sys.stderr) + + # get the GROUPING and QUALITY data + grouping_fits = fits.open(config["grouping"]) + grouping = grouping_fits["SPECTRUM"].data.columns["GROUPING"].array + quality = grouping_fits["SPECTRUM"].data.columns["QUALITY"].array + # squeeze the groupped spectral data, etc. + group_squeeze = True + + # load spectra for deprojection + if verbose: + print("INFO: preparing spectra for deprojection ...", file=sys.stderr) + proj_spectra = [] + for region in config.sections: + config_reg = config[region] + specset = SpectrumSet(filename=config_reg["spec"], + outfile=config_reg["outfile"], + arf=arf_dict.get(config_reg["arf"], config_reg["arf"]), + rmf=rmf_dict.get(config_reg["rmf"], config_reg["rmf"]), + bkg=config_reg["bkg"]) + proj_spectra.append(specset) + + deprojection = Deprojection(spectra=proj_spectra, grouping=grouping, + quality=quality, verbose=verbose) + deprojection.apply_grouping(verbose=verbose) + deprojection.estimate_errors() + if verbose: + print("INFO: scaling spectra according the region angular size ...", + file=sys.stderr) + deprojection.scale() + + # keep a (almost) clean copy of the input projected spectra + proj_spectra_cleancopy = [ spec.copy() for spec in proj_spectra ] + + if verbose: + print("INFO: subtract the background ...", file=sys.stderr) + deprojection.subtract_bkg(verbose=verbose) + if verbose: + print("INFO: doing deprojection ...", file=sys.stderr) + deprojection.do_deprojection(add_history=True, verbose=verbose) + deproj_results = [ deprojection.get_spec_data( + group_squeeze=group_squeeze, copy=True) ] + + # Monte Carlo for spectral group error estimation + print("INFO: Monte Carlo to estimate spectral errors (%d times) ..." % \ + mc_times, file=sys.stderr) + for i in range(mc_times): + if i % 100 == 0: + print("%d..." % i, end="", flush=True, file=sys.stderr) + # copy and randomize the input projected spectra + proj_spectra_copy = [ spec.copy().randomize() + for spec in proj_spectra_cleancopy ] + # deproject spectra + deprojection_copy = Deprojection(spectra=proj_spectra_copy, + grouping=grouping, quality=quality, verbose=False) + deprojection_copy.subtract_bkg(verbose=False) + deprojection_copy.do_deprojection(add_history=False, verbose=False) + deproj_results.append(deprojection_copy.get_spec_data( + group_squeeze=group_squeeze, copy=True)) + print("DONE!", flush=True, file=sys.stderr) + + if verbose: + print("INFO: Calculating the median and errors for each spectrum ...", + file=sys.stderr) + medians, errors = calc_median_errors(deproj_results) + deprojection.set_spec_data(medians, group_squeeze=group_squeeze) + deprojection.add_stat_err(errors, group_squeeze=group_squeeze) + if verbose: + print("INFO: Writing the deprojected spectra " + \ + "with estimated statistical errors ...", file=sys.stderr) + deprojection.fix_header() + deprojection.write(clobber=clobber) +# main_deprojection routine }}} + + +# main_crosstalk routine {{{ +def main_crosstalk(config, subtract_bkg, fix_negative, mc_times, + verbose=False, clobber=False): + """ + Only perform the crosstalk correction. + """ + # collect ARFs and RMFs into dictionaries (avoid interpolation every time) + arf_files = set() + rmf_files = set() + for region in config.sections: + config_reg = config[region] + arf_files.add(config_reg.get("arf")) + rmf_files.add(config_reg.get("rmf")) + for reg_in in config_reg["cross_in"].values(): + arf_files.add(reg_in.get("arf")) + arf_files.add(reg_in.get("cross_arf")) + if "cross_out" in config_reg.sections: + for arf in config_reg["cross_out"].as_list("cross_arf"): + arf_files.add(arf) + arf_files = arf_files - set([None]) + arf_dict = { arf: ARF(arf) for arf in arf_files } + rmf_files = rmf_files - set([None]) + rmf_dict = { rmf: RMF(rmf) for rmf in rmf_files } + if verbose: + print("INFO: arf_files:", arf_files, file=sys.stderr) + print("INFO: rmf_files:", rmf_files, file=sys.stderr) + + # get the GROUPING and QUALITY data + if "grouping" in config.keys(): + grouping_fits = fits.open(config["grouping"]) + grouping = grouping_fits["SPECTRUM"].data.columns["GROUPING"].array + quality = grouping_fits["SPECTRUM"].data.columns["QUALITY"].array + group_squeeze = True + else: + grouping = None + quality = None + group_squeeze = False + + # crosstalk objects (BEFORE background subtraction) + crosstalks_cleancopy = [] + # crosstalk-corrected spectra + cc_spectra = [] + + # correct crosstalk effects for each region first + for region in config.sections: + if verbose: + print("INFO: processing '%s' ..." % region, file=sys.stderr) + crosstalk = Crosstalk(config.get(region), + arf_dict=arf_dict, rmf_dict=rmf_dict, + grouping=grouping, quality=quality) + if grouping is not None: + crosstalk.apply_grouping(verbose=verbose) + crosstalk.estimate_errors(verbose=verbose) + # keep a (almost) clean copy of the crosstalk object + crosstalks_cleancopy.append(crosstalk.copy()) + if verbose: + print("INFO: doing crosstalk correction ...", file=sys.stderr) + crosstalk.do_correction(subtract_bkg=subtract_bkg, + fix_negative=fix_negative, group_squeeze=group_squeeze, + add_history=True, verbose=verbose) + crosstalk.fix_header() + cc_spectra.append(crosstalk.get_spectrum(copy=True)) + + # spectral data of the crosstalk-corrected spectra + cc_results = [] + cc_results.append([ spec.get_data(group_squeeze=group_squeeze, copy=True) + for spec in cc_spectra ]) + + # Monte Carlo for spectral group error estimation + print("INFO: Monte Carlo to estimate spectral errors (%d times) ..." % \ + mc_times, file=sys.stderr) + for i in range(mc_times): + if i % 100 == 0: + print("%d..." % i, end="", flush=True, file=sys.stderr) + # correct crosstalk effects + cc_spectra_copy = [] + for crosstalk in crosstalks_cleancopy: + # copy and randomize + crosstalk_copy = crosstalk.copy().randomize() + crosstalk_copy.do_correction(subtract_bkg=subtract_bkg, + fix_negative=fix_negative, group_squeeze=group_squeeze, + add_history=False, verbose=False) + cc_spectra_copy.append(crosstalk_copy.get_spectrum(copy=True)) + cc_results.append([ spec.get_data(group_squeeze=group_squeeze, + copy=True) + for spec in cc_spectra_copy ]) + print("DONE!", flush=True, file=sys.stderr) + + if verbose: + print("INFO: Calculating the median and errors for each spectrum ...", + file=sys.stderr) + medians, errors = calc_median_errors(cc_results) + if verbose: + print("INFO: Writing the crosstalk-corrected spectra " + \ + "with estimated statistical errors ...", + file=sys.stderr) + for spec, data, err in zip(cc_spectra, medians, errors): + spec.set_data(data, group_squeeze=group_squeeze) + spec.add_stat_err(err, group_squeeze=group_squeeze) + spec.write(clobber=clobber) +# main_crosstalk routine }}} + + +if __name__ == "__main__": + # arguments' default values + default_mode = "both" + default_mc_times = 5000 + # commandline arguments parser + parser = argparse.ArgumentParser( + description="Correct the crosstalk effects for XMM EPIC spectra", + epilog="Version: %s (%s)" % (__version__, __date__)) + parser.add_argument("config", help="config file in which describes " +\ + "the crosstalk relations ('ConfigObj' syntax)") + parser.add_argument("-m", "--mode", dest="mode", default=default_mode, + help="operation mode (both | crosstalk | deprojection)") + parser.add_argument("-B", "--no-subtract-bkg", dest="subtract_bkg", + action="store_false", help="do NOT subtract background first") + parser.add_argument("-N", "--fix-negative", dest="fix_negative", + action="store_true", help="fix negative channel values") + parser.add_argument("-M", "--mc-times", dest="mc_times", + type=int, default=default_mc_times, + help="Monte Carlo times for error estimation") + parser.add_argument("-C", "--clobber", dest="clobber", + action="store_true", help="overwrite output file if exists") + parser.add_argument("-v", "--verbose", dest="verbose", + action="store_true", help="show verbose information") + args = parser.parse_args() + # merge commandline arguments and config + config = ConfigObj(args.config) + subtract_bkg = set_argument("subtract_bkg", True, args, config) + fix_negative = set_argument("fix_negative", False, args, config) + verbose = set_argument("verbose", False, args, config) + clobber = set_argument("clobber", False, args, config) + # operation mode + mode = config.get("mode", default_mode) + if args.mode != default_mode: + mode = args.mode + # Monte Carlo times + mc_times = config.as_int("mc_times") + if args.mc_times != default_mc_times: + mc_times = args.mc_times + + if mode.lower() == "both": + print("MODE: CROSSTALK + DEPROJECTION", file=sys.stderr) + main(config, subtract_bkg=subtract_bkg, fix_negative=fix_negative, + mc_times=mc_times, verbose=verbose, clobber=clobber) + elif mode.lower() == "deprojection": + print("MODE: DEPROJECTION", file=sys.stderr) + main_deprojection(config, mc_times=mc_times, + verbose=verbose, clobber=clobber) + elif mode.lower() == "crosstalk": + print("MODE: CROSSTALK", file=sys.stderr) + main_crosstalk(config, subtract_bkg=subtract_bkg, + fix_negative=fix_negative, mc_times=mc_times, + verbose=verbose, clobber=clobber) + else: + raise ValueError("Invalid operation mode: %s" % mode) + print(WARNING) + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/fit_sbp.py b/python/fit_sbp.py new file mode 100755 index 0000000..c22e0c8 --- /dev/null +++ b/python/fit_sbp.py @@ -0,0 +1,807 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Aaron LI +# Created: 2016-03-13 +# Updated: 2016-04-26 +# +# Changelogs: +# 2016-04-26: +# * Reorder some methods of classes 'FitModelSBeta' and 'FitModelDBeta' +# * Change the output file extension from ".txt" to ".json" +# 2016-04-21: +# * Plot another X axis with unit "r500", with R500 values marked +# * Adjust output image size/resolution +# 2016-04-20: +# * Support "pix" and "kpc" units +# * Allow ignore data w.r.t R500 value +# * Major changes to the config syntax +# * Add commandline argument to select the sbp model +# 2016-04-05: +# * Allow fix parameters +# 2016-03-31: +# * Remove `ci_report()' +# * Add `make_results()' to orgnize all results as s Python dictionary +# * Report results as json string +# 2016-03-28: +# * Add `main()', `make_model()' +# * Use `configobj' to handle configurations +# * Save fit results and plot +# * Add `ci_report()' +# 2016-03-14: +# * Refactor classes `FitModelSBeta' and `FitModelDBeta' +# * Add matplotlib plot support +# * Add `ignore_data()' and `notice_data()' support +# * Add classes `FitModelSBetaNorm' and `FitModelDBetaNorm' +# +# TODO: +# * to allow fit the outer beta component, then fix it, and fit the inner one +# * to integrate basic information of config file to the output json +# * to output the ignored radius range in the same unit as input sbp data +# + +""" +Fit the surface brightness profile (SBP) with the single-beta model: + s(r) = s0 * [1.0 + (r/rc)^2] ^ (0.5-3*beta) + bkg +or the double-beta model: + s(r) = s01 * [1.0 + (r/rc1)^2] ^ (0.5-3*beta1) + + s02 * [1.0 + (r/rc2)^2] ^ (0.5-3*beta2) + bkg + + +Sample config file: +------------------------------------------------- +name = <NAME> +obsid = <OBSID> +r500_pix = <R500_PIX> +r500_kpc = <R500_KPC> + +sbpfile = sbprofile.txt +# unit of radius: pix (default) or kpc +unit = pixel + +# sbp model: "sbeta" or "dbeta" +model = sbeta +#model = dbeta + +# output file to store the fitting results +outfile = sbpfit.json +# output file to save the fitting plot +imgfile = sbpfit.png + +# data range to be ignored during fitting (same unit as the above "unit") +#ignore = 0.0-20.0, +# specify the ignore range w.r.t R500 ("r500_pix" or "r500_kpc" required) +#ignore_r500 = 0.0-0.15, + +[sbeta] +# model-related options (OVERRIDE the upper level options) +outfile = sbpfit_sbeta.json +imgfile = sbpfit_sbeta.png +#ignore = 0.0-20.0, +#ignore_r500 = 0.0-0.15, + [[params]] + # model parameters + # name = initial, lower, upper, variable (FIXED/False to fix the parameter) + s0 = 1.0e-8, 0.0, 1.0e-6 + rc = 30.0, 5.0, 1.0e4 + #rc = 30.0, 5.0, 1.0e4, FIXED + beta = 0.7, 0.3, 1.1 + bkg = 1.0e-10, 0.0, 1.0e-8 + + +[dbeta] +outfile = sbpfit_dbeta.json +imgfile = sbpfit_dbeta.png +#ignore = 0.0-20.0, +#ignore_r500 = 0.0-0.15, + [[params]] + s01 = 1.0e-8, 0.0, 1.0e-6 + rc1 = 50.0, 10.0, 1.0e4 + beta1 = 0.7, 0.3, 1.1 + s02 = 1.0e-8, 0.0, 1.0e-6 + rc2 = 30.0, 2.0, 5.0e2 + beta2 = 0.7, 0.3, 1.1 + bkg = 1.0e-10, 0.0, 1.0e-8 +------------------------------------------------- +""" + +__version__ = "0.6.2" +__date__ = "2016-04-26" + + +import os +import sys +import re +import argparse +import json +from collections import OrderedDict + +import numpy as np +import lmfit +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +from configobj import ConfigObj + + +plt.style.use("ggplot") + + +class FitModel: + """ + Meta-class of the fitting model. + + The supplied `func' should have the following syntax: + y = f(x, params) + where the `params' is `lmfit.Parameters' instance which contains all + the model parameters to be fitted, and should be provided as well. + """ + def __init__(self, name=None, func=None, params=lmfit.Parameters()): + self.name = name + self.func = func + self.params = params + + def f(self, x): + return self.func(x, self.params) + + def get_param(self, name=None): + """ + Return the requested `Parameter' object or the whole + `Parameters' object of no name supplied. + """ + try: + return self.params[name] + except KeyError: + return self.params + + def set_param(self, name, *args, **kwargs): + """ + Set the properties of the specified parameter. + """ + param = self.params[name] + param.set(*args, **kwargs) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model. + """ + f_fitted = lambda x: self.func(x, params) + ydata = f_fitted(xdata) + ax.plot(xdata, ydata, 'k-') + +class FitModelSBeta(FitModel): + """ + The single-beta model to be fitted. + Single-beta model, with a constant background. + """ + params = lmfit.Parameters() + params.add_many( # (name, value, vary, min, max, expr) + ("s0", 1.0e-8, True, 0.0, 1.0e-6, None), + ("rc", 30.0, True, 1.0, 1.0e4, None), + ("beta", 0.7, True, 0.3, 1.1, None), + ("bkg", 1.0e-9, True, 0.0, 1.0e-7, None)) + + def __init__(self): + super(self.__class__, self).__init__(name="Single-beta", + func=self.sbeta, params=self.params) + + @staticmethod + def sbeta(r, params): + parvals = params.valuesdict() + s0 = parvals["s0"] + rc = parvals["rc"] + beta = parvals["beta"] + bkg = parvals["bkg"] + return s0 * np.power((1 + (r/rc)**2), (0.5 - 3*beta)) + bkg + + def plot(self, params, xdata, ax): + """ + Plot the fitted model, as well as the fitted parameters. + """ + super(self.__class__, self).plot(params, xdata, ax) + ydata = self.sbeta(xdata, params) + # fitted paramters + ax.vlines(x=params["rc"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.hlines(y=params["bkg"].value, xmin=min(xdata), xmax=max(xdata), + linestyles="dashed") + ax.text(x=params["rc"].value, y=min(ydata), + s="beta: %.2f\nrc: %.2f" % (params["beta"].value, + params["rc"].value)) + ax.text(x=min(xdata), y=min(ydata), + s="bkg: %.3e" % params["bkg"].value, + verticalalignment="top") + + +class FitModelDBeta(FitModel): + """ + The double-beta model to be fitted. + Double-beta model, with a constant background. + + NOTE: + the first beta component (s01, rc1, beta1) describes the main and + outer SBP; while the second beta component (s02, rc2, beta2) accounts + for the central brightness excess. + """ + params = lmfit.Parameters() + params.add("s01", value=1.0e-8, min=0.0, max=1.0e-6) + params.add("rc1", value=50.0, min=10.0, max=1.0e4) + params.add("beta1", value=0.7, min=0.3, max=1.1) + #params.add("df_s0", value=1.0e-8, min=0.0, max=1.0e-6) + #params.add("s02", expr="s01 + df_s0") + params.add("s02", value=1.0e-8, min=0.0, max=1.0e-6) + #params.add("df_rc", value=30.0, min=0.0, max=1.0e4) + #params.add("rc2", expr="rc1 - df_rc") + params.add("rc2", value=20.0, min=1.0, max=5.0e2) + params.add("beta2", value=0.7, min=0.3, max=1.1) + params.add("bkg", value=1.0e-9, min=0.0, max=1.0e-7) + + def __init__(self): + super(self.__class__, self).__init__(name="Double-beta", + func=self.dbeta, params=self.params) + + @classmethod + def dbeta(self, r, params): + return self.beta1(r, params) + self.beta2(r, params) + + @staticmethod + def beta1(r, params): + """ + This beta component describes the main/outer part of the SBP. + """ + parvals = params.valuesdict() + s01 = parvals["s01"] + rc1 = parvals["rc1"] + beta1 = parvals["beta1"] + bkg = parvals["bkg"] + return s01 * np.power((1 + (r/rc1)**2), (0.5 - 3*beta1)) + bkg + + @staticmethod + def beta2(r, params): + """ + This beta component describes the central/excess part of the SBP. + """ + parvals = params.valuesdict() + s02 = parvals["s02"] + rc2 = parvals["rc2"] + beta2 = parvals["beta2"] + return s02 * np.power((1 + (r/rc2)**2), (0.5 - 3*beta2)) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model, and each beta component, + as well as the fitted parameters. + """ + super(self.__class__, self).plot(params, xdata, ax) + beta1_ydata = self.beta1(xdata, params) + beta2_ydata = self.beta2(xdata, params) + ax.plot(xdata, beta1_ydata, 'b-.') + ax.plot(xdata, beta2_ydata, 'b-.') + # fitted paramters + ydata = beta1_ydata + beta2_ydata + ax.vlines(x=params["rc1"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.vlines(x=params["rc2"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.hlines(y=params["bkg"].value, xmin=min(xdata), xmax=max(xdata), + linestyles="dashed") + ax.text(x=params["rc1"].value, y=min(ydata), + s="beta1: %.2f\nrc1: %.2f" % (params["beta1"].value, + params["rc1"].value)) + ax.text(x=params["rc2"].value, y=min(ydata), + s="beta2: %.2f\nrc2: %.2f" % (params["beta2"].value, + params["rc2"].value)) + ax.text(x=min(xdata), y=min(ydata), + s="bkg: %.3e" % params["bkg"].value, + verticalalignment="top") + + +class FitModelSBetaNorm(FitModel): + """ + The single-beta model to be fitted. + Single-beta model, with a constant background. + Normalized the `s0' and `bkg' parameters by take the logarithm. + """ + params = lmfit.Parameters() + params.add_many( # (name, value, vary, min, max, expr) + ("log10_s0", -8.0, True, -12.0, -6.0, None), + ("rc", 30.0, True, 1.0, 1.0e4, None), + ("beta", 0.7, True, 0.3, 1.1, None), + ("log10_bkg", -9.0, True, -12.0, -7.0, None)) + + @staticmethod + def sbeta(r, params): + parvals = params.valuesdict() + s0 = 10 ** parvals["log10_s0"] + rc = parvals["rc"] + beta = parvals["beta"] + bkg = 10 ** parvals["log10_bkg"] + return s0 * np.power((1 + (r/rc)**2), (0.5 - 3*beta)) + bkg + + def __init__(self): + super(self.__class__, self).__init__(name="Single-beta", + func=self.sbeta, params=self.params) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model, as well as the fitted parameters. + """ + super(self.__class__, self).plot(params, xdata, ax) + ydata = self.sbeta(xdata, params) + # fitted paramters + ax.vlines(x=params["rc"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.hlines(y=(10 ** params["bkg"].value), xmin=min(xdata), + xmax=max(xdata), linestyles="dashed") + ax.text(x=params["rc"].value, y=min(ydata), + s="beta: %.2f\nrc: %.2f" % (params["beta"].value, + params["rc"].value)) + ax.text(x=min(xdata), y=min(ydata), + s="bkg: %.3e" % (10 ** params["bkg"].value), + verticalalignment="top") + + +class FitModelDBetaNorm(FitModel): + """ + The double-beta model to be fitted. + Double-beta model, with a constant background. + Normalized the `s01', `s02' and `bkg' parameters by take the logarithm. + + NOTE: + the first beta component (s01, rc1, beta1) describes the main and + outer SBP; while the second beta component (s02, rc2, beta2) accounts + for the central brightness excess. + """ + params = lmfit.Parameters() + params.add("log10_s01", value=-8.0, min=-12.0, max=-6.0) + params.add("rc1", value=50.0, min=10.0, max=1.0e4) + params.add("beta1", value=0.7, min=0.3, max=1.1) + #params.add("df_s0", value=1.0e-8, min=0.0, max=1.0e-6) + #params.add("s02", expr="s01 + df_s0") + params.add("log10_s02", value=-8.0, min=-12.0, max=-6.0) + #params.add("df_rc", value=30.0, min=0.0, max=1.0e4) + #params.add("rc2", expr="rc1 - df_rc") + params.add("rc2", value=20.0, min=1.0, max=5.0e2) + params.add("beta2", value=0.7, min=0.3, max=1.1) + params.add("log10_bkg", value=-9.0, min=-12.0, max=-7.0) + + @staticmethod + def beta1(r, params): + """ + This beta component describes the main/outer part of the SBP. + """ + parvals = params.valuesdict() + s01 = 10 ** parvals["log10_s01"] + rc1 = parvals["rc1"] + beta1 = parvals["beta1"] + bkg = 10 ** parvals["log10_bkg"] + return s01 * np.power((1 + (r/rc1)**2), (0.5 - 3*beta1)) + bkg + + @staticmethod + def beta2(r, params): + """ + This beta component describes the central/excess part of the SBP. + """ + parvals = params.valuesdict() + s02 = 10 ** parvals["log10_s02"] + rc2 = parvals["rc2"] + beta2 = parvals["beta2"] + return s02 * np.power((1 + (r/rc2)**2), (0.5 - 3*beta2)) + + @classmethod + def dbeta(self, r, params): + return self.beta1(r, params) + self.beta2(r, params) + + def __init__(self): + super(self.__class__, self).__init__(name="Double-beta", + func=self.dbeta, params=self.params) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model, and each beta component, + as well as the fitted parameters. + """ + super(self.__class__, self).plot(params, xdata, ax) + beta1_ydata = self.beta1(xdata, params) + beta2_ydata = self.beta2(xdata, params) + ax.plot(xdata, beta1_ydata, 'b-.') + ax.plot(xdata, beta2_ydata, 'b-.') + # fitted paramters + ydata = beta1_ydata + beta2_ydata + ax.vlines(x=params["log10_rc1"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.vlines(x=params["rc2"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.hlines(y=(10 ** params["bkg"].value), xmin=min(xdata), + xmax=max(xdata), linestyles="dashed") + ax.text(x=params["rc1"].value, y=min(ydata), + s="beta1: %.2f\nrc1: %.2f" % (params["beta1"].value, + params["rc1"].value)) + ax.text(x=params["rc2"].value, y=min(ydata), + s="beta2: %.2f\nrc2: %.2f" % (params["beta2"].value, + params["rc2"].value)) + ax.text(x=min(xdata), y=min(ydata), + s="bkg: %.3e" % (10 ** params["bkg"].value), + verticalalignment="top") + + +class SbpFit: + """ + Class to handle the SBP fitting with single-/double-beta model. + """ + def __init__(self, model, method="lbfgsb", + xdata=None, ydata=None, xerr=None, yerr=None, xunit="pix", + name=None, obsid=None, r500_pix=None, r500_kpc=None): + self.method = method + self.model = model + self.load_data(xdata=xdata, ydata=ydata, xerr=xerr, yerr=yerr, + xunit=xunit) + self.set_source(name=name, obsid=obsid, r500_pix=r500_pix, + r500_kpc=r500_kpc) + + def set_source(self, name, obsid=None, r500_pix=None, r500_kpc=None): + self.name = name + try: + self.obsid = int(obsid) + except TypeError: + self.obsid = None + try: + self.r500_pix = float(r500_pix) + except TypeError: + self.r500_pix = None + try: + self.r500_kpc = float(r500_kpc) + except TypeError: + self.r500_kpc = None + try: + self.kpc_per_pix = self.r500_kpc / self.r500_pix + except (TypeError, ZeroDivisionError): + self.kpc_per_pix = -1 + + def load_data(self, xdata, ydata, xerr, yerr, xunit="pix"): + self.xdata = xdata + self.ydata = ydata + self.xerr = xerr + self.yerr = yerr + if xdata is not None: + self.mask = np.ones(xdata.shape, dtype=np.bool) + else: + self.mask = None + if xunit.lower() in ["pix", "pixel"]: + self.xunit = "pix" + elif xunit.lower() == "kpc": + self.xunit = "kpc" + else: + raise ValueError("invalid xunit: %s" % xunit) + + def ignore_data(self, xmin=None, xmax=None, unit=None): + """ + Ignore the data points within range [xmin, xmax]. + If xmin is None, then xmin=min(xdata); + if xmax is None, then xmax=max(xdata). + + if unit is None, then assume the same unit as `self.xunit'. + """ + if unit is None: + unit = self.xunit + if xmin is not None: + xmin = self.convert_unit(xmin, unit=unit) + else: + xmin = np.min(self.xdata) + if xmax is not None: + xmax = self.convert_unit(xmax, unit=unit) + else: + xmax = np.max(self.xdata) + ignore_idx = np.logical_and(self.xdata >= xmin, self.xdata <= xmax) + self.mask[ignore_idx] = False + # reset `f_residual' + self.f_residual = None + + def notice_data(self, xmin=None, xmax=None, unit=None): + """ + Notice the data points within range [xmin, xmax]. + If xmin is None, then xmin=min(xdata); + if xmax is None, then xmax=max(xdata). + + if unit is None, then assume the same unit as `self.xunit'. + """ + if unit is None: + unit = self.xunit + if xmin is not None: + xmin = self.convert_unit(xmin, unit=unit) + else: + xmin = np.min(self.xdata) + if xmax is not None: + xmax = self.convert_unit(xmax, unit=unit) + else: + xmax = np.max(self.xdata) + notice_idx = np.logical_and(self.xdata >= xmin, self.xdata <= xmax) + self.mask[notice_idx] = True + # reset `f_residual' + self.f_residual = None + + def convert_unit(self, x, unit): + """ + Convert the value x in given unit to be the unit `self.xunit' + """ + if unit == self.xunit: + return x + elif (unit == "pix") and (self.xunit == "kpc"): + return (x / self.r500_pix * self.r500_kpc) + elif (unit == "kpc") and (self.xunit == "pix"): + return (x / self.r500_kpc * self.r500_pix) + elif (unit == "r500") and (self.xunit == "pix"): + return (x * self.r500_pix) + elif (unit == "r500") and (self.xunit == "kpc"): + return (x * self.r500_kpc) + else: + raise ValueError("invalid units: %s vs. %s" % (unit, self.xunit)) + + def convert_to_r500(self, x, unit=None): + """ + Convert the value x in given unit to be in unit "r500". + """ + if unit is None: + unit = self.xunit + if unit == "r500": + return x + elif unit == "pix": + return (x / self.r500_pix) + elif unit == "kpc": + return (x / self.r500_kpc) + else: + raise ValueError("invalid unit: %s" % unit) + + def set_residual(self): + def f_residual(params): + if self.yerr is None: + return self.model.func(self.xdata[self.mask], params) - \ + self.ydata + else: + return (self.model.func(self.xdata[self.mask], params) - \ + self.ydata[self.mask]) / self.yerr[self.mask] + self.f_residual = f_residual + + def fit(self, method=None): + if method is None: + method = self.method + if not hasattr(self, "f_residual") or self.f_residual is None: + self.set_residual() + self.fitter = lmfit.Minimizer(self.f_residual, self.model.params) + self.fitted = self.fitter.minimize(method=method) + self.fitted_model = lambda x: self.model.func(x, self.fitted.params) + + def calc_ci(self, sigmas=[0.68, 0.90]): + # `conf_interval' requires the fitted results have valid `stderr', + # so we need to re-fit the model with method `leastsq'. + fitted = self.fitter.minimize(method="leastsq", + params=self.fitted.params) + self.ci, self.trace = lmfit.conf_interval(self.fitter, fitted, + sigmas=sigmas, trace=True) + + def make_results(self): + """ + Make the `self.results' dictionary which contains all the fitting + results as well as the confidence intervals. + """ + fitted = self.fitted + self.results = OrderedDict() + ## fitting results + self.results.update( + nfev = fitted.nfev, + ndata = fitted.ndata, + nvarys = fitted.nvarys, # number of varible paramters + nfree = fitted.nfree, # degree of freem + chisqr = fitted.chisqr, + redchi = fitted.redchi, + aic = fitted.aic, + bic = fitted.bic) + params = fitted.params + pnames = list(params.keys()) + pvalues = OrderedDict() + for pn in pnames: + par = params.get(pn) + pvalues[pn] = [par.value, par.min, par.max, par.vary] + self.results["params"] = pvalues + ## confidence intervals + if hasattr(self, "ci") and self.ci is not None: + ci = self.ci + ci_values = OrderedDict() + ci_sigmas = [ "ci%02d" % (v[0]*100) for v in ci.get(pnames[0]) ] + ci_names = sorted(list(set(ci_sigmas))) + ci_idx = { k: [] for k in ci_names } + for cn, idx in zip(ci_sigmas, range(len(ci_sigmas))): + ci_idx[cn].append(idx) + # parameters ci + for pn in pnames: + ci_pv = OrderedDict() + pv = [ v[1] for v in ci.get(pn) ] + # best + pv_best = pv[ ci_idx["ci00"][0] ] + ci_pv["best"] = pv_best + # ci of each sigma + pv2 = [ v-pv_best for v in pv ] + for cn in ci_names[1:]: + ci_pv[cn] = [ pv2[idx] for idx in ci_idx[cn] ] + ci_values[pn] = ci_pv + self.results["ci"] = ci_values + + def report(self, outfile=sys.stdout): + if not hasattr(self, "results") or self.results is None: + self.make_results() + jd = json.dumps(self.results, indent=2) + print(jd, file=outfile) + + def plot(self, ax=None, fig=None, r500_axis=True): + """ + Arguments: + * r500_axis: whether to add a second X axis in unit "r500" + """ + if ax is None: + fig, ax = plt.subplots(1, 1) + # noticed data points + eb = ax.errorbar(self.xdata[self.mask], self.ydata[self.mask], + xerr=self.xerr[self.mask], yerr=self.yerr[self.mask], + fmt="none") + # ignored data points + ignore_mask = np.logical_not(self.mask) + if np.sum(ignore_mask) > 0: + eb = ax.errorbar(self.xdata[ignore_mask], self.ydata[ignore_mask], + xerr=self.xerr[ignore_mask], yerr=self.yerr[ignore_mask], + fmt="none") + eb[-1][0].set_linestyle("-.") + # fitted model + xmax = self.xdata[-1] + self.xerr[-1] + xpred = np.power(10, np.linspace(0, np.log10(xmax), 2*len(self.xdata))) + ypred = self.fitted_model(xpred) + ymin = min(min(self.ydata), min(ypred)) + ymax = max(max(self.ydata), max(ypred)) + self.model.plot(params=self.fitted.params, xdata=xpred, ax=ax) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlim(1.0, xmax) + ax.set_ylim(ymin/1.2, ymax*1.2) + name = self.name + if self.obsid is not None: + name += "; %s" % self.obsid + ax.set_title("Fitted Surface Brightness Profile (%s)" % name) + ax.set_xlabel("Radius (%s)" % self.xunit) + ax.set_ylabel(r"Surface Brightness (photons/cm$^2$/pixel$^2$/s)") + ax.text(x=xmax, y=ymax, + s="redchi: %.2f / %.2f = %.2f" % (self.fitted.chisqr, + self.fitted.nfree, self.fitted.chisqr/self.fitted.nfree), + horizontalalignment="right", verticalalignment="top") + plot_ret = [fig, ax] + if r500_axis: + # Add a second X-axis with labels in unit "r500" + # Credit: https://stackoverflow.com/a/28192477/4856091 + try: + ax.title.set_position([0.5, 1.1]) # raise title position + ax2 = ax.twiny() + # NOTE: the ORDER of the following lines MATTERS + ax2.set_xscale(ax.get_xscale()) + ax2_ticks = ax.get_xticks() + ax2.set_xticks(ax2_ticks) + ax2.set_xbound(ax.get_xbound()) + ax2.set_xticklabels([ "%.2g" % self.convert_to_r500(x) + for x in ax2_ticks ]) + ax2.set_xlabel("Radius (r500; r500 = %s pix = %s kpc)" % (\ + self.r500_pix, self.r500_kpc)) + ax2.grid(False) + plot_ret.append(ax2) + except ValueError: + # cannot convert X values to unit "r500" + pass + # automatically adjust layout + fig.tight_layout() + return plot_ret + + +def make_model(config, modelname): + """ + Make the model with parameters set according to the config. + """ + if modelname == "sbeta": + # single-beta model + model = FitModelSBeta() + elif modelname == "dbeta": + # double-beta model + model = FitModelDBeta() + else: + raise ValueError("Invalid model: %s" % modelname) + # set initial values and bounds for the model parameters + params = config[modelname]["params"] + for p, value in params.items(): + variable = True + if len(value) == 4 and value[3].upper() in ["FIXED", "FALSE"]: + variable = False + model.set_param(name=p, value=float(value[0]), + min=float(value[1]), max=float(value[2]), vary=variable) + return model + + +def main(): + # parser for command line options and arguments + parser = argparse.ArgumentParser( + description="Fit surface brightness profile with " + \ + "single-/double-beta model", + epilog="Version: %s (%s)" % (__version__, __date__)) + parser.add_argument("-V", "--version", action="version", + version="%(prog)s " + "%s (%s)" % (__version__, __date__)) + parser.add_argument("config", help="Config file for SBP fitting") + # exclusive argument group for model selection + grp_model = parser.add_mutually_exclusive_group(required=False) + grp_model.add_argument("-s", "--sbeta", dest="sbeta", + action="store_true", help="single-beta model for SBP") + grp_model.add_argument("-d", "--dbeta", dest="dbeta", + action="store_true", help="double-beta model for SBP") + # + args = parser.parse_args() + + config = ConfigObj(args.config) + + # determine the model name + if args.sbeta: + modelname = "sbeta" + elif args.dbeta: + modelname = "dbeta" + else: + modelname = config["model"] + + config_model = config[modelname] + # determine the "outfile" and "imgfile" + outfile = config.get("outfile") + outfile = config_model.get("outfile", outfile) + imgfile = config.get("imgfile") + imgfile = config_model.get("imgfile", imgfile) + + # SBP fitting model + model = make_model(config, modelname=modelname) + + # sbp data and fit object + sbpdata = np.loadtxt(config["sbpfile"]) + sbpfit = SbpFit(model=model, xdata=sbpdata[:, 0], xerr=sbpdata[:, 1], + ydata=sbpdata[:, 2], yerr=sbpdata[:, 3], + xunit=config.get("unit", "pix")) + sbpfit.set_source(name=config["name"], obsid=config.get("obsid"), + r500_pix=config.get("r500_pix"), r500_kpc=config.get("r500_kpc")) + + # apply data range ignorance + if "ignore" in config.keys(): + for ig in config.as_list("ignore"): + xmin, xmax = map(float, ig.split("-")) + sbpfit.ignore_data(xmin=xmin, xmax=xmax) + if "ignore_r500" in config.keys(): + for ig in config.as_list("ignore_r500"): + xmin, xmax = map(float, ig.split("-")) + sbpfit.ignore_data(xmin=xmin, xmax=xmax, unit="r500") + + # apply additional data range ignorance specified within model section + if "ignore" in config_model.keys(): + for ig in config_model.as_list("ignore"): + xmin, xmax = map(float, ig.split("-")) + sbpfit.ignore_data(xmin=xmin, xmax=xmax) + if "ignore_r500" in config_model.keys(): + for ig in config_model.as_list("ignore_r500"): + xmin, xmax = map(float, ig.split("-")) + sbpfit.ignore_data(xmin=xmin, xmax=xmax, unit="r500") + + # fit and calculate confidence intervals + sbpfit.fit() + sbpfit.calc_ci() + sbpfit.report() + with open(outfile, "w") as ofile: + sbpfit.report(outfile=ofile) + + # make and save a plot + fig = Figure(figsize=(10, 8)) + canvas = FigureCanvas(fig) + ax = fig.add_subplot(111) + sbpfit.plot(ax=ax, fig=fig, r500_axis=True) + fig.savefig(imgfile, dpi=150) + + +if __name__ == "__main__": + main() + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/imapUTF7.py b/python/imapUTF7.py new file mode 100644 index 0000000..2e4db0a --- /dev/null +++ b/python/imapUTF7.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# This code was originally in PloneMailList, a GPL'd software. +# http://svn.plone.org/svn/collective/mxmImapClient/trunk/imapUTF7.py +# http://bugs.python.org/issue5305 +# +# Port to Python 3.x +# Credit: https://github.com/MarechJ/py3_imap_utf7 +# +# 2016-01-23 +# Aaron LI +# + +""" +Imap folder names are encoded using a special version of utf-7 as defined in RFC +2060 section 5.1.3. + +5.1.3. Mailbox International Naming Convention + + By convention, international mailbox names are specified using a + modified version of the UTF-7 encoding described in [UTF-7]. The + purpose of these modifications is to correct the following problems + with UTF-7: + + 1) UTF-7 uses the "+" character for shifting; this conflicts with + the common use of "+" in mailbox names, in particular USENET + newsgroup names. + + 2) UTF-7's encoding is BASE64 which uses the "/" character; this + conflicts with the use of "/" as a popular hierarchy delimiter. + + 3) UTF-7 prohibits the unencoded usage of "\"; this conflicts with + the use of "\" as a popular hierarchy delimiter. + + 4) UTF-7 prohibits the unencoded usage of "~"; this conflicts with + the use of "~" in some servers as a home directory indicator. + + 5) UTF-7 permits multiple alternate forms to represent the same + string; in particular, printable US-ASCII chararacters can be + represented in encoded form. + + In modified UTF-7, printable US-ASCII characters except for "&" + represent themselves; that is, characters with octet values 0x20-0x25 + and 0x27-0x7e. The character "&" (0x26) is represented by the two- + octet sequence "&-". + + All other characters (octet values 0x00-0x1f, 0x7f-0xff, and all + Unicode 16-bit octets) are represented in modified BASE64, with a + further modification from [UTF-7] that "," is used instead of "/". + Modified BASE64 MUST NOT be used to represent any printing US-ASCII + character which can represent itself. + + "&" is used to shift to modified BASE64 and "-" to shift back to US- + ASCII. All names start in US-ASCII, and MUST end in US-ASCII (that + is, a name that ends with a Unicode 16-bit octet MUST end with a "- + "). + + For example, here is a mailbox name which mixes English, Japanese, + and Chinese text: ~peter/mail/&ZeVnLIqe-/&U,BTFw- +""" + + +import binascii +import codecs + + +## encoding + +def modified_base64(s:str): + s = s.encode('utf-16be') # UTF-16, big-endian byte order + return binascii.b2a_base64(s).rstrip(b'\n=').replace(b'/', b',') + +def doB64(_in, r): + if _in: + r.append(b'&' + modified_base64(''.join(_in)) + b'-') + del _in[:] + +def encoder(s:str): + r = [] + _in = [] + for c in s: + ordC = ord(c) + if 0x20 <= ordC <= 0x25 or 0x27 <= ordC <= 0x7e: + doB64(_in, r) + r.append(c.encode()) + elif c == '&': + doB64(_in, r) + r.append(b'&-') + else: + _in.append(c) + doB64(_in, r) + return (b''.join(r), len(s)) + + +## decoding + +def modified_unbase64(s:bytes): + b = binascii.a2b_base64(s.replace(b',', b'/') + b'===') + return b.decode('utf-16be') + +def decoder(s:bytes): + r = [] + decode = bytearray() + for c in s: + if c == ord('&') and not decode: + decode.append(ord('&')) + elif c == ord('-') and decode: + if len(decode) == 1: + r.append('&') + else: + r.append(modified_unbase64(decode[1:])) + decode = bytearray() + elif decode: + decode.append(c) + else: + r.append(chr(c)) + if decode: + r.append(modified_unbase64(decode[1:])) + bin_str = ''.join(r) + return (bin_str, len(s)) + + +class StreamReader(codecs.StreamReader): + def decode(self, s, errors='strict'): + return decoder(s) + + +class StreamWriter(codecs.StreamWriter): + def decode(self, s, errors='strict'): + return encoder(s) + + +def imap4_utf_7(name): + if name == 'imap4-utf-7': + return (encoder, decoder, StreamReader, StreamWriter) + + +codecs.register(imap4_utf_7) + + +## testing methods + +def imapUTF7Encode(ust): + "Returns imap utf-7 encoded version of string" + return ust.encode('imap4-utf-7') + +def imapUTF7EncodeSequence(seq): + "Returns imap utf-7 encoded version of strings in sequence" + return [imapUTF7Encode(itm) for itm in seq] + + +def imapUTF7Decode(st): + "Returns utf7 encoded version of imap utf-7 string" + return st.decode('imap4-utf-7') + +def imapUTF7DecodeSequence(seq): + "Returns utf7 encoded version of imap utf-7 strings in sequence" + return [imapUTF7Decode(itm) for itm in seq] + + +def utf8Decode(st): + "Returns utf7 encoded version of imap utf-7 string" + return st.decode('utf-8') + + +def utf7SequenceToUTF8(seq): + "Returns utf7 encoded version of imap utf-7 strings in sequence" + return [itm.decode('imap4-utf-7').encode('utf-8') for itm in seq] + + +__all__ = [ 'imapUTF7Encode', 'imapUTF7Decode' ] + + +if __name__ == '__main__': + testdata = [ + (u'foo\r\n\nbar\n', b'foo&AA0ACgAK-bar&AAo-'), + (u'测试', b'&bUuL1Q-'), + (u'Hello 世界', b'Hello &ThZ1TA-') + ] + for s, e in testdata: + #assert s == decoder(encoder(s)[0])[0] + assert s == imapUTF7Decode(e) + assert e == imapUTF7Encode(s) + assert s == imapUTF7Decode(imapUTF7Encode(s)) + assert e == imapUTF7Encode(imapUTF7Decode(e)) + print("All tests passed!") + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/msvst_starlet.py b/python/msvst_starlet.py new file mode 100755 index 0000000..e534d3d --- /dev/null +++ b/python/msvst_starlet.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# References: +# [1] Jean-Luc Starck, Fionn Murtagh & Jalal M. Fadili +# Sparse Image and Signal Processing: Wavelets, Curvelets, Morphological Diversity +# Section 3.5, 6.6 +# +# Credits: +# [1] https://github.com/abrazhe/image-funcut/blob/master/imfun/atrous.py +# +# Aaron LI +# Created: 2016-03-17 +# Updated: 2016-04-22 +# +# ChangeLog: +# 2016-04-22: +# * Add argument "end-scale" to specifiy the end denoising scale +# * Check outfile existence first +# * Add argument "start-scale" to specifiy the start denoising scale +# * Fix a bug about "p_cutoff" when "comp" contains ALL False's +# * Show more verbose information/details +# 2016-04-20: +# * Add argparse and main() for scripting +# + +""" +Starlet wavelet transform, i.e., isotropic undecimated wavelet transform +(IUWT), or à trous wavelet transform. +And multi-scale variance stabling transform (MS-VST), which can be used +to effectively remove the Poisson noises. +""" + +__version__ = "0.2.5" +__date__ = "2016-04-22" + + +import sys +import os +import argparse +from datetime import datetime + +import numpy as np +import scipy as sp +from scipy import signal +from astropy.io import fits + + +class B3Spline: # {{{ + """ + B3-spline wavelet. + """ + # scaling function (phi) + dec_lo = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16 + dec_hi = np.array([-1.0, -4.0, 10.0, -4.0, -1.0]) / 16 + rec_lo = np.array([0.0, 0.0, 1.0, 0.0, 0.0]) + rec_hi = np.array([0.0, 0.0, 1.0, 0.0, 0.0]) +# B3Spline }}} + + +class IUWT: # {{{ + """ + Isotropic undecimated wavelet transform. + """ + ## Decomposition filters list: + # a_{scale} = convole(a_0, filters[scale]) + # Note: the zero-th scale filter (i.e., delta function) is the first + # element, thus the array index is the same as the decomposition scale. + filters = [] + + phi = None # wavelet scaling function (2D) + level = 0 # number of transform level + decomposition = None # decomposed coefficients/images + reconstruction = None # reconstructed image + + # convolution boundary condition + boundary = "symm" + + def __init__(self, phi=B3Spline.dec_lo, level=None, boundary="symm", + data=None): + self.set_wavelet(phi=phi) + self.level = level + self.boundary = boundary + self.data = np.array(data) + + def reset(self): + """ + Reset the object attributes. + """ + self.data = None + self.phi = None + self.decomposition = None + self.reconstruction = None + self.level = 0 + self.filters = [] + self.boundary = "symm" + + def load_data(self, data): + self.reset() + self.data = np.array(data) + + def set_wavelet(self, phi): + self.reset() + phi = np.array(phi) + if phi.ndim == 1: + phi_ = phi.reshape(1, -1) + self.phi = np.dot(phi_.T, phi_) + elif phi.ndim == 2: + self.phi = phi + else: + raise ValueError("Invalid phi dimension") + + def calc_filters(self): + """ + Calculate the convolution filters of each scale. + Note: the zero-th scale filter (i.e., delta function) is the first + element, thus the array index is the same as the decomposition scale. + """ + self.filters = [] + # scale 0: delta function + h = np.array([[1]]) # NOTE: 2D + self.filters.append(h) + # scale 1 + h = self.phi[::-1, ::-1] + self.filters.append(h) + for scale in range(2, self.level+1): + h_up = self.zupsample(self.phi, order=scale-1) + h2 = signal.convolve2d(h_up[::-1, ::-1], h, mode="same", + boundary=self.boundary) + self.filters.append(h2) + + def transform(self, data, scale, boundary="symm"): + """ + Perform only one scale wavelet transform for the given data. + + return: + [ approx, detail ] + """ + self.decomposition = [] + approx = signal.convolve2d(data, self.filters[scale], + mode="same", boundary=self.boundary) + detail = data - approx + return [approx, detail] + + def decompose(self, level, boundary="symm"): + """ + Perform IUWT decomposition in the plain loop way. + The filters of each scale/level are calculated first, then the + approximations of each scale/level are calculated by convolving the + raw/finest image with these filters. + + return: + [ W_1, W_2, ..., W_n, A_n ] + n = level + W: wavelet details + A: approximation + """ + self.boundary = boundary + if self.level != level or self.filters == []: + self.level = level + self.calc_filters() + self.decomposition = [] + approx = self.data + for scale in range(1, level+1): + # approximation: + approx2 = signal.convolve2d(self.data, self.filters[scale], + mode="same", boundary=self.boundary) + # wavelet details: + w = approx - approx2 + self.decomposition.append(w) + if scale == level: + self.decomposition.append(approx2) + approx = approx2 + return self.decomposition + + def decompose_recursive(self, level, boundary="symm"): + """ + Perform the IUWT decomposition in the recursive way. + + return: + [ W_1, W_2, ..., W_n, A_n ] + n = level + W: wavelet details + A: approximation + """ + self.level = level + self.boundary = boundary + self.decomposition = self.__decompose(self.data, self.phi, level=level) + return self.decomposition + + def __decompose(self, data, phi, level): + """ + 2D IUWT decomposition (or stationary wavelet transform). + + This is a convolution version, where kernel is zero-upsampled + explicitly. Not fast. + + Parameters: + - level : level of decomposition + - phi : low-pass filter kernel + - boundary : boundary conditions (passed to scipy.signal.convolve2d, + 'symm' by default) + + Returns: + list of wavelet details + last approximation. Each element in + the list is an image of the same size as the input image. + """ + if level <= 0: + return data + shapecheck = map(lambda a,b:a>b, data.shape, phi.shape) + assert np.all(shapecheck) + # approximation: + approx = signal.convolve2d(data, phi[::-1, ::-1], mode="same", + boundary=self.boundary) + # wavelet details: + w = data - approx + phi_up = self.zupsample(phi, order=1) + shapecheck = map(lambda a,b:a>b, data.shape, phi_up.shape) + if level == 1: + return [w, approx] + elif not np.all(shapecheck): + print("Maximum allowed decomposition level reached", + file=sys.stderr) + return [w, approx] + else: + return [w] + self.__decompose(approx, phi_up, level-1) + + @staticmethod + def zupsample(data, order=1): + """ + Upsample data array by interleaving it with zero's. + + h{up_order: n}[l] = (1) h[l], if l % 2^n == 0; + (2) 0, otherwise + """ + shape = data.shape + new_shape = [ (2**order * (n-1) + 1) for n in shape ] + output = np.zeros(new_shape, dtype=data.dtype) + output[[ slice(None, None, 2**order) for d in shape ]] = data + return output + + def reconstruct(self, decomposition=None): + if decomposition is not None: + reconstruction = np.sum(decomposition, axis=0) + return reconstruction + else: + self.reconstruction = np.sum(self.decomposition, axis=0) + + def get_detail(self, scale): + """ + Get the wavelet detail coefficients of given scale. + Note: 1 <= scale <= level + """ + if scale < 1 or scale > self.level: + raise ValueError("Invalid scale") + return self.decomposition[scale-1] + + def get_approx(self): + """ + Get the approximation coefficients of the largest scale. + """ + return self.decomposition[-1] +# IUWT }}} + + +class IUWT_VST(IUWT): # {{{ + """ + IUWT with Multi-scale variance stabling transform. + + Refernce: + [1] Bo Zhang, Jalal M. Fadili & Jean-Luc Starck, + IEEE Trans. Image Processing, 17, 17, 2008 + """ + # VST coefficients and the corresponding asymptotic standard deviation + # of each scale. + vst_coef = [] + + def reset(self): + super(self.__class__, self).reset() + vst_coef = [] + + def __decompose(self): + raise AttributeError("No '__decompose' attribute") + + @staticmethod + def soft_threshold(data, threshold): + if isinstance(data, np.ndarray): + data_th = data.copy() + data_th[np.abs(data) <= threshold] = 0.0 + data_th[data > threshold] -= threshold + data_th[data < -threshold] += threshold + else: + data_th = data + if np.abs(data) <= threshold: + data_th = 0.0 + elif data > threshold: + data_th -= threshold + else: + data_th += threshold + return data_th + + def tau(self, k, scale): + """ + Helper function used in VST coefficients calculation. + """ + return np.sum(np.power(self.filters[scale], k)) + + def filters_product(self, scale1, scale2): + """ + Calculate the scalar product of the filters of two scales, + considering only the overlapped part. + Helper function used in VST coefficients calculation. + """ + if scale1 > scale2: + filter_big = self.filters[scale1] + filter_small = self.filters[scale2] + else: + filter_big = self.filters[scale2] + filter_small = self.filters[scale1] + # crop the big filter to match the size of the small filter + size_big = filter_big.shape + size_small = filter_small.shape + size_diff2 = list(map(lambda a,b: (a-b)//2, size_big, size_small)) + filter_big_crop = filter_big[ + size_diff2[0]:(size_big[0]-size_diff2[0]), + size_diff2[1]:(size_big[1]-size_diff2[1])] + assert(np.all(list(map(lambda a,b: a==b, + size_small, filter_big_crop.shape)))) + product = np.sum(filter_small * filter_big_crop) + return product + + def calc_vst_coef(self): + """ + Calculate the VST coefficients and the corresponding + asymptotic standard deviation of each scale, according to the + calculated filters of each scale/level. + """ + self.vst_coef = [] + for scale in range(self.level+1): + b = 2 * np.sqrt(np.abs(self.tau(1, scale)) / self.tau(2, scale)) + c = 7.0*self.tau(2, scale) / (8.0*self.tau(1, scale)) - \ + self.tau(3, scale) / (2.0*self.tau(2, scale)) + if scale == 0: + std = -1.0 + else: + std = np.sqrt((self.tau(2, scale-1) / \ + (4 * self.tau(1, scale-1)**2)) + \ + (self.tau(2, scale) / (4 * self.tau(1, scale)**2)) - \ + (self.filters_product(scale-1, scale) / \ + (2 * self.tau(1, scale-1) * self.tau(1, scale)))) + self.vst_coef.append({ "b": b, "c": c, "std": std }) + + def vst(self, data, scale, coupled=True): + """ + Perform variance stabling transform + + XXX: parameter `coupled' why?? + Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h + """ + self.vst_coupled = coupled + if self.vst_coef == []: + self.calc_vst_coef() + if coupled: + b = 1.0 + else: + b = self.vst_coef[scale]["b"] + data_vst = b * np.sqrt(np.abs(data + self.vst_coef[scale]["c"])) + return data_vst + + def ivst(self, data, scale, cbias=True): + """ + Inverse variance stabling transform + NOTE: assuming that `a_{j} + c^{j}' are all positive. + + XXX: parameter `cbias' why?? + `bias correction' is recommended while reconstruct the data + after estimation + Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h + """ + self.vst_cbias = cbias + if cbias: + cb = 1.0 / (self.vst_coef[scale]["b"] ** 2) + else: + cb = 0.0 + data_ivst = data ** 2 + cb - self.vst_coef[scale]["c"] + return data_ivst + + def is_significant(self, scale, fdr=0.1, independent=False, verbose=False): + """ + Multiple hypothesis testing with false discovery rate (FDR) control. + + `independent': whether the test statistics of all the null + hypotheses are independent. + If `independent=True': FDR <= (m0/m) * q + otherwise: FDR <= (m0/m) * q * (1 + 1/2 + 1/3 + ... + 1/m) + + References: + [1] False discovery rate - Wikipedia + https://en.wikipedia.org/wiki/False_discovery_rate + """ + coef = self.get_detail(scale) + std = self.vst_coef[scale]["std"] + pvalues = 2.0 * (1.0 - sp.stats.norm.cdf(np.abs(coef) / std)) + p_sorted = pvalues.flatten() + p_sorted.sort() + N = len(p_sorted) + if independent: + cn = 1.0 + else: + cn = np.sum(1.0 / np.arange(1, N+1)) + p_comp = fdr * np.arange(N) / (N * cn) + comp = (p_sorted < p_comp) + if np.sum(comp) == 0: + # `comp' contains ALL False + p_cutoff = 0.0 + else: + # cutoff p-value after FDR control/correction + p_cutoff = np.max(p_sorted[comp]) + sig = (pvalues <= p_cutoff) + if verbose: + print("std/sigma: %g, p_cutoff: %g" % (std, p_cutoff), + flush=True, file=sys.stderr) + return (sig, p_cutoff) + + def denoise(self, fdr=0.1, fdr_independent=False, start_scale=1, + end_scale=None, verbose=False): + """ + Denoise the wavelet coefficients by controlling FDR. + """ + self.fdr = fdr + self.fdr_indepent = fdr_independent + self.denoised = [] + # supports of significant coefficients of each scale + self.sig_supports = [None] # make index match the scale + self.p_cutoff = [None] + if verbose: + print("MSVST denosing ...", flush=True, file=sys.stderr) + for scale in range(1, self.level+1): + coef = self.get_detail(scale) + if verbose: + print("\tScale %d: " % scale, end="", + flush=True, file=sys.stderr) + if (scale < start_scale) or \ + ((end_scale is not None) and scale > end_scale): + if verbose: + print("skipped", flush=True, file=sys.stderr) + sig, p_cutoff = None, None + else: + sig, p_cutoff = self.is_significant(scale, fdr=fdr, + independent=fdr_independent, verbose=verbose) + coef[np.logical_not(sig)] = 0.0 + # + self.denoised.append(coef) + self.sig_supports.append(sig) + self.p_cutoff.append(p_cutoff) + # append the last approximation + self.denoised.append(self.get_approx()) + + def decompose(self, level=5, boundary="symm", verbose=False): + """ + 2D IUWT decomposition with VST. + """ + self.boundary = boundary + if self.level != level or self.filters == []: + self.level = level + self.calc_filters() + self.calc_vst_coef() + self.decomposition = [] + approx = self.data + if verbose: + print("IUWT decomposing (%d levels): " % level, + end="", flush=True, file=sys.stderr) + for scale in range(1, level+1): + if verbose: + print("%d..." % scale, end="", flush=True, file=sys.stderr) + # approximation: + approx2 = signal.convolve2d(self.data, self.filters[scale], + mode="same", boundary=self.boundary) + # wavelet details: + w = self.vst(approx, scale=scale-1) - self.vst(approx2, scale=scale) + self.decomposition.append(w) + if scale == level: + self.decomposition.append(approx2) + approx = approx2 + if verbose: + print("DONE!", flush=True, file=sys.stderr) + return self.decomposition + + def reconstruct_ivst(self, denoised=True, positive_project=True): + """ + Reconstruct the original image from the *un-denoised* decomposition + by applying the inverse VST. + + This reconstruction result is also used as the `initial condition' + for the below `iterative reconstruction' algorithm. + + arguments: + * denoised: whether use th denoised data or the direct decomposition + * positive_project: whether replace negative values with zeros + """ + if denoised: + decomposition = self.denoised + else: + decomposition = self.decomposition + self.positive_project = positive_project + details = np.sum(decomposition[:-1], axis=0) + approx = self.vst(decomposition[-1], scale=self.level) + reconstruction = self.ivst(approx+details, scale=0) + if positive_project: + reconstruction[reconstruction < 0.0] = 0.0 + self.reconstruction = reconstruction + return reconstruction + + def reconstruct(self, denoised=True, niter=10, verbose=False): + """ + Reconstruct the original image using iterative method with + L1 regularization, because the denoising violates the exact inverse + procedure. + + arguments: + * denoised: whether use the denoised coefficients + * niter: number of iterations + """ + if denoised: + decomposition = self.denoised + else: + decomposition = self.decomposition + # L1 regularization + lbd = 1.0 + delta = lbd / (niter - 1) + # initial solution + solution = self.reconstruct_ivst(denoised=denoised, + positive_project=True) + # + iuwt = IUWT(level=self.level) + iuwt.calc_filters() + # iterative reconstruction + if verbose: + print("Iteratively reconstructing (%d times): " % niter, + end="", flush=True, file=sys.stderr) + for i in range(niter): + if verbose: + print("%d..." % i, end="", flush=True, file=sys.stderr) + tempd = self.data.copy() + solution_decomp = [] + for scale in range(1, self.level+1): + approx, detail = iuwt.transform(tempd, scale) + approx_sol, detail_sol = iuwt.transform(solution, scale) + # Update coefficients according to the significant supports, + # which are acquired during the denosing precodure with FDR. + sig = self.sig_supports[scale] + detail_sol[sig] = detail[sig] + detail_sol = self.soft_threshold(detail_sol, threshold=lbd) + # + solution_decomp.append(detail_sol) + tempd = approx.copy() + solution = approx_sol.copy() + # last approximation (the two are the same) + solution_decomp.append(approx) + # reconstruct + solution = iuwt.reconstruct(decomposition=solution_decomp) + # discard all negative values + solution[solution < 0] = 0.0 + # + lbd -= delta + if verbose: + print("DONE!", flush=True, file=sys.stderr) + # + self.reconstruction = solution + return self.reconstruction +# IUWT_VST }}} + + +def main(): + # commandline arguments parser + parser = argparse.ArgumentParser( + description="Poisson Noise Removal with Multi-scale Variance " + \ + "Stabling Transform and Wavelet Transform", + epilog="Version: %s (%s)" % (__version__, __date__)) + parser.add_argument("-l", "--level", dest="level", + type=int, default=5, + help="level of the IUWT decomposition") + parser.add_argument("-r", "--fdr", dest="fdr", + type=float, default=0.1, + help="false discovery rate") + parser.add_argument("-I", "--fdr-independent", dest="fdr_independent", + action="store_true", default=False, + help="whether the FDR null hypotheses are independent") + parser.add_argument("-s", "--start-scale", dest="start_scale", + type=int, default=1, + help="which scale to start the denoising (inclusive)") + parser.add_argument("-e", "--end-scale", dest="end_scale", + type=int, default=0, + help="which scale to end the denoising (inclusive)") + parser.add_argument("-n", "--niter", dest="niter", + type=int, default=10, + help="number of iterations for reconstruction") + parser.add_argument("-v", "--verbose", dest="verbose", + action="store_true", default=False, + help="show verbose progress") + parser.add_argument("-C", "--clobber", dest="clobber", + action="store_true", default=False, + help="overwrite output file if exists") + parser.add_argument("infile", help="input image with Poisson noises") + parser.add_argument("outfile", help="output denoised image") + args = parser.parse_args() + + if args.end_scale == 0: + args.end_scale = args.level + + if args.verbose: + print("infile: '%s'" % args.infile, file=sys.stderr) + print("outfile: '%s'" % args.outfile, file=sys.stderr) + print("level: %d" % args.level, file=sys.stderr) + print("fdr: %.2f" % args.fdr, file=sys.stderr) + print("fdr_independent: %s" % args.fdr_independent, file=sys.stderr) + print("start_scale: %d" % args.start_scale, file=sys.stderr) + print("end_scale: %d" % args.end_scale, file=sys.stderr) + print("niter: %d\n" % args.niter, flush=True, file=sys.stderr) + + if not args.clobber and os.path.exists(args.outfile): + raise OSError("outfile '%s' already exists" % args.outfile) + + imgfits = fits.open(args.infile) + img = imgfits[0].data + # Remove Poisson noises + msvst = IUWT_VST(data=img) + msvst.decompose(level=args.level, verbose=args.verbose) + msvst.denoise(fdr=args.fdr, fdr_independent=args.fdr_independent, + start_scale=args.start_scale, end_scale=args.end_scale, + verbose=args.verbose) + msvst.reconstruct(denoised=True, niter=args.niter, verbose=args.verbose) + img_denoised = msvst.reconstruction + # Output + imgfits[0].data = img_denoised + imgfits[0].header.add_history("%s: Removed Poisson Noises @ %s" % ( + os.path.basename(sys.argv[0]), datetime.utcnow().isoformat())) + imgfits[0].header.add_history(" TOOL: %s (v%s, %s)" % ( + os.path.basename(sys.argv[0]), __version__, __date__)) + imgfits[0].header.add_history(" PARAM: %s" % " ".join(sys.argv[1:])) + imgfits.writeto(args.outfile, checksum=True, clobber=args.clobber) + + +if __name__ == "__main__": + main() + diff --git a/python/plot.py b/python/plot.py new file mode 100644 index 0000000..b65f8a3 --- /dev/null +++ b/python/plot.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# +# Credits: http://www.aosabook.org/en/matplotlib.html +# +# Aaron LI +# 2016-03-14 +# + +# Import the FigureCanvas from the backend of your choice +# and attach the Figure artist to it. +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +fig = Figure() +canvas = FigureCanvas(fig) + +# Import the numpy library to generate the random numbers. +import numpy as np +x = np.random.randn(10000) + +# Now use a figure method to create an Axes artist; the Axes artist is +# added automatically to the figure container fig.axes. +# Here "111" is from the MATLAB convention: create a grid with 1 row and 1 +# column, and use the first cell in that grid for the location of the new +# Axes. +ax = fig.add_subplot(111) + +# Call the Axes method hist to generate the histogram; hist creates a +# sequence of Rectangle artists for each histogram bar and adds them +# to the Axes container. Here "100" means create 100 bins. +ax.hist(x, 100) + +# Decorate the figure with a title and save it. +ax.set_title('Normal distribution with $\mu=0, \sigma=1$') +fig.savefig('matplotlib_histogram.png') + diff --git a/python/plot_tprofiles_zzh.py b/python/plot_tprofiles_zzh.py new file mode 100644 index 0000000..e5824e9 --- /dev/null +++ b/python/plot_tprofiles_zzh.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# +# Weitian LI +# 2015-09-11 +# + +""" +Plot a list of *temperature profiles* in a grid of subplots with Matplotlib. +""" + +import matplotlib.pyplot as plt + + +def plot_tprofiles(tplist, nrows, ncols, + xlim=None, ylim=None, logx=False, logy=False, + xlab="", ylab="", title=""): + """ + Plot a list of *temperature profiles* in a grid of subplots of size + nrow x ncol. Each subplot is related to a temperature profile. + All the subplots share the same X and Y axes. + The order is by row. + + The tplist is a list of dictionaries, each of which contains all the + necessary data to make the subplot. + + The dictionary consists of the following components: + tpdat = { + "name": "NAME", + "radius": [[radius points], [radius errors]], + "temperature": [[temperature points], [temperature errors]], + "radius_model": [radus points of the fitted model], + "temperature_model": [ + [fitted model value], + [lower bounds given by the model], + [upper bounds given by the model] + ] + } + + Arguments: + tplist - a list of dictionaries containing the data of each + temperature profile. + Note that the length of this list should equal to nrows*ncols. + nrows - number of rows of the subplots + ncols - number of columns of the subplots + xlim - limits of the X axis + ylim - limits of the Y axis + logx - whether to set the log scale for X axis + logy - whether to set the log scale for Y axis + xlab - label for the X axis + ylab - label for the Y axis + title - title for the whole plot + """ + assert len(tplist) == nrows*ncols, "tplist length != nrows*ncols" + # All subplots share both X and Y axes. + fig, axarr = plt.subplots(nrows, ncols, sharex=True, sharey=True) + # Set title for the whole plot. + if title != "": + fig.suptitle(title) + # Set xlab and ylab for each subplot + if xlab != "": + for ax in axarr[-1, :]: + ax.set_xlabel(xlab) + if ylab != "": + for ax in axarr[:, 0]: + ax.set_ylabel(ylab) + for ax in axarr.reshape(-1): + # Set xlim and ylim. + if xlim is not None: + ax.set_xlim(xlim) + if ylim is not None: + ax.set_ylim(ylim) + # Set xscale and yscale. + if logx: + ax.set_xscale("log", nonposx="clip") + if logy: + ax.set_yscale("log", nonposy="clip") + # Decrease the spacing between the subplots and suptitle + fig.subplots_adjust(top=0.94) + # Eleminate the spaces between each row and column. + fig.subplots_adjust(hspace=0, wspace=0) + # Hide X ticks for all subplots but the bottom row. + plt.setp([ax.get_xticklabels() for ax in axarr[:-1, :].reshape(-1)], + visible=False) + # Hide Y ticks for all subplots but the left column. + plt.setp([ax.get_yticklabels() for ax in axarr[:, 1:].reshape(-1)], + visible=False) + # Plot each temperature profile in the tplist + for i, ax in zip(range(len(tplist)), axarr.reshape(-1)): + tpdat = tplist[i] + # Add text to display the name. + # The text is placed at (0.95, 0.95), i.e., the top-right corner, + # with respect to this subplot, and the top-right part of the text + # is aligned to the above position. + ax_pois = ax.get_position() + ax.text(0.95, 0.95, tpdat["name"], + verticalalignment="top", horizontalalignment="right", + transform=ax.transAxes, color="black", fontsize=10) + # Plot data points + if isinstance(tpdat["radius"][0], list) and \ + len(tpdat["radius"]) == 2 and \ + isinstance(tpdat["temperature"][0], list) and \ + len(tpdat["temperature"]) == 2: + # Data points have symmetric errorbar + ax.errorbar(tpdat["radius"][0], tpdat["temperature"][0], + xerr=tpdat["radius"][1], yerr=tpdat["temperature"][1], + color="black", linewidth=1.5, linestyle="None") + else: + ax.plot(tpdat["radius"], tpdat["temperature"], + color="black", linewidth=1.5, linestyle="None") + # Plot model line and bounds band + if isinstance(tpdat["temperature_model"][0], list) and \ + len(tpdat["temperature_model"]) == 3: + # Model data have bounds + ax.plot(tpdat["radius_model"], tpdat["temperature_model"][0], + color="blue", linewidth=1.0) + # Plot model bounds band + ax.fill_between(tpdat["radius_model"], + y1=tpdat["temperature_model"][1], + y2=tpdat["temperature_model"][2], + color="gray", alpha=0.5) + else: + ax.plot(tpdat["radius_model"], tpdat["temperature_model"], + color="blue", linewidth=1.5) + return (fig, axarr) + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/randomize_events.py b/python/randomize_events.py new file mode 100755 index 0000000..e1a6e31 --- /dev/null +++ b/python/randomize_events.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# +# Randomize the (X,Y) position of each X-ray photon events according +# to a Gaussian distribution of given sigma. +# +# References: +# [1] G. Scheellenberger, T.H. Reiprich, L. Lovisari, J. Nevalainen & L. David +# 2015, A&A, 575, A30 +# +# +# Aaron LI +# Created: 2016-03-24 +# Updated: 2016-03-24 +# + +from astropy.io import fits +import numpy as np + +import os +import sys +import datetime +import argparse + + +CHANDRA_ARCSEC_PER_PIXEL = 0.492 + +def randomize_events(infile, outfile, sigma, clobber=False): + """ + Randomize the position (X,Y) of each X-ray event according to a + specified size/sigma Gaussian distribution. + """ + sigma_pix = sigma / CHANDRA_ARCSEC_PER_PIXEL + evt_fits = fits.open(infile) + evt_table = evt_fits[1].data + # (X,Y) physical coordinate + evt_x = evt_table["x"] + evt_y = evt_table["y"] + rand_x = np.random.normal(scale=sigma_pix, size=evt_x.shape)\ + .astype(evt_x.dtype) + rand_y = np.random.normal(scale=sigma_pix, size=evt_y.shape)\ + .astype(evt_y.dtype) + evt_x += rand_x + evt_y += rand_y + # Add history to FITS header + evt_hdr = evt_fits[1].header + evt_hdr.add_history("TOOL: %s @ %s" % ( + os.path.basename(sys.argv[0]), + datetime.datetime.utcnow().isoformat())) + evt_hdr.add_history("COMMAND: %s" % " ".join(sys.argv)) + evt_fits.writeto(outfile, clobber=clobber, checksum=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Randomize the (X,Y) of each X-ray event") + parser.add_argument("infile", help="input event file") + parser.add_argument("outfile", help="output randomized event file") + parser.add_argument("-s", "--sigma", dest="sigma", + required=True, type=float, + help="sigma/size of the Gaussian distribution used" + \ + "to randomize the position of events (unit: arcsec)") + parser.add_argument("-C", "--clobber", dest="clobber", + action="store_true", help="overwrite output file if exists") + args = parser.parse_args() + + randomize_events(args.infile, args.outfile, + sigma=args.sigma, clobber=args.clobber) + + +if __name__ == "__main__": + main() + diff --git a/python/rebuild_ipod_db.py b/python/rebuild_ipod_db.py new file mode 100755 index 0000000..20c5454 --- /dev/null +++ b/python/rebuild_ipod_db.py @@ -0,0 +1,595 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# LICENSE: +# --------------------------------------------------------------------------- +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# --------------------------------------------------------------------------- +# +# Based on Matrin Fiedler's "rebuild_db.py" v1.0-rc1 (2006-04-26): +# http://shuffle-db.sourceforge.net/ +# + +from __future__ import print_function + + +__title__ = "iPod Shuffle Database Builder" +__author__ = "Aaron LI" +__version__ = "2.0.2" +__date__ = "2016-04-16" + + +import sys +import os +import operator +import array +import random +import fnmatch +import operator +import string +import argparse +import functools +import shutil +from collections import OrderedDict + + +domains = [] +total_count = 0 + + +class LogObj: + """ + Print and log the process information. + """ + def __init__(self, filename=None): + self.filename = filename + + def open(self): + if self.filename: + try: + self.logfile = open(self.filename, "w") + except IOError: + self.logfile = None + else: + self.logfile = None + + def log(self, line="", end="\n"): + value = line + end + if self.logfile: + self.logfile.write(value) + print(value, end="") + + def close(self): + if self.logfile: + self.logfile.close() + + +class Rule: + """ + A RuleSet for the way to handle the found playable files. + """ + SUPPORT_PROPS = ("filename", "size", "ignore", "type", + "shuffle", "reuse", "bookmark") + + def __init__(self, conditions=None, actions=None): + self.conditions = conditions + self.actions = actions + + @classmethod + def parse(cls, rule): + """ + Parse the whole line of a rule. + + Syntax: + condition1, condition2, ...: action1, action2, ... + + condition examples: + * filename ~ "*.mp3" + * size > 100000 + action examples: + * ignore = 1 + * shuffle = 1 + + Return: a object of this class with the parsed rule. + """ + conditions, actions = rule.split(":") + conditions = list(map(cls.parse_condition, conditions.split(","))) + actions = dict(map(cls.parse_action, actions.split(","))) + return cls(conditions, actions) + + @classmethod + def parse_condition(cls, cond): + sep_pos = min([ cond.find(sep) for sep in "~=<>" \ + if cond.find(sep)>0 ]) + prop = cond[:sep_pos].strip() + if prop not in cls.SUPPORT_PROPS: + raise ValueError("WARNING: unknown property '%s'" % prop) + return (prop, cond[sep_pos], + cls.parse_value(cond[sep_pos+1:].strip())) + + @classmethod + def parse_action(cls, action): + prop, value = map(str.strip, action.split("=", 1)) + if prop not in cls.SUPPORT_PROPS: + raise ValueError("WARNING: unknown property '%s'" % prop) + return (prop, cls.parse_value(value)) + + @staticmethod + def parse_value(value): + value = value.strip().strip('"').strip("'") + try: + return int(value) + except ValueError: + return value + + def match(self, props): + """ + Check whether the given props match all the conditions. + """ + def match_condition(props, cond): + """ + Check whether the given props match the given condition. + """ + try: + prop, op, ref = props[cond[0]], cond[1], cond[2] + except KeyError: + return False + if op == "~": + return fnmatch.fnmatchcase(prop.lower(), ref.lower()) + elif op == "=": + return prop == ref + elif op == ">": + return prop > ref + elif op == "<": + return prop < ref + else: + return False + # + return functools.reduce(operator.and_, + [ match_condition(props, cond) \ + for cond in self.conditions ], + True) + + +class Entries: + """ + Walk through the directory to find all files, and filter by the + extensions to get all the playable files. + """ + PLAYABLE_EXTS = (".mp3", ".m4a", ".m4b", ".m4p", ".aa", ".wav") + + def __init__(self, dirs=[], rename=True, recursive=True, ignore_dup=True): + self.entries = [] + self.add_dirs(dirs=dirs, rename=rename, recursive=recursive, + ignore_dup=ignore_dup) + + def add_dirs(self, dirs=[], rename=True, recursive=True, ignore_dup=True): + for dir in dirs: + self.add_dir(dir=dir, rename=rename, recursive=recursive, + ignore_dup=ignore_dup) + + def add_dir(self, dir, rename=True, recursive=True, ignore_dup=True): + global logobj + if recursive: + # Get all directories, and rename them if needed + dirs = [] + for dirName, subdirList, fileList in os.walk(dir): + dirs.append(dirName) + for dirName in dirs: + newDirName = self.get_newname(dirName) + if rename and newDirName != dirName: + logobj.log("Rename: '%s' -> '%s'" % (dirName, newDirName)) + shutil.move(dirName, newDirName) + # Get all files + files = [] + for dirName, subdirList, fileList in os.walk(dir): + files.extend([ os.path.join(dirName, f) for f in fileList ]) + else: + # rename the directory if needed + newDir = self.get_newname(dir) + if rename and newDir != dir: + logobj.log("Rename: '%s' -> '%s'" % (dir, newDir)) + shutil.move(dir, newDir) + files = [ os.path.join(newDir, f) for f in self.listfiles(newDir) ] + # + for fn in files: + # rename filename if needed + newfn = self.get_newname(fn) + if rename and newfn != fn: + logobj.log("Rename: '%s' -> '%s'" % (fn, newfn)) + shutil.move(fn, newfn) + fn = newfn + # filter by playable extensions + if os.path.splitext(fn)[1].lower() not in self.PLAYABLE_EXTS: + continue + if ignore_dup and (fn in self.entries): + continue + self.entries.append(fn) + print("Entry: %s" % fn) + + @staticmethod + def listfiles(path, ignore_hidden=True): + """ + List only files of a directory + """ + for f in os.listdir(path): + if os.path.isfile(os.path.join(path, f)): + if ignore_hidden and f[0] != ".": + yield f + else: + yield f + + @staticmethod + def get_newname(path): + def conv_char(ch): + safe_char = string.ascii_letters + string.digits + "-_" + if ch in safe_char: + return ch + return "_" + # + if path == ".": + return path + dirname, basename = os.path.split(path) + base, ext = os.path.splitext(basename) + newbase = "".join(map(conv_char, base)) + if basename == newbase+ext: + return os.path.join(dirname, basename) + if os.path.exists("%s/%s%s" % (dirname, newbase, ext)): + i = 0 + while os.path.exists("%s/%s_%d%s" % (dirname, newbase, i, ext)): + i += 1 + newbase += "_%d" % i + newname = "%s/%s%s" % (dirname, newbase, ext) + return newname + + def fix_and_sort(self): + """ + Fix the entries' pathes (should starts with "/"), and sort. + """ + self.entries = [ "/"+f.lstrip("./") for f in self.entries ] + self.entries.sort() + + def apply_rules(self, rules): + """ + Apply rules to the found entries. + The filtered/updated entries and properties are saved in: + 'self.entries_dict' + """ + self.entries_dict = OrderedDict() + + for fn in self.entries: + # set default properties + props = { + "filename": fn, + "size": os.stat(fn[1:]).st_size, + "ignore": 0, + "type": 1, + "shuffle": 1, + "bookmark": 0 + } + # check and apply rules + for rule in rules: + if rule.match(props): + props.update(rule.actions) + # + if props["ignore"]: + continue + # + self.entries_dict[fn] = props + + def get_entries(self): + return self.entries_dict.items() + + +class iTunesSD: + """ + Class to handle the iPod Shuffle main database + "iPod_Control/iTunes/iTunesSD" + """ + def __init__(self, dbfile="./iPod_Control/iTunes/iTunesSD"): + self.dbfile = dbfile + self.load() + + def load(self): + """ + Load original header and entries. + """ + self.old_entries = {} + self.header_main = array.array("B") # unsigned integer array + self.header_entry = array.array("B") # unsigned integer array + db = open(self.dbfile, "rb") + try: + self.header_main.fromfile(db, 18) + self.header_entry.fromfile(db, 33) + db.seek(18) + entry = db.read(558) + while len(entry) == 558: + filename = entry[33::2].split(b"\0", 1)[0] + self.old_entries[filename] = entry + entry = db.read(558) + except EOFError: + pass + db.close() + print("Loaded %d entries from existing database" % \ + len(self.old_entries)) + + def build_header(self, force=False): + global logobj + # rebuild database header + if force or len(self.header_main) != 18: + logobj.log("Rebuild iTunesSD main header ...") + del self.header_main[:] + self.header_main.fromlist([0,0,0,1,6,0,0,0,18] + [0]*9) + if force or len(self.header_entry) != 33: + logobj.log("Rebuild iTunesSD entry header ...") + del self.header_entry[:] + self.header_entry.fromlist([0,2,46,90,165,1] + [0]*20 + \ + [100,0,0,1,0,2,0]) + + def add_entries(self, entries, reuse=True): + """ + Prepare the entries for database + """ + self.entries = OrderedDict() + + for fn, props in entries.get_entries(): + if reuse and props.get("reuse") and (fn in self.old_entries): + # retrieve entry from old entries + entry = self.old_entries[fn] + else: + # build new entry + self.header_entry[29] = props["type"] + entry_data = "".join([ c+"\0" for c in fn[:261] ]) + \ + "\0"*(558 - len(self.header_entry) - 2*len(fn)) + entry = self.header_entry.tostring() + \ + entry_data.encode("utf-8") + # modify the shuffle and bookmark flags + entry = entry[:555] + chr(props["shuffle"]).encode("utf-8") + \ + chr(props["bookmark"]).encode("utf-8") + entry[557] + # + self.entries[fn] = entry + + def write(self, dbfile=None): + if dbfile is None: + dbfile = self.dbfile + # Make a backup + if os.path.exists(dbfile): + shutil.copy2(dbfile, dbfile+"_bak") + + # write main database file + with open(dbfile, "wb") as db: + self.header_main.tofile(db) + for entry in self.entries.values(): + db.write(entry) + # Update database header + num_entries = len(self.entries) + db.seek(0) + db.write(b"\0%c%c" % (num_entries>>8, num_entries&0xFF)) + + +class iTunesPState: + """ + iPod Shuffle playback state database: "iPod_Control/iTunes/iTunesPState" + """ + def __init__(self, dbfile="iPod_Control/iTunes/iTunesPState"): + self.dbfile = dbfile + self.load() + + def load(self): + with open(self.dbfile, "rb") as db: + a = array.array("B") + a.fromstring(db.read()) + self.PState = a.tolist() + + def update(self, volume=None): + if len(self.PState) != 21: + # volume 29, FW ver 1.0 + self.PState = self.listval(29) + [0]*15 + self.listval(1) + # track 0, shuffle mode, start of track + self.PState[3:15] = [0]*6 + [1] + [0]*5 + if volume is not None: + self.PState[:3] = self.listval(volume) + + def write(self, dbfile=None): + if dbfile is None: + dbfile = self.dbfile + # Make a backup + if os.path.exists(dbfile): + shutil.copy2(dbfile, dbfile+"_bak") + + with open(dbfile, "wb") as db: + array.array("B", self.PState).tofile(db) + + @staticmethod + def listval(i): + if i < 0: + i += 0x1000000 + return [i&0xFF, (i>>8)&0xFF, (i>>16)&0xFF] + + +class iTunesStats: + """ + iPod Shuffle statistics database: "iPod_Control/iTunes/iTunesStats" + """ + def __init__(self, dbfile="iPod_Control/iTunes/iTunesStats"): + self.dbfile = dbfile + + def write(self, count, dbfile=None): + if dbfile is None: + dbfile = self.dbfile + # Make a backup + if os.path.exists(dbfile): + shutil.copy2(dbfile, dbfile+"_bak") + + with open(dbfile, "wb") as db: + data = self.stringval(count) + "\0"*3 + \ + (self.stringval(18) + "\xff"*3 + "\0"*12) * count + db.write(data.encode("utf-8")) + + @staticmethod + def stringval(i): + if i < 0: + i += 0x1000000 + return "%c%c%c" % (i&0xFF, (i>>8)&0xFF, (i>>16)&0xFF) + + +class iTunesShuffle: + """ + iPod shuffle database: "iPod_Control/iTunes/iTunesShuffle" + """ + def __init__(self, dbfile="iPod_Control/iTunes/iTunesShuffle"): + self.dbfile = dbfile + + def shuffle(self, entries): + """ + Generate the shuffle sequences for the entries, and take care + of the "shuffle" property. + """ + shuffle_prop = [ props["shuffle"] + for fn, props in entries.get_entries() ] + shuffle_idx = [ idx for idx, s in enumerate(shuffle_prop) if s == 1 ] + shuffled = shuffle_idx.copy() + random.seed() + random.shuffle(shuffled) + shuffle_seq = list(range(len(shuffle_prop))) + for i, idx in enumerate(shuffle_idx): + shuffle_seq[idx] = shuffled[i] + self.shuffle_seq = shuffle_seq + + def write(self, dbfile=None): + if dbfile is None: + dbfile = self.dbfile + # Make a backup + if os.path.exists(dbfile): + shutil.copy2(dbfile, dbfile+"_bak") + + with open(dbfile, "wb") as db: + data = "".join(map(iTunesStats.stringval, self.shuffle_seq)) + db.write(data.encode("utf-8")) + + +def main(): + prog_basename = os.path.splitext(os.path.basename(sys.argv[0]))[0] + + # command line arguments + parser = argparse.ArgumentParser( + description="Rebuild iPod Shuffle Database", + epilog="Version: %s (%s)\n\n" % (__version__, __date__) + \ + "Only 1st and 2nd iPod Shuffle supported!\n\n" + \ + "The script must be placed under the iPod's root directory") + parser.add_argument("-f", "--force", dest="force", action="store_true", + help="always rebuild database entries, do NOT reuse old ones") + parser.add_argument("-M", "--no-rename", dest="norename", + action="store_false", default=True, + help="do NOT rename files") + parser.add_argument("-V", "--volume", dest="volume", type=int, + help="set playback volume (0 - 38)") + parser.add_argument("-r", "--rulesfile", dest="rulesfile", + default="%s.rules" % prog_basename, + help="additional rules filename") + parser.add_argument("-l", "--logfile", dest="logfile", + default="%s.log" % prog_basename, + help="log output filename") + parser.add_argument("dirs", nargs="*", + help="directories to be searched for playable files") + args = parser.parse_args() + + flag_reuse = not args.force + + # Start logging + global logobj + logobj = LogObj(args.logfile) + logobj.open() + + # Rules for how to handle the found playable files + rules = [] + # Add default rules + rules.append(Rule(conditions=[("filename", "~", "*.mp3")], + actions={"type":1, "shuffle":1, "bookmark":0})) + rules.append(Rule(conditions=[("filename", "~", "*.m4?")], + actions={"type":2, "shuffle":1, "bookmark":0})) + rules.append(Rule(conditions=[("filename", "~", "*.m4b")], + actions={"shuffle":0, "bookmark":1})) + rules.append(Rule(conditions=[("filename", "~", "*.aa")], + actions={"type":1, "shuffle":0, "bookmark":1, "reuse":1})) + rules.append(Rule(conditions=[("filename", "~", "*.wav")], + actions={"type":4, "shuffle":0, "bookmark":0})) + rules.append(Rule(conditions=[("filename", "~", "*.book.???")], + actions={"shuffle":0, "bookmark":1})) + rules.append(Rule(conditions=[("filename", "~", "*.announce.???")], + actions={"shuffle":0, "bookmark":0})) + rules.append(Rule(conditions=[("filename", "~", "/backup/*")], + actions={"ignore":1})) + # Load additional rules + try: + for line in open(args.rulesfile, "r").readlines(): + rules.append(Rule.parse(line)) + logobj.log("Loaded additional rules from file: %s" % args.rulesfile) + except IOError: + pass + + # cd to the directory of this script + os.chdir(os.path.dirname(sys.argv[0])) + + if not os.path.isdir("iPod_Control/iTunes"): + logobj.log("ERROR: No iPod control directory found!") + logobj.log("Please make sure that:") + logobj.log("(*) this script is placed under the iPod's root directory") + logobj.log("(*) the iPod was correctly initialized with iTunes") + sys.exit(1) + + # playable entries + logobj.log("Search for playable entries ...") + entries = Entries() + if args.dirs: + for dir in args.dirs: + entries.add_dir(dir=dir, recursive=True, rename=args.norename) + else: + entries.add_dir(".", recursive=True, rename=args.norename) + entries.fix_and_sort() + logobj.log("Apply rules to entries ...") + entries.apply_rules(rules=rules) + + # read main database file + logobj.log("Update main database ...") + db = iTunesSD(dbfile="iPod_Control/iTunes/iTunesSD") + db.build_header(force=args.force) + db.add_entries(entries=entries, reuse=flag_reuse) + assert len(db.entries) == len(entries.get_entries()) + db.write() + logobj.log("Added %d entries ..." % len(db.entries)) + + # other misc databases + logobj.log("Update playback state database ...") + db_pstate = iTunesPState(dbfile="iPod_Control/iTunes/iTunesPState") + db_pstate.update(volume=args.volume) + db_pstate.write() + logobj.log("Update statistics database ...") + db_stats = iTunesStats(dbfile="iPod_Control/iTunes/iTunesStats") + db_stats.write(count=len(db.entries)) + logobj.log("Update shuffle database ...") + db_shuffle = iTunesShuffle(dbfile="iPod_Control/iTunes/iTunesShuffle") + db_shuffle.shuffle(entries=entries) + db_shuffle.write() + + logobj.log("The iPod Shuffle database was rebuilt successfully!") + + logobj.close() + + +if __name__ == "__main__": + main() + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/splitBoxRegion.py b/python/splitBoxRegion.py new file mode 100755 index 0000000..5254686 --- /dev/null +++ b/python/splitBoxRegion.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8- +# +# Split the strip-shaped CCD gaps regions into a series of small +# square regions, which are used as the input regions of 'roi' to +# determine the corresponding background regions, and finally providied +# to 'dmfilth' in order to fill in the CCD gaps. +# +# Aaron LI +# 2015/08/12 +# +# Changelogs: +# v0.1.0, 2015/08/12 +# * initial version +# + + +__version__ = "0.1.0" +__date__ = "2015/08/12" + + +import os +import sys +import re +import math +import argparse +from io import TextIOWrapper + + +## BoxRegion {{{ +class BoxRegion(object): + """ + CIAO/DS9 "rotbox"/"box" region class. + + rotbox/box format: + rotbox(xc, yc, width, height, rotation) + box(xc, yc, width, height, rotation) + Notes: + rotation: [0, 360) (degree) + """ + def __init__(self, xc=None, yc=None, + width=None, height=None, rotation=None): + self.regtype = "rotbox" + self.xc = xc + self.yc = yc + self.width = width + self.height = height + self.rotation = rotation + + def __str__(self): + return "%s(%s,%s,%s,%s,%s)" % (self.regtype, self.xc, self.yc, + self.width, self.height, self.rotation) + + @classmethod + def parse(cls, regstr): + """ + Parse region string. + """ + regex_box = re.compile(r'^\s*(box|rotbox)\(([0-9. ]+),([0-9. ]+),([0-9. ]+),([0-9. ]+),([0-9. ]+)\)\s*$', re.I) + m = regex_box.match(regstr) + if m: + regtype = m.group(1) + xc = float(m.group(2)) + yc = float(m.group(3)) + width = float(m.group(4)) + height = float(m.group(5)) + rotation = float(m.group(6)) + return cls(xc, yc, width, height, rotation) + else: + return None + + def split(self, filename=None): + """ + Split strip-shaped box region into a series small square regions. + """ + angle = self.rotation * math.pi / 180.0 + # to record the center coordinates of each split region + centers = [] + if self.width > self.height: + # number of regions after split + nreg = math.ceil(self.width / self.height) + # width & height of the split region + width = self.width / nreg + height = self.height + # position of the left-most region + x_l = self.xc - 0.5*self.width * math.cos(angle) + y_l = self.yc - 0.5*self.width * math.sin(angle) + for i in range(nreg): + x = x_l + (0.5 + i) * width * math.cos(angle) + y = y_l + (0.5 + i) * width * math.sin(angle) + centers.append((x, y)) + else: + # number of regions after split + nreg = math.ceil(self.height / self.width) + # width & height of the split region + width = self.width + height = self.height / nreg + # position of the left-most region + x_l = self.xc + 0.5*self.height * math.cos(angle + math.pi/2) + y_l = self.yc + 0.5*self.height * math.sin(angle + math.pi/2) + for i in range(nreg): + x = x_l - (0.5 + i) * height * math.cos(angle + math.pi/2) + y = y_l - (0.5 + i) * height * math.sin(angle + math.pi/2) + centers.append((x, y)) + # create split regions + regions = [] + for (x, y) in centers: + regions.append(self.__class__(x, y, width+2, height+2, + self.rotation)) + # write split regions into file if specified + if isinstance(filename, str): + regout = open(filename, "w") + regout.write("\n".join(map(str, regions))) + regout.close() + else: + return regions + +## BoxRegion }}} + + +def main(): + # command line arguments + parser = argparse.ArgumentParser( + description="Split strip-shaped rotbox region into " + \ + "a series of small square regions.", + epilog="Version: %s (%s)" % (__version__, __date__)) + parser.add_argument("-V", "--version", action="version", + version="%(prog)s " + "%s (%s)" % (__version__, __date__)) + parser.add_argument("infile", help="input rotbox region file") + parser.add_argument("outfile", help="output file of the split regions") + args = parser.parse_args() + + outfile = open(args.outfile, "w") + regex_box = re.compile(r'^\s*(box|rotbox)\([0-9., ]+\)\s*$', re.I) + for line in open(args.infile, "r"): + if regex_box.match(line): + reg = BoxRegion.parse(line) + split_regs = reg.split() + outfile.write("\n".join(map(str, split_regs)) + "\n") + else: + outfile.write(line) + + outfile.close() + + +if __name__ == "__main__": + main() + diff --git a/python/splitCCDgaps.py b/python/splitCCDgaps.py new file mode 100644 index 0000000..bc26c29 --- /dev/null +++ b/python/splitCCDgaps.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8- +# +# Aaron LI +# 2015/08/12 +# + +""" +Split the long-strip-shaped CCD gaps regions into a series of little +square regions, which are used as the input regions of 'roi' to +determine the corresponding background regions, and finally providied +to 'dmfilth' in order to fill in the CCD gaps. +""" + + +import re +import math +from io import TextIOWrapper + + +class BoxRegion(object): + """ + CIAO/DS9 "rotbox"/"box" region class. + + rotbox/box format: + rotbox(xc, yc, width, height, rotation) + box(xc, yc, width, height, rotation) + Notes: + rotation: [0, 360) (degree) + """ + def __init__(self, xc=None, yc=None, + width=None, height=None, rotation=None): + self.regtype = "rotbox" + self.xc = xc + self.yc = yc + self.width = width + self.height = height + self.rotation = rotation + + def __str__(self): + return "%s(%s,%s,%s,%s,%s)" % (self.regtype, self.xc, self.yc, + self.width, self.height, self.rotation) + + @classmethod + def parse(cls, regstr): + """ + Parse region string. + """ + regex_box = re.compile(r'^(box|rotbox)\(([0-9.]+),([0-9.]+),([0-9.]+),([0-9.]+),([0-9.]+)\)$', re.I) + m = regex_box.match(regstr) + if m: + regtype = m.group(1) + xc = float(m.group(2)) + yc = float(m.group(3)) + width = float(m.group(4)) + height = float(m.group(5)) + rotation = float(m.group(6)) + return cls(xc, yc, width, height, rotation) + else: + return None + + def split(self, filename=None): + """ + Split long-strip-shaped box region into a series square box regions. + """ + angle = self.rotation * math.pi / 180.0 + # to record the center coordinates of each split region + centers = [] + if self.width > self.height: + # number of regions after split + nreg = math.ceil(self.width / self.height) + # width & height of the split region + width = self.width / nreg + height = self.height + # position of the left-most region + x_l = self.xc - 0.5*self.width * math.cos(angle) + y_l = self.yc - 0.5*self.width * math.sin(angle) + for i in range(nreg): + x = x_l + (0.5 + i) * width * math.cos(angle) + y = y_l + (0.5 + i) * width * math.sin(angle) + centers.append((x, y)) + else: + # number of regions after split + nreg = math.ceil(self.height / self.width) + # width & height of the split region + width = self.width + height = self.height / nreg + # position of the left-most region + x_l = self.xc + 0.5*self.height * math.cos(angle + math.pi/2) + y_l = self.yc + 0.5*self.height * math.sin(angle + math.pi/2) + for i in range(nreg): + x = x_l - (0.5 + i) * height * math.cos(angle + math.pi/2) + y = y_l - (0.5 + i) * height * math.sin(angle + math.pi/2) + centers.append((x, y)) + # create split regions + regions = [] + for (x, y) in centers: + regions.append(self.__class__(x, y, width+2, height+2, + self.rotation)) + # write split regions into file if specified + if isinstance(filename, str): + regout = open(filename, "w") + regout.write("\n".join(map(str, regions))) + regout.close() + else: + return regions + + diff --git a/python/xkeywordsync.py b/python/xkeywordsync.py new file mode 100644 index 0000000..9a7cd79 --- /dev/null +++ b/python/xkeywordsync.py @@ -0,0 +1,533 @@ +#!/bin/usr/env python3 +# -*- coding: utf-8 -*- +# +# Credits: +# [1] Gaute Hope: gauteh/abunchoftags +# https://github.com/gauteh/abunchoftags/blob/master/keywsync.cc +# +# TODO: +# * Support case-insensitive tags merge +# (ref: http://stackoverflow.com/a/1480230) +# * Accept a specified mtime, and only deal with files with newer mtime. +# +# Aaron LI +# Created: 2016-01-24 +# + +""" +Sync message 'X-Keywords' header with notmuch tags. + +* tags-to-keywords: + Check if the messages in the query have a matching 'X-Keywords' header + to the list of notmuch tags. + If not, update the 'X-Keywords' and re-write the message. + +* keywords-to-tags: + Check if the messages in the query have matching notmuch tags to the + 'X-Keywords' header. + If not, update the tags in the notmuch database. + +* merge-keywords-tags: + Merge the 'X-Keywords' labels and notmuch tags, and update both. +""" + +__version__ = "0.1.2" +__date__ = "2016-01-25" + +import os +import sys +import argparse +import email + +# Require Python 3.4, or install package 'enum34' +from enum import Enum + +from notmuch import Database, Query + +from imapUTF7 import imapUTF7Decode, imapUTF7Encode + + +class SyncDirection(Enum): + """ + Synchronization direction + """ + MERGE_KEYWORDS_TAGS = 0 # Merge 'X-Keywords' and notmuch tags and + # update both + KEYWORDS_TO_TAGS = 1 # Sync 'X-Keywords' header to notmuch tags + TAGS_TO_KEYWORDS = 2 # Sync notmuch tags to 'X-Keywords' header + +class SyncMode(Enum): + """ + Sync mode + """ + ADD_REMOVE = 0 # Allow add & remove tags/keywords + ADD_ONLY = 1 # Only allow add tags/keywords + REMOVE_ONLY = 2 # Only allow remove tags/keywords + + +class KwMessage: + """ + Message class to deal with 'X-Keywords' header synchronization + with notmuch tags. + + NOTE: + * The same message may have multiple files with different keywords + (e.g, the same message exported under each label by Gmail) + managed by OfflineIMAP. + For example: a message file in OfflineIMAP synced folder of + '[Gmail]/All Mail' have keywords ['google', 'test']; however, + the file in synced folder 'test' of the same message only have + keywords ['google'] without the keyword 'test'. + * All files associated to the same message are regarded as the same. + The keywords are extracted from all files and merged. + And the same updated keywords are written back to all files, which + results all files finally having the same 'X-Keywords' header. + * You may only sync the '[Gmail]/All Mail' folder without other + folders exported according the labels by Gmail. + """ + # Replace some special characters before mapping keyword to tag + enable_replace_chars = True + chars_replace = { + '/' : '.', + } + # Mapping between (Gmail) keywords and notmuch tags (before ignoring tags) + keywords_mapping = { + '\\Inbox' : 'inbox', + '\\Important' : 'important', + '\\Starred' : 'flagged', + '\\Sent' : 'sent', + '\\Muted' : 'killed', + '\\Draft' : 'draft', + '\\Trash' : 'deleted', + '\\Junk' : 'spam', + } + # Tags ignored from syncing + # These tags are either internal tags or tags handled by maildir flags. + enable_ignore_tags = True + tags_ignored = set([ + 'new', 'unread', 'attachment', 'signed', 'encrypted', + 'flagged', 'replied', 'passed', 'draft', + ]) + # Ignore case when merging tags + tags_ignorecase = True + + # Whether the tags updated against the message 'X-Keywords' header + tags_updated = False + # Added & removed tags for notmuch database against 'X-Keywords' + tags_added = [] + tags_removed = [] + # Newly updated/merged notmuch tags against 'X-Keywords' + tags_new = [] + + # Whether the keywords updated against the notmuch tags + keywords_updated = False + # Added & removed tags for 'X-Keywords' against notmuch database + tags_kw_added = [] + tags_kw_removed = [] + # Newly updated/merged tags for 'X-Keywords' against notmuch database + tags_kw_new = [] + + def __init__(self, msg, filename=None): + self.message = msg + self.filename = filename + self.allfiles = [ fn for fn in msg.get_filenames() ] + self.tags = set(msg.get_tags()) + + def sync(self, direction, mode=SyncMode.ADD_REMOVE, + dryrun=False, verbose=False): + """ + Wrapper function to sync between 'X-Keywords' and notmuch tags. + """ + if direction == SyncDirection.KEYWORDS_TO_TAGS: + self.sync_keywords_to_tags(sync_mode=mode, dryrun=dryrun, + verbose=verbose) + elif direction == SyncDirection.TAGS_TO_KEYWORDS: + self.sync_tags_to_keywords(sync_mode=mode, dryrun=dryrun, + verbose=verbose) + elif direction == SyncDirection.MERGE_KEYWORDS_TAGS: + self.merge_keywords_tags(sync_mode=mode, dryrun=dryrun, + verbose=verbose) + else: + raise ValueError("Invalid sync direction: %s" % direction) + + def sync_keywords_to_tags(self, sync_mode=SyncMode.ADD_REMOVE, + dryrun=False, verbose=False): + """ + Wrapper function to sync 'X-Keywords' to notmuch tags. + """ + self.get_keywords() + self.map_keywords() + self.merge_tags(sync_direction=SyncDirection.KEYWORDS_TO_TAGS, + sync_mode=sync_mode) + if dryrun or verbose: + print('* MSG: %s' % self.message) + print(' TAG: [%s] +[%s] -[%s] => [%s]' % ( + ','.join(self.tags), ','.join(self.tags_added), + ','.join(self.tags_removed), ','.join(self.tags_new))) + if not dryrun: + self.update_tags() + + def sync_tags_to_keywords(self, sync_mode=SyncMode.ADD_REMOVE, + dryrun=False, verbose=False): + """ + Wrapper function to sync notmuch tags to 'X-Keywords' + """ + self.get_keywords() + self.map_keywords() + self.merge_tags(sync_direction=SyncDirection.TAGS_TO_KEYWORDS, + sync_mode=sync_mode) + keywords_new = self.map_tags(tags=self.tags_kw_new) + if dryrun or verbose: + print('* MSG: %s' % self.message) + print('* FILES: %s' % ' ; '.join(self.allfiles)) + print(' XKW: {%s} +[%s] -[%s] => {%s}' % ( + ','.join(self.keywords), ','.join(self.tags_kw_added), + ','.join(self.tags_kw_removed), ','.join(keywords_new))) + if not dryrun: + self.update_keywords(keywords_new=keywords_new) + + def merge_keywords_tags(self, sync_mode=SyncMode.ADD_REMOVE, + dryrun=False, verbose=False): + """ + Wrapper function to merge 'X-Keywords' and notmuch tags + """ + self.get_keywords() + self.map_keywords() + self.merge_tags(sync_direction=SyncDirection.MERGE_KEYWORDS_TAGS, + sync_mode=sync_mode) + keywords_new = self.map_tags(tags=self.tags_kw_new) + if dryrun or verbose: + print('* MSG: %s' % self.message) + print('* FILES: %s' % ' ; '.join(self.allfiles)) + print(' TAG: [%s] +[%s] -[%s] => [%s]' % ( + ','.join(self.tags), ','.join(self.tags_added), + ','.join(self.tags_removed), ','.join(self.tags_new))) + print(' XKW: {%s} +[%s] -[%s] => {%s}' % ( + ','.join(self.keywords), ','.join(self.tags_kw_added), + ','.join(self.tags_kw_removed), ','.join(keywords_new))) + if not dryrun: + self.update_tags() + self.update_keywords(keywords_new=keywords_new) + + def get_keywords(self): + """ + Get 'X-Keywords' header from all files associated with the same + message, decode, split and merge. + + NOTE: Do NOT simply use the `message.get_header()` method, which + cannot get the complete keywords from all files. + """ + keywords_utf7 = [] + for fn in self.allfiles: + msg = email.message_from_file(open(fn, 'r')) + val = msg['X-Keywords'] + if val: + keywords_utf7.append(val) + else: + print("WARNING: 'X-Keywords' header not found or empty " +\ + "for file: %s" % fn, file=sys.stderr) + keywords_utf7 = ','.join(keywords_utf7) + if keywords_utf7 != '': + keywords = imapUTF7Decode(keywords_utf7.encode()).split(',') + keywords = [ kw.strip() for kw in keywords ] + # Remove duplications + keywords = set(keywords) + else: + keywords = set() + self.keywords = keywords + return keywords + + def map_keywords(self, keywords=None): + """ + Map keywords to notmuch tags according to the mapping table. + """ + if keywords is None: + keywords = self.keywords + if self.enable_replace_chars: + # Replace specified characters in keywords + trans = str.maketrans(self.chars_replace) + keywords = [ kw.translate(trans) for kw in keywords ] + # Map keywords to tags + tags = set([ self.keywords_mapping.get(kw, kw) for kw in keywords ]) + self.tags_kw = tags + return tags + + def map_tags(self, tags=None): + """ + Map tags to keywords according to the inversed mapping table. + """ + if tags is None: + tags = self.tags + if self.enable_replace_chars: + # Inversely replace specified characters in tags + chars_replace_inv = { v: k for k, v in self.chars_replace.items() } + trans = str.maketrans(chars_replace_inv) + tags = [ tag.translate(trans) for tag in tags ] + # Map keywords to tags + keywords_mapping_inv = { v:k for k,v in self.keywords_mapping.items() } + keywords = set([ keywords_mapping_inv.get(tag, tag) for tag in tags ]) + self.keywords_tags = keywords + return keywords + + def merge_tags(self, sync_direction, sync_mode=SyncMode.ADD_REMOVE, + tags_nm=None, tags_kw=None): + """ + Merge the tags from notmuch database and 'X-Keywords' header, + according to the specified sync direction and operation restriction. + + TODO: support case-insensitive set operations + """ + # Added & removed tags for notmuch database against 'X-Keywords' + tags_added = [] + tags_removed = [] + # Newly updated/merged notmuch tags against 'X-Keywords' + tags_new = [] + # Added & removed tags for 'X-Keywords' against notmuch database + tags_kw_added = [] + tags_kw_removed = [] + # Newly updated/merged tags for 'X-Keywords' against notmuch database + tags_kw_new = [] + # + if tags_nm is None: + tags_nm = self.tags + if tags_kw is None: + tags_kw = self.tags_kw + if self.enable_ignore_tags: + # Remove ignored tags before merge + tags_nm2 = tags_nm.difference(self.tags_ignored) + tags_kw2 = tags_kw.difference(self.tags_ignored) + else: + tags_nm2 = tags_nm + tags_kw2 = tags_kw + # + if sync_direction == SyncDirection.KEYWORDS_TO_TAGS: + # Sync 'X-Keywords' to notmuch tags + tags_added = tags_kw2.difference(tags_nm2) + tags_removed = tags_nm2.difference(tags_kw2) + elif sync_direction == SyncDirection.TAGS_TO_KEYWORDS: + # Sync notmuch tags to 'X-Keywords' + tags_kw_added = tags_nm2.difference(tags_kw2) + tags_kw_removed = tags_kw2.difference(tags_nm2) + elif sync_direction == SyncDirection.MERGE_KEYWORDS_TAGS: + # Merge both notmuch tags and 'X-Keywords' + tags_merged = tags_nm2.union(tags_kw2) + # notmuch tags + tags_added = tags_merged.difference(tags_nm2) + tags_removed = tags_nm2.difference(tags_merged) + # tags for 'X-Keywords' + tags_kw_added = tags_merged.difference(tags_kw2) + tags_kw_removed = tags_kw2.difference(tags_merged) + else: + raise ValueError("Invalid synchronization direction") + # Apply sync operation restriction + self.tags_added = [] + self.tags_removed = [] + self.tags_kw_added = [] + self.tags_kw_removed = [] + tags_new = tags_nm # Use un-ignored notmuch tags + tags_kw_new = tags_kw # Use un-ignored 'X-Keywords' tags + if sync_mode != SyncMode.REMOVE_ONLY: + self.tags_added = tags_added + self.tags_kw_added = tags_kw_added + tags_new = tags_new.union(tags_added) + tags_kw_new = tags_kw_new.union(tags_kw_added) + if sync_mode != SyncMode.ADD_ONLY: + self.tags_removed = tags_removed + self.tags_kw_removed = tags_kw_removed + tags_new = tags_new.difference(tags_removed) + tags_kw_new = tags_kw_new.difference(tags_kw_removed) + # + self.tags_new = tags_new + self.tags_kw_new = tags_kw_new + if self.tags_added or self.tags_removed: + self.tags_updated = True + if self.tags_kw_added or self.tags_kw_removed: + self.keywords_updated = True + # + return { + 'tags_updated' : self.tags_updated, + 'tags_added' : self.tags_added, + 'tags_removed' : self.tags_removed, + 'tags_new' : self.tags_new, + 'keywords_updated' : self.keywords_updated, + 'tags_kw_added' : self.tags_kw_added, + 'tags_kw_removed' : self.tags_kw_removed, + 'tags_kw_new' : self.tags_kw_new, + } + + def update_keywords(self, keywords_new=None, outfile=None): + """ + Encode the keywords (default: self.keywords_new) and write back to + all message files. + + If parameter 'outfile' specified, then write the updated message + to that file instead of overwriting. + + NOTE: + * The modification time of the message file should be kept to prevent + OfflineIMAP from treating it as a new one (and the previous a + deleted one). + * All files associated with the same message are updated to have + the same 'X-Keywords' header. + """ + if not self.keywords_updated: + # keywords NOT updated, just skip + return + + if keywords_new is None: + keywords_new = self.keywords_new + # + if outfile is not None: + infile = self.allfiles[0:1] + outfile = [ os.path.expanduser(outfile) ] + else: + infile = self.allfiles + outfile = self.allfiles + # + for ifname, ofname in zip(infile, outfile): + msg = email.message_from_file(open(ifname, 'r')) + fstat = os.stat(ifname) + if keywords_new == []: + # Delete 'X-Keywords' header + print("WARNING: delete 'X-Keywords' header from file: %s" % + ifname, file=sys.stderr) + del msg['X-Keywords'] + else: + # Update 'X-Keywords' header + keywords = ','.join(keywords_new) + keywords_utf7 = imapUTF7Encode(keywords).decode() + # Delete then add, to avoid multiple occurrences + del msg['X-Keywords'] + msg['X-Keywords'] = keywords_utf7 + # Write updated message + with open(ofname, 'w') as fp: + fp.write(msg.as_string()) + # Reset the timestamps + os.utime(ofname, ns=(fstat.st_atime_ns, fstat.st_mtime_ns)) + + def update_tags(self, tags_added=None, tags_removed=None): + """ + Update notmuch tags according to keywords. + """ + if not self.tags_updated: + # tags NOT updated, just skip + return + + if tags_added is None: + tags_added = self.tags_added + if tags_removed is None: + tags_removed = self.tags_removed + # Use freeze/thaw for safer transactions to change tag values. + self.message.freeze() + for tag in tags_added: + self.message.add_tag(tag, sync_maildir_flags=False) + for tag in tags_removed: + self.message.remove_tag(tag, sync_maildir_flags=False) + self.message.thaw() + + +def get_notmuch_revision(dbpath=None): + """ + Get the current revision and UUID of notmuch database. + """ + import subprocess + import tempfile + if dbpath: + tf = tempfile.NamedTemporaryFile() + # Create a minimal notmuch config for the specified dbpath + config = '[database]\npath=%s\n' % os.path.expanduser(dbpath) + tf.file.write(config.encode()) + tf.file.flush() + cmd = 'notmuch --config=%s count --lastmod' % tf.name + output = subprocess.check_output(cmd, shell=True) + tf.close() + else: + cmd = 'notmuch count --lastmod' + output = subprocess.check_output(cmd, shell=True) + # Extract output + dbinfo = output.decode().split() + return { 'revision': int(dbinfo[2]), 'uuid': dbinfo[1] } + + +def main(): + parser = argparse.ArgumentParser( + description="Sync message 'X-Keywords' header with notmuch tags.") + parser.add_argument("-V", "--version", action="version", + version="%(prog)s " + "v%s (%s)" % (__version__, __date__)) + parser.add_argument("-q", "--query", dest="query", required=True, + help="notmuch database query string") + parser.add_argument("-p", "--db-path", dest="dbpath", + help="notmuch database path (default to try user configuration)") + parser.add_argument("-n", "--dry-run", dest="dryrun", + action="store_true", help="dry run") + parser.add_argument("-v", "--verbose", dest="verbose", + action="store_true", help="show verbose information") + # Exclusive argument group for sync mode + exgroup1 = parser.add_mutually_exclusive_group(required=True) + exgroup1.add_argument("-m", "--merge-keywords-tags", + dest="direction_merge", action="store_true", + help="merge 'X-Keywords' and tags and update both") + exgroup1.add_argument("-k", "--keywords-to-tags", + dest="direction_keywords2tags", action="store_true", + help="sync 'X-Keywords' to notmuch tags") + exgroup1.add_argument("-t", "--tags-to-keywords", + dest="direction_tags2keywords", action="store_true", + help="sync notmuch tags to 'X-Keywords'") + # Exclusive argument group for tag operation mode + exgroup2 = parser.add_mutually_exclusive_group(required=False) + exgroup2.add_argument("-a", "--add-only", dest="mode_addonly", + action="store_true", help="only add notmuch tags") + exgroup2.add_argument("-r", "--remove-only", dest="mode_removeonly", + action="store_true", help="only remove notmuch tags") + # Parse + args = parser.parse_args() + # Sync direction + if args.direction_merge: + sync_direction = SyncDirection.MERGE_KEYWORDS_TAGS + elif args.direction_keywords2tags: + sync_direction = SyncDirection.KEYWORDS_TO_TAGS + elif args.direction_tags2keywords: + sync_direction = SyncDirection.TAGS_TO_KEYWORDS + else: + raise ValueError("Invalid synchronization direction") + # Sync mode + if args.mode_addonly: + sync_mode = SyncMode.ADD_ONLY + elif args.mode_removeonly: + sync_mode = SyncMode.REMOVE_ONLY + else: + sync_mode = SyncMode.ADD_REMOVE + # + if args.dbpath: + dbpath = os.path.abspath(os.path.expanduser(args.dbpath)) + else: + dbpath = None + # + db = Database(path=dbpath, create=False, mode=Database.MODE.READ_WRITE) + dbinfo = get_notmuch_revision(dbpath=dbpath) + q = Query(db, args.query) + total_msgs = q.count_messages() + msgs = q.search_messages() + # + if args.verbose: + print("# Notmuch database path: %s" % dbpath) + print("# Database revision: %d (uuid: %s)" % + (dbinfo['revision'], dbinfo['uuid'])) + print("# Query: %s" % args.query) + print("# Sync direction: %s" % sync_direction.name) + print("# Sync mode: %s" % sync_mode.name) + print("# Total messages to check: %d" % total_msgs) + print("# Dryn run: %s" % args.dryrun) + # + for msg in msgs: + kwmsg = KwMessage(msg) + kwmsg.sync(direction=sync_direction, mode=sync_mode, + dryrun=args.dryrun, verbose=args.verbose) + # + db.close() + + +if __name__ == "__main__": + main() + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # |