diff options
Diffstat (limited to 'python')
| -rwxr-xr-x | python/adjust_spectrum_error.py | 170 | ||||
| -rwxr-xr-x | python/calc_radial_psd.py | 450 | ||||
| -rwxr-xr-x | python/crosstalk_deprojection.py | 1808 | ||||
| -rwxr-xr-x | python/fit_sbp.py | 807 | ||||
| -rw-r--r-- | python/imapUTF7.py | 189 | ||||
| -rwxr-xr-x | python/msvst_starlet.py | 646 | ||||
| -rw-r--r-- | python/plot.py | 35 | ||||
| -rw-r--r-- | python/plot_tprofiles_zzh.py | 126 | ||||
| -rwxr-xr-x | python/randomize_events.py | 72 | ||||
| -rwxr-xr-x | python/rebuild_ipod_db.py | 595 | ||||
| -rwxr-xr-x | python/splitBoxRegion.py | 148 | ||||
| -rw-r--r-- | python/splitCCDgaps.py | 107 | ||||
| -rw-r--r-- | python/xkeywordsync.py | 533 | 
13 files changed, 5686 insertions, 0 deletions
| diff --git a/python/adjust_spectrum_error.py b/python/adjust_spectrum_error.py new file mode 100755 index 0000000..0f80ec7 --- /dev/null +++ b/python/adjust_spectrum_error.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Squeeze the spectrum according to the grouping specification, then +calculate the statistical errors for each group, and apply error +adjustments (e.g., incorporate the systematic uncertainties). +""" + +__version__ = "0.1.0" +__date__ = "2016-01-11" + + +import sys +import argparse + +import numpy as np +from astropy.io import fits + + +class Spectrum: +    """ +    Spectrum class to keep spectrum information and perform manipulations. +    """ +    header = None +    channel = None +    counts = None +    grouping = None +    quality = None + +    def __init__(self, specfile): +        f = fits.open(specfile) +        spechdu = f['SPECTRUM'] +        self.header = spechdu.header +        self.channel = spechdu.data.field('CHANNEL') +        self.counts = spechdu.data.field('COUNTS') +        self.grouping = spechdu.data.field('GROUPING') +        self.quality = spechdu.data.field('QUALITY') +        f.close() + +    def squeezeByGrouping(self): +        """ +        Squeeze the spectrum according to the grouping specification, +        i.e., sum the counts belonging to the same group, and place the +        sum as the first channel within each group with other channels +        of counts zero's. +        """ +        counts_squeezed = [] +        cnt_sum = 0 +        cnt_num = 0 +        first = True +        for grp, cnt in zip(self.grouping, self.counts): +            if first and grp == 1: +                # first group +                cnt_sum = cnt +                cnt_num = 1 +                first = False +            elif grp == 1: +                # save previous group +                counts_squeezed.append(cnt_sum) +                counts_squeezed += [ 0 for i in range(cnt_num-1) ] +                # start new group +                cnt_sum = cnt +                cnt_num = 1 +            else: +                # group continues +                cnt_sum += cnt +                cnt_num += 1 +        # last group +        # save previous group +        counts_squeezed.append(cnt_sum) +        counts_squeezed += [ 0 for i in range(cnt_num-1) ] +        self.counts_squeezed = np.array(counts_squeezed, dtype=np.int32) + +    def calcStatErr(self, gehrels=False): +        """ +        Calculate the statistical errors for the grouped channels, +        and save as the STAT_ERR column. +        """ +        idx_nz = np.nonzero(self.counts_squeezed) +        stat_err = np.zeros(self.counts_squeezed.shape) +        if gehrels: +            # Gehrels +            stat_err[idx_nz] = 1 + np.sqrt(self.counts_squeezed[idx_nz] + 0.75) +        else: +            stat_err[idx_nz] = np.sqrt(self.counts_squeezed[idx_nz]) +        self.stat_err = stat_err + +    @staticmethod +    def parseSysErr(syserr): +        """ +        Parse the string format of syserr supplied in the commandline. +        """ +        items = map(str.strip, syserr.split(',')) +        syserr_spec = [] +        for item in items: +            spec = item.split(':') +            try: +                spec = (int(spec[0]), int(spec[1]), float(spec[2])) +            except: +                raise ValueError("invalid syserr specficiation") +            syserr_spec.append(spec) +        return syserr_spec + +    def applySysErr(self, syserr): +        """ +        Apply systematic error adjustments to the above calculated +        statistical errors. +        """ +        syserr_spec = self.parseSysErr(syserr) +        for lo, hi, se in syserr_spec: +            err_adjusted = self.stat_err[(lo-1):(hi-1)] * np.sqrt(1+se) +        self.stat_err_adjusted = err_adjusted + +    def updateHeader(self): +        """ +        Update header accordingly. +        """ +        # POISSERR +        self.header['POISSERR'] = False + +    def write(self, filename, clobber=False): +        """ +        Write the updated/modified spectrum block to file. +        """ +        channel_col = fits.Column(name='CHANNEL', format='J', +                array=self.channel) +        counts_col = fits.Column(name='COUNTS', format='J', +                array=self.counts_squeezed) +        stat_err_col = fits.Column(name='STAT_ERR', format='D', +                array=self.stat_err_adjusted) +        grouping_col = fits.Column(name='GROUPING', format='I', +                array=self.grouping) +        quality_col = fits.Column(name='QUALITY', format='I', +                array=self.quality) +        spec_cols = fits.ColDefs([channel_col, counts_col, stat_err_col, +                                  grouping_col, quality_col]) +        spechdu = fits.BinTableHDU.from_columns(spec_cols, header=self.header) +        spechdu.writeto(filename, clobber=clobber) + + +def main(): +    parser = argparse.ArgumentParser( +            description="Apply systematic error adjustments to spectrum.") +    parser.add_argument("-V", "--version", action="version", +            version="%(prog)s " + "%s (%s)" % (__version__, __date__)) +    parser.add_argument("infile", help="input spectrum file") +    parser.add_argument("outfile", help="output adjusted spectrum file") +    parser.add_argument("-e", "--syserr", dest="syserr", required=True, +            help="systematic error specification; " + \ +                 "syntax: ch1low:ch1high:syserr1,...") +    parser.add_argument("-C", "--clobber", dest="clobber", +            action="store_true", help="overwrite output file if exists") +    parser.add_argument("-G", "--gehrels", dest="gehrels", +            action="store_true", help="use Gehrels error?") +    args = parser.parse_args() + +    spec = Spectrum(args.infile) +    spec.squeezeByGrouping() +    spec.calcStatErr(gehrels=args.gehrels) +    spec.applySysErr(syserr=args.syserr) +    spec.updateHeader() +    spec.write(args.outfile, clobber=args.clobber) + + +if __name__ == "__main__": +    main() + + +#  vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/calc_radial_psd.py b/python/calc_radial_psd.py new file mode 100755 index 0000000..23bd819 --- /dev/null +++ b/python/calc_radial_psd.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Credit: +# [1]  Radially averaged power spectrum of 2D real-valued matrix +#      Evan Ruzanski +#      'raPsd2d.m' +#      https://www.mathworks.com/matlabcentral/fileexchange/23636-radially-averaged-power-spectrum-of-2d-real-valued-matrix +# +# XXX: +# * If the input image is NOT SQUARE; then are the horizontal frequencies +#   the same as the vertical frequencies ?? +# +# Aaron LI <aaronly.me@gmail.com> +# Created: 2015-04-22 +# Updated: 2016-04-28 +# +# Changelog: +# 2016-04-28: +#   * Fix wrong meshgrid with respect to the shift zero-frequency component +#   * Use "numpy.fft" instead of "scipy.fftpack" +#   * Split method "pad_square()" from "calc_radial_psd()" +#   * Hide numpy warning when dividing by zero +#   * Add method "AstroImage.fix_shapes()" +#   * Add support for background subtraction and exposure correction +#   * Show verbose information during calculation +#   * Add class "AstroImage" +#   * Set default value for 'args.png' +#   * Rename from 'radialPSD2d.py' to 'calc_radial_psd.py' +# 2016-04-26: +#   * Adjust plot function +#   * Update normalize argument; Add pixel argument +# 2016-04-25: +#   * Update plot function +#   * Add command line scripting support +#   * Encapsulate the functions within class 'PSD' +#   * Update docs/comments +# + +""" +Compute the radially averaged power spectral density (i.e., power spectrum). +""" + +__version__ = "0.5.0" +__date__    = "2016-04-28" + + +import sys +import os +import argparse + +import numpy as np +from astropy.io import fits + +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure + +plt.style.use("ggplot") + + +class PSD: +    """ +    Computes the 2D power spectral density and the radially averaged power +    spectral density (i.e., 1D power spectrum). +    """ +    # 2D image data +    img = None +    # value and unit of 1 pixel for the input image +    pixel = (None, None) +    # whether to normalize the power spectral density by image size +    normalize = True +    # 2D power spectral density +    psd2d = None +    # 1D (radially averaged) power spectral density +    freqs     = None +    psd1d     = None +    psd1d_err = None + +    def __init__(self, img, pixel=(1.0, "pixel"), normalize=True): +        self.img = img.astype(np.float) +        self.pixel = pixel +        self.normalize = normalize + +    def calc_psd2d(self, verbose=False): +        """ +        Computes the 2D power spectral density of the given image. +        Note that the low frequency components are shifted to the center +        of the FFT'ed image. + +        NOTE: +        The zero-frequency component is shifted to position of index (0-based) +            (ceil((n-1) / 2), ceil((m-1) / 2)), +        where (n, m) are the number of rows and columns of the image/psd2d. + +        Return: +            2D power spectral density, which is dimensionless if normalized, +            otherwise has unit ${pixel_unit}^2. +        """ +        if verbose: +            print("Calculating 2D power spectral density ... ", +                    end="", flush=True) +        rows, cols = self.img.shape +        ## Compute the power spectral density (i.e., power spectrum) +        imgf = np.fft.fftshift(np.fft.fft2(self.img)) +        if self.normalize: +            norm = rows * cols * self.pixel[0]**2 +        else: +            norm = 1.0  # Do not normalize +        self.psd2d = (np.abs(imgf) / norm) ** 2 +        if verbose: +            print("DONE", flush=True) +        return self.psd2d + +    def calc_radial_psd1d(self, verbose=False): +        """ +        Computes the radially averaged power spectral density from the +        provided 2D power spectral density. + +        Return: +            (freqs, radial_psd, radial_psd_err) +            freqs: spatial freqencies (unit: ${pixel_unit}^(-1)) +            radial_psd: radially averaged power spectral density for each +                        frequency +            radial_psd_err: standard deviations of each radial_psd +        """ +        if verbose: +            print("Calculating radial (1D) power spectral density ... ", +                    end="", flush=True) +        if verbose: +            print("padding ... ", end="", flush=True) +        psd2d    = self.pad_square(self.psd2d, value=np.nan) +        dim      = psd2d.shape[0] +        dim_half = (dim+1) // 2 +        # NOTE: +        # The zero-frequency component is shifted to position of index +        # (0-based): (ceil((n-1) / 2), ceil((m-1) / 2)) +        px       = np.arange(dim_half-dim, dim_half) +        x, y     = np.meshgrid(px, px) +        rho, phi = self.cart2pol(x, y) +        rho      = np.around(rho).astype(np.int) +        radial_psd     = np.zeros(dim_half) +        radial_psd_err = np.zeros(dim_half) +        if verbose: +            print("radially averaging ... ", +                    end="", flush=True) +        for r in range(dim_half): +            # Get the indices of the elements satisfying rho[i,j]==r +            ii, jj = (rho == r).nonzero() +            # Calculate the mean value at a given radii +            data              = psd2d[ii, jj] +            radial_psd[r]     = np.nanmean(data) +            radial_psd_err[r] = np.nanstd(data) +        # Calculate frequencies +        f = np.fft.fftfreq(dim, d=self.pixel[0]) +        freqs = np.abs(f[:dim_half]) +        # +        self.freqs     = freqs +        self.psd1d     = radial_psd +        self.psd1d_err = radial_psd_err +        if verbose: +            print("DONE", end="", flush=True) +        return (freqs, radial_psd, radial_psd_err) + +    @staticmethod +    def cart2pol(x, y): +        """ +        Convert Cartesian coordinates to polar coordinates. +        """ +        rho = np.sqrt(x**2 + y**2) +        phi = np.arctan2(y, x) +        return (rho, phi) + +    @staticmethod +    def pol2cart(rho, phi): +        """ +        Convert polar coordinates to Cartesian coordinates. +        """ +        x = rho * np.cos(phi) +        y = rho * np.sin(phi) +        return (x, y) + +    @staticmethod +    def pad_square(data, value=np.nan): +        """ +        Symmetrically pad the supplied data matrix to make it square. +        The padding rows are equally added to the top and bottom, +        as well as the columns to the left and right sides. +        The padded rows/columns are filled with the specified value. +        """ +        mat = data.copy() +        rows, cols = mat.shape +        dim_diff = abs(rows - cols) +        dim_max  = max(rows, cols) +        if rows > cols: +            # pad columns +            if dim_diff // 2 == 0: +                cols_left     = np.zeros((rows, dim_diff/2)) +                cols_left[:]  = value +                cols_right    = np.zeros((rows, dim_diff/2)) +                cols_right[:] = value +                mat           = np.hstack((cols_left, mat, cols_right)) +            else: +                cols_left     = np.zeros((rows, np.floor(dim_diff/2))) +                cols_left[:]  = value +                cols_right    = np.zeros((rows, np.floor(dim_diff/2)+1)) +                cols_right[:] = value +                mat           = np.hstack((cols_left, mat, cols_right)) +        elif rows < cols: +            # pad rows +            if dim_diff // 2 == 0: +                rows_top       = np.zeros((dim_diff/2, cols)) +                rows_top[:]    = value +                rows_bottom    = np.zeros((dim_diff/2, cols)) +                rows_bottom[:] = value +                mat            = np.vstack((rows_top, mat, rows_bottom)) +            else: +                rows_top       = np.zeros((np.floor(dim_diff/2), cols)) +                rows_top[:]    = value +                rows_bottom    = np.zeros((np.floor(dim_diff/2)+1, cols)) +                rows_bottom[:] = value +                mat            = np.vstack((rows_top, mat, rows_bottom)) +        return mat + +    def plot(self, ax=None, fig=None): +        """ +        Make a plot of the radial (1D) PSD with matplotlib. +        """ +        if ax is None: +            fig, ax = plt.subplots(1, 1) +        # +        xmin = self.freqs[1] / 1.2  # ignore the first 0 +        xmax = self.freqs[-1] +        ymin = np.nanmin(self.psd1d) / 10.0 +        ymax = np.nanmax(self.psd1d + self.psd1d_err) +        # +        eb = ax.errorbar(self.freqs, self.psd1d, yerr=self.psd1d_err, +                         fmt="none") +        ax.plot(self.freqs, self.psd1d, "ko") +        ax.set_xscale("log") +        ax.set_yscale("log") +        ax.set_xlim(xmin, xmax) +        ax.set_ylim(ymin, ymax) +        ax.set_title("Radially Averaged Power Spectral Density") +        ax.set_xlabel(r"k (%s$^{-1}$)" % self.pixel[1]) +        if self.normalize: +            ax.set_ylabel("Power") +        else: +            ax.set_ylabel(r"Power (%s$^2$)" % self.pixel[1]) +        fig.tight_layout() +        return (fig, ax) + + +class AstroImage: +    """ +    Manipulate the astronimcal counts image, as well as the corresponding +    exposure map and background map. +    """ +    # input counts image +    image = None +    # exposure map with respect to the input counts image +    expmap = None +    # background map (e.g., stowed background) +    bkgmap = None +    # exposure time of the input image +    exposure     = None +    # exposure time of the background map +    exposure_bkg = None + +    def __init__(self, image, expmap=None, bkgmap=None, verbose=False): +        self.load_image(image,   verbose=verbose) +        self.load_expmap(expmap, verbose=verbose) +        self.load_bkgmap(bkgmap, verbose=verbose) + +    def load_image(self, image, verbose=False): +        if verbose: +            print("Loading image ... ", end="", flush=True) +        with fits.open(image) as imgfits: +            self.image = imgfits[0].data.astype(np.float) +            self.exposure = imgfits[0].header["EXPOSURE"] +        if verbose: +            print("DONE", flush=True) + +    def load_expmap(self, expmap, verbose=False): +        if expmap: +            if verbose: +                print("Loading exposure map ... ", end="", flush=True) +            with fits.open(expmap) as imgfits: +                self.expmap = imgfits[0].data.astype(np.float) +            if verbose: +                print("DONE", flush=True) + +    def load_bkgmap(self, bkgmap, verbose=False): +        if bkgmap: +            if verbose: +                print("Loading background map ... ", end="", flush=True) +            with fits.open(bkgmap) as imgfits: +                self.bkgmap = imgfits[0].data.astype(np.float) +                self.exposure_bkg = imgfits[0].header["EXPOSURE"] +            if verbose: +                print("DONE", flush=True) + +    def fix_shapes(self, tolerance=2, verbose=False): +        """ +        Fix the shapes of self.expmap and self.bkgmap to make them have +        the same shape as the self.image. + +        NOTE: +        * if the image is bigger than the reference image, then its +          columns on the right and rows on the botton are clipped; +        * if the image is smaller than the reference image, then padding +          columns on the right and rows on the botton are added. +        * Original images are REPLACED! + +        Arguments: +          * tolerance: allow absolute difference between images +        """ +        def _fix_shape(img, ref, tol=tolerance, verbose=verbose): +            if img.shape == ref.shape: +                if verbose: +                    print("SKIPPED", flush=True) +                return img +            elif np.allclose(img.shape, ref.shape, atol=tol): +                if verbose: +                    print(img.shape, "->", ref.shape, flush=True) +                rows, cols = img.shape +                rows_ref, cols_ref = ref.shape +                # rows +                if rows > rows_ref: +                    img_fixed = img[:rows_ref, :] +                else: +                    img_fixed = np.row_stack((img, +                        np.zeros((rows_ref-rows, cols), dtype=img.dtype))) +                # columns +                if cols > cols_ref: +                    img_fixed = img_fixed[:, :cols_ref] +                else: +                    img_fixed = np.column_stack((img_fixed, +                        np.zeros((rows_ref, cols_ref-cols), dtype=img.dtype))) +                return img_fixed +            else: +                raise ValueError("shape difference exceeds tolerance: " + \ +                        "(%d, %d) vs. (%d, %d)" % (img.shape + ref.shape)) +        # +        if self.bkgmap is not None: +            if verbose: +                print("Fixing shape for bkgmap ... ", end="", flush=True) +            self.bkgmap = _fix_shape(self.bkgmap, self.image) +        if self.expmap is not None: +            if verbose: +                print("Fixing shape for expmap ... ", end="", flush=True) +            self.expmap = _fix_shape(self.expmap, self.image) + +    def subtract_bkg(self, verbose=False): +        if verbose: +            print("Subtracting background ... ", end="", flush=True) +        self.image -= (self.bkgmap / self.exposure_bkg * self.exposure) +        if verbose: +            print("DONE", flush=True) + +    def correct_exposure(self, cut=0.015, verbose=False): +        """ +        Correct the image for exposure by dividing by the expmap to +        create the exposure-corrected image. + +        Arguments: +          * cut: the threshold percentage with respect to the maximum +                 exposure map value; and those pixels with lower values +                 than this threshold will be excluded/clipped (set to ZERO) +                 if set to None, then skip clipping image +        """ +        if verbose: +            print("Correcting image for exposure ... ", end="", flush=True) +        with np.errstate(divide="ignore", invalid="ignore"): +            self.image /= self.expmap +        # set invalid values to ZERO +        self.image[ ~ np.isfinite(self.image) ] = 0.0 +        if verbose: +            print("DONE", flush=True) +        if cut is not None: +            # clip image according the exposure threshold +            if verbose: +                print("Clipping image (%s) ... " % cut, end="", flush=True) +            threshold = cut * np.max(self.expmap) +            self.image[ self.expmap < threshold ] = 0.0 +            if verbose: +                print("DONE", flush=True) + + +def main(): +    parser = argparse.ArgumentParser( +            description="Compute the radially averaged power spectral density", +            epilog="Version: %s (%s)" % (__version__, __date__)) +    parser.add_argument("-V", "--version", action="version", +            version="%(prog)s " + "%s (%s)" % (__version__, __date__)) +    parser.add_argument("-v", "--verbose", dest="verbose", +            action="store_true", help="show verbose information") +    parser.add_argument("-C", "--clobber", dest="clobber", +            action="store_true", +            help="overwrite the output files if already exist") +    parser.add_argument("-i", "--infile", dest="infile", +            required=True, help="input image") +    parser.add_argument("-b", "--bkgmap", dest="bkgmap", default=None, +            help="background map (for background subtraction)") +    parser.add_argument("-e", "--expmap", dest="expmap", default=None, +            help="exposure map (for exposure correction)") +    parser.add_argument("-o", "--outfile", dest="outfile", +            required=True, help="output file to store the PSD data") +    parser.add_argument("-p", "--png", dest="png", default=None, +            help="plot the PSD and save (default: same basename as outfile)") +    args = parser.parse_args() + +    if args.png is None: +        args.png = os.path.splitext(args.outfile)[0] + ".png" + +    # Check output files whether already exists +    if (not args.clobber) and os.path.exists(args.outfile): +        raise ValueError("outfile '%s' already exists" % args.outfile) +    if (not args.clobber) and os.path.exists(args.png): +        raise ValueError("output png '%s' already exists" % args.png) + +    # Load image data +    image = AstroImage(image=args.infile, expmap=args.expmap, +                       bkgmap=args.bkgmap, verbose=args.verbose) +    image.fix_shapes(verbose=args.verbose) +    if args.bkgmap: +        image.subtract_bkg(verbose=args.verbose) +    if args.expmap: +        image.correct_exposure(verbose=args.verbose) + +    # Calculate the power spectral density +    psd = PSD(img=image.image, normalize=True) +    psd.calc_psd2d(verbose=args.verbose) +    freqs, psd1d, psd1d_err = psd.calc_radial_psd1d(verbose=args.verbose) + +    # Write out PSD results +    psd_data = np.column_stack((freqs, psd1d, psd1d_err)) +    np.savetxt(args.outfile, psd_data, header="freqs  psd1d  psd1d_err") + +    # Make and save a plot +    fig = Figure(figsize=(10, 8)) +    canvas = FigureCanvas(fig) +    ax = fig.add_subplot(111) +    psd.plot(ax=ax, fig=fig) +    fig.savefig(args.png, format="png", dpi=150) + + +if __name__ == "__main__": +    main() + diff --git a/python/crosstalk_deprojection.py b/python/crosstalk_deprojection.py new file mode 100755 index 0000000..d5bab05 --- /dev/null +++ b/python/crosstalk_deprojection.py @@ -0,0 +1,1808 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# References: +# [1] Definition of RMF and ARF file formats +#     https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html +# [2] The OGIP Spectral File Format +#     https://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/summary/ogip_92_007_summary.html +# [3] CIAO: Auxiliary Response File +#     http://cxc.harvard.edu/ciao/dictionary/arf.html +# [4] CIAO: Redistribution Matrix File +#     http://cxc.harvard.edu/ciao/dictionary/rmf.html +# [5] astropy - FITS format code +#     http://docs.astropy.org/en/stable/io/fits/usage/table.html#column-creation +# [6] XSPEC - Spectral Fitting +#     https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/XspecSpectralFitting.html +# [7] Direct X-ray Spectra Deprojection +#     https://www-xray.ast.cam.ac.uk/papers/dsdeproj/ +#     Sanders & Fabian 2007, MNRAS, 381, 1381 +# +# +# Weitian LI +# Created: 2016-03-26 +# Updated: 2016-04-20 +# +# ChangeLog: +# 2016-04-20: +#   * Add argument 'add_history' to some methods (to avoid many duplicated +#     histories due to Monte Carlo) +#   * Rename 'reset_header_keywords()' to 'fix_header_keywords()', +#     and add mandatory spectral keywords if missing +#   * Add method 'fix_header()' to class 'Crosstalk' and 'Deprojection', +#     and fix the headers before write spectra +# 2016-04-19: +#   * Ignore numpy error due to division by zero +#   * Update tool description and sample configuration +#   * Add two other main methods: `main_deprojection()' and `main_crosstalk()' +#   * Add argument 'group_squeeze' to some methods for better performance +#   * Rename from 'correct_crosstalk.py' to 'crosstalk_deprojection.py' +# 2016-04-18: +#   * Implement deprojection function: class Deprojection +#   * Support spectral grouping (supply the grouping specification) +#   * Add grouping, estimate_errors, copy, randomize, etc. methods +#   * Utilize the Monte Carlo techniques to estimate the final spectral errors +#   * Collect all ARFs and RMFs within dictionaries +# 2016-04-06: +#   * Fix `RMF: get_rmfimg()' for XMM EPIC RMF +# 2016-04-02: +#   * Interpolate ARF in order to match the spectral channel energies +#   * Add version and date information +#   * Update documentations +#   * Update header history contents +# 2016-04-01: +#   * Greatly update the documentations (e.g., description, sample config) +#   * Add class `RMF' +#   * Add method `get_energy()' for class `ARF' +#   * Split out class `SpectrumSet' from `Spectrum' +#   * Implement background subtraction +#   * Add config `subtract_bkg' and corresponding argument +# +# XXX/FIXME: +#   * Deprojection: account for ARF differences across different regions +# +# TODO: +#   * Split classes ARF, RMF, Spectrum, and SpectrumSet to a separate module +# + +__version__ = "0.5.2" +__date__    = "2016-04-20" + + +""" +Correct the crosstalk effect of XMM spectra by subtracting the photons +that scattered from the surrounding regions due to the finite PSF, and +by compensating the photons that scattered to the surrounding regions, +according to the generated crosstalk ARFs by SAS `arfgen'. + +After the crosstalk effect being corrected, the deprojection is performed +to deproject the crosstalk-corrected spectra to derive the spectra with +both the crosstalk effect and projection effect corrected. + + +Sample config file (in `ConfigObj' syntax): +----------------------------------------------------------- +# operation mode: deprojection, crosstalk, or both (default) +mode = both +# supply a *groupped* spectrum (from which the "GROUPING" and "QUALITY" +# are used to group all the following spectra) +grouping = spec_grp.pi +# whether to subtract the background before crosstalk correction +subtract_bkg = True +# whether to fix the negative channel values due to spectral subtractions +fix_negative = False +# Monte Carlo times for spectral error estimation +mc_times = 5000 +# show progress details and verbose information +verbose = True +# overwrite existing files +clobber = False + +# NOTE: +# ONLY specifiy ONE set of projected spectra (i.e., from the same detector +# of one observation), since ALL the following specified spectra will be +# used for the deprojection. + +[reg1] +... + +[reg2] +outfile = deprojcc_reg2.pi +spec = reg2.pi +arf = reg2.arf +rmf = reg2.rmf +bkg = reg2_bkg.pi +  [[cross_in]] +    [[[in1]]] +    spec = reg1.pi +    arf = reg1.arf +    rmf = reg1.rmf +    bkg = reg1_bkg.pi +    cross_arf = reg_1-2.arf +    [[[in2]]] +    spec = reg3.pi +    arf = reg3.arf +    rmf = reg3.rmf +    bkg = reg3_bkg.pi +    cross_arf = reg_3-2.arf +  [[cross_out]] +  cross_arf = reg_2-1.arf, reg_2-3.arf + +[...] +... +----------------------------------------------------------- +""" + +WARNING = """ +********************************* WARNING ************************************ +The generated spectra are substantially modified (e.g., scale, add, subtract), +therefore, take special care when interpretating the fitting results, +especially the metal abundances and normalizations. +****************************************************************************** +""" + + +import sys +import os +import argparse +from datetime import datetime +from copy import copy + +import numpy as np +import scipy as sp +import scipy.interpolate +from astropy.io import fits +from configobj import ConfigObj + + +def group_data(data, grouping): +    """ +    Group the data with respect to the supplied `grouping' specification +    (i.e., "GROUPING" columns of a spectrum).  The channel counts of the +    same group are summed up and assigned to the FIRST channel of this +    group, while the OTHRE channels are all set to ZERO. +    """ +    data_grp = np.array(data).copy() +    for i in reversed(range(len(data))): +        if grouping[i] == 1: +            # the beginning channel of a group +            continue +        else: +            # other channels of a group +            data_grp[i-1] += data_grp[i] +            data_grp[i]    = 0 +    assert np.isclose(sum(data_grp), sum(data)) +    return data_grp + + +class ARF:  # {{{ +    """ +    Class to handle the ARF (ancillary/auxiliary response file), +    which contains the combined instrumental effective area +    (telescope/filter/detector) and the quantum efficiency (QE) as a +    function of energy averaged over time. +    The effective area is [cm^2], and the QE is [counts/photon]; they are +    multiplied together to create the ARF, resulting in [cm^2 counts/photon]. + +    **CAVEAT/NOTE**: +    Generally, the "ENERG_LO" and "ENERG_HI" columns of an ARF are *different* +    to the "E_MIN" and "E_MAX" columns of a RMF (which are corresponding +    to the spectrum channel energies). +    For the XMM EPIC *pn* and Chandra *ACIS*, the generated ARF does NOT have +    the same number of data points to that of spectral channels, i.e., the +    "ENERG_LO" and "ENERG_HI" columns of ARF is different to the "E_MIN" and +    "E_MAX" columns of RMF. +    Therefore it is necessary to interpolate and extrapolate the ARF curve +    in order to match the spectrum (or RMF "EBOUNDS" extension). +    As for the XMM EPIC *MOS1* and *MOS2*, the ARF data points match the +    spectral channels, i.e., the energy positions of each ARF data point and +    spectral channel are consistent.  Thus the interpolation is not needed. + +    References: +    [1] CIAO: Auxiliary Response File +        http://cxc.harvard.edu/ciao/dictionary/arf.html +    [2] Definition of RMF and ARF file formats +        https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html +    """ +    filename = None +    fitsobj  = None +    # only consider the "SPECTRUM" extension +    header   = None +    energ_lo = None +    energ_hi = None +    specresp = None +    # function of the interpolated ARF +    f_interp = None +    # energies of the spectral channels +    energy_channel = None +    # spectral channel grouping specification +    grouping       = None +    groupped       = False +    # groupped ARF channels with respect to the grouping +    specresp_grp   = None + +    def __init__(self, filename): +        self.filename = filename +        self.fitsobj  = fits.open(filename) +        ext_specresp  = self.fitsobj["SPECRESP"] +        self.header   = ext_specresp.header +        self.energ_lo = ext_specresp.data["ENERG_LO"] +        self.energ_hi = ext_specresp.data["ENERG_HI"] +        self.specresp = ext_specresp.data["SPECRESP"] + +    def get_data(self, groupped=False, group_squeeze=False, copy=True): +        if groupped: +            specresp = self.specresp_grp +            if group_squeeze: +                specresp = specresp[self.grouping == 1] +        else: +            specresp = self.specresp +        if copy: +            return specresp.copy() +        else: +            return specresp + +    def get_energy(self, mean="geometric"): +        """ +        Return the mean energy values of the ARF. + +        Arguments: +          * mean: type of the mean energy: +                  + "geometric": geometric mean, i.e., e = sqrt(e_min*e_max) +                  + "arithmetic": arithmetic mean, i.e., e = 0.5*(e_min+e_max) +        """ +        if mean == "geometric": +            energy = np.sqrt(self.energ_lo * self.energ_hi) +        elif mean == "arithmetic": +            energy = 0.5 * (self.energ_lo + self.energ_hi) +        else: +            raise ValueError("Invalid mean type: %s" % mean) +        return energy + +    def interpolate(self, x=None, verbose=False): +        """ +        Cubic interpolate the ARF curve using `scipy.interpolate' + +        If the requested point is outside of the data range, the +        fill value of *zero* is returned. + +        Arguments: +          * x: points at which the interpolation to be calculated. + +        Return: +          If x is None, then the interpolated function is returned, +          otherwise, the interpolated data are returned. +        """ +        if not hasattr(self, "f_interp") or self.f_interp is None: +            energy = self.get_energy() +            arf    = self.get_data(copy=False) +            if verbose: +                print("INFO: interpolating '%s' (this may take a while) ..." \ +                        % self.filename, file=sys.stderr) +            f_interp = sp.interpolate.interp1d(energy, arf, kind="cubic", +                    bounds_error=False, fill_value=0.0, assume_sorted=True) +            self.f_interp = f_interp +        if x is not None: +            return self.f_interp(x) +        else: +            return self.f_interp + +    def apply_grouping(self, energy_channel, grouping, verbose=False): +        """ +        Group the ARF channels (INTERPOLATED with respect to the spectral +        channels) by the supplied grouping specification. + +        Arguments: +          * energy_channel: energies of the spectral channel +          * grouping: spectral grouping specification + +        Return: `self.specresp_grp' +        """ +        if self.groupped: +            return +        if verbose: +            print("INFO: Grouping spectrum '%s' ..." % self.filename, +                    file=sys.stderr) +        self.energy_channel = energy_channel +        self.grouping = grouping +        # interpolate the ARF w.r.t the spectral channel energies +        arf_interp = self.interpolate(x=energy_channel, verbose=verbose) +        self.specresp_grp = group_data(arf_interp, grouping) +        self.groupped = True +# class ARF }}} + + +class RMF:  # {{{ +    """ +    Class to handle the RMF (redistribution matrix file), +    which maps from energy space into detector pulse height (or position) +    space.  Since detectors are not perfect, this involves a spreading of +    the observed counts by the detector resolution, which is expressed as +    a matrix multiplication. +    For X-ray spectral analysis, the RMF encodes the probability R(E,p) +    that a detected photon of energy E will be assisgned to a given +    channel value (PHA or PI) of p. + +    The standard Legacy format [2] for the RMF uses a binary table in which +    each row contains R(E,p) for a single value of E as a function of p. +    Non-zero sequences of elements of R(E,p) are encoded using a set of +    variable length array columns.  This format is compact but hard to +    manipulate and understand. + +    **CAVEAT/NOTE**: +    + See also the above ARF CAVEAT/NOTE. +    + The "EBOUNDS" extension contains the `CHANNEL', `E_MIN' and `E_MAX' +      columns.  This `CHANNEL' is the same as that of a spectrum.  Therefore, +      the energy values determined from the `E_MIN' and `E_MAX' columns are +      used to interpolate and extrapolate the ARF curve. +    + The `ENERG_LO' and `ENERG_HI' columns of the "MATRIX" extension are +      the same as that of a ARF. + +    References: +    [1] CIAO: Redistribution Matrix File +        http://cxc.harvard.edu/ciao/dictionary/rmf.html +    [2] Definition of RMF and ARF file formats +        https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html +    """ +    filename    = None +    fitsobj     = None +    ## extension "MATRIX" +    hdr_matrix  = None +    energ_lo    = None +    energ_hi    = None +    n_grp       = None +    f_chan      = None +    n_chan      = None +    # raw squeezed RMF matrix data +    matrix      = None +    ## extension "EBOUNDS" +    hdr_ebounds = None +    channel     = None +    e_min       = None +    e_max       = None +    ## converted 2D RMF matrix/image from the squeezed binary table +    #  size: len(energ_lo) x len(channel) +    rmfimg      = None + +    def __init__(self, filename): +        self.filename    = filename +        self.fitsobj     = fits.open(filename) +        ## "MATRIX" extension +        ext_matrix       = self.fitsobj["MATRIX"] +        self.hdr_matrix  = ext_matrix.header +        self.energ_lo    = ext_matrix.data["ENERG_LO"] +        self.energ_hi    = ext_matrix.data["ENERG_HI"] +        self.n_grp       = ext_matrix.data["N_GRP"] +        self.f_chan      = ext_matrix.data["F_CHAN"] +        self.n_chan      = ext_matrix.data["N_CHAN"] +        self.matrix      = ext_matrix.data["MATRIX"] +        ## "EBOUNDS" extension +        ext_ebounds      = self.fitsobj["EBOUNDS"] +        self.hdr_ebounds = ext_ebounds.header +        self.channel     = ext_ebounds.data["CHANNEL"] +        self.e_min       = ext_ebounds.data["E_MIN"] +        self.e_max       = ext_ebounds.data["E_MAX"] + +    def get_energy(self, mean="geometric"): +        """ +        Return the mean energy values of the RMF "EBOUNDS". + +        Arguments: +          * mean: type of the mean energy: +                  + "geometric": geometric mean, i.e., e = sqrt(e_min*e_max) +                  + "arithmetic": arithmetic mean, i.e., e = 0.5*(e_min+e_max) +        """ +        if mean == "geometric": +            energy = np.sqrt(self.e_min * self.e_max) +        elif mean == "arithmetic": +            energy = 0.5 * (self.e_min + self.e_max) +        else: +            raise ValueError("Invalid mean type: %s" % mean) +        return energy + +    def get_rmfimg(self): +        """ +        Convert the RMF data in squeezed binary table (standard Legacy format) +        to a 2D image/matrix. +        """ +        def _make_rmfimg_row(n_channel, dtype, f_chan, n_chan, mat_row): +            # make sure that `f_chan' and `n_chan' are 1-D numpy array +            f_chan = np.array(f_chan).reshape(-1) +            f_chan -= 1  # FITS indices are 1-based +            n_chan = np.array(n_chan).reshape(-1) +            idx = np.concatenate([ np.arange(f, f+n) \ +                    for f, n in zip(f_chan, n_chan) ]) +            rmfrow = np.zeros(n_channel, dtype=dtype) +            rmfrow[idx] = mat_row +            return rmfrow +        # +        if self.rmfimg is None: +            # Make the 2D RMF matrix/image +            n_energy  = len(self.energ_lo) +            n_channel = len(self.channel) +            rmf_dtype = self.matrix[0].dtype +            rmfimg = np.zeros(shape=(n_energy, n_channel), dtype=rmf_dtype) +            for i in np.arange(n_energy)[self.n_grp > 0]: +                rmfimg[i, :] = _make_rmfimg_row(n_channel, rmf_dtype, +                        self.f_chan[i], self.n_chan[i], self.matrix[i]) +            self.rmfimg = rmfimg +        return self.rmfimg + +    def write_rmfimg(self, outfile, clobber=False): +        rmfimg = self.get_rmfimg() +        # merge headers +        header = self.hdr_matrix.copy(strip=True) +        header.extend(self.hdr_ebounds.copy(strip=True)) +        outfits = fits.PrimaryHDU(data=rmfimg, header=header) +        outfits.writeto(outfile, checksum=True, clobber=clobber) +# class RMF }}} + + +class Spectrum:  # {{{ +    """ +    Class that deals with the X-ray spectrum file (usually *.pi). +    """ +    filename  = None +    # FITS object return by `fits.open()' +    fitsobj   = None +    # header of "SPECTRUM" extension +    header    = None +    # "SPECTRUM" extension data +    channel   = None +    # name of the spectrum data column (i.e., type, "COUNTS" or "RATE") +    spec_type = None +    # unit of the spectrum data ("count" for "COUNTS", "count/s" for "RATE") +    spec_unit = None +    # spectrum data +    spec_data = None +    # estimated spectral errors for each channel/group +    spec_err  = None +    # statistical errors for each channel/group +    stat_err  = None +    # grouping and quality +    grouping  = None +    quality   = None +    # whether the spectral data being groupped +    groupped  = False +    # several important keywords +    EXPOSURE  = None +    BACKSCAL  = None +    AREASCAL  = None +    RESPFILE  = None +    ANCRFILE  = None +    BACKFILE  = None +    # numpy dtype and FITS format code of the spectrum data +    spec_dtype       = None +    spec_fits_format = None +    # output filename for writing the spectrum if no filename provided +    outfile = None + +    def __init__(self, filename, outfile=None): +        self.filename = filename +        self.fitsobj  = fits.open(filename) +        ext_spec      = self.fitsobj["SPECTRUM"] +        self.header   = ext_spec.header.copy(strip=True) +        colnames      = ext_spec.columns.names +        if "COUNTS" in colnames: +            self.spec_type = "COUNTS" +        elif "RATE" in colnames: +            self.spec_type = "RATE" +        else: +            raise ValueError("Invalid spectrum file") +        self.channel          = ext_spec.data.columns["CHANNEL"].array +        col_spec_data         = ext_spec.data.columns[self.spec_type] +        self.spec_data        = col_spec_data.array.copy() +        self.spec_unit        = col_spec_data.unit +        self.spec_dtype       = col_spec_data.dtype +        self.spec_fits_format = col_spec_data.format +        # grouping and quality +        if "GROUPING" in colnames: +            self.grouping = ext_spec.data.columns["GROUPING"].array +        if "QUALITY"  in colnames: +            self.quality  = ext_spec.data.columns["QUALITY"].array +        # keywords +        self.EXPOSURE = self.header.get("EXPOSURE") +        self.BACKSCAL = self.header.get("BACKSCAL") +        self.AREASCAL = self.header.get("AREASCAL") +        self.RESPFILE = self.header.get("RESPFILE") +        self.ANCRFILE = self.header.get("ANCRFILE") +        self.BACKFILE = self.header.get("BACKFILE") +        # output filename +        self.outfile  = outfile + +    def get_data(self, group_squeeze=False, copy=True): +        """ +        Get the spectral data (i.e., self.spec_data). + +        Arguments: +          * group_squeeze: whether squeeze the spectral data according to +                           the grouping (i.e., exclude the channels that +                           are not the first channel of the group, which +                           also have value of ZERO). +                           This argument is effective only the grouping +                           being applied. +        """ +        if group_squeeze and self.groupped: +            spec_data = self.spec_data[self.grouping == 1] +        else: +            spec_data = self.spec_data +        if copy: +            return spec_data.copy() +        else: +            return spec_data + +    def get_channel(self, copy=True): +        if copy: +            return self.channel.copy() +        else: +            return self.channel + +    def set_data(self, spec_data, group_squeeze=True): +        """ +        Set the spectral data of this spectrum to the supplied data. +        """ +        if group_squeeze and self.groupped: +            assert sum(self.grouping == 1) == len(spec_data) +            self.spec_data[self.grouping == 1] = spec_data +        else: +            assert len(self.spec_data) == len(spec_data) +            self.spec_data = spec_data.copy() + +    def add_stat_err(self, stat_err, group_squeeze=True): +        """ +        Add the "STAT_ERR" column as the statistical errors of each spectral +        group, which are estimated by utilizing the Monte Carlo techniques. +        """ +        self.stat_err = np.zeros(self.spec_data.shape, +                dtype=self.spec_data.dtype) +        if group_squeeze and self.groupped: +            assert sum(self.grouping == 1) == len(stat_err) +            self.stat_err[self.grouping == 1] = stat_err +        else: +            assert len(self.stat_err) == len(stat_err) +            self.stat_err = stat_err.copy() +        self.header["POISSERR"] = False + +    def apply_grouping(self, grouping=None, quality=None): +        """ +        Apply the spectral channel grouping specification to the spectrum. + +        NOTE: +        * The spectral data (i.e., self.spec_data) is MODIFIED! +        * The spectral data within the same group are summed up. +        * The self grouping is overwritten if `grouping' is supplied, as well +          as the self quality. +        """ +        if grouping is not None: +            self.grouping = grouping +        if quality is not None: +            self.quality = quality +        self.spec_data = group_data(self.spec_data, self.grouping) +        self.groupped = True + +    def estimate_errors(self, gehrels=True): +        """ +        Estimate the statistical errors of each spectral group (after +        applying grouping) for the source spectrum (and background spectrum). + +        If `gehrels=True', the statistical error for a spectral group with +        N photons is given by `1 + sqrt(N + 0.75)'; otherwise, the error +        is given by `sqrt(N)'. + +        Results: `self.spec_err' +        """ +        eps = 1.0e-10 +        if gehrels: +            self.spec_err = 1.0 + np.sqrt(self.spec_data + 0.75) +        else: +            self.spec_err = np.sqrt(self.spec_data) +        # replace the zeros with a very small value (because +        # `np.random.normal' requires `scale' > 0) +        self.spec_err[self.spec_err <= 0.0] = eps + +    def copy(self): +        """ +        Return a copy of this object, with the `np.ndarray' properties are +        copied. +        """ +        new = copy(self) +        for k, v in self.__dict__.items(): +            if isinstance(v, np.ndarray): +                setattr(new, k, v.copy()) +        return new + +    def randomize(self): +        """ +        Randomize the spectral data according to the estimated spectral +        group errors by assuming the normal distribution. + +        NOTE: this method should be called AFTER the `copy()' method. +        """ +        if self.spec_err is None: +            raise ValueError("No valid 'spec_err' presents") +        if self.groupped: +            idx = self.grouping == 1 +            self.spec_data[idx] = np.random.normal(self.spec_data[idx], +                    self.spec_err[idx]) +        else: +            self.spec_data = np.random.normal(self.spec_data, self.spec_err) +        return self + +    def fix_header_keywords(self, +            reset_kw=["ANCRFILE", "RESPFILE", "BACKFILE"]): +        """ +        Reset the keywords to "NONE" to avoid confusion or mistakes, +        and also add mandatory spectral keywords if missing. + +        Reference: +        [1] The OGIP Spectral File Format, Sec. 3.1.1 +            https://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/summary/ogip_92_007_summary.html +        """ +        default_keywords = { +                ## Mandatory keywords +                #"EXTNAME"  : "SPECTRUM", +                "TELESCOP" : "NONE", +                "INSTRUME" : "NONE", +                "FILTER"   : "NONE", +                #"EXPOSURE" : <integration_time (s)>, +                "BACKFILE" : "NONE", +                "CORRFILE" : "NONE", +                "CORRSCAL" : 1.0, +                "RESPFILE" : "NONE", +                "ANCRFILE" : "NONE", +                "HDUCLASS" : "OGIP", +                "HDUCLAS1" : "SPECTRUM", +                "HDUVERS"  : "1.2.1", +                "POISSERR" : True, +                #"CHANTYPE" : "PI", +                #"DETCHANS" : <total_number_of_detector_channels>, +                ## Optional keywords for further information +                "BACKSCAL" : 1.0, +                "AREASCAL" : 1.0, +                # Type of spectral data: +                #   (1) "TOTAL": gross spectrum (source+bkg); +                #   (2) "NET": background-subtracted spectrum +                #   (3) "BKG" background spectrum +                #"HDUCLAS2" : "NET", +                # Details of the type of data: +                #   (1) "COUNT": data stored as counts +                #   (2) "RATE": data stored as counts/s +                "HDUCLAS3" : { "COUNTS":"COUNT", +                               "RATE":"RATE" }.get(self.spec_type), +        } +        # add mandatory keywords if missing +        for kw, value in default_keywords.items(): +            if kw not in self.header: +                self.header[kw] = value +        # reset the specified keywords +        for kw in reset_kw: +            self.header[kw] = default_keywords.get(kw) + +    def write(self, filename=None, clobber=False): +        """ +        Create a new "SPECTRUM" table/extension and replace the original +        one, then write to output file. +        """ +        if filename is None: +            filename = self.outfile +        columns = [ +                fits.Column(name="CHANNEL", format="I", array=self.channel), +                fits.Column(name=self.spec_type, format=self.spec_fits_format, +                            unit=self.spec_unit, array=self.spec_data), +        ] +        if self.grouping is not None: +            columns.append(fits.Column(name="GROUPING", +                                       format="I", array=self.grouping)) +        if self.quality  is not None: +            columns.append(fits.Column(name="QUALITY", +                                       format="I", array=self.quality)) +        if self.stat_err is not None: +            columns.append(fits.Column(name="STAT_ERR", unit=self.spec_unit, +                                       format=self.spec_fits_format, +                                       array=self.stat_err)) +        ext_spec_cols = fits.ColDefs(columns) +        ext_spec = fits.BinTableHDU.from_columns(ext_spec_cols, +                                                 header=self.header) +        self.fitsobj["SPECTRUM"] = ext_spec +        self.fitsobj.writeto(filename, clobber=clobber, checksum=True) +# class Spectrum }}} + + +class SpectrumSet(Spectrum):  # {{{ +    """ +    This class handles a set of spectrum, including the source spectrum, +    RMF, ARF, and the background spectrum. + +    **NOTE**: +    The "COUNTS" column data are converted from "int32" to "float32", +    since this spectrum will be subtracted/compensated according to the +    ratios of ARFs. +    """ +    # ARF object for this spectrum +    arf = None +    # RMF object for this spectrum +    rmf = None +    # background Spectrum object for this spectrum +    bkg = None +    # inner and outer radius of the region from which the spectrum extracted +    radius_inner = None +    radius_outer = None +    # total angular range of the spectral region +    angle = None + +    # numpy dtype and FITS format code to which the spectrum data be +    # converted if the data is "COUNTS" +    #_spec_dtype       = np.float32 +    #_spec_fits_format = "E" +    _spec_dtype       = np.float64 +    _spec_fits_format = "D" + +    def __init__(self, filename, outfile=None, arf=None, rmf=None, bkg=None): +        super().__init__(filename, outfile) +        # convert spectrum data type if necessary +        if self.spec_data.dtype != self._spec_dtype: +            self.spec_data        = self.spec_data.astype(self._spec_dtype) +            self.spec_dtype       = self._spec_dtype +            self.spec_fits_format = self._spec_fits_format +        if arf is not None: +            if isinstance(arf, ARF): +                self.arf = arf +            else: +                self.arf = ARF(arf) +        if rmf is not None: +            if isinstance(rmf, RMF): +                self.rmf = rmf +            else: +                self.rmf = RMF(rmf) +        if bkg is not None: +            if isinstance(bkg, Spectrum): +                self.bkg = bkg +            else: +                self.bkg = Spectrum(bkg) +            # convert background spectrum data type if necessary +            if self.bkg.spec_data.dtype != self._spec_dtype: +                self.bkg.spec_data = self.bkg.spec_data.astype(self._spec_dtype) +                self.bkg.spec_dtype = self._spec_dtype +                self.bkg.spec_fits_format = self._spec_fits_format + +    def get_energy(self, mean="geometric"): +        """ +        Get the energy values of each channel if RMF present. + +        NOTE: +        The "E_MIN" and "E_MAX" columns of the RMF is required to calculate +        the spectrum channel energies. +        And the channel energies are generally different to the "ENERG_LO" +        and "ENERG_HI" of the corresponding ARF. +        """ +        if self.rmf is None: +            return None +        else: +            return self.rmf.get_energy(mean=mean) + +    def get_arf(self, mean="geometric", groupped=True, copy=True): +        """ +        Get the interpolated ARF data w.r.t the spectral channel energies +        if the ARF presents. + +        Arguments: +          * groupped: (bool) whether to get the groupped ARF + +        Return: (groupped) interpolated ARF data +        """ +        if self.arf is None: +            return None +        else: +            return self.arf.get_data(groupped=groupped, copy=copy) + +    def read_xflt(self): +        """ +        Read the XFLT000# keywords from the header, check the validity (e.g., +        "XFLT0001" should equals "XFLT0002", "XFLT0003" should equals 0). +        Sum all the additional XFLT000# pairs (e.g., ) which describes the +        regions angluar ranges. +        """ +        eps = 1.0e-6 +        xflt0001 = float(self.header["XFLT0001"]) +        xflt0002 = float(self.header["XFLT0002"]) +        xflt0003 = float(self.header["XFLT0003"]) +        # XFLT000# validity check +        assert np.isclose(xflt0001, xflt0002) +        assert abs(xflt0003) < eps +        # outer radius of the region +        self.radius_outer = xflt0001 +        # angular regions +        self.angle = 0.0 +        num = 4 +        while True: +            try: +                angle_begin = float(self.header["XFLT%04d" % num]) +                angle_end   = float(self.header["XFLT%04d" % (num+1)]) +                num += 2 +            except KeyError: +                break +            self.angle += (angle_end - angle_begin) +        # if NO additional XFLT000# keys exist, assume "annulus" region +        if self.angle < eps: +            self.angle = 360.0 + +    def scale(self): +        """ +        Scale the spectral data (and spectral group errors if present) of +        the source spectrum (and background spectra if present) according +        to the region angular size to make it correspond to the whole annulus +        region (i.e., 360 degrees). + +        NOTE: the spectral data and errors (i.e., `self.spec_data', and +        `self.spec_err') is MODIFIED! +        """ +        self.spec_data *= (360.0 / self.angle) +        if self.spec_err is not None: +            self.spec_err *= (360.0 / self.angle) +        # also scale the background spectrum if present +        if self.bkg: +            self.bkg.spec_data *= (360.0 / self.angle) +            if self.bkg.spec_err is not None: +                self.bkg.spec_err *= (360.0 / self.angle) + +    def apply_grouping(self, grouping=None, quality=None, verbose=False): +        """ +        Apply the spectral channel grouping specification to the source +        spectrum, the ARF (which is used during the later spectral +        manipulations), and the background spectrum (if presents). + +        NOTE: +        * The spectral data (i.e., self.spec_data) is MODIFIED! +        * The spectral data within the same group are summed up. +        * The self grouping is overwritten if `grouping' is supplied, as well +          as the self quality. +        """ +        super().apply_grouping(grouping=grouping, quality=quality) +        # also group the ARF accordingly +        self.arf.apply_grouping(energy_channel=self.get_energy(), +                                grouping=self.grouping, verbose=verbose) +        # group the background spectrum if present +        if self.bkg: +            self.bkg.spec_data = group_data(self.bkg.spec_data, self.grouping) + +    def estimate_errors(self, gehrels=True): +        """ +        Estimate the statistical errors of each spectral group (after +        applying grouping) for the source spectrum (and background spectrum). + +        If `gehrels=True', the statistical error for a spectral group with +        N photons is given by `1 + sqrt(N + 0.75)'; otherwise, the error +        is given by `sqrt(N)'. + +        Results: `self.spec_err' (and `self.bkg.spec_err') +        """ +        super().estimate_errors(gehrels=gehrels) +        eps = 1.0e-10 +        # estimate the errors for background spectrum if present +        if self.bkg: +            if gehrels: +                self.bkg.spec_err = 1.0 + np.sqrt(self.bkg.spec_data + 0.75) +            else: +                self.bkg.spec_err = np.sqrt(self.bkg.spec_data) +            self.bkg.spec_err[self.bkg.spec_err <= 0.0] = eps + +    def subtract_bkg(self, inplace=True, add_history=False, verbose=False): +        """ +        Subtract the background contribution from the source spectrum. +        The `EXPOSURE' and `BACKSCAL' values are required to calculate +        the fraction/ratio for the background subtraction. + +        Arguments: +          * inplace: whether replace the `spec_data' with the background- +                     subtracted spectrum data; If True, the attribute +                     `spec_bkg_subtracted' is also set to `True' when +                     the subtraction finished. +                     The keywords "BACKSCAL" and "AREASCAL" are set to 1.0. + +        Return: +          background-subtracted spectrum data +        """ +        ratio = (self.EXPOSURE / self.bkg.EXPOSURE) * \ +                (self.BACKSCAL / self.bkg.BACKSCAL) * \ +                (self.AREASCAL / self.bkg.AREASCAL) +        operation = "  SUBTRACT_BACKGROUND: %s - %s * %s" % \ +                (self.filename, ratio, self.bkg.filename) +        if verbose: +            print(operation, file=sys.stderr) +        spec_data_subbkg = self.spec_data - ratio * self.bkg.get_data() +        if inplace: +            self.spec_data = spec_data_subbkg +            self.spec_bkg_subtracted = True +            self.BACKSCAL = 1.0 +            self.AREASCAL = 1.0 +            # update header +            self.header["BACKSCAL"] = 1.0 +            self.header["AREASCAL"] = 1.0 +            self.header["BACKFILE"] = "NONE" +            self.header["HDUCLAS2"] = "NET"  # background-subtracted spectrum +            # also record history +            if add_history: +                self.header.add_history(operation) +        return spec_data_subbkg + +    def subtract(self, spectrumset, cross_arf, groupped=False, +            group_squeeze=False, add_history=False, verbose=False): +        """ +        Subtract the photons that originate from the surrounding regions +        but were scattered into this spectrum due to the finite PSF. + +        The background of this spectrum and the given spectrum should +        both be subtracted before applying this subtraction for crosstalk +        correction, as well as the below `compensate()' procedure. + +        NOTE: +        1. The crosstalk ARF must be provided, since the `spectrumset.arf' +           is required to be its ARF without taking crosstalk into account: +               spec1_new = spec1 - spec2 * (cross_arf_2_to_1 / arf2) +        2. The ARF are interpolated to match the energies of spetral channels. +        """ +        operation = "  SUBTRACT: %s - (%s/%s) * %s" % (self.filename, +                cross_arf.filename, spectrumset.arf.filename, +                spectrumset.filename) +        if verbose: +            print(operation, file=sys.stderr) +        energy = self.get_energy() +        if groupped: +            spectrumset.arf.apply_grouping(energy_channel=energy, +                    grouping=self.grouping, verbose=verbose) +            cross_arf.apply_grouping(energy_channel=energy, +                    grouping=self.grouping, verbose=verbose) +            arfresp_spec  = spectrumset.arf.get_data(groupped=True, +                    group_squeeze=group_squeeze) +            arfresp_cross = cross_arf.get_data(groupped=True, +                    group_squeeze=group_squeeze) +        else: +            arfresp_spec  = spectrumset.arf.interpolate(x=energy, +                    verbose=verbose) +            arfresp_cross = cross_arf.interpolate(x=energy, verbose=verbose) +        with np.errstate(divide="ignore", invalid="ignore"): +            arf_ratio = arfresp_cross / arfresp_spec +            # fix nan/inf values due to division by zero +            arf_ratio[ ~ np.isfinite(arf_ratio) ] = 0.0 +        spec_data = self.get_data(group_squeeze=group_squeeze) - \ +                spectrumset.get_data(group_squeeze=group_squeeze)*arf_ratio +        self.set_data(spec_data, group_squeeze=group_squeeze) +        # record history +        if add_history: +            self.header.add_history(operation) + +    def compensate(self, cross_arf, groupped=False, group_squeeze=False, +            add_history=False, verbose=False): +        """ +        Compensate the photons that originate from this regions but were +        scattered into the surrounding regions due to the finite PSF. + +        formula: +            spec1_new = spec1 + spec1 * (cross_arf_1_to_2 / arf1) +        """ +        operation = "  COMPENSATE: %s + (%s/%s) * %s" % (self.filename, +                cross_arf.filename, self.arf.filename, self.filename) +        if verbose: +            print(operation, file=sys.stderr) +        energy = self.get_energy() +        if groupped: +            cross_arf.apply_grouping(energy_channel=energy, +                    grouping=self.grouping, verbose=verbose) +            arfresp_this  = self.arf.get_data(groupped=True, +                    group_squeeze=group_squeeze) +            arfresp_cross = cross_arf.get_data(groupped=True, +                    group_squeeze=group_squeeze) +        else: +            arfresp_this  = self.arf.interpolate(x=energy, verbose=verbose) +            arfresp_cross = cross_arf.interpolate(x=energy, verbose=verbose) +        with np.errstate(divide="ignore", invalid="ignore"): +            arf_ratio = arfresp_cross / arfresp_this +            # fix nan/inf values due to division by zero +            arf_ratio[ ~ np.isfinite(arf_ratio) ] = 0.0 +        spec_data = self.get_data(group_squeeze=group_squeeze) + \ +                self.get_data(group_squeeze=group_squeeze) * arf_ratio +        self.set_data(spec_data, group_squeeze=group_squeeze) +        # record history +        if add_history: +            self.header.add_history(operation) + +    def fix_negative(self, add_history=False, verbose=False): +        """ +        The subtractions may lead to negative counts, it may be necessary +        to fix these channels with negative values. +        """ +        neg_counts = self.spec_data < 0 +        N = len(neg_counts) +        neg_channels = np.arange(N, dtype=np.int)[neg_counts] +        if len(neg_channels) > 0: +            print("WARNING: %d channels have NEGATIVE counts" % \ +                    len(neg_channels), file=sys.stderr) +        i = 0 +        while len(neg_channels) > 0: +            i += 1 +            if verbose: +                if i == 1: +                    print("*** Fixing negative channels: iter %d..." % i, +                            end="", file=sys.stderr) +                else: +                    print("%d..." % i, end="", file=sys.stderr) +            for ch in neg_channels: +                neg_val = self.spec_data[ch] +                if ch < N-2: +                    self.spec_data[ch] = 0 +                    self.spec_data[(ch+1):(ch+3)] -= 0.5 * np.abs(neg_val) +                else: +                    # just set to zero if it is the last 2 channels +                    self.spec_data[ch] = 0 +            # update negative channels indices +            neg_counts = self.spec_data < 0 +            neg_channels = np.arange(N, dtype=np.int)[neg_counts] +        if i > 0: +            print("FIXED!", file=sys.stderr) +            # record history +            if add_history: +                self.header.add_history("  FIXED NEGATIVE CHANNELS") + +    def set_radius_inner(self, radius_inner): +        """ +        Set the inner radius of the spectral region. +        """ +        assert radius_inner < self.radius_outer +        self.radius_inner = radius_inner + +    def copy(self): +        """ +        Return a copy of this object. +        """ +        new = super().copy() +        if self.bkg: +            new.bkg = self.bkg.copy() +        return new + +    def randomize(self): +        """ +        Randomize the source (and background if present) spectral data +        according to the estimated spectral group errors by assuming the +        normal distribution. + +        NOTE: this method should be called AFTER the `copy()' method. +        """ +        super().randomize() +        if self.bkg: +            self.bkg.spec_data = np.random.normal(self.bkg.spec_data, +                                                  self.bkg.spec_err) +            self.bkg.spec_data[self.grouping == -1] = 0.0 +        return self +# class SpectrumSet }}} + + +class Crosstalk:  # {{{ +    """ +    XMM-Newton PSF Crosstalk effect correction. +    """ +    # `SpectrumSet' object for the spectrum to be corrected +    spectrumset = None +    # NOTE/XXX: do NOT use list (e.g., []) here, otherwise, all the +    #           instances will share these list properties. +    # `SpectrumSet' and `ARF' objects corresponding to the spectra from +    # which the photons were scattered into this spectrum. +    cross_in_specset = None +    cross_in_arf     = None +    # `ARF' objects corresponding to the regions to which the photons of +    # this spectrum were scattered into. +    cross_out_arf = None +    # grouping specification and quality data +    grouping = None +    quality = None +    # whether the spectrum is groupped +    groupped = False + +    def __init__(self, config, arf_dict={}, rmf_dict={}, +            grouping=None, quality=None): +        """ +        Arguments: +          * config: a section of the whole config file (`ConfigObj' object) +        """ +        self.cross_in_specset = [] +        self.cross_in_arf     = [] +        self.cross_out_arf    = [] +        # this spectrum to be corrected +        self.spectrumset = SpectrumSet(filename=config["spec"], +                outfile=config["outfile"], +                arf=arf_dict.get(config["arf"], config["arf"]), +                rmf=rmf_dict.get(config.get("rmf"), config.get("rmf")), +                bkg=config.get("bkg")) +        # spectra and cross arf from which photons were scattered in +        for reg_in in config["cross_in"].values(): +            specset = SpectrumSet(filename=reg_in["spec"], +                    arf=arf_dict.get(reg_in["arf"], reg_in["arf"]), +                    rmf=rmf_dict.get(reg_in.get("rmf"), reg_in.get("rmf")), +                    bkg=reg_in.get("bkg")) +            self.cross_in_specset.append(specset) +            self.cross_in_arf.append(arf_dict.get(reg_in["cross_arf"], +                                                  ARF(reg_in["cross_arf"]))) +        # regions into which the photons of this spectrum were scattered into +        if "cross_out" in config.sections: +            cross_arf = config["cross_out"].as_list("cross_arf") +            for arffile in cross_arf: +                self.cross_out_arf.append(arf_dict.get(arffile, ARF(arffile))) +        # grouping and quality +        self.grouping = grouping +        self.quality = quality + +    def apply_grouping(self, verbose=False): +        self.spectrumset.apply_grouping(grouping=self.grouping, +                quality=self.quality, verbose=verbose) +        # also group the related surrounding spectra +        for specset in self.cross_in_specset: +            specset.apply_grouping(grouping=self.grouping, +                quality=self.quality, verbose=verbose) +        self.groupped = True + +    def estimate_errors(self, gehrels=True, verbose=False): +        if verbose: +            print("INFO: Estimating spectral errors ...") +        self.spectrumset.estimate_errors(gehrels=gehrels) +        # also estimate errors for the related surrounding spectra +        for specset in self.cross_in_specset: +            specset.estimate_errors(gehrels=gehrels) + +    def do_correction(self, subtract_bkg=True, fix_negative=False, +            group_squeeze=True, add_history=False, verbose=False): +        """ +        Perform the crosstalk correction.  The background contribution +        for each spectrum is subtracted first if `subtract_bkg' is True. +        The basic correction procedures are recorded to the header. +        """ +        if add_history: +            self.spectrumset.header.add_history("Crosstalk Correction BEGIN") +            self.spectrumset.header.add_history("  TOOL: %s (v%s) @ %s" % (\ +                    os.path.basename(sys.argv[0]), __version__, +                    datetime.utcnow().isoformat())) +        # background subtraction +        if subtract_bkg: +            if verbose: +                print("INFO: subtract background ...", file=sys.stderr) +            self.spectrumset.subtract_bkg(inplace=True, +                    add_history=add_history, verbose=verbose) +            # also apply background subtraction to the surrounding spectra +            for specset in self.cross_in_specset: +                specset.subtract_bkg(inplace=True, +                        add_history=add_history, verbose=verbose) +        # subtractions +        if verbose: +            print("INFO: apply subtractions ...", file=sys.stderr) +        for specset, cross_arf in zip(self.cross_in_specset, +                self.cross_in_arf): +            self.spectrumset.subtract(spectrumset=specset, +                    cross_arf=cross_arf, groupped=self.groupped, +                    group_squeeze=group_squeeze, add_history=add_history, +                    verbose=verbose) +        # compensations +        if verbose: +            print("INFO: apply compensations ...", file=sys.stderr) +        for cross_arf in self.cross_out_arf: +            self.spectrumset.compensate(cross_arf=cross_arf, +                    groupped=self.groupped, group_squeeze=group_squeeze, +                    add_history=add_history, verbose=verbose) +        # fix negative values in channels +        if fix_negative: +            if verbose: +                print("INFO: fix negative channel values ...", file=sys.stderr) +            self.spectrumset.fix_negative(add_history=add_history, +                    verbose=verbose) +        if add_history: +            self.spectrumset.header.add_history("END Crosstalk Correction") + +    def fix_header(self): +        # fix header keywords +        self.spectrumset.fix_header_keywords( +                reset_kw=["RESPFILE", "ANCRFILE", "BACKFILE"]) + +    def copy(self): +        new = copy(self) +        # properly handle the copy of spectrumsets +        new.spectrumset = self.spectrumset.copy() +        new.cross_in_specset = [ specset.copy() \ +                for specset in self.cross_in_specset ] +        return new + +    def randomize(self): +        self.spectrumset.randomize() +        for specset in self.cross_in_specset: +            specset.randomize() +        return self + +    def get_spectrum(self, copy=True): +        if copy: +            return self.spectrumset.copy() +        else: +            return self.spectrumset + +    def write(self, filename=None, clobber=False): +        self.spectrumset.write(filename=filename, clobber=clobber) +# class Crosstalk }}} + + +class Deprojection:  # {{{ +    """ +    Perform the deprojection on a set of PROJECTED spectra with the +    assumption of spherical symmetry of the source object, and produce +    the DEPROJECTED spectra. + +    NOTE: +    * Assumption of the spherical symmetry +    * Background should be subtracted before deprojection +    * ARF differences of different regions are taken into account + +    Reference & Credit: +    [1] Direct X-ray Spectra Deprojection +        https://www-xray.ast.cam.ac.uk/papers/dsdeproj/ +        Sanders & Fabian 2007, MNRAS, 381, 1381 +    """ +    spectra  = None +    grouping = None +    quality  = None + +    def __init__(self, spectra, grouping=None, quality=None, verbose=False): +        """ +        Arguments: +          * spectra: a set of spectra from the inner-most to the outer-most +                     regions (e.g., spectra after correcting crosstalk effect) +          * grouping: grouping specification for all the spectra +          * quality: quality column for the spectra +        """ +        self.spectra = [] +        for spec in spectra: +            if not isinstance(spec, SpectrumSet): +                raise ValueError("Not a 'SpectrumSet' object") +            spec.read_xflt() +            self.spectra.append(spec) +        self.spectra  = spectra +        self.grouping = grouping +        self.quality  = quality +        # sort spectra by `radius_outer' +        self.spectra.sort(key=lambda x: x.radius_outer) +        # set the inner radii +        radii_inner = [0.0] + [ x.radius_outer for x in self.spectra[:-1] ] +        for spec, rin in zip(self.spectra, radii_inner): +            spec.set_radius_inner(rin) +            if verbose: +                print("Deprojection: loaded spectrum: radius: (%s, %s)" % \ +                        (spec.radius_inner, spec.radius_outer), +                        file=sys.stderr) +        # check EXPOSURE validity (all spectra must have the same exposures) +        exposures = [ spec.EXPOSURE for spec in self.spectra ] +        assert np.allclose(exposures[:-1], exposures[1:]) + +    def subtract_bkg(self, verbose=True): +        for spec in self.spectra: +            if not spec.bkg: +                raise ValueError("Spectrum '%s' has NO background" % \ +                        spec.filename) +            spec.subtract_bkg(inplace=True, verbose=verbose) + +    def apply_grouping(self, verbose=False): +        for spec in self.spectra: +            spec.apply_grouping(grouping=self.grouping, quality=self.quality, +                    verbose=verbose) + +    def estimate_errors(self, gehrels=True): +        for spec in self.spectra: +            spec.estimate_errors(gehrels=gehrels) + +    def scale(self): +        """ +        Scale the spectral data according to the region angular size. +        """ +        for spec in self.spectra: +            spec.scale() + +    def do_deprojection(self, group_squeeze=True, +            add_history=False, verbose=False): +        # +        # TODO/XXX: How to apply ARF correction here??? +        # +        num_spec = len(self.spectra) +        tmp_spec_data = self.spectra[0].get_data(group_squeeze=group_squeeze) +        spec_shape = tmp_spec_data.shape +        spec_dtype = tmp_spec_data.dtype +        spec_per_vol = [None] * num_spec +        # +        for shellnum in reversed(range(num_spec)): +            if verbose: +                print("DEPROJECTION: deprojecting shell %d ..." % shellnum, +                        file=sys.stderr) +            spec = self.spectra[shellnum] +            # calculate projected spectrum of outlying shells +            proj_spec = np.zeros(spec_shape, spec_dtype) +            for outer in range(shellnum+1, num_spec): +                vol = self.projected_volume( +                        r1=self.spectra[outer].radius_inner, +                        r2=self.spectra[outer].radius_outer, +                        R1=spec.radius_inner, +                        R2=spec.radius_outer) +                proj_spec += spec_per_vol[outer] * vol +            # +            this_spec = spec.get_data(group_squeeze=group_squeeze, copy=True) +            deproj_spec = this_spec - proj_spec +            # calculate the volume that this spectrum is from +            this_vol = self.projected_volume( +                    r1=spec.radius_inner, r2=spec.radius_outer, +                    R1=spec.radius_inner, R2=spec.radius_outer) +            # calculate the spectral data per unit volume +            spec_per_vol[shellnum] = deproj_spec / this_vol +        # set the spectral data to these deprojected values +        self.set_spec_data(spec_per_vol, group_squeeze=group_squeeze) +        # add history to header +        if add_history: +            self.add_history() + +    def get_spec_data(self, group_squeeze=True, copy=True): +        """ +        Extract the spectral data of each spectrum after deprojection +        performed. +        """ +        return [ spec.get_data(group_squeeze=group_squeeze, copy=copy) +                 for spec in self.spectra ] + +    def set_spec_data(self, spec_data, group_squeeze=True): +        """ +        Set `spec_data' for each spectrum to the deprojected spectral data. +        """ +        assert len(spec_data) == len(self.spectra) +        for spec, data in zip(self.spectra, spec_data): +            spec.set_data(data, group_squeeze=group_squeeze) + +    def add_stat_err(self, stat_err, group_squeeze=True): +        """ +        Add the "STAT_ERR" column to each spectrum. +        """ +        assert len(stat_err) == len(self.spectra) +        for spec, err in zip(self.spectra, stat_err): +            spec.add_stat_err(err, group_squeeze=group_squeeze) + +    def add_history(self): +        """ +        Append a brief history about this tool to the header. +        """ +        history = "Deprojected by %s (v%s) @ %s" % ( +                os.path.basename(sys.argv[0]), __version__, +                datetime.utcnow().isoformat()) +        for spec in self.spectra: +            spec.header.add_history(history) + +    def fix_header(self): +        # fix header keywords +        for spec in self.spectra: +            spec.fix_header_keywords( +                    reset_kw=["RESPFILE", "ANCRFILE", "BACKFILE"]) + +    def write(self, filenames=[], clobber=False): +        """ +        Write the deprojected spectra to output file. +        """ +        if filenames == []: +            filenames = [ spec.outfile for spec in self.spectra ] +        for spec, outfile in zip(self.spectra, filenames): +            spec.write(filename=outfile, clobber=clobber) + +    @staticmethod +    def projected_volume(r1, r2, R1, R2): +        """ +        Calculate the projected volume of a spherical shell of radii r1 -> r2 +        onto an annulus on the sky of radius R1 -> R2. + +        This volume is the integral: +          Int(R=R1,R2) Int(x=sqrt(r1^2-R^2),sqrt(r2^2-R^2)) 2*pi*R dx dR +          = +          Int(R=R1,R2) 2*pi*R * (sqrt(r2^2-R^2) - sqrt(r1^2-R^2)) dR + +        Note that the above integral is only half the total volume +        (i.e., front only). +        """ +        def sqrt_trunc(x): +            if x > 0: +                return np.sqrt(x) +            else: +                return 0.0 +        # +        p1 = sqrt_trunc(r1**2 - R2**2) +        p2 = sqrt_trunc(r1**2 - R1**2) +        p3 = sqrt_trunc(r2**2 - R2**2) +        p4 = sqrt_trunc(r2**2 - R1**2) +        return 2.0 * (2.0/3.0) * np.pi * ((p1**3 - p2**3) + (p4**3 - p3**3)) +# class Deprojection }}} + + +# Helper functions {{{ +def calc_median_errors(results): +    """ +    Calculate the median and errors for the spectral data gathered +    through Monte Carlo simulations. + +    TODO: investigate the errors calculation approach used here! +    """ +    results = np.array(results) +    # `results' now has shape: (mc_times, num_spec, num_channel) +    # sort by the Monte Carlo simulation axis +    results.sort(0) +    mc_times = results.shape[0] +    medians  = results[ int(mc_times * 0.5) ] +    lowerpcs = results[ int(mc_times * 0.1585) ] +    upperpcs = results[ int(mc_times * 0.8415) ] +    errors   = np.sqrt(0.5 * ((medians-lowerpcs)**2 + (upperpcs-medians)**2)) +    return (medians, errors) + + +def set_argument(name, default, cmdargs, config): +    value = default +    if name in config.keys(): +        value = config.as_bool(name) +    value_cmd = vars(cmdargs)[name] +    if value_cmd != default: +        value = value_cmd  # command arguments overwrite others +    return value +# helper functions }}} + + +# main routine {{{ +def main(config, subtract_bkg, fix_negative, mc_times, +        verbose=False, clobber=False): +    # collect ARFs and RMFs into dictionaries (avoid interpolation every time) +    arf_files = set() +    rmf_files = set() +    for region in config.sections: +        config_reg = config[region] +        arf_files.add(config_reg.get("arf")) +        rmf_files.add(config_reg.get("rmf")) +        for reg_in in config_reg["cross_in"].values(): +            arf_files.add(reg_in.get("arf")) +            arf_files.add(reg_in.get("cross_arf")) +        if "cross_out" in config_reg.sections: +            for arf in config_reg["cross_out"].as_list("cross_arf"): +                arf_files.add(arf) +    arf_files = arf_files - set([None]) +    arf_dict  = { arf: ARF(arf) for arf in arf_files } +    rmf_files = rmf_files - set([None]) +    rmf_dict  = { rmf: RMF(rmf) for rmf in rmf_files } +    if verbose: +        print("INFO: arf_files:", arf_files, file=sys.stderr) +        print("INFO: rmf_files:", rmf_files, file=sys.stderr) + +    # get the GROUPING and QUALITY data +    grouping_fits = fits.open(config["grouping"]) +    grouping = grouping_fits["SPECTRUM"].data.columns["GROUPING"].array +    quality  = grouping_fits["SPECTRUM"].data.columns["QUALITY"].array +    # squeeze the groupped spectral data, etc. +    group_squeeze = True + +    # crosstalk objects (BEFORE background subtraction) +    crosstalks_cleancopy = [] +    # crosstalk-corrected spectra +    cc_spectra = [] + +    # correct crosstalk effects for each region first +    for region in config.sections: +        if verbose: +            print("INFO: processing '%s' ..." % region, file=sys.stderr) +        crosstalk = Crosstalk(config.get(region), +                arf_dict=arf_dict, rmf_dict=rmf_dict, +                grouping=grouping, quality=quality) +        crosstalk.apply_grouping(verbose=verbose) +        crosstalk.estimate_errors(verbose=verbose) +        # keep a (almost) clean copy of the crosstalk object +        crosstalks_cleancopy.append(crosstalk.copy()) +        if verbose: +            print("INFO: doing crosstalk correction ...", file=sys.stderr) +        crosstalk.do_correction(subtract_bkg=subtract_bkg, +                fix_negative=fix_negative, group_squeeze=group_squeeze, +                add_history=True, verbose=verbose) +        cc_spectra.append(crosstalk.get_spectrum(copy=True)) + +    # load back the crosstalk-corrected spectra for deprojection +    if verbose: +        print("INFO: preparing spectra for deprojection ...", file=sys.stderr) +    deprojection = Deprojection(spectra=cc_spectra, grouping=grouping, +            quality=quality, verbose=verbose) +    if verbose: +        print("INFO: scaling spectra according the region angular size...", +                file=sys.stderr) +    deprojection.scale() +    if verbose: +        print("INFO: doing deprojection ...", file=sys.stderr) +    deprojection.do_deprojection(add_history=True, verbose=verbose) +    deproj_results = [ deprojection.get_spec_data( +            group_squeeze=group_squeeze, copy=True) ] + +    # Monte Carlo for spectral group error estimation +    print("INFO: Monte Carlo to estimate spectral errors (%d times) ..." % \ +            mc_times, file=sys.stderr) +    for i in range(mc_times): +        if i % 100 == 0: +            print("%d..." % i, end="", flush=True, file=sys.stderr) +        # correct crosstalk effects +        cc_spectra_copy = [] +        for crosstalk in crosstalks_cleancopy: +            # copy and randomize +            crosstalk_copy = crosstalk.copy().randomize() +            crosstalk_copy.do_correction(subtract_bkg=subtract_bkg, +                    fix_negative=fix_negative, group_squeeze=group_squeeze, +                    add_history=False, verbose=False) +            cc_spectra_copy.append(crosstalk_copy.get_spectrum(copy=True)) +        # deproject spectra +        deprojection_copy = Deprojection(spectra=cc_spectra_copy, +                grouping=grouping, quality=quality, verbose=False) +        deprojection_copy.scale() +        deprojection_copy.do_deprojection(add_history=False, verbose=False) +        deproj_results.append(deprojection_copy.get_spec_data( +                group_squeeze=group_squeeze, copy=True)) +    print("DONE!", flush=True, file=sys.stderr) + +    if verbose: +        print("INFO: Calculating the median and errors for each spectrum ...", +                file=sys.stderr) +    medians, errors = calc_median_errors(deproj_results) +    deprojection.set_spec_data(medians, group_squeeze=group_squeeze) +    deprojection.add_stat_err(errors, group_squeeze=group_squeeze) +    if verbose: +        print("INFO: Writing the crosstalk-corrected and deprojected " + \ +              "spectra with estimated statistical errors ...", file=sys.stderr) +    deprojection.fix_header() +    deprojection.write(clobber=clobber) +# main routine }}} + + +# main_deprojection routine {{{ +def main_deprojection(config, mc_times, verbose=False, clobber=False): +    """ +    Only perform the spectral deprojection. +    """ +    # collect ARFs and RMFs into dictionaries (avoid interpolation every time) +    arf_files = set() +    rmf_files = set() +    for region in config.sections: +        config_reg = config[region] +        arf_files.add(config_reg.get("arf")) +        rmf_files.add(config_reg.get("rmf")) +    arf_files = arf_files - set([None]) +    arf_dict  = { arf: ARF(arf) for arf in arf_files } +    rmf_files = rmf_files - set([None]) +    rmf_dict  = { rmf: RMF(rmf) for rmf in rmf_files } +    if verbose: +        print("INFO: arf_files:", arf_files, file=sys.stderr) +        print("INFO: rmf_files:", rmf_files, file=sys.stderr) + +    # get the GROUPING and QUALITY data +    grouping_fits = fits.open(config["grouping"]) +    grouping = grouping_fits["SPECTRUM"].data.columns["GROUPING"].array +    quality  = grouping_fits["SPECTRUM"].data.columns["QUALITY"].array +    # squeeze the groupped spectral data, etc. +    group_squeeze = True + +    # load spectra for deprojection +    if verbose: +        print("INFO: preparing spectra for deprojection ...", file=sys.stderr) +    proj_spectra = [] +    for region in config.sections: +        config_reg = config[region] +        specset = SpectrumSet(filename=config_reg["spec"], +                outfile=config_reg["outfile"], +                arf=arf_dict.get(config_reg["arf"], config_reg["arf"]), +                rmf=rmf_dict.get(config_reg["rmf"], config_reg["rmf"]), +                bkg=config_reg["bkg"]) +        proj_spectra.append(specset) + +    deprojection = Deprojection(spectra=proj_spectra, grouping=grouping, +            quality=quality, verbose=verbose) +    deprojection.apply_grouping(verbose=verbose) +    deprojection.estimate_errors() +    if verbose: +        print("INFO: scaling spectra according the region angular size ...", +                file=sys.stderr) +    deprojection.scale() + +    # keep a (almost) clean copy of the input projected spectra +    proj_spectra_cleancopy = [ spec.copy() for spec in proj_spectra ] + +    if verbose: +        print("INFO: subtract the background ...", file=sys.stderr) +    deprojection.subtract_bkg(verbose=verbose) +    if verbose: +        print("INFO: doing deprojection ...", file=sys.stderr) +    deprojection.do_deprojection(add_history=True, verbose=verbose) +    deproj_results = [ deprojection.get_spec_data( +            group_squeeze=group_squeeze, copy=True) ] + +    # Monte Carlo for spectral group error estimation +    print("INFO: Monte Carlo to estimate spectral errors (%d times) ..." % \ +            mc_times, file=sys.stderr) +    for i in range(mc_times): +        if i % 100 == 0: +            print("%d..." % i, end="", flush=True, file=sys.stderr) +        # copy and randomize the input projected spectra +        proj_spectra_copy = [ spec.copy().randomize() +                              for spec in proj_spectra_cleancopy ] +        # deproject spectra +        deprojection_copy = Deprojection(spectra=proj_spectra_copy, +                grouping=grouping, quality=quality, verbose=False) +        deprojection_copy.subtract_bkg(verbose=False) +        deprojection_copy.do_deprojection(add_history=False, verbose=False) +        deproj_results.append(deprojection_copy.get_spec_data( +                group_squeeze=group_squeeze, copy=True)) +    print("DONE!", flush=True, file=sys.stderr) + +    if verbose: +        print("INFO: Calculating the median and errors for each spectrum ...", +                file=sys.stderr) +    medians, errors = calc_median_errors(deproj_results) +    deprojection.set_spec_data(medians, group_squeeze=group_squeeze) +    deprojection.add_stat_err(errors, group_squeeze=group_squeeze) +    if verbose: +        print("INFO: Writing the deprojected spectra " + \ +              "with estimated statistical errors ...", file=sys.stderr) +    deprojection.fix_header() +    deprojection.write(clobber=clobber) +# main_deprojection routine }}} + + +# main_crosstalk routine {{{ +def main_crosstalk(config, subtract_bkg, fix_negative, mc_times, +        verbose=False, clobber=False): +    """ +    Only perform the crosstalk correction. +    """ +    # collect ARFs and RMFs into dictionaries (avoid interpolation every time) +    arf_files = set() +    rmf_files = set() +    for region in config.sections: +        config_reg = config[region] +        arf_files.add(config_reg.get("arf")) +        rmf_files.add(config_reg.get("rmf")) +        for reg_in in config_reg["cross_in"].values(): +            arf_files.add(reg_in.get("arf")) +            arf_files.add(reg_in.get("cross_arf")) +        if "cross_out" in config_reg.sections: +            for arf in config_reg["cross_out"].as_list("cross_arf"): +                arf_files.add(arf) +    arf_files = arf_files - set([None]) +    arf_dict  = { arf: ARF(arf) for arf in arf_files } +    rmf_files = rmf_files - set([None]) +    rmf_dict  = { rmf: RMF(rmf) for rmf in rmf_files } +    if verbose: +        print("INFO: arf_files:", arf_files, file=sys.stderr) +        print("INFO: rmf_files:", rmf_files, file=sys.stderr) + +    # get the GROUPING and QUALITY data +    if "grouping" in config.keys(): +        grouping_fits = fits.open(config["grouping"]) +        grouping = grouping_fits["SPECTRUM"].data.columns["GROUPING"].array +        quality  = grouping_fits["SPECTRUM"].data.columns["QUALITY"].array +        group_squeeze = True +    else: +        grouping = None +        quality  = None +        group_squeeze = False + +    # crosstalk objects (BEFORE background subtraction) +    crosstalks_cleancopy = [] +    # crosstalk-corrected spectra +    cc_spectra = [] + +    # correct crosstalk effects for each region first +    for region in config.sections: +        if verbose: +            print("INFO: processing '%s' ..." % region, file=sys.stderr) +        crosstalk = Crosstalk(config.get(region), +                arf_dict=arf_dict, rmf_dict=rmf_dict, +                grouping=grouping, quality=quality) +        if grouping is not None: +            crosstalk.apply_grouping(verbose=verbose) +        crosstalk.estimate_errors(verbose=verbose) +        # keep a (almost) clean copy of the crosstalk object +        crosstalks_cleancopy.append(crosstalk.copy()) +        if verbose: +            print("INFO: doing crosstalk correction ...", file=sys.stderr) +        crosstalk.do_correction(subtract_bkg=subtract_bkg, +                fix_negative=fix_negative, group_squeeze=group_squeeze, +                add_history=True, verbose=verbose) +        crosstalk.fix_header() +        cc_spectra.append(crosstalk.get_spectrum(copy=True)) + +    # spectral data of the crosstalk-corrected spectra +    cc_results = [] +    cc_results.append([ spec.get_data(group_squeeze=group_squeeze, copy=True) +                        for spec in cc_spectra ]) + +    # Monte Carlo for spectral group error estimation +    print("INFO: Monte Carlo to estimate spectral errors (%d times) ..." % \ +            mc_times, file=sys.stderr) +    for i in range(mc_times): +        if i % 100 == 0: +            print("%d..." % i, end="", flush=True, file=sys.stderr) +        # correct crosstalk effects +        cc_spectra_copy = [] +        for crosstalk in crosstalks_cleancopy: +            # copy and randomize +            crosstalk_copy = crosstalk.copy().randomize() +            crosstalk_copy.do_correction(subtract_bkg=subtract_bkg, +                    fix_negative=fix_negative, group_squeeze=group_squeeze, +                    add_history=False, verbose=False) +            cc_spectra_copy.append(crosstalk_copy.get_spectrum(copy=True)) +        cc_results.append([ spec.get_data(group_squeeze=group_squeeze, +                                          copy=True) +                            for spec in cc_spectra_copy ]) +    print("DONE!", flush=True, file=sys.stderr) + +    if verbose: +        print("INFO: Calculating the median and errors for each spectrum ...", +                file=sys.stderr) +    medians, errors = calc_median_errors(cc_results) +    if verbose: +        print("INFO: Writing the crosstalk-corrected spectra " + \ +                "with estimated statistical errors ...", +                file=sys.stderr) +    for spec, data, err in zip(cc_spectra, medians, errors): +        spec.set_data(data, group_squeeze=group_squeeze) +        spec.add_stat_err(err, group_squeeze=group_squeeze) +        spec.write(clobber=clobber) +# main_crosstalk routine }}} + + +if __name__ == "__main__": +    # arguments' default values +    default_mode = "both" +    default_mc_times = 5000 +    # commandline arguments parser +    parser = argparse.ArgumentParser( +            description="Correct the crosstalk effects for XMM EPIC spectra", +            epilog="Version: %s (%s)" % (__version__, __date__)) +    parser.add_argument("config", help="config file in which describes " +\ +            "the crosstalk relations ('ConfigObj' syntax)") +    parser.add_argument("-m", "--mode", dest="mode", default=default_mode, +            help="operation mode (both | crosstalk | deprojection)") +    parser.add_argument("-B", "--no-subtract-bkg", dest="subtract_bkg", +            action="store_false", help="do NOT subtract background first") +    parser.add_argument("-N", "--fix-negative", dest="fix_negative", +            action="store_true", help="fix negative channel values") +    parser.add_argument("-M", "--mc-times", dest="mc_times", +            type=int, default=default_mc_times, +            help="Monte Carlo times for error estimation") +    parser.add_argument("-C", "--clobber", dest="clobber", +            action="store_true", help="overwrite output file if exists") +    parser.add_argument("-v", "--verbose", dest="verbose", +            action="store_true", help="show verbose information") +    args = parser.parse_args() +    # merge commandline arguments and config +    config       = ConfigObj(args.config) +    subtract_bkg = set_argument("subtract_bkg", True,  args, config) +    fix_negative = set_argument("fix_negative", False, args, config) +    verbose      = set_argument("verbose",      False, args, config) +    clobber      = set_argument("clobber",      False, args, config) +    # operation mode +    mode = config.get("mode", default_mode) +    if args.mode != default_mode: +        mode = args.mode +    # Monte Carlo times +    mc_times = config.as_int("mc_times") +    if args.mc_times != default_mc_times: +        mc_times = args.mc_times + +    if mode.lower() == "both": +        print("MODE: CROSSTALK + DEPROJECTION", file=sys.stderr) +        main(config, subtract_bkg=subtract_bkg, fix_negative=fix_negative, +                mc_times=mc_times, verbose=verbose, clobber=clobber) +    elif mode.lower() == "deprojection": +        print("MODE: DEPROJECTION", file=sys.stderr) +        main_deprojection(config, mc_times=mc_times, +                verbose=verbose, clobber=clobber) +    elif mode.lower() == "crosstalk": +        print("MODE: CROSSTALK", file=sys.stderr) +        main_crosstalk(config, subtract_bkg=subtract_bkg, +                fix_negative=fix_negative, mc_times=mc_times, +                verbose=verbose, clobber=clobber) +    else: +        raise ValueError("Invalid operation mode: %s" % mode) +    print(WARNING) + +#  vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/fit_sbp.py b/python/fit_sbp.py new file mode 100755 index 0000000..c22e0c8 --- /dev/null +++ b/python/fit_sbp.py @@ -0,0 +1,807 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Aaron LI +# Created: 2016-03-13 +# Updated: 2016-04-26 +# +# Changelogs: +# 2016-04-26: +#   * Reorder some methods of classes 'FitModelSBeta' and 'FitModelDBeta' +#   * Change the output file extension from ".txt" to ".json" +# 2016-04-21: +#   * Plot another X axis with unit "r500", with R500 values marked +#   * Adjust output image size/resolution +# 2016-04-20: +#   * Support "pix" and "kpc" units +#   * Allow ignore data w.r.t R500 value +#   * Major changes to the config syntax +#   * Add commandline argument to select the sbp model +# 2016-04-05: +#   * Allow fix parameters +# 2016-03-31: +#   * Remove `ci_report()' +#   * Add `make_results()' to orgnize all results as s Python dictionary +#   * Report results as json string +# 2016-03-28: +#   * Add `main()', `make_model()' +#   * Use `configobj' to handle configurations +#   * Save fit results and plot +#   * Add `ci_report()' +# 2016-03-14: +#   * Refactor classes `FitModelSBeta' and `FitModelDBeta' +#   * Add matplotlib plot support +#   * Add `ignore_data()' and `notice_data()' support +#   * Add classes `FitModelSBetaNorm' and `FitModelDBetaNorm' +# +# TODO: +#   * to allow fit the outer beta component, then fix it, and fit the inner one +#   * to integrate basic information of config file to the output json +#   * to output the ignored radius range in the same unit as input sbp data +# + +""" +Fit the surface brightness profile (SBP) with the single-beta model: +  s(r) = s0 * [1.0 + (r/rc)^2] ^ (0.5-3*beta) + bkg +or the double-beta model: +  s(r) = s01 * [1.0 + (r/rc1)^2] ^ (0.5-3*beta1) + +         s02 * [1.0 + (r/rc2)^2] ^ (0.5-3*beta2) + bkg + + +Sample config file: +------------------------------------------------- +name     = <NAME> +obsid    = <OBSID> +r500_pix = <R500_PIX> +r500_kpc = <R500_KPC> + +sbpfile  = sbprofile.txt +# unit of radius: pix (default) or kpc +unit     = pixel + +# sbp model: "sbeta" or "dbeta" +model    = sbeta +#model    = dbeta + +# output file to store the fitting results +outfile  = sbpfit.json +# output file to save the fitting plot +imgfile  = sbpfit.png + +# data range to be ignored during fitting (same unit as the above "unit") +#ignore      = 0.0-20.0, +# specify the ignore range w.r.t R500 ("r500_pix" or "r500_kpc" required) +#ignore_r500 = 0.0-0.15, + +[sbeta] +# model-related options (OVERRIDE the upper level options) +outfile     = sbpfit_sbeta.json +imgfile     = sbpfit_sbeta.png +#ignore      = 0.0-20.0, +#ignore_r500 = 0.0-0.15, +  [[params]] +  # model parameters +  # name = initial, lower, upper, variable (FIXED/False to fix the parameter) +  s0    = 1.0e-8,  0.0,  1.0e-6 +  rc    = 30.0,    5.0,  1.0e4 +  #rc    = 30.0,    5.0,  1.0e4,  FIXED +  beta  = 0.7,     0.3,  1.1 +  bkg   = 1.0e-10, 0.0,  1.0e-8 + + +[dbeta] +outfile     = sbpfit_dbeta.json +imgfile     = sbpfit_dbeta.png +#ignore      = 0.0-20.0, +#ignore_r500 = 0.0-0.15, +  [[params]] +  s01   = 1.0e-8,  0.0,  1.0e-6 +  rc1   = 50.0,    10.0, 1.0e4 +  beta1 = 0.7,     0.3,  1.1 +  s02   = 1.0e-8,  0.0,  1.0e-6 +  rc2   = 30.0,    2.0,  5.0e2 +  beta2 = 0.7,     0.3,  1.1 +  bkg   = 1.0e-10, 0.0,  1.0e-8 +------------------------------------------------- +""" + +__version__ = "0.6.2" +__date__    = "2016-04-26" + + +import os +import sys +import re +import argparse +import json +from collections import OrderedDict + +import numpy as np +import lmfit +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +from configobj import ConfigObj + + +plt.style.use("ggplot") + + +class FitModel: +    """ +    Meta-class of the fitting model. + +    The supplied `func' should have the following syntax: +        y = f(x, params) +    where the `params' is `lmfit.Parameters' instance which contains all +    the model parameters to be fitted, and should be provided as well. +    """ +    def __init__(self, name=None, func=None, params=lmfit.Parameters()): +        self.name = name +        self.func = func +        self.params = params + +    def f(self, x): +        return self.func(x, self.params) + +    def get_param(self, name=None): +        """ +        Return the requested `Parameter' object or the whole +        `Parameters' object of no name supplied. +        """ +        try: +            return self.params[name] +        except KeyError: +            return self.params + +    def set_param(self, name, *args, **kwargs): +        """ +        Set the properties of the specified parameter. +        """ +        param = self.params[name] +        param.set(*args, **kwargs) + +    def plot(self, params, xdata, ax): +        """ +        Plot the fitted model. +        """ +        f_fitted = lambda x: self.func(x, params) +        ydata = f_fitted(xdata) +        ax.plot(xdata, ydata, 'k-') + +class FitModelSBeta(FitModel): +    """ +    The single-beta model to be fitted. +    Single-beta model, with a constant background. +    """ +    params = lmfit.Parameters() +    params.add_many( # (name, value, vary, min, max, expr) +                    ("s0",   1.0e-8, True, 0.0, 1.0e-6, None), +                    ("rc",   30.0,   True, 1.0, 1.0e4,  None), +                    ("beta", 0.7,    True, 0.3, 1.1,    None), +                    ("bkg",  1.0e-9, True, 0.0, 1.0e-7, None)) + +    def __init__(self): +        super(self.__class__, self).__init__(name="Single-beta", +                func=self.sbeta, params=self.params) + +    @staticmethod +    def sbeta(r, params): +        parvals = params.valuesdict() +        s0   = parvals["s0"] +        rc   = parvals["rc"] +        beta = parvals["beta"] +        bkg  = parvals["bkg"] +        return s0 * np.power((1 + (r/rc)**2), (0.5 - 3*beta)) + bkg + +    def plot(self, params, xdata, ax): +        """ +        Plot the fitted model, as well as the fitted parameters. +        """ +        super(self.__class__, self).plot(params, xdata, ax) +        ydata = self.sbeta(xdata, params) +        # fitted paramters +        ax.vlines(x=params["rc"].value, ymin=min(ydata), ymax=max(ydata), +                linestyles="dashed") +        ax.hlines(y=params["bkg"].value, xmin=min(xdata), xmax=max(xdata), +                linestyles="dashed") +        ax.text(x=params["rc"].value, y=min(ydata), +                s="beta: %.2f\nrc: %.2f" % (params["beta"].value, +                    params["rc"].value)) +        ax.text(x=min(xdata), y=min(ydata), +                s="bkg: %.3e" % params["bkg"].value, +                verticalalignment="top") + + +class FitModelDBeta(FitModel): +    """ +    The double-beta model to be fitted. +    Double-beta model, with a constant background. + +    NOTE: +    the first beta component (s01, rc1, beta1) describes the main and +    outer SBP; while the second beta component (s02, rc2, beta2) accounts +    for the central brightness excess. +    """ +    params = lmfit.Parameters() +    params.add("s01",   value=1.0e-8, min=0.0,  max=1.0e-6) +    params.add("rc1",   value=50.0,   min=10.0, max=1.0e4) +    params.add("beta1", value=0.7,    min=0.3,  max=1.1) +    #params.add("df_s0", value=1.0e-8, min=0.0,  max=1.0e-6) +    #params.add("s02",   expr="s01 + df_s0") +    params.add("s02",   value=1.0e-8, min=0.0,  max=1.0e-6) +    #params.add("df_rc", value=30.0,   min=0.0,  max=1.0e4) +    #params.add("rc2",   expr="rc1 - df_rc") +    params.add("rc2",   value=20.0,   min=1.0,  max=5.0e2) +    params.add("beta2", value=0.7,    min=0.3,  max=1.1) +    params.add("bkg",   value=1.0e-9, min=0.0,  max=1.0e-7) + +    def __init__(self): +        super(self.__class__, self).__init__(name="Double-beta", +                func=self.dbeta, params=self.params) + +    @classmethod +    def dbeta(self, r, params): +        return self.beta1(r, params) + self.beta2(r, params) + +    @staticmethod +    def beta1(r, params): +        """ +        This beta component describes the main/outer part of the SBP. +        """ +        parvals = params.valuesdict() +        s01   = parvals["s01"] +        rc1   = parvals["rc1"] +        beta1 = parvals["beta1"] +        bkg   = parvals["bkg"] +        return s01 * np.power((1 + (r/rc1)**2), (0.5 - 3*beta1)) + bkg + +    @staticmethod +    def beta2(r, params): +        """ +        This beta component describes the central/excess part of the SBP. +        """ +        parvals = params.valuesdict() +        s02   = parvals["s02"] +        rc2   = parvals["rc2"] +        beta2 = parvals["beta2"] +        return s02 * np.power((1 + (r/rc2)**2), (0.5 - 3*beta2)) + +    def plot(self, params, xdata, ax): +        """ +        Plot the fitted model, and each beta component, +        as well as the fitted parameters. +        """ +        super(self.__class__, self).plot(params, xdata, ax) +        beta1_ydata = self.beta1(xdata, params) +        beta2_ydata = self.beta2(xdata, params) +        ax.plot(xdata, beta1_ydata, 'b-.') +        ax.plot(xdata, beta2_ydata, 'b-.') +        # fitted paramters +        ydata = beta1_ydata + beta2_ydata +        ax.vlines(x=params["rc1"].value, ymin=min(ydata), ymax=max(ydata), +                linestyles="dashed") +        ax.vlines(x=params["rc2"].value, ymin=min(ydata), ymax=max(ydata), +                linestyles="dashed") +        ax.hlines(y=params["bkg"].value, xmin=min(xdata), xmax=max(xdata), +                linestyles="dashed") +        ax.text(x=params["rc1"].value, y=min(ydata), +                s="beta1: %.2f\nrc1: %.2f" % (params["beta1"].value, +                    params["rc1"].value)) +        ax.text(x=params["rc2"].value, y=min(ydata), +                s="beta2: %.2f\nrc2: %.2f" % (params["beta2"].value, +                    params["rc2"].value)) +        ax.text(x=min(xdata), y=min(ydata), +                s="bkg: %.3e" % params["bkg"].value, +                verticalalignment="top") + + +class FitModelSBetaNorm(FitModel): +    """ +    The single-beta model to be fitted. +    Single-beta model, with a constant background. +    Normalized the `s0' and `bkg' parameters by take the logarithm. +    """ +    params = lmfit.Parameters() +    params.add_many( # (name, value, vary, min, max, expr) +                    ("log10_s0",   -8.0, True, -12.0, -6.0,  None), +                    ("rc",   30.0, True, 1.0,   1.0e4, None), +                    ("beta", 0.7,  True, 0.3,   1.1,   None), +                    ("log10_bkg",  -9.0, True, -12.0, -7.0,  None)) + +    @staticmethod +    def sbeta(r, params): +        parvals = params.valuesdict() +        s0   = 10 ** parvals["log10_s0"] +        rc   = parvals["rc"] +        beta = parvals["beta"] +        bkg  = 10 ** parvals["log10_bkg"] +        return s0 * np.power((1 + (r/rc)**2), (0.5 - 3*beta)) + bkg + +    def __init__(self): +        super(self.__class__, self).__init__(name="Single-beta", +                func=self.sbeta, params=self.params) + +    def plot(self, params, xdata, ax): +        """ +        Plot the fitted model, as well as the fitted parameters. +        """ +        super(self.__class__, self).plot(params, xdata, ax) +        ydata = self.sbeta(xdata, params) +        # fitted paramters +        ax.vlines(x=params["rc"].value, ymin=min(ydata), ymax=max(ydata), +                linestyles="dashed") +        ax.hlines(y=(10 ** params["bkg"].value), xmin=min(xdata), +                xmax=max(xdata), linestyles="dashed") +        ax.text(x=params["rc"].value, y=min(ydata), +                s="beta: %.2f\nrc: %.2f" % (params["beta"].value, +                    params["rc"].value)) +        ax.text(x=min(xdata), y=min(ydata), +                s="bkg: %.3e" % (10 ** params["bkg"].value), +                verticalalignment="top") + + +class FitModelDBetaNorm(FitModel): +    """ +    The double-beta model to be fitted. +    Double-beta model, with a constant background. +    Normalized the `s01', `s02' and `bkg' parameters by take the logarithm. + +    NOTE: +    the first beta component (s01, rc1, beta1) describes the main and +    outer SBP; while the second beta component (s02, rc2, beta2) accounts +    for the central brightness excess. +    """ +    params = lmfit.Parameters() +    params.add("log10_s01",   value=-8.0, min=-12.0, max=-6.0) +    params.add("rc1",   value=50.0, min=10.0,  max=1.0e4) +    params.add("beta1", value=0.7,  min=0.3,   max=1.1) +    #params.add("df_s0", value=1.0e-8, min=0.0, max=1.0e-6) +    #params.add("s02",   expr="s01 + df_s0") +    params.add("log10_s02",   value=-8.0, min=-12.0, max=-6.0) +    #params.add("df_rc", value=30.0, min=0.0,   max=1.0e4) +    #params.add("rc2",   expr="rc1 - df_rc") +    params.add("rc2",   value=20.0, min=1.0,   max=5.0e2) +    params.add("beta2", value=0.7,  min=0.3,   max=1.1) +    params.add("log10_bkg",   value=-9.0, min=-12.0, max=-7.0) + +    @staticmethod +    def beta1(r, params): +        """ +        This beta component describes the main/outer part of the SBP. +        """ +        parvals = params.valuesdict() +        s01   = 10 ** parvals["log10_s01"] +        rc1   = parvals["rc1"] +        beta1 = parvals["beta1"] +        bkg   = 10 ** parvals["log10_bkg"] +        return s01 * np.power((1 + (r/rc1)**2), (0.5 - 3*beta1)) + bkg + +    @staticmethod +    def beta2(r, params): +        """ +        This beta component describes the central/excess part of the SBP. +        """ +        parvals = params.valuesdict() +        s02   = 10 ** parvals["log10_s02"] +        rc2   = parvals["rc2"] +        beta2 = parvals["beta2"] +        return s02 * np.power((1 + (r/rc2)**2), (0.5 - 3*beta2)) + +    @classmethod +    def dbeta(self, r, params): +        return self.beta1(r, params) + self.beta2(r, params) + +    def __init__(self): +        super(self.__class__, self).__init__(name="Double-beta", +                func=self.dbeta, params=self.params) + +    def plot(self, params, xdata, ax): +        """ +        Plot the fitted model, and each beta component, +        as well as the fitted parameters. +        """ +        super(self.__class__, self).plot(params, xdata, ax) +        beta1_ydata = self.beta1(xdata, params) +        beta2_ydata = self.beta2(xdata, params) +        ax.plot(xdata, beta1_ydata, 'b-.') +        ax.plot(xdata, beta2_ydata, 'b-.') +        # fitted paramters +        ydata = beta1_ydata + beta2_ydata +        ax.vlines(x=params["log10_rc1"].value, ymin=min(ydata), ymax=max(ydata), +                linestyles="dashed") +        ax.vlines(x=params["rc2"].value, ymin=min(ydata), ymax=max(ydata), +                linestyles="dashed") +        ax.hlines(y=(10 ** params["bkg"].value), xmin=min(xdata), +                xmax=max(xdata), linestyles="dashed") +        ax.text(x=params["rc1"].value, y=min(ydata), +                s="beta1: %.2f\nrc1: %.2f" % (params["beta1"].value, +                    params["rc1"].value)) +        ax.text(x=params["rc2"].value, y=min(ydata), +                s="beta2: %.2f\nrc2: %.2f" % (params["beta2"].value, +                    params["rc2"].value)) +        ax.text(x=min(xdata), y=min(ydata), +                s="bkg: %.3e" % (10 ** params["bkg"].value), +                verticalalignment="top") + + +class SbpFit: +    """ +    Class to handle the SBP fitting with single-/double-beta model. +    """ +    def __init__(self, model, method="lbfgsb", +            xdata=None, ydata=None, xerr=None, yerr=None, xunit="pix", +            name=None, obsid=None, r500_pix=None, r500_kpc=None): +        self.method = method +        self.model = model +        self.load_data(xdata=xdata, ydata=ydata, xerr=xerr, yerr=yerr, +                xunit=xunit) +        self.set_source(name=name, obsid=obsid, r500_pix=r500_pix, +                r500_kpc=r500_kpc) + +    def set_source(self, name, obsid=None, r500_pix=None, r500_kpc=None): +        self.name = name +        try: +            self.obsid = int(obsid) +        except TypeError: +            self.obsid = None +        try: +            self.r500_pix = float(r500_pix) +        except TypeError: +            self.r500_pix = None +        try: +            self.r500_kpc = float(r500_kpc) +        except TypeError: +            self.r500_kpc = None +        try: +            self.kpc_per_pix = self.r500_kpc / self.r500_pix +        except (TypeError, ZeroDivisionError): +            self.kpc_per_pix = -1 + +    def load_data(self, xdata, ydata, xerr, yerr, xunit="pix"): +        self.xdata = xdata +        self.ydata = ydata +        self.xerr  = xerr +        self.yerr  = yerr +        if xdata is not None: +            self.mask = np.ones(xdata.shape, dtype=np.bool) +        else: +            self.mask = None +        if xunit.lower() in ["pix", "pixel"]: +            self.xunit = "pix" +        elif xunit.lower() == "kpc": +            self.xunit = "kpc" +        else: +            raise ValueError("invalid xunit: %s" % xunit) + +    def ignore_data(self, xmin=None, xmax=None, unit=None): +        """ +        Ignore the data points within range [xmin, xmax]. +        If xmin is None, then xmin=min(xdata); +        if xmax is None, then xmax=max(xdata). + +        if unit is None, then assume the same unit as `self.xunit'. +        """ +        if unit is None: +            unit = self.xunit +        if xmin is not None: +            xmin = self.convert_unit(xmin, unit=unit) +        else: +            xmin = np.min(self.xdata) +        if xmax is not None: +            xmax = self.convert_unit(xmax, unit=unit) +        else: +            xmax = np.max(self.xdata) +        ignore_idx = np.logical_and(self.xdata >= xmin, self.xdata <= xmax) +        self.mask[ignore_idx] = False +        # reset `f_residual' +        self.f_residual = None + +    def notice_data(self, xmin=None, xmax=None, unit=None): +        """ +        Notice the data points within range [xmin, xmax]. +        If xmin is None, then xmin=min(xdata); +        if xmax is None, then xmax=max(xdata). + +        if unit is None, then assume the same unit as `self.xunit'. +        """ +        if unit is None: +            unit = self.xunit +        if xmin is not None: +            xmin = self.convert_unit(xmin, unit=unit) +        else: +            xmin = np.min(self.xdata) +        if xmax is not None: +            xmax = self.convert_unit(xmax, unit=unit) +        else: +            xmax = np.max(self.xdata) +        notice_idx = np.logical_and(self.xdata >= xmin, self.xdata <= xmax) +        self.mask[notice_idx] = True +        # reset `f_residual' +        self.f_residual = None + +    def convert_unit(self, x, unit): +        """ +        Convert the value x in given unit to be the unit `self.xunit' +        """ +        if unit == self.xunit: +            return x +        elif (unit == "pix")  and (self.xunit == "kpc"): +            return (x / self.r500_pix * self.r500_kpc) +        elif (unit == "kpc")  and (self.xunit == "pix"): +            return (x / self.r500_kpc * self.r500_pix) +        elif (unit == "r500") and (self.xunit == "pix"): +            return (x * self.r500_pix) +        elif (unit == "r500") and (self.xunit == "kpc"): +            return (x * self.r500_kpc) +        else: +            raise ValueError("invalid units: %s vs. %s" % (unit, self.xunit)) + +    def convert_to_r500(self, x, unit=None): +        """ +        Convert the value x in given unit to be in unit "r500". +        """ +        if unit is None: +            unit = self.xunit +        if unit == "r500": +            return x +        elif unit == "pix": +            return (x / self.r500_pix) +        elif unit == "kpc": +            return (x / self.r500_kpc) +        else: +            raise ValueError("invalid unit: %s" % unit) + +    def set_residual(self): +        def f_residual(params): +            if self.yerr is None: +                return self.model.func(self.xdata[self.mask], params) - \ +                        self.ydata +            else: +                return (self.model.func(self.xdata[self.mask], params) - \ +                        self.ydata[self.mask]) / self.yerr[self.mask] +        self.f_residual = f_residual + +    def fit(self, method=None): +        if method is None: +            method = self.method +        if not hasattr(self, "f_residual") or self.f_residual is None: +            self.set_residual() +        self.fitter = lmfit.Minimizer(self.f_residual, self.model.params) +        self.fitted = self.fitter.minimize(method=method) +        self.fitted_model = lambda x: self.model.func(x, self.fitted.params) + +    def calc_ci(self, sigmas=[0.68, 0.90]): +        # `conf_interval' requires the fitted results have valid `stderr', +        # so we need to re-fit the model with method `leastsq'. +        fitted = self.fitter.minimize(method="leastsq", +                params=self.fitted.params) +        self.ci, self.trace = lmfit.conf_interval(self.fitter, fitted, +                sigmas=sigmas, trace=True) + +    def make_results(self): +        """ +        Make the `self.results' dictionary which contains all the fitting +        results as well as the confidence intervals. +        """ +        fitted = self.fitted +        self.results = OrderedDict() +        ## fitting results +        self.results.update( +                nfev   = fitted.nfev, +                ndata  = fitted.ndata, +                nvarys = fitted.nvarys,  # number of varible paramters +                nfree  = fitted.nfree,  # degree of freem +                chisqr = fitted.chisqr, +                redchi = fitted.redchi, +                aic    = fitted.aic, +                bic    = fitted.bic) +        params = fitted.params +        pnames = list(params.keys()) +        pvalues = OrderedDict() +        for pn in pnames: +            par = params.get(pn) +            pvalues[pn] = [par.value, par.min, par.max, par.vary] +        self.results["params"] = pvalues +        ## confidence intervals +        if hasattr(self, "ci") and self.ci is not None: +            ci = self.ci +            ci_values = OrderedDict() +            ci_sigmas = [ "ci%02d" % (v[0]*100) for v in ci.get(pnames[0]) ] +            ci_names = sorted(list(set(ci_sigmas))) +            ci_idx = { k: [] for k in ci_names } +            for cn, idx in zip(ci_sigmas, range(len(ci_sigmas))): +                ci_idx[cn].append(idx) +            # parameters ci +            for pn in pnames: +                ci_pv = OrderedDict() +                pv = [ v[1] for v in ci.get(pn) ] +                # best +                pv_best = pv[ ci_idx["ci00"][0] ] +                ci_pv["best"] = pv_best +                # ci of each sigma +                pv2 = [ v-pv_best for v in pv ] +                for cn in ci_names[1:]: +                    ci_pv[cn] = [ pv2[idx] for idx in ci_idx[cn] ] +                ci_values[pn] = ci_pv +            self.results["ci"] = ci_values + +    def report(self, outfile=sys.stdout): +        if not hasattr(self, "results") or self.results is None: +            self.make_results() +        jd = json.dumps(self.results, indent=2) +        print(jd, file=outfile) + +    def plot(self, ax=None, fig=None, r500_axis=True): +        """ +        Arguments: +          * r500_axis: whether to add a second X axis in unit "r500" +        """ +        if ax is None: +            fig, ax = plt.subplots(1, 1) +        # noticed data points +        eb = ax.errorbar(self.xdata[self.mask], self.ydata[self.mask], +                xerr=self.xerr[self.mask], yerr=self.yerr[self.mask], +                fmt="none") +        # ignored data points +        ignore_mask = np.logical_not(self.mask) +        if np.sum(ignore_mask) > 0: +            eb = ax.errorbar(self.xdata[ignore_mask], self.ydata[ignore_mask], +                    xerr=self.xerr[ignore_mask], yerr=self.yerr[ignore_mask], +                    fmt="none") +            eb[-1][0].set_linestyle("-.") +        # fitted model +        xmax = self.xdata[-1] + self.xerr[-1] +        xpred = np.power(10, np.linspace(0, np.log10(xmax), 2*len(self.xdata))) +        ypred = self.fitted_model(xpred) +        ymin = min(min(self.ydata), min(ypred)) +        ymax = max(max(self.ydata), max(ypred)) +        self.model.plot(params=self.fitted.params, xdata=xpred, ax=ax) +        ax.set_xscale("log") +        ax.set_yscale("log") +        ax.set_xlim(1.0, xmax) +        ax.set_ylim(ymin/1.2, ymax*1.2) +        name = self.name +        if self.obsid is not None: +            name += "; %s" % self.obsid +        ax.set_title("Fitted Surface Brightness Profile (%s)" % name) +        ax.set_xlabel("Radius (%s)" % self.xunit) +        ax.set_ylabel(r"Surface Brightness (photons/cm$^2$/pixel$^2$/s)") +        ax.text(x=xmax, y=ymax, +                s="redchi: %.2f / %.2f = %.2f" % (self.fitted.chisqr, +                    self.fitted.nfree, self.fitted.chisqr/self.fitted.nfree), +                horizontalalignment="right", verticalalignment="top") +        plot_ret = [fig, ax] +        if r500_axis: +            # Add a second X-axis with labels in unit "r500" +            # Credit: https://stackoverflow.com/a/28192477/4856091 +            try: +                ax.title.set_position([0.5, 1.1])  # raise title position +                ax2 = ax.twiny() +                # NOTE: the ORDER of the following lines MATTERS +                ax2.set_xscale(ax.get_xscale()) +                ax2_ticks = ax.get_xticks() +                ax2.set_xticks(ax2_ticks) +                ax2.set_xbound(ax.get_xbound()) +                ax2.set_xticklabels([ "%.2g" % self.convert_to_r500(x) +                                    for x in ax2_ticks ]) +                ax2.set_xlabel("Radius (r500; r500 = %s pix = %s kpc)" % (\ +                        self.r500_pix, self.r500_kpc)) +                ax2.grid(False) +                plot_ret.append(ax2) +            except ValueError: +                # cannot convert X values to unit "r500" +                pass +        # automatically adjust layout +        fig.tight_layout() +        return plot_ret + + +def make_model(config, modelname): +    """ +    Make the model with parameters set according to the config. +    """ +    if modelname == "sbeta": +        # single-beta model +        model = FitModelSBeta() +    elif modelname == "dbeta": +        # double-beta model +        model = FitModelDBeta() +    else: +        raise ValueError("Invalid model: %s" % modelname) +    # set initial values and bounds for the model parameters +    params = config[modelname]["params"] +    for p, value in params.items(): +        variable = True +        if len(value) == 4 and value[3].upper() in ["FIXED", "FALSE"]: +            variable = False +        model.set_param(name=p, value=float(value[0]), +                min=float(value[1]), max=float(value[2]), vary=variable) +    return model + + +def main(): +    # parser for command line options and arguments +    parser = argparse.ArgumentParser( +            description="Fit surface brightness profile with " + \ +                        "single-/double-beta model", +            epilog="Version: %s (%s)" % (__version__, __date__)) +    parser.add_argument("-V", "--version", action="version", +            version="%(prog)s " + "%s (%s)" % (__version__, __date__)) +    parser.add_argument("config", help="Config file for SBP fitting") +    # exclusive argument group for model selection +    grp_model = parser.add_mutually_exclusive_group(required=False) +    grp_model.add_argument("-s", "--sbeta", dest="sbeta", +            action="store_true", help="single-beta model for SBP") +    grp_model.add_argument("-d", "--dbeta", dest="dbeta", +            action="store_true", help="double-beta model for SBP") +    # +    args = parser.parse_args() + +    config = ConfigObj(args.config) + +    # determine the model name +    if args.sbeta: +        modelname = "sbeta" +    elif args.dbeta: +        modelname = "dbeta" +    else: +        modelname = config["model"] + +    config_model = config[modelname] +    # determine the "outfile" and "imgfile" +    outfile = config.get("outfile") +    outfile = config_model.get("outfile", outfile) +    imgfile = config.get("imgfile") +    imgfile = config_model.get("imgfile", imgfile) + +    # SBP fitting model +    model = make_model(config, modelname=modelname) + +    # sbp data and fit object +    sbpdata = np.loadtxt(config["sbpfile"]) +    sbpfit = SbpFit(model=model, xdata=sbpdata[:, 0], xerr=sbpdata[:, 1], +            ydata=sbpdata[:, 2], yerr=sbpdata[:, 3], +            xunit=config.get("unit", "pix")) +    sbpfit.set_source(name=config["name"], obsid=config.get("obsid"), +            r500_pix=config.get("r500_pix"), r500_kpc=config.get("r500_kpc")) + +    # apply data range ignorance +    if "ignore" in config.keys(): +        for ig in config.as_list("ignore"): +            xmin, xmax = map(float, ig.split("-")) +            sbpfit.ignore_data(xmin=xmin, xmax=xmax) +    if "ignore_r500" in config.keys(): +        for ig in config.as_list("ignore_r500"): +            xmin, xmax = map(float, ig.split("-")) +            sbpfit.ignore_data(xmin=xmin, xmax=xmax, unit="r500") + +    # apply additional data range ignorance specified within model section +    if "ignore" in config_model.keys(): +        for ig in config_model.as_list("ignore"): +            xmin, xmax = map(float, ig.split("-")) +            sbpfit.ignore_data(xmin=xmin, xmax=xmax) +    if "ignore_r500" in config_model.keys(): +        for ig in config_model.as_list("ignore_r500"): +            xmin, xmax = map(float, ig.split("-")) +            sbpfit.ignore_data(xmin=xmin, xmax=xmax, unit="r500") + +    # fit and calculate confidence intervals +    sbpfit.fit() +    sbpfit.calc_ci() +    sbpfit.report() +    with open(outfile, "w") as ofile: +        sbpfit.report(outfile=ofile) + +    # make and save a plot +    fig = Figure(figsize=(10, 8)) +    canvas = FigureCanvas(fig) +    ax = fig.add_subplot(111) +    sbpfit.plot(ax=ax, fig=fig, r500_axis=True) +    fig.savefig(imgfile, dpi=150) + + +if __name__ == "__main__": +    main() + +#  vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/imapUTF7.py b/python/imapUTF7.py new file mode 100644 index 0000000..2e4db0a --- /dev/null +++ b/python/imapUTF7.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# This code was originally in PloneMailList, a GPL'd software. +# http://svn.plone.org/svn/collective/mxmImapClient/trunk/imapUTF7.py +# http://bugs.python.org/issue5305 +# +# Port to Python 3.x +# Credit: https://github.com/MarechJ/py3_imap_utf7 +# +# 2016-01-23 +# Aaron LI +# + +""" +Imap folder names are encoded using a special version of utf-7 as defined in RFC +2060 section 5.1.3. + +5.1.3.  Mailbox International Naming Convention + +   By convention, international mailbox names are specified using a +   modified version of the UTF-7 encoding described in [UTF-7].  The +   purpose of these modifications is to correct the following problems +   with UTF-7: + +      1) UTF-7 uses the "+" character for shifting; this conflicts with +         the common use of "+" in mailbox names, in particular USENET +         newsgroup names. + +      2) UTF-7's encoding is BASE64 which uses the "/" character; this +         conflicts with the use of "/" as a popular hierarchy delimiter. + +      3) UTF-7 prohibits the unencoded usage of "\"; this conflicts with +         the use of "\" as a popular hierarchy delimiter. + +      4) UTF-7 prohibits the unencoded usage of "~"; this conflicts with +         the use of "~" in some servers as a home directory indicator. + +      5) UTF-7 permits multiple alternate forms to represent the same +         string; in particular, printable US-ASCII chararacters can be +         represented in encoded form. + +   In modified UTF-7, printable US-ASCII characters except for "&" +   represent themselves; that is, characters with octet values 0x20-0x25 +   and 0x27-0x7e.  The character "&" (0x26) is represented by the two- +   octet sequence "&-". + +   All other characters (octet values 0x00-0x1f, 0x7f-0xff, and all +   Unicode 16-bit octets) are represented in modified BASE64, with a +   further modification from [UTF-7] that "," is used instead of "/". +   Modified BASE64 MUST NOT be used to represent any printing US-ASCII +   character which can represent itself. + +   "&" is used to shift to modified BASE64 and "-" to shift back to US- +   ASCII.  All names start in US-ASCII, and MUST end in US-ASCII (that +   is, a name that ends with a Unicode 16-bit octet MUST end with a "- +   "). + +      For example, here is a mailbox name which mixes English, Japanese, +      and Chinese text: ~peter/mail/&ZeVnLIqe-/&U,BTFw- +""" + + +import binascii +import codecs + + +## encoding + +def modified_base64(s:str): +    s = s.encode('utf-16be')  # UTF-16, big-endian byte order +    return binascii.b2a_base64(s).rstrip(b'\n=').replace(b'/', b',') + +def doB64(_in, r): +    if _in: +        r.append(b'&' + modified_base64(''.join(_in)) + b'-') +        del _in[:] + +def encoder(s:str): +    r = [] +    _in = [] +    for c in s: +        ordC = ord(c) +        if 0x20 <= ordC <= 0x25 or 0x27 <= ordC <= 0x7e: +            doB64(_in, r) +            r.append(c.encode()) +        elif c == '&': +            doB64(_in, r) +            r.append(b'&-') +        else: +            _in.append(c) +    doB64(_in, r) +    return (b''.join(r), len(s)) + + +## decoding + +def modified_unbase64(s:bytes): +    b = binascii.a2b_base64(s.replace(b',', b'/') + b'===') +    return b.decode('utf-16be') + +def decoder(s:bytes): +    r = [] +    decode = bytearray() +    for c in s: +        if c == ord('&') and not decode: +            decode.append(ord('&')) +        elif c == ord('-') and decode: +            if len(decode) == 1: +                r.append('&') +            else: +                r.append(modified_unbase64(decode[1:])) +            decode = bytearray() +        elif decode: +            decode.append(c) +        else: +            r.append(chr(c)) +    if decode: +        r.append(modified_unbase64(decode[1:])) +    bin_str = ''.join(r) +    return (bin_str, len(s)) + + +class StreamReader(codecs.StreamReader): +    def decode(self, s, errors='strict'): +        return decoder(s) + + +class StreamWriter(codecs.StreamWriter): +    def decode(self, s, errors='strict'): +        return encoder(s) + + +def imap4_utf_7(name): +    if name == 'imap4-utf-7': +        return (encoder, decoder, StreamReader, StreamWriter) + + +codecs.register(imap4_utf_7) + + +## testing methods + +def imapUTF7Encode(ust): +    "Returns imap utf-7 encoded version of string" +    return ust.encode('imap4-utf-7') + +def imapUTF7EncodeSequence(seq): +    "Returns imap utf-7 encoded version of strings in sequence" +    return [imapUTF7Encode(itm) for itm in seq] + + +def imapUTF7Decode(st): +    "Returns utf7 encoded version of imap utf-7 string" +    return st.decode('imap4-utf-7') + +def imapUTF7DecodeSequence(seq): +    "Returns utf7 encoded version of imap utf-7 strings in sequence" +    return [imapUTF7Decode(itm) for itm in seq] + + +def utf8Decode(st): +    "Returns utf7 encoded version of imap utf-7 string" +    return st.decode('utf-8') + + +def utf7SequenceToUTF8(seq): +    "Returns utf7 encoded version of imap utf-7 strings in sequence" +    return [itm.decode('imap4-utf-7').encode('utf-8') for itm in seq] + + +__all__ = [ 'imapUTF7Encode', 'imapUTF7Decode' ] + + +if __name__ == '__main__': +    testdata = [ +            (u'foo\r\n\nbar\n', b'foo&AA0ACgAK-bar&AAo-'), +            (u'测试', b'&bUuL1Q-'), +            (u'Hello 世界', b'Hello &ThZ1TA-') +    ] +    for s, e in testdata: +        #assert s == decoder(encoder(s)[0])[0] +        assert s == imapUTF7Decode(e) +        assert e == imapUTF7Encode(s) +        assert s == imapUTF7Decode(imapUTF7Encode(s)) +        assert e == imapUTF7Encode(imapUTF7Decode(e)) +    print("All tests passed!") + +#  vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/msvst_starlet.py b/python/msvst_starlet.py new file mode 100755 index 0000000..e534d3d --- /dev/null +++ b/python/msvst_starlet.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# References: +# [1] Jean-Luc Starck, Fionn Murtagh & Jalal M. Fadili +#     Sparse Image and Signal Processing: Wavelets, Curvelets, Morphological Diversity +#     Section 3.5, 6.6 +# +# Credits: +# [1] https://github.com/abrazhe/image-funcut/blob/master/imfun/atrous.py +# +# Aaron LI +# Created: 2016-03-17 +# Updated: 2016-04-22 +# +# ChangeLog: +# 2016-04-22: +#   * Add argument "end-scale" to specifiy the end denoising scale +#   * Check outfile existence first +#   * Add argument "start-scale" to specifiy the start denoising scale +#   * Fix a bug about "p_cutoff" when "comp" contains ALL False's +#   * Show more verbose information/details +# 2016-04-20: +#   * Add argparse and main() for scripting +# + +""" +Starlet wavelet transform, i.e., isotropic undecimated wavelet transform +(IUWT), or à trous wavelet transform. +And multi-scale variance stabling transform (MS-VST), which can be used +to effectively remove the Poisson noises. +""" + +__version__ = "0.2.5" +__date__    = "2016-04-22" + + +import sys +import os +import argparse +from datetime import datetime + +import numpy as np +import scipy as sp +from scipy import signal +from astropy.io import fits + + +class B3Spline:  # {{{ +    """ +    B3-spline wavelet. +    """ +    # scaling function (phi) +    dec_lo = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16 +    dec_hi = np.array([-1.0, -4.0, 10.0, -4.0, -1.0]) / 16 +    rec_lo = np.array([0.0, 0.0, 1.0, 0.0, 0.0]) +    rec_hi = np.array([0.0, 0.0, 1.0, 0.0, 0.0]) +# B3Spline }}} + + +class IUWT:  # {{{ +    """ +    Isotropic undecimated wavelet transform. +    """ +    ## Decomposition filters list: +    # a_{scale} = convole(a_0, filters[scale]) +    # Note: the zero-th scale filter (i.e., delta function) is the first +    # element, thus the array index is the same as the decomposition scale. +    filters = [] + +    phi = None              # wavelet scaling function (2D) +    level = 0               # number of transform level +    decomposition = None    # decomposed coefficients/images +    reconstruction = None   # reconstructed image + +    # convolution boundary condition +    boundary = "symm" + +    def __init__(self, phi=B3Spline.dec_lo, level=None, boundary="symm", +            data=None): +        self.set_wavelet(phi=phi) +        self.level = level +        self.boundary = boundary +        self.data = np.array(data) + +    def reset(self): +        """ +        Reset the object attributes. +        """ +        self.data = None +        self.phi = None +        self.decomposition = None +        self.reconstruction = None +        self.level = 0 +        self.filters = [] +        self.boundary = "symm" + +    def load_data(self, data): +        self.reset() +        self.data = np.array(data) + +    def set_wavelet(self, phi): +        self.reset() +        phi = np.array(phi) +        if phi.ndim == 1: +            phi_ = phi.reshape(1, -1) +            self.phi = np.dot(phi_.T, phi_) +        elif phi.ndim == 2: +            self.phi = phi +        else: +            raise ValueError("Invalid phi dimension") + +    def calc_filters(self): +        """ +        Calculate the convolution filters of each scale. +        Note: the zero-th scale filter (i.e., delta function) is the first +        element, thus the array index is the same as the decomposition scale. +        """ +        self.filters = [] +        # scale 0: delta function +        h = np.array([[1]])  # NOTE: 2D +        self.filters.append(h) +        # scale 1 +        h = self.phi[::-1, ::-1] +        self.filters.append(h) +        for scale in range(2, self.level+1): +            h_up = self.zupsample(self.phi, order=scale-1) +            h2 = signal.convolve2d(h_up[::-1, ::-1], h, mode="same", +                    boundary=self.boundary) +            self.filters.append(h2) + +    def transform(self, data, scale, boundary="symm"): +        """ +        Perform only one scale wavelet transform for the given data. + +        return: +            [ approx, detail ] +        """ +        self.decomposition = [] +        approx = signal.convolve2d(data, self.filters[scale], +                mode="same", boundary=self.boundary) +        detail = data - approx +        return [approx, detail] + +    def decompose(self, level, boundary="symm"): +        """ +        Perform IUWT decomposition in the plain loop way. +        The filters of each scale/level are calculated first, then the +        approximations of each scale/level are calculated by convolving the +        raw/finest image with these filters. + +        return: +            [ W_1, W_2, ..., W_n, A_n ] +            n = level +            W: wavelet details +            A: approximation +        """ +        self.boundary = boundary +        if self.level != level or self.filters == []: +            self.level = level +            self.calc_filters() +        self.decomposition = [] +        approx = self.data +        for scale in range(1, level+1): +            # approximation: +            approx2 = signal.convolve2d(self.data, self.filters[scale], +                    mode="same", boundary=self.boundary) +            # wavelet details: +            w = approx - approx2 +            self.decomposition.append(w) +            if scale == level: +                self.decomposition.append(approx2) +            approx = approx2 +        return self.decomposition + +    def decompose_recursive(self, level, boundary="symm"): +        """ +        Perform the IUWT decomposition in the recursive way. + +        return: +            [ W_1, W_2, ..., W_n, A_n ] +            n = level +            W: wavelet details +            A: approximation +        """ +        self.level = level +        self.boundary = boundary +        self.decomposition = self.__decompose(self.data, self.phi, level=level) +        return self.decomposition + +    def __decompose(self, data, phi, level): +        """ +        2D IUWT decomposition (or stationary wavelet transform). + +        This is a convolution version, where kernel is zero-upsampled +        explicitly. Not fast. + +        Parameters: +        - level : level of decomposition +        - phi : low-pass filter kernel +        - boundary : boundary conditions (passed to scipy.signal.convolve2d, +                     'symm' by default) + +        Returns: +            list of wavelet details + last approximation. Each element in +            the list is an image of the same size as the input image.  +        """ +        if level <= 0: +            return data +        shapecheck = map(lambda a,b:a>b, data.shape, phi.shape) +        assert np.all(shapecheck) +        # approximation: +        approx = signal.convolve2d(data, phi[::-1, ::-1], mode="same", +                boundary=self.boundary) +        # wavelet details: +        w = data - approx +        phi_up = self.zupsample(phi, order=1) +        shapecheck = map(lambda a,b:a>b, data.shape, phi_up.shape) +        if level == 1: +            return [w, approx] +        elif not np.all(shapecheck): +            print("Maximum allowed decomposition level reached", +                    file=sys.stderr) +            return [w, approx] +        else: +            return [w] + self.__decompose(approx, phi_up, level-1) + +    @staticmethod +    def zupsample(data, order=1): +        """ +        Upsample data array by interleaving it with zero's. + +        h{up_order: n}[l] = (1) h[l], if l % 2^n == 0; +                            (2) 0, otherwise +        """ +        shape = data.shape +        new_shape = [ (2**order * (n-1) + 1) for n in shape ] +        output = np.zeros(new_shape, dtype=data.dtype) +        output[[ slice(None, None, 2**order) for d in shape ]] = data +        return output + +    def reconstruct(self, decomposition=None): +        if decomposition is not None: +            reconstruction = np.sum(decomposition, axis=0) +            return reconstruction +        else: +            self.reconstruction = np.sum(self.decomposition, axis=0) + +    def get_detail(self, scale): +        """ +        Get the wavelet detail coefficients of given scale. +        Note: 1 <= scale <= level +        """ +        if scale < 1 or scale > self.level: +            raise ValueError("Invalid scale") +        return self.decomposition[scale-1] + +    def get_approx(self): +        """ +        Get the approximation coefficients of the largest scale. +        """ +        return self.decomposition[-1] +# IUWT }}} + + +class IUWT_VST(IUWT):  # {{{ +    """ +    IUWT with Multi-scale variance stabling transform. + +    Refernce: +    [1] Bo Zhang, Jalal M. Fadili & Jean-Luc Starck, +        IEEE Trans. Image Processing, 17, 17, 2008 +    """ +    # VST coefficients and the corresponding asymptotic standard deviation +    # of each scale. +    vst_coef = [] + +    def reset(self): +        super(self.__class__, self).reset() +        vst_coef = [] + +    def __decompose(self): +        raise AttributeError("No '__decompose' attribute") + +    @staticmethod +    def soft_threshold(data, threshold): +        if isinstance(data, np.ndarray): +            data_th = data.copy() +            data_th[np.abs(data) <= threshold] = 0.0 +            data_th[data > threshold] -= threshold +            data_th[data < -threshold] += threshold +        else: +            data_th = data +            if np.abs(data) <= threshold: +                data_th = 0.0 +            elif data > threshold: +                data_th -= threshold +            else: +                data_th += threshold +        return data_th + +    def tau(self, k, scale): +        """ +        Helper function used in VST coefficients calculation. +        """ +        return np.sum(np.power(self.filters[scale], k)) + +    def filters_product(self, scale1, scale2): +        """ +        Calculate the scalar product of the filters of two scales, +        considering only the overlapped part. +        Helper function used in VST coefficients calculation. +        """ +        if scale1 > scale2: +            filter_big   = self.filters[scale1] +            filter_small = self.filters[scale2] +        else: +            filter_big   = self.filters[scale2] +            filter_small = self.filters[scale1] +        # crop the big filter to match the size of the small filter +        size_big = filter_big.shape +        size_small = filter_small.shape +        size_diff2 = list(map(lambda a,b: (a-b)//2, size_big, size_small)) +        filter_big_crop = filter_big[ +                size_diff2[0]:(size_big[0]-size_diff2[0]), +                size_diff2[1]:(size_big[1]-size_diff2[1])] +        assert(np.all(list(map(lambda a,b: a==b, +                size_small, filter_big_crop.shape)))) +        product = np.sum(filter_small * filter_big_crop) +        return product + +    def calc_vst_coef(self): +        """ +        Calculate the VST coefficients and the corresponding +        asymptotic standard deviation of each scale, according to the +        calculated filters of each scale/level. +        """ +        self.vst_coef = [] +        for scale in range(self.level+1): +            b = 2 * np.sqrt(np.abs(self.tau(1, scale)) / self.tau(2, scale)) +            c = 7.0*self.tau(2, scale) / (8.0*self.tau(1, scale)) - \ +                    self.tau(3, scale) / (2.0*self.tau(2, scale)) +            if scale == 0: +                std = -1.0 +            else: +                std = np.sqrt((self.tau(2, scale-1) / \ +                        (4 * self.tau(1, scale-1)**2)) + \ +                        (self.tau(2, scale) / (4 * self.tau(1, scale)**2)) - \ +                        (self.filters_product(scale-1, scale) / \ +                        (2 * self.tau(1, scale-1) * self.tau(1, scale)))) +            self.vst_coef.append({ "b": b, "c": c, "std": std }) + +    def vst(self, data, scale, coupled=True): +        """ +        Perform variance stabling transform + +        XXX: parameter `coupled' why?? +        Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h +        """ +        self.vst_coupled = coupled +        if self.vst_coef == []: +            self.calc_vst_coef() +        if coupled: +            b = 1.0 +        else: +            b = self.vst_coef[scale]["b"] +        data_vst = b * np.sqrt(np.abs(data + self.vst_coef[scale]["c"])) +        return data_vst + +    def ivst(self, data, scale, cbias=True): +        """ +        Inverse variance stabling transform +        NOTE: assuming that `a_{j} + c^{j}' are all positive. + +        XXX: parameter `cbias' why?? +             `bias correction' is recommended while reconstruct the data +             after estimation +        Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h +        """ +        self.vst_cbias = cbias +        if cbias: +            cb = 1.0 / (self.vst_coef[scale]["b"] ** 2) +        else: +            cb = 0.0 +        data_ivst = data ** 2 + cb - self.vst_coef[scale]["c"] +        return data_ivst + +    def is_significant(self, scale, fdr=0.1, independent=False, verbose=False): +        """ +        Multiple hypothesis testing with false discovery rate (FDR) control. + +        `independent': whether the test statistics of all the null +        hypotheses are independent. +        If `independent=True': FDR <= (m0/m) * q +        otherwise: FDR <= (m0/m) * q * (1 + 1/2 + 1/3 + ... + 1/m) + +        References: +        [1] False discovery rate - Wikipedia +            https://en.wikipedia.org/wiki/False_discovery_rate +        """ +        coef = self.get_detail(scale) +        std = self.vst_coef[scale]["std"] +        pvalues = 2.0 * (1.0 - sp.stats.norm.cdf(np.abs(coef) / std)) +        p_sorted = pvalues.flatten() +        p_sorted.sort() +        N = len(p_sorted) +        if independent: +            cn = 1.0 +        else: +            cn = np.sum(1.0 / np.arange(1, N+1)) +        p_comp = fdr * np.arange(N) / (N * cn) +        comp = (p_sorted < p_comp) +        if np.sum(comp) == 0: +            # `comp' contains ALL False +            p_cutoff = 0.0 +        else: +            # cutoff p-value after FDR control/correction +            p_cutoff = np.max(p_sorted[comp]) +        sig = (pvalues <= p_cutoff) +        if verbose: +            print("std/sigma: %g, p_cutoff: %g" % (std, p_cutoff), +                    flush=True, file=sys.stderr) +        return (sig, p_cutoff) + +    def denoise(self, fdr=0.1, fdr_independent=False, start_scale=1, +            end_scale=None, verbose=False): +        """ +        Denoise the wavelet coefficients by controlling FDR. +        """ +        self.fdr = fdr +        self.fdr_indepent = fdr_independent +        self.denoised = [] +        # supports of significant coefficients of each scale +        self.sig_supports = [None]  # make index match the scale +        self.p_cutoff = [None] +        if verbose: +            print("MSVST denosing ...", flush=True, file=sys.stderr) +        for scale in range(1, self.level+1): +            coef = self.get_detail(scale) +            if verbose: +                print("\tScale %d: " % scale, end="", +                        flush=True, file=sys.stderr) +            if (scale < start_scale) or \ +                    ((end_scale is not None) and scale > end_scale): +                if verbose: +                    print("skipped", flush=True, file=sys.stderr) +                sig, p_cutoff = None, None +            else: +                sig, p_cutoff = self.is_significant(scale, fdr=fdr, +                        independent=fdr_independent, verbose=verbose) +                coef[np.logical_not(sig)] = 0.0 +            # +            self.denoised.append(coef) +            self.sig_supports.append(sig) +            self.p_cutoff.append(p_cutoff) +        # append the last approximation +        self.denoised.append(self.get_approx()) + +    def decompose(self, level=5, boundary="symm", verbose=False): +        """ +        2D IUWT decomposition with VST. +        """ +        self.boundary = boundary +        if self.level != level or self.filters == []: +            self.level = level +            self.calc_filters() +            self.calc_vst_coef() +        self.decomposition = [] +        approx = self.data +        if verbose: +            print("IUWT decomposing (%d levels): " % level, +                    end="", flush=True, file=sys.stderr) +        for scale in range(1, level+1): +            if verbose: +                print("%d..." % scale, end="", flush=True, file=sys.stderr) +            # approximation: +            approx2 = signal.convolve2d(self.data, self.filters[scale], +                    mode="same", boundary=self.boundary) +            # wavelet details: +            w = self.vst(approx, scale=scale-1) - self.vst(approx2, scale=scale) +            self.decomposition.append(w) +            if scale == level: +                self.decomposition.append(approx2) +            approx = approx2 +        if verbose: +            print("DONE!", flush=True, file=sys.stderr) +        return self.decomposition + +    def reconstruct_ivst(self, denoised=True, positive_project=True): +        """ +        Reconstruct the original image from the *un-denoised* decomposition +        by applying the inverse VST. + +        This reconstruction result is also used as the `initial condition' +        for the below `iterative reconstruction' algorithm. + +        arguments: +        * denoised: whether use th denoised data or the direct decomposition +        * positive_project: whether replace negative values with zeros +        """ +        if denoised: +            decomposition = self.denoised +        else: +            decomposition = self.decomposition +        self.positive_project = positive_project +        details = np.sum(decomposition[:-1], axis=0) +        approx = self.vst(decomposition[-1], scale=self.level) +        reconstruction = self.ivst(approx+details, scale=0) +        if positive_project: +            reconstruction[reconstruction < 0.0] = 0.0 +        self.reconstruction = reconstruction +        return reconstruction + +    def reconstruct(self, denoised=True, niter=10, verbose=False): +        """ +        Reconstruct the original image using iterative method with +        L1 regularization, because the denoising violates the exact inverse +        procedure. + +        arguments: +        * denoised: whether use the denoised coefficients +        * niter: number of iterations +        """ +        if denoised: +            decomposition = self.denoised +        else: +            decomposition = self.decomposition +        # L1 regularization +        lbd = 1.0 +        delta = lbd / (niter - 1) +        # initial solution +        solution = self.reconstruct_ivst(denoised=denoised, +                positive_project=True) +        # +        iuwt = IUWT(level=self.level) +        iuwt.calc_filters() +        # iterative reconstruction +        if verbose: +            print("Iteratively reconstructing (%d times): " % niter, +                    end="", flush=True, file=sys.stderr) +        for i in range(niter): +            if verbose: +                print("%d..." % i, end="", flush=True, file=sys.stderr) +            tempd = self.data.copy() +            solution_decomp = [] +            for scale in range(1, self.level+1): +                approx, detail = iuwt.transform(tempd, scale) +                approx_sol, detail_sol = iuwt.transform(solution, scale) +                # Update coefficients according to the significant supports, +                # which are acquired during the denosing precodure with FDR. +                sig = self.sig_supports[scale] +                detail_sol[sig] = detail[sig] +                detail_sol = self.soft_threshold(detail_sol, threshold=lbd) +                # +                solution_decomp.append(detail_sol) +                tempd = approx.copy() +                solution = approx_sol.copy() +            # last approximation (the two are the same) +            solution_decomp.append(approx) +            # reconstruct +            solution = iuwt.reconstruct(decomposition=solution_decomp) +            # discard all negative values +            solution[solution < 0] = 0.0 +            # +            lbd -= delta +        if verbose: +            print("DONE!", flush=True, file=sys.stderr) +        # +        self.reconstruction = solution +        return self.reconstruction +# IUWT_VST }}} + + +def main(): +    # commandline arguments parser +    parser = argparse.ArgumentParser( +            description="Poisson Noise Removal with Multi-scale Variance " + \ +                        "Stabling Transform and Wavelet Transform", +            epilog="Version: %s (%s)" % (__version__, __date__)) +    parser.add_argument("-l", "--level", dest="level", +            type=int, default=5, +            help="level of the IUWT decomposition") +    parser.add_argument("-r", "--fdr", dest="fdr", +            type=float, default=0.1, +            help="false discovery rate") +    parser.add_argument("-I", "--fdr-independent", dest="fdr_independent", +            action="store_true", default=False, +            help="whether the FDR null hypotheses are independent") +    parser.add_argument("-s", "--start-scale", dest="start_scale", +            type=int, default=1, +            help="which scale to start the denoising (inclusive)") +    parser.add_argument("-e", "--end-scale", dest="end_scale", +            type=int, default=0, +            help="which scale to end the denoising (inclusive)") +    parser.add_argument("-n", "--niter", dest="niter", +            type=int, default=10, +            help="number of iterations for reconstruction") +    parser.add_argument("-v", "--verbose", dest="verbose", +            action="store_true", default=False, +            help="show verbose progress") +    parser.add_argument("-C", "--clobber", dest="clobber", +            action="store_true", default=False, +            help="overwrite output file if exists") +    parser.add_argument("infile", help="input image with Poisson noises") +    parser.add_argument("outfile", help="output denoised image") +    args = parser.parse_args() + +    if args.end_scale == 0: +        args.end_scale = args.level + +    if args.verbose: +        print("infile: '%s'" % args.infile, file=sys.stderr) +        print("outfile: '%s'" % args.outfile, file=sys.stderr) +        print("level: %d" % args.level, file=sys.stderr) +        print("fdr: %.2f" % args.fdr, file=sys.stderr) +        print("fdr_independent: %s" % args.fdr_independent, file=sys.stderr) +        print("start_scale: %d" % args.start_scale, file=sys.stderr) +        print("end_scale: %d" % args.end_scale, file=sys.stderr) +        print("niter: %d\n" % args.niter, flush=True, file=sys.stderr) + +    if not args.clobber and os.path.exists(args.outfile): +        raise OSError("outfile '%s' already exists" % args.outfile) + +    imgfits = fits.open(args.infile) +    img = imgfits[0].data +    # Remove Poisson noises +    msvst = IUWT_VST(data=img) +    msvst.decompose(level=args.level, verbose=args.verbose) +    msvst.denoise(fdr=args.fdr, fdr_independent=args.fdr_independent, +            start_scale=args.start_scale, end_scale=args.end_scale, +            verbose=args.verbose) +    msvst.reconstruct(denoised=True, niter=args.niter, verbose=args.verbose) +    img_denoised = msvst.reconstruction +    # Output +    imgfits[0].data = img_denoised +    imgfits[0].header.add_history("%s: Removed Poisson Noises @ %s" % ( +                os.path.basename(sys.argv[0]), datetime.utcnow().isoformat())) +    imgfits[0].header.add_history("  TOOL: %s (v%s, %s)" % ( +                os.path.basename(sys.argv[0]), __version__, __date__)) +    imgfits[0].header.add_history("  PARAM: %s" % " ".join(sys.argv[1:])) +    imgfits.writeto(args.outfile, checksum=True, clobber=args.clobber) + + +if __name__ == "__main__": +    main() + diff --git a/python/plot.py b/python/plot.py new file mode 100644 index 0000000..b65f8a3 --- /dev/null +++ b/python/plot.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# +# Credits: http://www.aosabook.org/en/matplotlib.html +# +# Aaron LI +# 2016-03-14 +# + +# Import the FigureCanvas from the backend of your choice +#  and attach the Figure artist to it. +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +fig = Figure() +canvas = FigureCanvas(fig) + +# Import the numpy library to generate the random numbers. +import numpy as np +x = np.random.randn(10000) + +# Now use a figure method to create an Axes artist; the Axes artist is +#  added automatically to the figure container fig.axes. +# Here "111" is from the MATLAB convention: create a grid with 1 row and 1 +#  column, and use the first cell in that grid for the location of the new +#  Axes. +ax = fig.add_subplot(111) + +# Call the Axes method hist to generate the histogram; hist creates a +#  sequence of Rectangle artists for each histogram bar and adds them +#  to the Axes container.  Here "100" means create 100 bins. +ax.hist(x, 100) + +# Decorate the figure with a title and save it. +ax.set_title('Normal distribution with $\mu=0, \sigma=1$') +fig.savefig('matplotlib_histogram.png') + diff --git a/python/plot_tprofiles_zzh.py b/python/plot_tprofiles_zzh.py new file mode 100644 index 0000000..e5824e9 --- /dev/null +++ b/python/plot_tprofiles_zzh.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# +# Weitian LI +# 2015-09-11 +# + +""" +Plot a list of *temperature profiles* in a grid of subplots with Matplotlib. +""" + +import matplotlib.pyplot as plt + + +def plot_tprofiles(tplist, nrows, ncols, +        xlim=None, ylim=None, logx=False, logy=False, +        xlab="", ylab="", title=""): +    """ +    Plot a list of *temperature profiles* in a grid of subplots of size +    nrow x ncol. Each subplot is related to a temperature profile. +    All the subplots share the same X and Y axes. +    The order is by row. + +    The tplist is a list of dictionaries, each of which contains all the +    necessary data to make the subplot. + +    The dictionary consists of the following components: +    tpdat = { +        "name": "NAME", +        "radius": [[radius points], [radius errors]], +        "temperature": [[temperature points], [temperature errors]], +        "radius_model": [radus points of the fitted model], +        "temperature_model": [ +            [fitted model value], +            [lower bounds given by the model], +            [upper bounds given by the model] +        ] +    } + +    Arguments: +        tplist - a list of dictionaries containing the data of each +                 temperature profile. +                 Note that the length of this list should equal to nrows*ncols. +        nrows  - number of rows of the subplots +        ncols  - number of columns of the subplots +        xlim   - limits of the X axis +        ylim   - limits of the Y axis +        logx   - whether to set the log scale for X axis +        logy   - whether to set the log scale for Y axis +        xlab   - label for the X axis +        ylab   - label for the Y axis +        title  - title for the whole plot +    """ +    assert len(tplist) == nrows*ncols, "tplist length != nrows*ncols" +    # All subplots share both X and Y axes. +    fig, axarr = plt.subplots(nrows, ncols, sharex=True, sharey=True) +    # Set title for the whole plot. +    if title != "": +        fig.suptitle(title) +    # Set xlab and ylab for each subplot +    if xlab != "": +        for ax in axarr[-1, :]: +            ax.set_xlabel(xlab) +    if ylab != "": +        for ax in axarr[:, 0]: +            ax.set_ylabel(ylab) +    for ax in axarr.reshape(-1): +        # Set xlim and ylim. +        if xlim is not None: +            ax.set_xlim(xlim) +        if ylim is not None: +            ax.set_ylim(ylim) +        # Set xscale and yscale. +        if logx: +            ax.set_xscale("log", nonposx="clip") +        if logy: +            ax.set_yscale("log", nonposy="clip") +    # Decrease the spacing between the subplots and suptitle +    fig.subplots_adjust(top=0.94) +    # Eleminate the spaces between each row and column. +    fig.subplots_adjust(hspace=0, wspace=0) +    # Hide X ticks for all subplots but the bottom row. +    plt.setp([ax.get_xticklabels() for ax in axarr[:-1, :].reshape(-1)], +            visible=False) +    # Hide Y ticks for all subplots but the left column. +    plt.setp([ax.get_yticklabels() for ax in axarr[:, 1:].reshape(-1)], +            visible=False) +    # Plot each temperature profile in the tplist +    for i, ax in zip(range(len(tplist)), axarr.reshape(-1)): +        tpdat = tplist[i] +        # Add text to display the name. +        # The text is placed at (0.95, 0.95), i.e., the top-right corner, +        # with respect to this subplot, and the top-right part of the text +        # is aligned to the above position. +        ax_pois = ax.get_position() +        ax.text(0.95, 0.95, tpdat["name"], +                verticalalignment="top", horizontalalignment="right", +                transform=ax.transAxes, color="black", fontsize=10) +        # Plot data points +        if isinstance(tpdat["radius"][0], list) and \ +                len(tpdat["radius"]) == 2 and \ +                isinstance(tpdat["temperature"][0], list) and \ +                len(tpdat["temperature"]) == 2: +            # Data points have symmetric errorbar +            ax.errorbar(tpdat["radius"][0], tpdat["temperature"][0], +                    xerr=tpdat["radius"][1], yerr=tpdat["temperature"][1], +                    color="black", linewidth=1.5, linestyle="None") +        else: +            ax.plot(tpdat["radius"], tpdat["temperature"], +                    color="black", linewidth=1.5, linestyle="None") +        # Plot model line and bounds band +        if isinstance(tpdat["temperature_model"][0], list) and \ +                len(tpdat["temperature_model"]) == 3: +            # Model data have bounds +            ax.plot(tpdat["radius_model"], tpdat["temperature_model"][0], +                    color="blue", linewidth=1.0) +            # Plot model bounds band +            ax.fill_between(tpdat["radius_model"], +                    y1=tpdat["temperature_model"][1], +                    y2=tpdat["temperature_model"][2], +                    color="gray", alpha=0.5) +        else: +            ax.plot(tpdat["radius_model"], tpdat["temperature_model"], +                    color="blue", linewidth=1.5) +    return (fig, axarr) + +#  vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/randomize_events.py b/python/randomize_events.py new file mode 100755 index 0000000..e1a6e31 --- /dev/null +++ b/python/randomize_events.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# +# Randomize the (X,Y) position of each X-ray photon events according +# to a Gaussian distribution of given sigma. +# +# References: +# [1] G. Scheellenberger, T.H. Reiprich, L. Lovisari, J. Nevalainen & L. David +#     2015, A&A, 575, A30 +# +# +# Aaron LI +# Created: 2016-03-24 +# Updated: 2016-03-24 +# + +from astropy.io import fits +import numpy as np + +import os +import sys +import datetime +import argparse + + +CHANDRA_ARCSEC_PER_PIXEL = 0.492 + +def randomize_events(infile, outfile, sigma, clobber=False): +    """ +    Randomize the position (X,Y) of each X-ray event according to a +    specified size/sigma Gaussian distribution. +    """ +    sigma_pix = sigma / CHANDRA_ARCSEC_PER_PIXEL +    evt_fits = fits.open(infile) +    evt_table = evt_fits[1].data +    # (X,Y) physical coordinate +    evt_x = evt_table["x"] +    evt_y = evt_table["y"] +    rand_x = np.random.normal(scale=sigma_pix, size=evt_x.shape)\ +            .astype(evt_x.dtype) +    rand_y = np.random.normal(scale=sigma_pix, size=evt_y.shape)\ +            .astype(evt_y.dtype) +    evt_x += rand_x +    evt_y += rand_y +    # Add history to FITS header +    evt_hdr = evt_fits[1].header +    evt_hdr.add_history("TOOL: %s @ %s" % ( +            os.path.basename(sys.argv[0]), +            datetime.datetime.utcnow().isoformat())) +    evt_hdr.add_history("COMMAND: %s" % " ".join(sys.argv)) +    evt_fits.writeto(outfile, clobber=clobber, checksum=True) + + +def main(): +    parser = argparse.ArgumentParser( +            description="Randomize the (X,Y) of each X-ray event") +    parser.add_argument("infile", help="input event file") +    parser.add_argument("outfile", help="output randomized event file") +    parser.add_argument("-s", "--sigma", dest="sigma", +            required=True, type=float, +            help="sigma/size of the Gaussian distribution used" + \ +                 "to randomize the position of events (unit: arcsec)") +    parser.add_argument("-C", "--clobber", dest="clobber", +            action="store_true", help="overwrite output file if exists") +    args = parser.parse_args() + +    randomize_events(args.infile, args.outfile, +            sigma=args.sigma, clobber=args.clobber) + + +if __name__ == "__main__": +    main() + diff --git a/python/rebuild_ipod_db.py b/python/rebuild_ipod_db.py new file mode 100755 index 0000000..20c5454 --- /dev/null +++ b/python/rebuild_ipod_db.py @@ -0,0 +1,595 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# LICENSE: +# --------------------------------------------------------------------------- +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA +# --------------------------------------------------------------------------- +# +# Based on Matrin Fiedler's "rebuild_db.py" v1.0-rc1 (2006-04-26): +# http://shuffle-db.sourceforge.net/ +# + +from __future__ import print_function + + +__title__   = "iPod Shuffle Database Builder" +__author__  = "Aaron LI" +__version__ = "2.0.2" +__date__    = "2016-04-16" + + +import sys +import os +import operator +import array +import random +import fnmatch +import operator +import string +import argparse +import functools +import shutil +from collections import OrderedDict + + +domains = [] +total_count = 0 + + +class LogObj: +    """ +    Print and log the process information. +    """ +    def __init__(self, filename=None): +        self.filename = filename + +    def open(self): +        if self.filename: +            try: +                self.logfile = open(self.filename, "w") +            except IOError: +                self.logfile = None +        else: +            self.logfile = None + +    def log(self, line="", end="\n"): +        value = line + end +        if self.logfile: +            self.logfile.write(value) +        print(value, end="") + +    def close(self): +        if self.logfile: +            self.logfile.close() + + +class Rule: +    """ +    A RuleSet for the way to handle the found playable files. +    """ +    SUPPORT_PROPS = ("filename", "size", "ignore", "type", +                     "shuffle", "reuse", "bookmark") + +    def __init__(self, conditions=None, actions=None): +        self.conditions = conditions +        self.actions    = actions + +    @classmethod +    def parse(cls, rule): +        """ +        Parse the whole line of a rule. + +        Syntax: +            condition1, condition2, ...: action1, action2, ... + +        condition examples: +          * filename ~ "*.mp3" +          * size > 100000 +        action examples: +          * ignore = 1 +          * shuffle = 1 + +        Return: a object of this class with the parsed rule. +        """ +        conditions, actions = rule.split(":") +        conditions = list(map(cls.parse_condition, conditions.split(","))) +        actions = dict(map(cls.parse_action, actions.split(","))) +        return cls(conditions, actions) + +    @classmethod +    def parse_condition(cls, cond): +        sep_pos = min([ cond.find(sep) for sep in "~=<>" \ +                        if cond.find(sep)>0 ]) +        prop = cond[:sep_pos].strip() +        if prop not in cls.SUPPORT_PROPS: +            raise ValueError("WARNING: unknown property '%s'" % prop) +        return (prop, cond[sep_pos], +                cls.parse_value(cond[sep_pos+1:].strip())) + +    @classmethod +    def parse_action(cls, action): +        prop, value = map(str.strip, action.split("=", 1)) +        if prop not in cls.SUPPORT_PROPS: +            raise ValueError("WARNING: unknown property '%s'" % prop) +        return (prop, cls.parse_value(value)) + +    @staticmethod +    def parse_value(value): +        value = value.strip().strip('"').strip("'") +        try: +            return int(value) +        except ValueError: +            return value + +    def match(self, props): +        """ +        Check whether the given props match all the conditions. +        """ +        def match_condition(props, cond): +            """ +            Check whether the given props match the given condition. +            """ +            try: +                prop, op, ref = props[cond[0]], cond[1], cond[2] +            except KeyError: +                return False +            if op == "~": +                return fnmatch.fnmatchcase(prop.lower(), ref.lower()) +            elif op == "=": +                return prop == ref +            elif op == ">": +                return prop > ref +            elif op == "<": +                return prop < ref +            else: +                return False +        # +        return functools.reduce(operator.and_, +                                [ match_condition(props, cond) \ +                                        for cond in self.conditions ], +                                True) + + +class Entries: +    """ +    Walk through the directory to find all files, and filter by the +    extensions to get all the playable files. +    """ +    PLAYABLE_EXTS = (".mp3", ".m4a", ".m4b", ".m4p", ".aa", ".wav") + +    def __init__(self, dirs=[], rename=True, recursive=True, ignore_dup=True): +        self.entries = [] +        self.add_dirs(dirs=dirs, rename=rename, recursive=recursive, +                ignore_dup=ignore_dup) + +    def add_dirs(self, dirs=[], rename=True, recursive=True, ignore_dup=True): +        for dir in dirs: +            self.add_dir(dir=dir, rename=rename, recursive=recursive, +                    ignore_dup=ignore_dup) + +    def add_dir(self, dir, rename=True, recursive=True, ignore_dup=True): +        global logobj +        if recursive: +            # Get all directories, and rename them if needed +            dirs = [] +            for dirName, subdirList, fileList in os.walk(dir): +                dirs.append(dirName) +            for dirName in dirs: +                newDirName = self.get_newname(dirName) +                if rename and newDirName != dirName: +                    logobj.log("Rename: '%s' -> '%s'" % (dirName, newDirName)) +                    shutil.move(dirName, newDirName) +            # Get all files +            files = [] +            for dirName, subdirList, fileList in os.walk(dir): +                files.extend([ os.path.join(dirName, f) for f in fileList ]) +        else: +            # rename the directory if needed +            newDir = self.get_newname(dir) +            if rename and newDir != dir: +                logobj.log("Rename: '%s' -> '%s'" % (dir, newDir)) +                shutil.move(dir, newDir) +            files = [ os.path.join(newDir, f) for f in self.listfiles(newDir) ] +        # +        for fn in files: +            # rename filename if needed +            newfn = self.get_newname(fn) +            if rename and newfn != fn: +                logobj.log("Rename: '%s' -> '%s'" % (fn, newfn)) +                shutil.move(fn, newfn) +                fn = newfn +            # filter by playable extensions +            if os.path.splitext(fn)[1].lower() not in self.PLAYABLE_EXTS: +                continue +            if ignore_dup and (fn in self.entries): +                continue +            self.entries.append(fn) +            print("Entry: %s" % fn) + +    @staticmethod +    def listfiles(path, ignore_hidden=True): +        """ +        List only files of a directory +        """ +        for f in os.listdir(path): +            if os.path.isfile(os.path.join(path, f)): +                if ignore_hidden and f[0] != ".": +                    yield f +                else: +                    yield f + +    @staticmethod +    def get_newname(path): +        def conv_char(ch): +            safe_char = string.ascii_letters + string.digits + "-_" +            if ch in safe_char: +                return ch +            return "_" +        # +        if path == ".": +            return path +        dirname, basename = os.path.split(path) +        base, ext = os.path.splitext(basename) +        newbase = "".join(map(conv_char, base)) +        if basename == newbase+ext: +            return os.path.join(dirname, basename) +        if os.path.exists("%s/%s%s" % (dirname, newbase, ext)): +            i = 0 +            while os.path.exists("%s/%s_%d%s" % (dirname, newbase, i, ext)): +                i += 1 +            newbase += "_%d" % i +        newname = "%s/%s%s" % (dirname, newbase, ext) +        return newname + +    def fix_and_sort(self): +        """ +        Fix the entries' pathes (should starts with "/"), and sort. +        """ +        self.entries = [ "/"+f.lstrip("./") for f in self.entries ] +        self.entries.sort() + +    def apply_rules(self, rules): +        """ +        Apply rules to the found entries. +        The filtered/updated entries and properties are saved in: +        'self.entries_dict' +        """ +        self.entries_dict = OrderedDict() + +        for fn in self.entries: +            # set default properties +            props = { +                "filename": fn, +                "size":     os.stat(fn[1:]).st_size, +                "ignore":   0, +                "type":     1, +                "shuffle":  1, +                "bookmark": 0 +            } +            # check and apply rules +            for rule in rules: +                if rule.match(props): +                    props.update(rule.actions) +            # +            if props["ignore"]: +                continue +            # +            self.entries_dict[fn] = props + +    def get_entries(self): +        return self.entries_dict.items() + + +class iTunesSD: +    """ +    Class to handle the iPod Shuffle main database +    "iPod_Control/iTunes/iTunesSD" +    """ +    def __init__(self, dbfile="./iPod_Control/iTunes/iTunesSD"): +        self.dbfile = dbfile +        self.load() + +    def load(self): +        """ +        Load original header and entries. +        """ +        self.old_entries = {} +        self.header_main = array.array("B")  # unsigned integer array +        self.header_entry = array.array("B")  # unsigned integer array +        db = open(self.dbfile, "rb") +        try: +            self.header_main.fromfile(db, 18) +            self.header_entry.fromfile(db, 33) +            db.seek(18) +            entry = db.read(558) +            while len(entry) == 558: +                filename = entry[33::2].split(b"\0", 1)[0] +                self.old_entries[filename] = entry +                entry = db.read(558) +        except EOFError: +            pass +        db.close() +        print("Loaded %d entries from existing database" % \ +                len(self.old_entries)) + +    def build_header(self, force=False): +        global logobj +        # rebuild database header +        if force or len(self.header_main) != 18: +            logobj.log("Rebuild iTunesSD main header ...") +            del self.header_main[:] +            self.header_main.fromlist([0,0,0,1,6,0,0,0,18] + [0]*9) +        if force or len(self.header_entry) != 33: +            logobj.log("Rebuild iTunesSD entry header ...") +            del self.header_entry[:] +            self.header_entry.fromlist([0,2,46,90,165,1] + [0]*20 + \ +                                       [100,0,0,1,0,2,0]) + +    def add_entries(self, entries, reuse=True): +        """ +        Prepare the entries for database +        """ +        self.entries = OrderedDict() + +        for fn, props in entries.get_entries(): +            if reuse and props.get("reuse") and (fn in self.old_entries): +                # retrieve entry from old entries +                entry = self.old_entries[fn] +            else: +                # build new entry +                self.header_entry[29] = props["type"] +                entry_data = "".join([ c+"\0" for c in fn[:261] ]) + \ +                        "\0"*(558 - len(self.header_entry) - 2*len(fn)) +                entry = self.header_entry.tostring() + \ +                        entry_data.encode("utf-8") +            # modify the shuffle and bookmark flags +            entry = entry[:555] + chr(props["shuffle"]).encode("utf-8") + \ +                    chr(props["bookmark"]).encode("utf-8") + entry[557] +            # +            self.entries[fn] = entry + +    def write(self, dbfile=None): +        if dbfile is None: +            dbfile = self.dbfile +        # Make a backup +        if os.path.exists(dbfile): +            shutil.copy2(dbfile, dbfile+"_bak") + +        # write main database file +        with open(dbfile, "wb") as db: +            self.header_main.tofile(db) +            for entry in self.entries.values(): +                db.write(entry) +            # Update database header +            num_entries = len(self.entries) +            db.seek(0) +            db.write(b"\0%c%c" % (num_entries>>8, num_entries&0xFF)) + + +class iTunesPState: +    """ +    iPod Shuffle playback state database: "iPod_Control/iTunes/iTunesPState" +    """ +    def __init__(self, dbfile="iPod_Control/iTunes/iTunesPState"): +        self.dbfile = dbfile +        self.load() + +    def load(self): +        with open(self.dbfile, "rb") as db: +            a = array.array("B") +            a.fromstring(db.read()) +            self.PState = a.tolist() + +    def update(self, volume=None): +        if len(self.PState) != 21: +            # volume 29, FW ver 1.0 +            self.PState = self.listval(29) + [0]*15 + self.listval(1) +        # track 0, shuffle mode, start of track +        self.PState[3:15] = [0]*6 + [1] + [0]*5 +        if volume is not None: +            self.PState[:3] = self.listval(volume) + +    def write(self, dbfile=None): +        if dbfile is None: +            dbfile = self.dbfile +        # Make a backup +        if os.path.exists(dbfile): +            shutil.copy2(dbfile, dbfile+"_bak") + +        with open(dbfile, "wb") as db: +            array.array("B", self.PState).tofile(db) + +    @staticmethod +    def listval(i): +        if i < 0: +            i += 0x1000000 +        return [i&0xFF, (i>>8)&0xFF, (i>>16)&0xFF] + + +class iTunesStats: +    """ +    iPod Shuffle statistics database: "iPod_Control/iTunes/iTunesStats" +    """ +    def __init__(self, dbfile="iPod_Control/iTunes/iTunesStats"): +        self.dbfile = dbfile + +    def write(self, count, dbfile=None): +        if dbfile is None: +            dbfile = self.dbfile +        # Make a backup +        if os.path.exists(dbfile): +            shutil.copy2(dbfile, dbfile+"_bak") + +        with open(dbfile, "wb") as db: +            data = self.stringval(count) + "\0"*3 + \ +                     (self.stringval(18) + "\xff"*3 + "\0"*12) * count +            db.write(data.encode("utf-8")) + +    @staticmethod +    def stringval(i): +        if i < 0: +            i += 0x1000000 +        return "%c%c%c" % (i&0xFF, (i>>8)&0xFF, (i>>16)&0xFF) + + +class iTunesShuffle: +    """ +    iPod shuffle database: "iPod_Control/iTunes/iTunesShuffle" +    """ +    def __init__(self, dbfile="iPod_Control/iTunes/iTunesShuffle"): +        self.dbfile = dbfile + +    def shuffle(self, entries): +        """ +        Generate the shuffle sequences for the entries, and take care +        of the "shuffle" property. +        """ +        shuffle_prop = [ props["shuffle"] +                         for fn, props in entries.get_entries() ] +        shuffle_idx = [ idx for idx, s in enumerate(shuffle_prop) if s == 1 ] +        shuffled = shuffle_idx.copy() +        random.seed() +        random.shuffle(shuffled) +        shuffle_seq = list(range(len(shuffle_prop))) +        for i, idx in enumerate(shuffle_idx): +            shuffle_seq[idx] = shuffled[i] +        self.shuffle_seq = shuffle_seq + +    def write(self, dbfile=None): +        if dbfile is None: +            dbfile = self.dbfile +        # Make a backup +        if os.path.exists(dbfile): +            shutil.copy2(dbfile, dbfile+"_bak") + +        with open(dbfile, "wb") as db: +            data = "".join(map(iTunesStats.stringval, self.shuffle_seq)) +            db.write(data.encode("utf-8")) + + +def main(): +    prog_basename = os.path.splitext(os.path.basename(sys.argv[0]))[0] + +    # command line arguments +    parser = argparse.ArgumentParser( +            description="Rebuild iPod Shuffle Database", +            epilog="Version: %s (%s)\n\n" % (__version__, __date__) + \ +                   "Only 1st and 2nd iPod Shuffle supported!\n\n" + \ +                   "The script must be placed under the iPod's root directory") +    parser.add_argument("-f", "--force", dest="force", action="store_true", +            help="always rebuild database entries, do NOT reuse old ones") +    parser.add_argument("-M", "--no-rename", dest="norename", +            action="store_false", default=True, +            help="do NOT rename files") +    parser.add_argument("-V", "--volume", dest="volume", type=int, +            help="set playback volume (0 - 38)") +    parser.add_argument("-r", "--rulesfile", dest="rulesfile", +            default="%s.rules" % prog_basename, +            help="additional rules filename") +    parser.add_argument("-l", "--logfile", dest="logfile", +            default="%s.log" % prog_basename, +            help="log output filename") +    parser.add_argument("dirs", nargs="*", +            help="directories to be searched for playable files") +    args = parser.parse_args() + +    flag_reuse = not args.force + +    # Start logging +    global logobj +    logobj = LogObj(args.logfile) +    logobj.open() + +    # Rules for how to handle the found playable files +    rules = [] +    # Add default rules +    rules.append(Rule(conditions=[("filename", "~", "*.mp3")], +                      actions={"type":1, "shuffle":1, "bookmark":0})) +    rules.append(Rule(conditions=[("filename", "~", "*.m4?")], +                      actions={"type":2, "shuffle":1, "bookmark":0})) +    rules.append(Rule(conditions=[("filename", "~", "*.m4b")], +                      actions={"shuffle":0, "bookmark":1})) +    rules.append(Rule(conditions=[("filename", "~", "*.aa")], +                      actions={"type":1, "shuffle":0, "bookmark":1, "reuse":1})) +    rules.append(Rule(conditions=[("filename", "~", "*.wav")], +                      actions={"type":4, "shuffle":0, "bookmark":0})) +    rules.append(Rule(conditions=[("filename", "~", "*.book.???")], +                      actions={"shuffle":0, "bookmark":1})) +    rules.append(Rule(conditions=[("filename", "~", "*.announce.???")], +                      actions={"shuffle":0, "bookmark":0})) +    rules.append(Rule(conditions=[("filename", "~", "/backup/*")], +                      actions={"ignore":1})) +    # Load additional rules +    try: +        for line in open(args.rulesfile, "r").readlines(): +            rules.append(Rule.parse(line)) +        logobj.log("Loaded additional rules from file: %s" % args.rulesfile) +    except IOError: +        pass + +    # cd to the directory of this script +    os.chdir(os.path.dirname(sys.argv[0])) + +    if not os.path.isdir("iPod_Control/iTunes"): +        logobj.log("ERROR: No iPod control directory found!") +        logobj.log("Please make sure that:") +        logobj.log("(*) this script is placed under the iPod's root directory") +        logobj.log("(*) the iPod was correctly initialized with iTunes") +        sys.exit(1) + +    # playable entries +    logobj.log("Search for playable entries ...") +    entries = Entries() +    if args.dirs: +        for dir in args.dirs: +            entries.add_dir(dir=dir, recursive=True, rename=args.norename) +    else: +        entries.add_dir(".", recursive=True, rename=args.norename) +    entries.fix_and_sort() +    logobj.log("Apply rules to entries ...") +    entries.apply_rules(rules=rules) + +    # read main database file +    logobj.log("Update main database ...") +    db = iTunesSD(dbfile="iPod_Control/iTunes/iTunesSD") +    db.build_header(force=args.force) +    db.add_entries(entries=entries, reuse=flag_reuse) +    assert len(db.entries) == len(entries.get_entries()) +    db.write() +    logobj.log("Added %d entries ..." % len(db.entries)) + +    # other misc databases +    logobj.log("Update playback state database ...") +    db_pstate = iTunesPState(dbfile="iPod_Control/iTunes/iTunesPState") +    db_pstate.update(volume=args.volume) +    db_pstate.write() +    logobj.log("Update statistics database ...") +    db_stats = iTunesStats(dbfile="iPod_Control/iTunes/iTunesStats") +    db_stats.write(count=len(db.entries)) +    logobj.log("Update shuffle database ...") +    db_shuffle = iTunesShuffle(dbfile="iPod_Control/iTunes/iTunesShuffle") +    db_shuffle.shuffle(entries=entries) +    db_shuffle.write() + +    logobj.log("The iPod Shuffle database was rebuilt successfully!") + +    logobj.close() + + +if __name__ == "__main__": +    main() + +#  vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/splitBoxRegion.py b/python/splitBoxRegion.py new file mode 100755 index 0000000..5254686 --- /dev/null +++ b/python/splitBoxRegion.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8- +# +# Split the strip-shaped CCD gaps regions into a series of small +# square regions, which are used as the input regions of 'roi' to +# determine the corresponding background regions, and finally providied +# to 'dmfilth' in order to fill in the CCD gaps. +# +# Aaron LI +# 2015/08/12 +# +# Changelogs: +# v0.1.0, 2015/08/12 +#   * initial version +# + + +__version__ = "0.1.0" +__date__    = "2015/08/12" + + +import os +import sys +import re +import math +import argparse +from io import TextIOWrapper + + +## BoxRegion {{{ +class BoxRegion(object): +    """ +    CIAO/DS9 "rotbox"/"box" region class. + +    rotbox/box format: +        rotbox(xc, yc, width, height, rotation) +        box(xc, yc, width, height, rotation) +    Notes: +        rotation: [0, 360) (degree) +    """ +    def __init__(self, xc=None, yc=None, +                 width=None, height=None, rotation=None): +        self.regtype  = "rotbox" +        self.xc       = xc +        self.yc       = yc +        self.width    = width +        self.height   = height +        self.rotation = rotation + +    def __str__(self): +        return "%s(%s,%s,%s,%s,%s)" % (self.regtype, self.xc, self.yc, +                self.width, self.height, self.rotation) + +    @classmethod +    def parse(cls, regstr): +        """ +        Parse region string. +        """ +        regex_box = re.compile(r'^\s*(box|rotbox)\(([0-9. ]+),([0-9. ]+),([0-9. ]+),([0-9. ]+),([0-9. ]+)\)\s*$', re.I) +        m = regex_box.match(regstr) +        if m: +            regtype  = m.group(1) +            xc       = float(m.group(2)) +            yc       = float(m.group(3)) +            width    = float(m.group(4)) +            height   = float(m.group(5)) +            rotation = float(m.group(6)) +            return cls(xc, yc, width, height, rotation) +        else: +            return None + +    def split(self, filename=None): +        """ +        Split strip-shaped box region into a series small square regions. +        """ +        angle = self.rotation * math.pi / 180.0 +        # to record the center coordinates of each split region +        centers = [] +        if self.width > self.height: +            # number of regions after split +            nreg = math.ceil(self.width / self.height) +            # width & height of the split region +            width = self.width / nreg +            height = self.height +            # position of the left-most region +            x_l = self.xc - 0.5*self.width * math.cos(angle) +            y_l = self.yc - 0.5*self.width * math.sin(angle) +            for i in range(nreg): +                x = x_l + (0.5 + i) * width * math.cos(angle) +                y = y_l + (0.5 + i) * width * math.sin(angle) +                centers.append((x, y)) +        else: +            # number of regions after split +            nreg = math.ceil(self.height / self.width) +            # width & height of the split region +            width = self.width +            height = self.height / nreg +            # position of the left-most region +            x_l = self.xc + 0.5*self.height * math.cos(angle + math.pi/2) +            y_l = self.yc + 0.5*self.height * math.sin(angle + math.pi/2) +            for i in range(nreg): +                x = x_l - (0.5 + i) * height * math.cos(angle + math.pi/2) +                y = y_l - (0.5 + i) * height * math.sin(angle + math.pi/2) +                centers.append((x, y)) +        # create split regions +        regions = [] +        for (x, y) in centers: +            regions.append(self.__class__(x, y, width+2, height+2, +                                          self.rotation)) +        # write split regions into file if specified +        if isinstance(filename, str): +            regout = open(filename, "w") +            regout.write("\n".join(map(str, regions))) +            regout.close() +        else: +            return regions + +## BoxRegion }}} + + +def main(): +    # command line arguments +    parser = argparse.ArgumentParser( +            description="Split strip-shaped rotbox region into " + \ +                    "a series of small square regions.", +            epilog="Version: %s (%s)" % (__version__, __date__)) +    parser.add_argument("-V", "--version", action="version", +            version="%(prog)s " + "%s (%s)" % (__version__, __date__)) +    parser.add_argument("infile", help="input rotbox region file") +    parser.add_argument("outfile", help="output file of the split regions") +    args = parser.parse_args() + +    outfile = open(args.outfile, "w") +    regex_box = re.compile(r'^\s*(box|rotbox)\([0-9., ]+\)\s*$', re.I) +    for line in open(args.infile, "r"): +        if regex_box.match(line): +            reg = BoxRegion.parse(line) +            split_regs = reg.split() +            outfile.write("\n".join(map(str, split_regs)) + "\n") +        else: +            outfile.write(line) + +    outfile.close() + + +if __name__ == "__main__": +    main() + diff --git a/python/splitCCDgaps.py b/python/splitCCDgaps.py new file mode 100644 index 0000000..bc26c29 --- /dev/null +++ b/python/splitCCDgaps.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8- +# +# Aaron LI +# 2015/08/12 +# + +""" +Split the long-strip-shaped CCD gaps regions into a series of little +square regions, which are used as the input regions of 'roi' to +determine the corresponding background regions, and finally providied +to 'dmfilth' in order to fill in the CCD gaps. +""" + + +import re +import math +from io import TextIOWrapper + + +class BoxRegion(object): +    """ +    CIAO/DS9 "rotbox"/"box" region class. + +    rotbox/box format: +        rotbox(xc, yc, width, height, rotation) +        box(xc, yc, width, height, rotation) +    Notes: +        rotation: [0, 360) (degree) +    """ +    def __init__(self, xc=None, yc=None, +                 width=None, height=None, rotation=None): +        self.regtype  = "rotbox" +        self.xc       = xc +        self.yc       = yc +        self.width    = width +        self.height   = height +        self.rotation = rotation + +    def __str__(self): +        return "%s(%s,%s,%s,%s,%s)" % (self.regtype, self.xc, self.yc, +                self.width, self.height, self.rotation) + +    @classmethod +    def parse(cls, regstr): +        """ +        Parse region string. +        """ +        regex_box = re.compile(r'^(box|rotbox)\(([0-9.]+),([0-9.]+),([0-9.]+),([0-9.]+),([0-9.]+)\)$', re.I) +        m = regex_box.match(regstr) +        if m: +            regtype  = m.group(1) +            xc       = float(m.group(2)) +            yc       = float(m.group(3)) +            width    = float(m.group(4)) +            height   = float(m.group(5)) +            rotation = float(m.group(6)) +            return cls(xc, yc, width, height, rotation) +        else: +            return None + +    def split(self, filename=None): +        """ +        Split long-strip-shaped box region into a series square box regions. +        """ +        angle = self.rotation * math.pi / 180.0 +        # to record the center coordinates of each split region +        centers = [] +        if self.width > self.height: +            # number of regions after split +            nreg = math.ceil(self.width / self.height) +            # width & height of the split region +            width = self.width / nreg +            height = self.height +            # position of the left-most region +            x_l = self.xc - 0.5*self.width * math.cos(angle) +            y_l = self.yc - 0.5*self.width * math.sin(angle) +            for i in range(nreg): +                x = x_l + (0.5 + i) * width * math.cos(angle) +                y = y_l + (0.5 + i) * width * math.sin(angle) +                centers.append((x, y)) +        else: +            # number of regions after split +            nreg = math.ceil(self.height / self.width) +            # width & height of the split region +            width = self.width +            height = self.height / nreg +            # position of the left-most region +            x_l = self.xc + 0.5*self.height * math.cos(angle + math.pi/2) +            y_l = self.yc + 0.5*self.height * math.sin(angle + math.pi/2) +            for i in range(nreg): +                x = x_l - (0.5 + i) * height * math.cos(angle + math.pi/2) +                y = y_l - (0.5 + i) * height * math.sin(angle + math.pi/2) +                centers.append((x, y)) +        # create split regions +        regions = [] +        for (x, y) in centers: +            regions.append(self.__class__(x, y, width+2, height+2, +                                          self.rotation)) +        # write split regions into file if specified +        if isinstance(filename, str): +            regout = open(filename, "w") +            regout.write("\n".join(map(str, regions))) +            regout.close() +        else: +            return regions + + diff --git a/python/xkeywordsync.py b/python/xkeywordsync.py new file mode 100644 index 0000000..9a7cd79 --- /dev/null +++ b/python/xkeywordsync.py @@ -0,0 +1,533 @@ +#!/bin/usr/env python3 +# -*- coding: utf-8 -*- +# +# Credits: +# [1] Gaute Hope: gauteh/abunchoftags +#     https://github.com/gauteh/abunchoftags/blob/master/keywsync.cc +# +# TODO: +# * Support case-insensitive tags merge +#   (ref: http://stackoverflow.com/a/1480230) +# * Accept a specified mtime, and only deal with files with newer mtime. +# +# Aaron LI +# Created: 2016-01-24 +# + +""" +Sync message 'X-Keywords' header with notmuch tags. + +* tags-to-keywords: +  Check if the messages in the query have a matching 'X-Keywords' header +  to the list of notmuch tags. +  If not, update the 'X-Keywords' and re-write the message. + +* keywords-to-tags: +  Check if the messages in the query have matching notmuch tags to the +  'X-Keywords' header. +  If not, update the tags in the notmuch database. + +* merge-keywords-tags: +  Merge the 'X-Keywords' labels and notmuch tags, and update both. +""" + +__version__ = "0.1.2" +__date__    = "2016-01-25" + +import os +import sys +import argparse +import email + +# Require Python 3.4, or install package 'enum34' +from enum import Enum + +from notmuch import Database, Query + +from imapUTF7 import imapUTF7Decode, imapUTF7Encode + + +class SyncDirection(Enum): +    """ +    Synchronization direction +    """ +    MERGE_KEYWORDS_TAGS = 0  # Merge 'X-Keywords' and notmuch tags and +                             # update both +    KEYWORDS_TO_TAGS    = 1  # Sync 'X-Keywords' header to notmuch tags +    TAGS_TO_KEYWORDS    = 2  # Sync notmuch tags to 'X-Keywords' header + +class SyncMode(Enum): +    """ +    Sync mode +    """ +    ADD_REMOVE  = 0  # Allow add & remove tags/keywords +    ADD_ONLY    = 1  # Only allow add tags/keywords +    REMOVE_ONLY = 2  # Only allow remove tags/keywords + + +class KwMessage: +    """ +    Message class to deal with 'X-Keywords' header synchronization +    with notmuch tags. + +    NOTE: +    * The same message may have multiple files with different keywords +      (e.g, the same message exported under each label by Gmail) +      managed by OfflineIMAP. +      For example: a message file in OfflineIMAP synced folder of +      '[Gmail]/All Mail' have keywords ['google', 'test']; however, +      the file in synced folder 'test' of the same message only have +      keywords ['google'] without the keyword 'test'. +    * All files associated to the same message are regarded as the same. +      The keywords are extracted from all files and merged. +      And the same updated keywords are written back to all files, which +      results all files finally having the same 'X-Keywords' header. +    * You may only sync the '[Gmail]/All Mail' folder without other +      folders exported according the labels by Gmail. +    """ +    # Replace some special characters before mapping keyword to tag +    enable_replace_chars = True +    chars_replace = { +            '/' : '.', +    } +    # Mapping between (Gmail) keywords and notmuch tags (before ignoring tags) +    keywords_mapping = { +            '\\Inbox'     : 'inbox', +            '\\Important' : 'important', +            '\\Starred'   : 'flagged', +            '\\Sent'      : 'sent', +            '\\Muted'     : 'killed', +            '\\Draft'     : 'draft', +            '\\Trash'     : 'deleted', +            '\\Junk'      : 'spam', +    } +    # Tags ignored from syncing +    # These tags are either internal tags or tags handled by maildir flags. +    enable_ignore_tags = True +    tags_ignored = set([ +            'new', 'unread', 'attachment', 'signed', 'encrypted', +            'flagged', 'replied', 'passed', 'draft', +    ]) +    # Ignore case when merging tags +    tags_ignorecase = True + +    # Whether the tags updated against the message 'X-Keywords' header +    tags_updated = False +    # Added & removed tags for notmuch database against 'X-Keywords' +    tags_added   = [] +    tags_removed = [] +    # Newly updated/merged notmuch tags against 'X-Keywords' +    tags_new     = [] + +    # Whether the keywords updated against the notmuch tags +    keywords_updated = False +    # Added & removed tags for 'X-Keywords' against notmuch database +    tags_kw_added    = [] +    tags_kw_removed  = [] +    # Newly updated/merged tags for 'X-Keywords' against notmuch database +    tags_kw_new      = [] + +    def __init__(self, msg, filename=None): +        self.message  = msg +        self.filename = filename +        self.allfiles = [ fn for fn in msg.get_filenames() ] +        self.tags     = set(msg.get_tags()) + +    def sync(self, direction, mode=SyncMode.ADD_REMOVE, +             dryrun=False, verbose=False): +        """ +        Wrapper function to sync between 'X-Keywords' and notmuch tags. +        """ +        if direction == SyncDirection.KEYWORDS_TO_TAGS: +            self.sync_keywords_to_tags(sync_mode=mode, dryrun=dryrun, +                                       verbose=verbose) +        elif direction == SyncDirection.TAGS_TO_KEYWORDS: +            self.sync_tags_to_keywords(sync_mode=mode, dryrun=dryrun, +                                       verbose=verbose) +        elif direction == SyncDirection.MERGE_KEYWORDS_TAGS: +            self.merge_keywords_tags(sync_mode=mode, dryrun=dryrun, +                                     verbose=verbose) +        else: +            raise ValueError("Invalid sync direction: %s" % direction) + +    def sync_keywords_to_tags(self, sync_mode=SyncMode.ADD_REMOVE, +                              dryrun=False, verbose=False): +        """ +        Wrapper function to sync 'X-Keywords' to notmuch tags. +        """ +        self.get_keywords() +        self.map_keywords() +        self.merge_tags(sync_direction=SyncDirection.KEYWORDS_TO_TAGS, +                        sync_mode=sync_mode) +        if dryrun or verbose: +            print('* MSG: %s' % self.message) +            print('  TAG: [%s]  +[%s] -[%s]  =>  [%s]' % ( +                ','.join(self.tags), ','.join(self.tags_added), +                ','.join(self.tags_removed), ','.join(self.tags_new))) +        if not dryrun: +            self.update_tags() + +    def sync_tags_to_keywords(self, sync_mode=SyncMode.ADD_REMOVE, +                              dryrun=False, verbose=False): +        """ +        Wrapper function to sync notmuch tags to 'X-Keywords' +        """ +        self.get_keywords() +        self.map_keywords() +        self.merge_tags(sync_direction=SyncDirection.TAGS_TO_KEYWORDS, +                        sync_mode=sync_mode) +        keywords_new = self.map_tags(tags=self.tags_kw_new) +        if dryrun or verbose: +            print('* MSG: %s' % self.message) +            print('* FILES: %s' % ' ; '.join(self.allfiles)) +            print('  XKW: {%s}  +[%s] -[%s]  =>  {%s}' % ( +                ','.join(self.keywords), ','.join(self.tags_kw_added), +                ','.join(self.tags_kw_removed), ','.join(keywords_new))) +        if not dryrun: +            self.update_keywords(keywords_new=keywords_new) + +    def merge_keywords_tags(self, sync_mode=SyncMode.ADD_REMOVE, +                            dryrun=False, verbose=False): +        """ +        Wrapper function to merge 'X-Keywords' and notmuch tags +        """ +        self.get_keywords() +        self.map_keywords() +        self.merge_tags(sync_direction=SyncDirection.MERGE_KEYWORDS_TAGS, +                        sync_mode=sync_mode) +        keywords_new = self.map_tags(tags=self.tags_kw_new) +        if dryrun or verbose: +            print('* MSG: %s' % self.message) +            print('* FILES: %s' % ' ; '.join(self.allfiles)) +            print('  TAG: [%s]  +[%s] -[%s]  =>  [%s]' % ( +                ','.join(self.tags), ','.join(self.tags_added), +                ','.join(self.tags_removed), ','.join(self.tags_new))) +            print('  XKW: {%s}  +[%s] -[%s]  =>  {%s}' % ( +                ','.join(self.keywords), ','.join(self.tags_kw_added), +                ','.join(self.tags_kw_removed), ','.join(keywords_new))) +        if not dryrun: +            self.update_tags() +            self.update_keywords(keywords_new=keywords_new) + +    def get_keywords(self): +        """ +        Get 'X-Keywords' header from all files associated with the same +        message, decode, split and merge. + +        NOTE: Do NOT simply use the `message.get_header()` method, which +              cannot get the complete keywords from all files. +        """ +        keywords_utf7 = [] +        for fn in self.allfiles: +            msg = email.message_from_file(open(fn, 'r')) +            val = msg['X-Keywords'] +            if val: +                keywords_utf7.append(val) +            else: +                print("WARNING: 'X-Keywords' header not found or empty " +\ +                        "for file: %s" % fn, file=sys.stderr) +        keywords_utf7 = ','.join(keywords_utf7) +        if keywords_utf7 != '': +            keywords = imapUTF7Decode(keywords_utf7.encode()).split(',') +            keywords = [ kw.strip() for kw in keywords ] +            # Remove duplications +            keywords = set(keywords) +        else: +            keywords = set() +        self.keywords = keywords +        return keywords + +    def map_keywords(self, keywords=None): +        """ +        Map keywords to notmuch tags according to the mapping table. +        """ +        if keywords is None: +            keywords = self.keywords +        if self.enable_replace_chars: +            # Replace specified characters in keywords +            trans = str.maketrans(self.chars_replace) +            keywords = [ kw.translate(trans) for kw in keywords ] +        # Map keywords to tags +        tags = set([ self.keywords_mapping.get(kw, kw) for kw in keywords ]) +        self.tags_kw = tags +        return tags + +    def map_tags(self, tags=None): +        """ +        Map tags to keywords according to the inversed mapping table. +        """ +        if tags is None: +            tags = self.tags +        if self.enable_replace_chars: +            # Inversely replace specified characters in tags +            chars_replace_inv = { v: k for k, v in self.chars_replace.items() } +            trans = str.maketrans(chars_replace_inv) +            tags = [ tag.translate(trans) for tag in tags ] +        # Map keywords to tags +        keywords_mapping_inv = { v:k for k,v in self.keywords_mapping.items() } +        keywords = set([ keywords_mapping_inv.get(tag, tag) for tag in tags ]) +        self.keywords_tags = keywords +        return keywords + +    def merge_tags(self, sync_direction, sync_mode=SyncMode.ADD_REMOVE, +                   tags_nm=None, tags_kw=None): +        """ +        Merge the tags from notmuch database and 'X-Keywords' header, +        according to the specified sync direction and operation restriction. + +        TODO: support case-insensitive set operations +        """ +        # Added & removed tags for notmuch database against 'X-Keywords' +        tags_added       = [] +        tags_removed     = [] +        # Newly updated/merged notmuch tags against 'X-Keywords' +        tags_new         = [] +        # Added & removed tags for 'X-Keywords' against notmuch database +        tags_kw_added    = [] +        tags_kw_removed  = [] +        # Newly updated/merged tags for 'X-Keywords' against notmuch database +        tags_kw_new      = [] +        # +        if tags_nm is None: +            tags_nm = self.tags +        if tags_kw is None: +            tags_kw = self.tags_kw +        if self.enable_ignore_tags: +            # Remove ignored tags before merge +            tags_nm2 = tags_nm.difference(self.tags_ignored) +            tags_kw2 = tags_kw.difference(self.tags_ignored) +        else: +            tags_nm2 = tags_nm +            tags_kw2 = tags_kw +        # +        if sync_direction == SyncDirection.KEYWORDS_TO_TAGS: +            # Sync 'X-Keywords' to notmuch tags +            tags_added      = tags_kw2.difference(tags_nm2) +            tags_removed    = tags_nm2.difference(tags_kw2) +        elif sync_direction == SyncDirection.TAGS_TO_KEYWORDS: +            # Sync notmuch tags to 'X-Keywords' +            tags_kw_added   = tags_nm2.difference(tags_kw2) +            tags_kw_removed = tags_kw2.difference(tags_nm2) +        elif sync_direction == SyncDirection.MERGE_KEYWORDS_TAGS: +            # Merge both notmuch tags and 'X-Keywords' +            tags_merged     = tags_nm2.union(tags_kw2) +            # notmuch tags +            tags_added      = tags_merged.difference(tags_nm2) +            tags_removed    = tags_nm2.difference(tags_merged) +            # tags for 'X-Keywords' +            tags_kw_added   = tags_merged.difference(tags_kw2) +            tags_kw_removed = tags_kw2.difference(tags_merged) +        else: +            raise ValueError("Invalid synchronization direction") +        # Apply sync operation restriction +        self.tags_added       = [] +        self.tags_removed     = [] +        self.tags_kw_added    = [] +        self.tags_kw_removed  = [] +        tags_new              = tags_nm  # Use un-ignored notmuch tags +        tags_kw_new           = tags_kw  # Use un-ignored 'X-Keywords' tags +        if sync_mode != SyncMode.REMOVE_ONLY: +            self.tags_added      = tags_added +            self.tags_kw_added   = tags_kw_added +            tags_new             = tags_new.union(tags_added) +            tags_kw_new          = tags_kw_new.union(tags_kw_added) +        if sync_mode != SyncMode.ADD_ONLY: +            self.tags_removed    = tags_removed +            self.tags_kw_removed = tags_kw_removed +            tags_new             = tags_new.difference(tags_removed) +            tags_kw_new          = tags_kw_new.difference(tags_kw_removed) +        # +        self.tags_new         = tags_new +        self.tags_kw_new      = tags_kw_new +        if self.tags_added or self.tags_removed: +            self.tags_updated = True +        if self.tags_kw_added or self.tags_kw_removed: +            self.keywords_updated = True +        # +        return { +                'tags_updated'     : self.tags_updated, +                'tags_added'       : self.tags_added, +                'tags_removed'     : self.tags_removed, +                'tags_new'         : self.tags_new, +                'keywords_updated' : self.keywords_updated, +                'tags_kw_added'    : self.tags_kw_added, +                'tags_kw_removed'  : self.tags_kw_removed, +                'tags_kw_new'      : self.tags_kw_new, +        } + +    def update_keywords(self, keywords_new=None, outfile=None): +        """ +        Encode the keywords (default: self.keywords_new) and write back to +        all message files. + +        If parameter 'outfile' specified, then write the updated message +        to that file instead of overwriting. + +        NOTE: +        * The modification time of the message file should be kept to prevent +          OfflineIMAP from treating it as a new one (and the previous a +          deleted one). +        * All files associated with the same message are updated to have +          the same 'X-Keywords' header. +        """ +        if not self.keywords_updated: +            # keywords NOT updated, just skip +            return + +        if keywords_new is None: +            keywords_new = self.keywords_new +        # +        if outfile is not None: +            infile  = self.allfiles[0:1] +            outfile = [ os.path.expanduser(outfile) ] +        else: +            infile  = self.allfiles +            outfile = self.allfiles +        # +        for ifname, ofname in zip(infile, outfile): +            msg   = email.message_from_file(open(ifname, 'r')) +            fstat = os.stat(ifname) +            if keywords_new == []: +                # Delete 'X-Keywords' header +                print("WARNING: delete 'X-Keywords' header from file: %s" % +                        ifname, file=sys.stderr) +                del msg['X-Keywords'] +            else: +                # Update 'X-Keywords' header +                keywords = ','.join(keywords_new) +                keywords_utf7 = imapUTF7Encode(keywords).decode() +                # Delete then add, to avoid multiple occurrences +                del msg['X-Keywords'] +                msg['X-Keywords'] = keywords_utf7 +            # Write updated message +            with open(ofname, 'w') as fp: +                fp.write(msg.as_string()) +            # Reset the timestamps +            os.utime(ofname, ns=(fstat.st_atime_ns, fstat.st_mtime_ns)) + +    def update_tags(self, tags_added=None, tags_removed=None): +        """ +        Update notmuch tags according to keywords. +        """ +        if not self.tags_updated: +            # tags NOT updated, just skip +            return + +        if tags_added is None: +            tags_added   = self.tags_added +        if tags_removed is None: +            tags_removed = self.tags_removed +        # Use freeze/thaw for safer transactions to change tag values. +        self.message.freeze() +        for tag in tags_added: +            self.message.add_tag(tag, sync_maildir_flags=False) +        for tag in tags_removed: +            self.message.remove_tag(tag, sync_maildir_flags=False) +        self.message.thaw() + + +def get_notmuch_revision(dbpath=None): +    """ +    Get the current revision and UUID of notmuch database. +    """ +    import subprocess +    import tempfile +    if dbpath: +        tf = tempfile.NamedTemporaryFile() +        # Create a minimal notmuch config for the specified dbpath +        config = '[database]\npath=%s\n' % os.path.expanduser(dbpath) +        tf.file.write(config.encode()) +        tf.file.flush() +        cmd = 'notmuch --config=%s count --lastmod' % tf.name +        output = subprocess.check_output(cmd, shell=True) +        tf.close() +    else: +        cmd = 'notmuch count --lastmod' +        output = subprocess.check_output(cmd, shell=True) +    # Extract output +    dbinfo = output.decode().split() +    return { 'revision': int(dbinfo[2]), 'uuid': dbinfo[1] } + + +def main(): +    parser = argparse.ArgumentParser( +            description="Sync message 'X-Keywords' header with notmuch tags.") +    parser.add_argument("-V", "--version", action="version", +            version="%(prog)s " + "v%s (%s)" % (__version__, __date__)) +    parser.add_argument("-q", "--query", dest="query", required=True, +            help="notmuch database query string") +    parser.add_argument("-p", "--db-path", dest="dbpath", +            help="notmuch database path (default to try user configuration)") +    parser.add_argument("-n", "--dry-run", dest="dryrun", +            action="store_true", help="dry run") +    parser.add_argument("-v", "--verbose", dest="verbose", +            action="store_true", help="show verbose information") +    # Exclusive argument group for sync mode +    exgroup1 = parser.add_mutually_exclusive_group(required=True) +    exgroup1.add_argument("-m", "--merge-keywords-tags", +            dest="direction_merge", action="store_true", +            help="merge 'X-Keywords' and tags and update both") +    exgroup1.add_argument("-k", "--keywords-to-tags", +            dest="direction_keywords2tags", action="store_true", +            help="sync 'X-Keywords' to notmuch tags") +    exgroup1.add_argument("-t", "--tags-to-keywords", +            dest="direction_tags2keywords", action="store_true", +            help="sync notmuch tags to 'X-Keywords'") +    # Exclusive argument group for tag operation mode +    exgroup2 = parser.add_mutually_exclusive_group(required=False) +    exgroup2.add_argument("-a", "--add-only", dest="mode_addonly", +            action="store_true", help="only add notmuch tags") +    exgroup2.add_argument("-r", "--remove-only", dest="mode_removeonly", +            action="store_true", help="only remove notmuch tags") +    # Parse +    args = parser.parse_args() +    # Sync direction +    if args.direction_merge: +        sync_direction = SyncDirection.MERGE_KEYWORDS_TAGS +    elif args.direction_keywords2tags: +        sync_direction = SyncDirection.KEYWORDS_TO_TAGS +    elif args.direction_tags2keywords: +        sync_direction = SyncDirection.TAGS_TO_KEYWORDS +    else: +        raise ValueError("Invalid synchronization direction") +    # Sync mode +    if args.mode_addonly: +        sync_mode = SyncMode.ADD_ONLY +    elif args.mode_removeonly: +        sync_mode = SyncMode.REMOVE_ONLY +    else: +        sync_mode = SyncMode.ADD_REMOVE +    # +    if args.dbpath: +        dbpath = os.path.abspath(os.path.expanduser(args.dbpath)) +    else: +        dbpath = None +    # +    db = Database(path=dbpath, create=False, mode=Database.MODE.READ_WRITE) +    dbinfo = get_notmuch_revision(dbpath=dbpath) +    q = Query(db, args.query) +    total_msgs = q.count_messages() +    msgs = q.search_messages() +    # +    if args.verbose: +        print("# Notmuch database path: %s" % dbpath) +        print("# Database revision: %d (uuid: %s)" % +                (dbinfo['revision'], dbinfo['uuid'])) +        print("# Query: %s" % args.query) +        print("# Sync direction: %s" % sync_direction.name) +        print("# Sync mode: %s" % sync_mode.name) +        print("# Total messages to check: %d" % total_msgs) +        print("# Dryn run: %s" % args.dryrun) +    # +    for msg in msgs: +        kwmsg = KwMessage(msg) +        kwmsg.sync(direction=sync_direction, mode=sync_mode, +                   dryrun=args.dryrun, verbose=args.verbose) +    # +    db.close() + + +if __name__ == "__main__": +    main() + +#  vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # | 
