diff options
-rwxr-xr-x | python/adjust_spectrum_error.py | 170 | ||||
-rwxr-xr-x | python/correct_crosstalk.py | 319 | ||||
-rw-r--r-- | python/imapUTF7.py | 189 | ||||
-rw-r--r-- | python/plot.py | 35 | ||||
-rw-r--r-- | python/plot_tprofiles_zzh.py | 126 | ||||
-rwxr-xr-x | python/randomize_events.py | 72 | ||||
-rw-r--r-- | python/sbp_fit.py | 613 | ||||
-rw-r--r-- | python/starlet.py | 513 | ||||
-rw-r--r-- | python/xkeywordsync.py | 533 |
9 files changed, 2570 insertions, 0 deletions
diff --git a/python/adjust_spectrum_error.py b/python/adjust_spectrum_error.py new file mode 100755 index 0000000..0f80ec7 --- /dev/null +++ b/python/adjust_spectrum_error.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Squeeze the spectrum according to the grouping specification, then +calculate the statistical errors for each group, and apply error +adjustments (e.g., incorporate the systematic uncertainties). +""" + +__version__ = "0.1.0" +__date__ = "2016-01-11" + + +import sys +import argparse + +import numpy as np +from astropy.io import fits + + +class Spectrum: + """ + Spectrum class to keep spectrum information and perform manipulations. + """ + header = None + channel = None + counts = None + grouping = None + quality = None + + def __init__(self, specfile): + f = fits.open(specfile) + spechdu = f['SPECTRUM'] + self.header = spechdu.header + self.channel = spechdu.data.field('CHANNEL') + self.counts = spechdu.data.field('COUNTS') + self.grouping = spechdu.data.field('GROUPING') + self.quality = spechdu.data.field('QUALITY') + f.close() + + def squeezeByGrouping(self): + """ + Squeeze the spectrum according to the grouping specification, + i.e., sum the counts belonging to the same group, and place the + sum as the first channel within each group with other channels + of counts zero's. 
+ """ + counts_squeezed = [] + cnt_sum = 0 + cnt_num = 0 + first = True + for grp, cnt in zip(self.grouping, self.counts): + if first and grp == 1: + # first group + cnt_sum = cnt + cnt_num = 1 + first = False + elif grp == 1: + # save previous group + counts_squeezed.append(cnt_sum) + counts_squeezed += [ 0 for i in range(cnt_num-1) ] + # start new group + cnt_sum = cnt + cnt_num = 1 + else: + # group continues + cnt_sum += cnt + cnt_num += 1 + # last group + # save previous group + counts_squeezed.append(cnt_sum) + counts_squeezed += [ 0 for i in range(cnt_num-1) ] + self.counts_squeezed = np.array(counts_squeezed, dtype=np.int32) + + def calcStatErr(self, gehrels=False): + """ + Calculate the statistical errors for the grouped channels, + and save as the STAT_ERR column. + """ + idx_nz = np.nonzero(self.counts_squeezed) + stat_err = np.zeros(self.counts_squeezed.shape) + if gehrels: + # Gehrels + stat_err[idx_nz] = 1 + np.sqrt(self.counts_squeezed[idx_nz] + 0.75) + else: + stat_err[idx_nz] = np.sqrt(self.counts_squeezed[idx_nz]) + self.stat_err = stat_err + + @staticmethod + def parseSysErr(syserr): + """ + Parse the string format of syserr supplied in the commandline. + """ + items = map(str.strip, syserr.split(',')) + syserr_spec = [] + for item in items: + spec = item.split(':') + try: + spec = (int(spec[0]), int(spec[1]), float(spec[2])) + except: + raise ValueError("invalid syserr specficiation") + syserr_spec.append(spec) + return syserr_spec + + def applySysErr(self, syserr): + """ + Apply systematic error adjustments to the above calculated + statistical errors. + """ + syserr_spec = self.parseSysErr(syserr) + for lo, hi, se in syserr_spec: + err_adjusted = self.stat_err[(lo-1):(hi-1)] * np.sqrt(1+se) + self.stat_err_adjusted = err_adjusted + + def updateHeader(self): + """ + Update header accordingly. 
+ """ + # POISSERR + self.header['POISSERR'] = False + + def write(self, filename, clobber=False): + """ + Write the updated/modified spectrum block to file. + """ + channel_col = fits.Column(name='CHANNEL', format='J', + array=self.channel) + counts_col = fits.Column(name='COUNTS', format='J', + array=self.counts_squeezed) + stat_err_col = fits.Column(name='STAT_ERR', format='D', + array=self.stat_err_adjusted) + grouping_col = fits.Column(name='GROUPING', format='I', + array=self.grouping) + quality_col = fits.Column(name='QUALITY', format='I', + array=self.quality) + spec_cols = fits.ColDefs([channel_col, counts_col, stat_err_col, + grouping_col, quality_col]) + spechdu = fits.BinTableHDU.from_columns(spec_cols, header=self.header) + spechdu.writeto(filename, clobber=clobber) + + +def main(): + parser = argparse.ArgumentParser( + description="Apply systematic error adjustments to spectrum.") + parser.add_argument("-V", "--version", action="version", + version="%(prog)s " + "%s (%s)" % (__version__, __date__)) + parser.add_argument("infile", help="input spectrum file") + parser.add_argument("outfile", help="output adjusted spectrum file") + parser.add_argument("-e", "--syserr", dest="syserr", required=True, + help="systematic error specification; " + \ + "syntax: ch1low:ch1high:syserr1,...") + parser.add_argument("-C", "--clobber", dest="clobber", + action="store_true", help="overwrite output file if exists") + parser.add_argument("-G", "--gehrels", dest="gehrels", + action="store_true", help="use Gehrels error?") + args = parser.parse_args() + + spec = Spectrum(args.infile) + spec.squeezeByGrouping() + spec.calcStatErr(gehrels=args.gehrels) + spec.applySysErr(syserr=args.syserr) + spec.updateHeader() + spec.write(args.outfile, clobber=args.clobber) + + +if __name__ == "__main__": + main() + + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/correct_crosstalk.py b/python/correct_crosstalk.py new file mode 100755 index 0000000..4c7f820 --- 
/dev/null +++ b/python/correct_crosstalk.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Correct the crosstalk effect of XMM spectra by subtracting the +# scattered photons from surrounding regions, and by compensating +# the photons scattered to surrounding regions, according to the +# generated crosstalk ARFs. +# +# Sample config file (in `ConfigObj' syntax): +#---------------------- +# fix_negative = True +# verbose = True +# clobber = False +# +# [reg2] +# outfile = cc_reg2.pi +# spec = reg2.pi +# arf = reg2.arf +# [[cross_in]] +# [[[in1]]] +# spec = reg1.pi +# arf = reg1.arf +# cross_arf = reg_1-2.arf +# [[[in2]]] +# spec = reg3.pi +# arf = reg3.arf +# cross_arf = reg_3-2.arf +# [[cross_out]] +# cross_arf = reg_2-1.arf, reg_2-3.arf +#---------------------- +# +# Weitian LI +# Created: 2016-03-26 +# Updated: 2016-03-28 +# + +from astropy.io import fits +import numpy as np +from configobj import ConfigObj + +import sys +import os +import argparse +from datetime import datetime + + +class ARF: + """ + Deal with X-ray ARF file (.arf) + """ + filename = None + fitsobj = None + header = None + energ_lo = None + energ_hi = None + specresp = None + + def __init__(self, filename): + self.filename = filename + self.fitsobj = fits.open(filename) + ext_specresp = self.fitsobj["SPECRESP"] + self.header = ext_specresp.header + self.energ_lo = ext_specresp.data["ENERG_LO"] + self.energ_hi = ext_specresp.data["ENERG_HI"] + self.specresp = ext_specresp.data["SPECRESP"] + + def get_data(self, copy=True): + if copy: + return self.specresp.copy() + else: + return self.specresp + + +class Spectrum: + """ + Deal with X-ray spectrum (.pi) + + NOTE: + The "COUNTS" column data are converted from "int32" to "float32". 
+ """ + filename = None + # FITS object return by `fits.open' + fitsobj = None + # header of "SPECTRUM" extension + header = None + # "SPECTRUM" extension data + channel = None + # name of the column containing the spectrum data, either "COUNTS" or "RATE" + spec_colname = None + # spectrum data + spec_data = None + # ARF object for this spectrum + arf = None + + def __init__(self, filename, arffile): + self.filename = filename + self.fitsobj = fits.open(filename) + ext_spec = self.fitsobj["SPECTRUM"] + self.header = ext_spec.header.copy(strip=True) + colnames = ext_spec.columns.names + if "COUNTS" in colnames: + self.spec_colname = "COUNTS" + elif "RATE" in colnames: + self.spec_colname = "RATE" + else: + raise ValueError("Invalid spectrum file") + self.channel = ext_spec.data["CHANNEL"].copy() + self.spec_data = ext_spec.data.field(self.spec_colname)\ + .astype(np.float32) + self.arf = ARF(arffile) + + def get_data(self, copy=True): + if copy: + return self.spec_data.copy() + else: + return self.spec_data + + def get_arf(self, copy=True): + if self.arf is None: + return None + else: + return self.arf.get_data(copy=copy) + + def subtract(self, spectrum, cross_arf, verbose=False): + """ + Subtract the photons that originate from the surrounding regions + but were scattered into this spectrum due to the finite PSF. 
+ + NOTE: + The crosstalk ARF must be provided, since the `spectrum.arf' is + required to be its ARF without taking crosstalk into account: + spec1_new = spec1 - spec2 * (cross_arf_2_to_1 / arf2) + """ + operation = " SUBTRACT: %s - (%s/%s) * %s" % (self.filename, + cross_arf.filename, spectrum.arf.filename, spectrum.filename) + if verbose: + print(operation, file=sys.stderr) + arf_ratio = cross_arf.get_data() / spectrum.get_arf() + arf_ratio[np.isnan(arf_ratio)] = 0.0 + self.spec_data -= spectrum.get_data() * arf_ratio + # record history + self.header.add_history(operation) + + def compensate(self, cross_arf, verbose=False): + """ + Compensate the photons that originate from this regions but were + scattered into the surrounding regions due to the finite PSF. + + formula: + spec1_new = spec1 + spec1 * (cross_arf_1_to_2 / arf1) + """ + operation = " COMPENSATE: %s + (%s/%s) * %s" % (self.filename, + cross_arf.filename, self.arf.filename, self.filename) + if verbose: + print(operation, file=sys.stderr) + arf_ratio = cross_arf.get_data() / self.get_arf() + arf_ratio[np.isnan(arf_ratio)] = 0.0 + self.spec_data += self.get_data() * arf_ratio + # record history + self.header.add_history(operation) + + def fix_negative(self, verbose=False): + """ + The subtractions may lead to negative counts, it may be necessary + to fix these channels with negative values. + """ + neg_counts = self.spec_data < 0 + N = len(neg_counts) + neg_channels = np.arange(N, dtype=np.int)[neg_counts] + if len(neg_channels) > 0: + print("WARNING: %d channels have NEGATIVE counts" % \ + len(neg_channels), file=sys.stderr) + i = 0 + while len(neg_channels) > 0: + i += 1 + if verbose: + print("*** Fixing negative channels: iteration %d ..." 
% i, + file=sys.stderr) + for ch in neg_channels: + neg_val = self.spec_data[ch] + if ch < N-2: + self.spec_data[ch] = 0 + self.spec_data[(ch+1):(ch+3)] -= 0.5 * np.abs(neg_val) + else: + # just set to zero if it is the last 2 channels + self.spec_data[ch] = 0 + # update negative channels indices + neg_counts = self.spec_data < 0 + neg_channels = np.arange(N, dtype=np.int)[neg_counts] + if i > 0: + print("*** Fixed negative channels ***", file=sys.stderr) + + def write(self, filename, clobber=False): + """ + Create a new "SPECTRUM" table/extension and replace the original + one, then write to output file. + """ + ext_spec_cols = fits.ColDefs([ + fits.Column(name="CHANNEL", format="I", array=self.channel), + fits.Column(name="COUNTS", format="E", unit="count", + array=self.spec_data)]) + ext_spec = fits.BinTableHDU.from_columns(ext_spec_cols, + header=self.header) + self.fitsobj["SPECTRUM"] = ext_spec + self.fitsobj.writeto(filename, clobber=clobber, checksum=True) + + +class Crosstalk: + """ + Crosstalk correction. + """ + # `Spectrum' object for the spectrum to be corrected + spectrum = None + # XXX: do NOT use list (e.g., []) here, otherwise, all the instances + # will share these list properties. + # `Spectrum' and `ARF' objects corresponding to the spectra from which + # the photons were scattered into this spectrum. + cross_in_spec = None + cross_in_arf = None + # `ARF' objects corresponding to the regions to which the photons of + # this spectrum were scattered into. + cross_out_arf = None + # output filename to which write the corrected spectrum + outfile = None + + def __init__(self, config): + """ + `config': a section of the whole config file (`ConfigObj` object). 
+ """ + self.cross_in_spec = [] + self.cross_in_arf = [] + self.cross_out_arf = [] + # this spectrum to be corrected + self.spectrum = Spectrum(config["spec"], config["arf"]) + # spectra and cross arf from which photons were scattered in + for reg_in in config["cross_in"].values(): + spec = Spectrum(reg_in["spec"], reg_in["arf"]) + self.cross_in_spec.append(spec) + self.cross_in_arf.append(ARF(reg_in["cross_arf"])) + # regions into which the photons of this spectrum were scattered into + if "cross_out" in config.sections: + cross_arf = config["cross_out"].as_list("cross_arf") + for arffile in cross_arf: + self.cross_out_arf.append(ARF(arffile)) + # output filename + self.outfile = config["outfile"] + + def do_correction(self, fix_negative=False, verbose=False): + self.spectrum.header.add_history("Crosstalk Correction BEGIN") + self.spectrum.header.add_history(" TOOL: %s @ %s" % (\ + os.path.basename(sys.argv[0]), datetime.utcnow().isoformat())) + # subtractions + if verbose: + print("INFO: apply subtractions ...", file=sys.stderr) + for spec, cross_arf in zip(self.cross_in_spec, self.cross_in_arf): + self.spectrum.subtract(spectrum=spec, cross_arf=cross_arf, + verbose=verbose) + # compensations + if verbose: + print("INFO: apply compensations ...", file=sys.stderr) + for cross_arf in self.cross_out_arf: + self.spectrum.compensate(cross_arf=cross_arf, verbose=verbose) + # fix negative values in channels + if fix_negative: + if verbose: + print("INFO: fix negative channel values ...", file=sys.stderr) + self.spectrum.fix_negative(verbose=verbose) + self.spectrum.header.add_history("END Crosstalk Correction") + + def write(self, filename=None, clobber=False): + if filename is None: + filename = self.outfile + self.spectrum.write(filename, clobber=clobber) + + +def main(): + parser = argparse.ArgumentParser( + description="Correct the crosstalk effects of XMM spectra") + parser.add_argument("config", help="config file in which describes " +\ + "the crosstalk relations. 
('ConfigObj' syntax)") + parser.add_argument("-N", "--fix-negative", dest="fix_negative", + action="store_true", help="fix negative channel values") + parser.add_argument("-C", "--clobber", dest="clobber", + action="store_true", help="overwrite output file if exists") + parser.add_argument("-v", "--verbose", dest="verbose", + action="store_true", help="show verbose information") + args = parser.parse_args() + + config = ConfigObj(args.config) + + fix_negative = False + if "fix_negative" in config.keys(): + fix_negative = config.as_bool("fix_negative") + if args.fix_negative: + fix_negative = args.fix_negative + + verbose = False + if "verbose" in config.keys(): + verbose = config.as_bool("verbose") + if args.verbose: + verbose = args.verbose + + clobber = False + if "clobber" in config.keys(): + clobber = config.as_bool("clobber") + if args.clobber: + clobber = args.clobber + + for region in config.sections: + if verbose: + print("INFO: processing '%s' ..." % region, file=sys.stderr) + crosstalk = Crosstalk(config.get(region)) + crosstalk.do_correction(fix_negative=fix_negative, verbose=verbose) + crosstalk.write(clobber=clobber) + + +if __name__ == "__main__": + main() + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/imapUTF7.py b/python/imapUTF7.py new file mode 100644 index 0000000..2e4db0a --- /dev/null +++ b/python/imapUTF7.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# This code was originally in PloneMailList, a GPL'd software. +# http://svn.plone.org/svn/collective/mxmImapClient/trunk/imapUTF7.py +# http://bugs.python.org/issue5305 +# +# Port to Python 3.x +# Credit: https://github.com/MarechJ/py3_imap_utf7 +# +# 2016-01-23 +# Aaron LI +# + +""" +Imap folder names are encoded using a special version of utf-7 as defined in RFC +2060 section 5.1.3. + +5.1.3. 
Mailbox International Naming Convention + + By convention, international mailbox names are specified using a + modified version of the UTF-7 encoding described in [UTF-7]. The + purpose of these modifications is to correct the following problems + with UTF-7: + + 1) UTF-7 uses the "+" character for shifting; this conflicts with + the common use of "+" in mailbox names, in particular USENET + newsgroup names. + + 2) UTF-7's encoding is BASE64 which uses the "/" character; this + conflicts with the use of "/" as a popular hierarchy delimiter. + + 3) UTF-7 prohibits the unencoded usage of "\"; this conflicts with + the use of "\" as a popular hierarchy delimiter. + + 4) UTF-7 prohibits the unencoded usage of "~"; this conflicts with + the use of "~" in some servers as a home directory indicator. + + 5) UTF-7 permits multiple alternate forms to represent the same + string; in particular, printable US-ASCII chararacters can be + represented in encoded form. + + In modified UTF-7, printable US-ASCII characters except for "&" + represent themselves; that is, characters with octet values 0x20-0x25 + and 0x27-0x7e. The character "&" (0x26) is represented by the two- + octet sequence "&-". + + All other characters (octet values 0x00-0x1f, 0x7f-0xff, and all + Unicode 16-bit octets) are represented in modified BASE64, with a + further modification from [UTF-7] that "," is used instead of "/". + Modified BASE64 MUST NOT be used to represent any printing US-ASCII + character which can represent itself. + + "&" is used to shift to modified BASE64 and "-" to shift back to US- + ASCII. All names start in US-ASCII, and MUST end in US-ASCII (that + is, a name that ends with a Unicode 16-bit octet MUST end with a "- + "). 
+ + For example, here is a mailbox name which mixes English, Japanese, + and Chinese text: ~peter/mail/&ZeVnLIqe-/&U,BTFw- +""" + + +import binascii +import codecs + + +## encoding + +def modified_base64(s:str): + s = s.encode('utf-16be') # UTF-16, big-endian byte order + return binascii.b2a_base64(s).rstrip(b'\n=').replace(b'/', b',') + +def doB64(_in, r): + if _in: + r.append(b'&' + modified_base64(''.join(_in)) + b'-') + del _in[:] + +def encoder(s:str): + r = [] + _in = [] + for c in s: + ordC = ord(c) + if 0x20 <= ordC <= 0x25 or 0x27 <= ordC <= 0x7e: + doB64(_in, r) + r.append(c.encode()) + elif c == '&': + doB64(_in, r) + r.append(b'&-') + else: + _in.append(c) + doB64(_in, r) + return (b''.join(r), len(s)) + + +## decoding + +def modified_unbase64(s:bytes): + b = binascii.a2b_base64(s.replace(b',', b'/') + b'===') + return b.decode('utf-16be') + +def decoder(s:bytes): + r = [] + decode = bytearray() + for c in s: + if c == ord('&') and not decode: + decode.append(ord('&')) + elif c == ord('-') and decode: + if len(decode) == 1: + r.append('&') + else: + r.append(modified_unbase64(decode[1:])) + decode = bytearray() + elif decode: + decode.append(c) + else: + r.append(chr(c)) + if decode: + r.append(modified_unbase64(decode[1:])) + bin_str = ''.join(r) + return (bin_str, len(s)) + + +class StreamReader(codecs.StreamReader): + def decode(self, s, errors='strict'): + return decoder(s) + + +class StreamWriter(codecs.StreamWriter): + def decode(self, s, errors='strict'): + return encoder(s) + + +def imap4_utf_7(name): + if name == 'imap4-utf-7': + return (encoder, decoder, StreamReader, StreamWriter) + + +codecs.register(imap4_utf_7) + + +## testing methods + +def imapUTF7Encode(ust): + "Returns imap utf-7 encoded version of string" + return ust.encode('imap4-utf-7') + +def imapUTF7EncodeSequence(seq): + "Returns imap utf-7 encoded version of strings in sequence" + return [imapUTF7Encode(itm) for itm in seq] + + +def imapUTF7Decode(st): + "Returns utf7 encoded 
version of imap utf-7 string" + return st.decode('imap4-utf-7') + +def imapUTF7DecodeSequence(seq): + "Returns utf7 encoded version of imap utf-7 strings in sequence" + return [imapUTF7Decode(itm) for itm in seq] + + +def utf8Decode(st): + "Returns utf7 encoded version of imap utf-7 string" + return st.decode('utf-8') + + +def utf7SequenceToUTF8(seq): + "Returns utf7 encoded version of imap utf-7 strings in sequence" + return [itm.decode('imap4-utf-7').encode('utf-8') for itm in seq] + + +__all__ = [ 'imapUTF7Encode', 'imapUTF7Decode' ] + + +if __name__ == '__main__': + testdata = [ + (u'foo\r\n\nbar\n', b'foo&AA0ACgAK-bar&AAo-'), + (u'测试', b'&bUuL1Q-'), + (u'Hello 世界', b'Hello &ThZ1TA-') + ] + for s, e in testdata: + #assert s == decoder(encoder(s)[0])[0] + assert s == imapUTF7Decode(e) + assert e == imapUTF7Encode(s) + assert s == imapUTF7Decode(imapUTF7Encode(s)) + assert e == imapUTF7Encode(imapUTF7Decode(e)) + print("All tests passed!") + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/plot.py b/python/plot.py new file mode 100644 index 0000000..b65f8a3 --- /dev/null +++ b/python/plot.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# +# Credits: http://www.aosabook.org/en/matplotlib.html +# +# Aaron LI +# 2016-03-14 +# + +# Import the FigureCanvas from the backend of your choice +# and attach the Figure artist to it. +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +fig = Figure() +canvas = FigureCanvas(fig) + +# Import the numpy library to generate the random numbers. +import numpy as np +x = np.random.randn(10000) + +# Now use a figure method to create an Axes artist; the Axes artist is +# added automatically to the figure container fig.axes. +# Here "111" is from the MATLAB convention: create a grid with 1 row and 1 +# column, and use the first cell in that grid for the location of the new +# Axes. 
+ax = fig.add_subplot(111) + +# Call the Axes method hist to generate the histogram; hist creates a +# sequence of Rectangle artists for each histogram bar and adds them +# to the Axes container. Here "100" means create 100 bins. +ax.hist(x, 100) + +# Decorate the figure with a title and save it. +ax.set_title('Normal distribution with $\mu=0, \sigma=1$') +fig.savefig('matplotlib_histogram.png') + diff --git a/python/plot_tprofiles_zzh.py b/python/plot_tprofiles_zzh.py new file mode 100644 index 0000000..e5824e9 --- /dev/null +++ b/python/plot_tprofiles_zzh.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# +# Weitian LI +# 2015-09-11 +# + +""" +Plot a list of *temperature profiles* in a grid of subplots with Matplotlib. +""" + +import matplotlib.pyplot as plt + + +def plot_tprofiles(tplist, nrows, ncols, + xlim=None, ylim=None, logx=False, logy=False, + xlab="", ylab="", title=""): + """ + Plot a list of *temperature profiles* in a grid of subplots of size + nrow x ncol. Each subplot is related to a temperature profile. + All the subplots share the same X and Y axes. + The order is by row. + + The tplist is a list of dictionaries, each of which contains all the + necessary data to make the subplot. + + The dictionary consists of the following components: + tpdat = { + "name": "NAME", + "radius": [[radius points], [radius errors]], + "temperature": [[temperature points], [temperature errors]], + "radius_model": [radus points of the fitted model], + "temperature_model": [ + [fitted model value], + [lower bounds given by the model], + [upper bounds given by the model] + ] + } + + Arguments: + tplist - a list of dictionaries containing the data of each + temperature profile. + Note that the length of this list should equal to nrows*ncols. 
+ nrows - number of rows of the subplots + ncols - number of columns of the subplots + xlim - limits of the X axis + ylim - limits of the Y axis + logx - whether to set the log scale for X axis + logy - whether to set the log scale for Y axis + xlab - label for the X axis + ylab - label for the Y axis + title - title for the whole plot + """ + assert len(tplist) == nrows*ncols, "tplist length != nrows*ncols" + # All subplots share both X and Y axes. + fig, axarr = plt.subplots(nrows, ncols, sharex=True, sharey=True) + # Set title for the whole plot. + if title != "": + fig.suptitle(title) + # Set xlab and ylab for each subplot + if xlab != "": + for ax in axarr[-1, :]: + ax.set_xlabel(xlab) + if ylab != "": + for ax in axarr[:, 0]: + ax.set_ylabel(ylab) + for ax in axarr.reshape(-1): + # Set xlim and ylim. + if xlim is not None: + ax.set_xlim(xlim) + if ylim is not None: + ax.set_ylim(ylim) + # Set xscale and yscale. + if logx: + ax.set_xscale("log", nonposx="clip") + if logy: + ax.set_yscale("log", nonposy="clip") + # Decrease the spacing between the subplots and suptitle + fig.subplots_adjust(top=0.94) + # Eleminate the spaces between each row and column. + fig.subplots_adjust(hspace=0, wspace=0) + # Hide X ticks for all subplots but the bottom row. + plt.setp([ax.get_xticklabels() for ax in axarr[:-1, :].reshape(-1)], + visible=False) + # Hide Y ticks for all subplots but the left column. + plt.setp([ax.get_yticklabels() for ax in axarr[:, 1:].reshape(-1)], + visible=False) + # Plot each temperature profile in the tplist + for i, ax in zip(range(len(tplist)), axarr.reshape(-1)): + tpdat = tplist[i] + # Add text to display the name. + # The text is placed at (0.95, 0.95), i.e., the top-right corner, + # with respect to this subplot, and the top-right part of the text + # is aligned to the above position. 
+ ax_pois = ax.get_position() + ax.text(0.95, 0.95, tpdat["name"], + verticalalignment="top", horizontalalignment="right", + transform=ax.transAxes, color="black", fontsize=10) + # Plot data points + if isinstance(tpdat["radius"][0], list) and \ + len(tpdat["radius"]) == 2 and \ + isinstance(tpdat["temperature"][0], list) and \ + len(tpdat["temperature"]) == 2: + # Data points have symmetric errorbar + ax.errorbar(tpdat["radius"][0], tpdat["temperature"][0], + xerr=tpdat["radius"][1], yerr=tpdat["temperature"][1], + color="black", linewidth=1.5, linestyle="None") + else: + ax.plot(tpdat["radius"], tpdat["temperature"], + color="black", linewidth=1.5, linestyle="None") + # Plot model line and bounds band + if isinstance(tpdat["temperature_model"][0], list) and \ + len(tpdat["temperature_model"]) == 3: + # Model data have bounds + ax.plot(tpdat["radius_model"], tpdat["temperature_model"][0], + color="blue", linewidth=1.0) + # Plot model bounds band + ax.fill_between(tpdat["radius_model"], + y1=tpdat["temperature_model"][1], + y2=tpdat["temperature_model"][2], + color="gray", alpha=0.5) + else: + ax.plot(tpdat["radius_model"], tpdat["temperature_model"], + color="blue", linewidth=1.5) + return (fig, axarr) + +# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: # diff --git a/python/randomize_events.py b/python/randomize_events.py new file mode 100755 index 0000000..e1a6e31 --- /dev/null +++ b/python/randomize_events.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# +# Randomize the (X,Y) position of each X-ray photon events according +# to a Gaussian distribution of given sigma. +# +# References: +# [1] G. Scheellenberger, T.H. Reiprich, L. Lovisari, J. Nevalainen & L. 
David +# 2015, A&A, 575, A30 +# +# +# Aaron LI +# Created: 2016-03-24 +# Updated: 2016-03-24 +# + +from astropy.io import fits +import numpy as np + +import os +import sys +import datetime +import argparse + + +CHANDRA_ARCSEC_PER_PIXEL = 0.492 + +def randomize_events(infile, outfile, sigma, clobber=False): + """ + Randomize the position (X,Y) of each X-ray event according to a + specified size/sigma Gaussian distribution. + """ + sigma_pix = sigma / CHANDRA_ARCSEC_PER_PIXEL + evt_fits = fits.open(infile) + evt_table = evt_fits[1].data + # (X,Y) physical coordinate + evt_x = evt_table["x"] + evt_y = evt_table["y"] + rand_x = np.random.normal(scale=sigma_pix, size=evt_x.shape)\ + .astype(evt_x.dtype) + rand_y = np.random.normal(scale=sigma_pix, size=evt_y.shape)\ + .astype(evt_y.dtype) + evt_x += rand_x + evt_y += rand_y + # Add history to FITS header + evt_hdr = evt_fits[1].header + evt_hdr.add_history("TOOL: %s @ %s" % ( + os.path.basename(sys.argv[0]), + datetime.datetime.utcnow().isoformat())) + evt_hdr.add_history("COMMAND: %s" % " ".join(sys.argv)) + evt_fits.writeto(outfile, clobber=clobber, checksum=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Randomize the (X,Y) of each X-ray event") + parser.add_argument("infile", help="input event file") + parser.add_argument("outfile", help="output randomized event file") + parser.add_argument("-s", "--sigma", dest="sigma", + required=True, type=float, + help="sigma/size of the Gaussian distribution used" + \ + "to randomize the position of events (unit: arcsec)") + parser.add_argument("-C", "--clobber", dest="clobber", + action="store_true", help="overwrite output file if exists") + args = parser.parse_args() + + randomize_events(args.infile, args.outfile, + sigma=args.sigma, clobber=args.clobber) + + +if __name__ == "__main__": + main() + diff --git a/python/sbp_fit.py b/python/sbp_fit.py new file mode 100644 index 0000000..0b7d29a --- /dev/null +++ b/python/sbp_fit.py @@ -0,0 +1,613 @@ 
+#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Aaron LI +# Created: 2016-03-13 +# Updated: 2016-03-14 +# +# Changelogs: +# 2016-03-28: +# * Add `main()', `make_model()' +# * Use `configobj' to handle configurations +# * Save fit results and plot +# * Add `ci_report()' +# 2016-03-14: +# * Refactor classes `FitModelSBeta' and `FitModelDBeta' +# * Add matplotlib plot support +# * Add `ignore_data()' and `notice_data()' support +# * Add classes `FitModelSBetaNorm' and `FitModelDBetaNorm' +# + +""" +Fit the surface brightness profile (SBP) with the single-beta model: + s(r) = s0 * [1.0 + (r/rc)^2] ^ (0.5-3*beta) + bkg +or the double-beta model: + s(r) = s01 * [1.0 + (r/rc1)^2] ^ (0.5-3*beta1) + + s02 * [1.0 + (r/rc2)^2] ^ (0.5-3*beta2) + bkg + +Sample config file: +------------------------------------------------- +name = <NAME> +obsid = <OBSID> + +sbpfile = sbprofile.txt + +model = sbeta +outfile = sbpfit_sbeta.txt +imgfile = sbpfit_sbeta.png + +#model = dbeta +#outfile = sbpfit_dbeta.txt +#imgfile = sbpfit_dbeta.png + +# data range to be ignored during fitting +#ignore = 0.0-20.0, + +[sbeta] +# name = initial, lower, upper +s0 = 1.0e-8, 0.0, 1.0e-6 +rc = 30.0, 1.0, 1.0e4 +beta = 0.7, 0.3, 1.1 +bkg = 1.0e-9, 0.0, 1.0e-7 + +[dbeta] +s01 = 1.0e-8, 0.0, 1.0e-6 +rc1 = 50.0, 10.0, 1.0e4 +beta1 = 0.7, 0.3, 1.1 +s02 = 1.0e-8, 0.0, 1.0e-6 +rc2 = 30.0, 1.0, 5.0e2 +beta2 = 0.7, 0.3, 1.1 +bkg = 1.0e-9, 0.0, 1.0e-7 +------------------------------------------------- +""" + +__version__ = "0.4.0" +__date__ = "2016-03-28" + +import numpy as np +import lmfit +import matplotlib.pyplot as plt + +from configobj import ConfigObj +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure + +import os +import sys +import re +import argparse + + +plt.style.use("ggplot") + + +class FitModel: + """ + Meta-class of the fitting model. 
+ + The supplied `func' should have the following syntax: + y = f(x, params) + where the `params' is the parameters to be fitted, + and should be provided as well. + """ + def __init__(self, name=None, func=None, params=lmfit.Parameters()): + self.name = name + self.func = func + self.params = params + + def f(self, x): + return self.func(x, self.params) + + def get_param(self, name=None): + """ + Return the requested `Parameter' object or the whole + `Parameters' object of no name supplied. + """ + try: + return self.params[name] + except KeyError: + return self.params + + def set_param(self, name, *args, **kwargs): + """ + Set the properties of the specified parameter. + """ + param = self.params[name] + param.set(*args, **kwargs) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model. + """ + f_fitted = lambda x: self.func(x, params) + ydata = f_fitted(xdata) + ax.plot(xdata, ydata, 'k-') + +class FitModelSBeta(FitModel): + """ + The single-beta model to be fitted. + Single-beta model, with a constant background. + """ + params = lmfit.Parameters() + params.add_many( # (name, value, vary, min, max, expr) + ("s0", 1.0e-8, True, 0.0, 1.0e-6, None), + ("rc", 30.0, True, 1.0, 1.0e4, None), + ("beta", 0.7, True, 0.3, 1.1, None), + ("bkg", 1.0e-9, True, 0.0, 1.0e-7, None)) + + @staticmethod + def sbeta(r, params): + parvals = params.valuesdict() + s0 = parvals["s0"] + rc = parvals["rc"] + beta = parvals["beta"] + bkg = parvals["bkg"] + return s0 * np.power((1 + (r/rc)**2), (0.5 - 3*beta)) + bkg + + def __init__(self): + super(self.__class__, self).__init__(name="Single-beta", + func=self.sbeta, params=self.params) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model, as well as the fitted parameters. 
+ """ + super(self.__class__, self).plot(params, xdata, ax) + ydata = self.sbeta(xdata, params) + # fitted paramters + ax.vlines(x=params["rc"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.hlines(y=params["bkg"].value, xmin=min(xdata), xmax=max(xdata), + linestyles="dashed") + ax.text(x=params["rc"].value, y=min(ydata), + s="beta: %.2f\nrc: %.2f" % (params["beta"].value, + params["rc"].value)) + ax.text(x=min(xdata), y=min(ydata), + s="bkg: %.3e" % params["bkg"].value, + verticalalignment="top") + + +class FitModelDBeta(FitModel): + """ + The double-beta model to be fitted. + Double-beta model, with a constant background. + + NOTE: + the first beta component (s01, rc1, beta1) describes the main and + outer SBP; while the second beta component (s02, rc2, beta2) accounts + for the central brightness excess. + """ + params = lmfit.Parameters() + params.add("s01", value=1.0e-8, min=0.0, max=1.0e-6) + params.add("rc1", value=50.0, min=10.0, max=1.0e4) + params.add("beta1", value=0.7, min=0.3, max=1.1) + #params.add("df_s0", value=1.0e-8, min=0.0, max=1.0e-6) + #params.add("s02", expr="s01 + df_s0") + params.add("s02", value=1.0e-8, min=0.0, max=1.0e-6) + #params.add("df_rc", value=30.0, min=0.0, max=1.0e4) + #params.add("rc2", expr="rc1 - df_rc") + params.add("rc2", value=20.0, min=1.0, max=5.0e2) + params.add("beta2", value=0.7, min=0.3, max=1.1) + params.add("bkg", value=1.0e-9, min=0.0, max=1.0e-7) + + @staticmethod + def beta1(r, params): + """ + This beta component describes the main/outer part of the SBP. + """ + parvals = params.valuesdict() + s01 = parvals["s01"] + rc1 = parvals["rc1"] + beta1 = parvals["beta1"] + bkg = parvals["bkg"] + return s01 * np.power((1 + (r/rc1)**2), (0.5 - 3*beta1)) + bkg + + @staticmethod + def beta2(r, params): + """ + This beta component describes the central/excess part of the SBP. 
+ """ + parvals = params.valuesdict() + s02 = parvals["s02"] + rc2 = parvals["rc2"] + beta2 = parvals["beta2"] + return s02 * np.power((1 + (r/rc2)**2), (0.5 - 3*beta2)) + + @classmethod + def dbeta(self, r, params): + return self.beta1(r, params) + self.beta2(r, params) + + def __init__(self): + super(self.__class__, self).__init__(name="Double-beta", + func=self.dbeta, params=self.params) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model, and each beta component, + as well as the fitted parameters. + """ + super(self.__class__, self).plot(params, xdata, ax) + beta1_ydata = self.beta1(xdata, params) + beta2_ydata = self.beta2(xdata, params) + ax.plot(xdata, beta1_ydata, 'b-.') + ax.plot(xdata, beta2_ydata, 'b-.') + # fitted paramters + ydata = beta1_ydata + beta2_ydata + ax.vlines(x=params["rc1"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.vlines(x=params["rc2"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.hlines(y=params["bkg"].value, xmin=min(xdata), xmax=max(xdata), + linestyles="dashed") + ax.text(x=params["rc1"].value, y=min(ydata), + s="beta1: %.2f\nrc1: %.2f" % (params["beta1"].value, + params["rc1"].value)) + ax.text(x=params["rc2"].value, y=min(ydata), + s="beta2: %.2f\nrc2: %.2f" % (params["beta2"].value, + params["rc2"].value)) + ax.text(x=min(xdata), y=min(ydata), + s="bkg: %.3e" % params["bkg"].value, + verticalalignment="top") + + +class FitModelSBetaNorm(FitModel): + """ + The single-beta model to be fitted. + Single-beta model, with a constant background. + Normalized the `s0' and `bkg' parameters by take the logarithm. 
+ """ + params = lmfit.Parameters() + params.add_many( # (name, value, vary, min, max, expr) + ("log10_s0", -8.0, True, -12.0, -6.0, None), + ("rc", 30.0, True, 1.0, 1.0e4, None), + ("beta", 0.7, True, 0.3, 1.1, None), + ("log10_bkg", -9.0, True, -12.0, -7.0, None)) + + @staticmethod + def sbeta(r, params): + parvals = params.valuesdict() + s0 = 10 ** parvals["log10_s0"] + rc = parvals["rc"] + beta = parvals["beta"] + bkg = 10 ** parvals["log10_bkg"] + return s0 * np.power((1 + (r/rc)**2), (0.5 - 3*beta)) + bkg + + def __init__(self): + super(self.__class__, self).__init__(name="Single-beta", + func=self.sbeta, params=self.params) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model, as well as the fitted parameters. + """ + super(self.__class__, self).plot(params, xdata, ax) + ydata = self.sbeta(xdata, params) + # fitted paramters + ax.vlines(x=params["rc"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.hlines(y=(10 ** params["bkg"].value), xmin=min(xdata), + xmax=max(xdata), linestyles="dashed") + ax.text(x=params["rc"].value, y=min(ydata), + s="beta: %.2f\nrc: %.2f" % (params["beta"].value, + params["rc"].value)) + ax.text(x=min(xdata), y=min(ydata), + s="bkg: %.3e" % (10 ** params["bkg"].value), + verticalalignment="top") + + +class FitModelDBetaNorm(FitModel): + """ + The double-beta model to be fitted. + Double-beta model, with a constant background. + Normalized the `s01', `s02' and `bkg' parameters by take the logarithm. + + NOTE: + the first beta component (s01, rc1, beta1) describes the main and + outer SBP; while the second beta component (s02, rc2, beta2) accounts + for the central brightness excess. 
+ """ + params = lmfit.Parameters() + params.add("log10_s01", value=-8.0, min=-12.0, max=-6.0) + params.add("rc1", value=50.0, min=10.0, max=1.0e4) + params.add("beta1", value=0.7, min=0.3, max=1.1) + #params.add("df_s0", value=1.0e-8, min=0.0, max=1.0e-6) + #params.add("s02", expr="s01 + df_s0") + params.add("log10_s02", value=-8.0, min=-12.0, max=-6.0) + #params.add("df_rc", value=30.0, min=0.0, max=1.0e4) + #params.add("rc2", expr="rc1 - df_rc") + params.add("rc2", value=20.0, min=1.0, max=5.0e2) + params.add("beta2", value=0.7, min=0.3, max=1.1) + params.add("log10_bkg", value=-9.0, min=-12.0, max=-7.0) + + @staticmethod + def beta1(r, params): + """ + This beta component describes the main/outer part of the SBP. + """ + parvals = params.valuesdict() + s01 = 10 ** parvals["log10_s01"] + rc1 = parvals["rc1"] + beta1 = parvals["beta1"] + bkg = 10 ** parvals["log10_bkg"] + return s01 * np.power((1 + (r/rc1)**2), (0.5 - 3*beta1)) + bkg + + @staticmethod + def beta2(r, params): + """ + This beta component describes the central/excess part of the SBP. + """ + parvals = params.valuesdict() + s02 = 10 ** parvals["log10_s02"] + rc2 = parvals["rc2"] + beta2 = parvals["beta2"] + return s02 * np.power((1 + (r/rc2)**2), (0.5 - 3*beta2)) + + @classmethod + def dbeta(self, r, params): + return self.beta1(r, params) + self.beta2(r, params) + + def __init__(self): + super(self.__class__, self).__init__(name="Double-beta", + func=self.dbeta, params=self.params) + + def plot(self, params, xdata, ax): + """ + Plot the fitted model, and each beta component, + as well as the fitted parameters. 
+ """ + super(self.__class__, self).plot(params, xdata, ax) + beta1_ydata = self.beta1(xdata, params) + beta2_ydata = self.beta2(xdata, params) + ax.plot(xdata, beta1_ydata, 'b-.') + ax.plot(xdata, beta2_ydata, 'b-.') + # fitted paramters + ydata = beta1_ydata + beta2_ydata + ax.vlines(x=params["log10_rc1"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.vlines(x=params["rc2"].value, ymin=min(ydata), ymax=max(ydata), + linestyles="dashed") + ax.hlines(y=(10 ** params["bkg"].value), xmin=min(xdata), + xmax=max(xdata), linestyles="dashed") + ax.text(x=params["rc1"].value, y=min(ydata), + s="beta1: %.2f\nrc1: %.2f" % (params["beta1"].value, + params["rc1"].value)) + ax.text(x=params["rc2"].value, y=min(ydata), + s="beta2: %.2f\nrc2: %.2f" % (params["beta2"].value, + params["rc2"].value)) + ax.text(x=min(xdata), y=min(ydata), + s="bkg: %.3e" % (10 ** params["bkg"].value), + verticalalignment="top") + + +class SbpFit: + """ + Class to handle the SBP fitting with single-/double-beta model. + """ + def __init__(self, model, method="lbfgsb", + xdata=None, ydata=None, xerr=None, yerr=None, + name=None, obsid=None): + self.method = method + self.model = model + self.xdata = xdata + self.ydata = ydata + self.xerr = xerr + self.yerr = yerr + if xdata is not None: + self.mask = np.ones(xdata.shape, dtype=np.bool) + else: + self.mask = None + self.name = name + self.obsid = obsid + + def set_source(self, name, obsid=None): + self.name = name + self.obsid = obsid + + def load_data(self, xdata, ydata, xerr, yerr): + self.xdata = xdata + self.ydata = ydata + self.xerr = xerr + self.yerr = yerr + self.mask = np.ones(xdata.shape, dtype=np.bool) + + def ignore_data(self, xmin=None, xmax=None): + """ + Ignore the data points within range [xmin, xmax]. + If xmin is None, then xmin=min(xdata); + if xmax is None, then xmax=max(xdata). 
+ """ + if xmin is None: + xmin = min(self.xdata) + if xmax is None: + xmax = max(self.xdata) + ignore_idx = np.logical_and(self.xdata >= xmin, self.xdata <= xmax) + self.mask[ignore_idx] = False + # reset `f_residual' + self.f_residual = None + + def notice_data(self, xmin=None, xmax=None): + """ + Notice the data points within range [xmin, xmax]. + If xmin is None, then xmin=min(xdata); + if xmax is None, then xmax=max(xdata). + """ + if xmin is None: + xmin = min(self.xdata) + if xmax is None: + xmax = max(self.xdata) + notice_idx = np.logical_and(self.xdata >= xmin, self.xdata <= xmax) + self.mask[notice_idx] = True + # reset `f_residual' + self.f_residual = None + + def set_residual(self): + def f_residual(params): + if self.yerr is None: + return self.model.func(self.xdata[self.mask], params) - \ + self.ydata + else: + return (self.model.func(self.xdata[self.mask], params) - \ + self.ydata[self.mask]) / self.yerr[self.mask] + self.f_residual = f_residual + + def fit(self, method=None): + if method is None: + method = self.method + if not hasattr(self, "f_residual") or self.f_residual is None: + self.set_residual() + self.fitter = lmfit.Minimizer(self.f_residual, self.model.params) + self.results = self.fitter.minimize(method=method) + self.fitted_model = lambda x: self.model.func(x, self.results.params) + + def calc_ci(self, sigmas=[0.68, 0.90]): + # `conf_interval' requires the fitted results have valid `stderr', + # so we need to re-fit the model with method `leastsq'. + results = self.fitter.minimize(method="leastsq", + params=self.results.params) + self.ci, self.trace = lmfit.conf_interval(self.fitter, results, + sigmas=sigmas, trace=True) + + @staticmethod + def ci_report(ci, with_offset=True): + """ + Format and return the report for confidence intervals. + This function is intended to replace the `lmfit.ci_report()` + which do not use scitenfic notation for small values. 
+ """ + pnames = list(ci.keys()) + plen = max(map(len, pnames)) + report = [] + # header + sigmas = [ v[0] for v in ci.get(pnames[0]) if v[0] > 0.0 ] + fmt = "{:<%d}" % (plen+1) + \ + " {:11.2%} {:11.2%} _BEST_ {:11.2%} {:11.2%}" + report.append(fmt.format(" ", *sigmas)) + # parameters ci + for par in pnames: + values = [ v[1] for v in ci.get(par) ] + if with_offset: + N = len(values) + offsets = [ values[N//2] for i in range(N) ] + offsets[N//2] = 0.0 + values = list(map(lambda a,b: a-b, values, offsets)) + fmt = "{:<%d}" % (plen+1) + \ + " {:+11.5g} {:+11.5g} {:11.5g} {:+11.5g} {:+11.5g}" + report.append(fmt.format(par+":", *values)) + return "\n".join(report) + + def report(self, outfile=sys.stdout): + if hasattr(self, "results") and self.results is not None: + p = lmfit.fit_report(self.results) + print(p, file=outfile) + if hasattr(self, "ci") and self.ci is not None: + print("-----------------------------------------------", + file=outfile) + p = self.ci_report(self.ci) + print(p, file=outfile) + + def plot(self, ax=None, fig=None): + if ax is None: + fig, ax = plt.subplots(1, 1) + # noticed data points + eb = ax.errorbar(self.xdata[self.mask], self.ydata[self.mask], + xerr=self.xerr[self.mask], yerr=self.yerr[self.mask], + fmt="none") + # ignored data points + ignore_mask = np.logical_not(self.mask) + if np.sum(ignore_mask) > 0: + eb = ax.errorbar(self.xdata[ignore_mask], self.ydata[ignore_mask], + xerr=self.xerr[ignore_mask], yerr=self.yerr[ignore_mask], + fmt="none") + eb[-1][0].set_linestyle("-.") + # fitted model + xmax = self.xdata[-1] + self.xerr[-1] + xpred = np.power(10, np.linspace(0, np.log10(xmax), 2*len(self.xdata))) + ypred = self.fitted_model(xpred) + ymin = min(min(self.ydata), min(ypred)) + ymax = max(max(self.ydata), max(ypred)) + self.model.plot(params=self.results.params, xdata=xpred, ax=ax) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlim(1.0, xmax) + ax.set_ylim(ymin/1.2, ymax*1.2) + name = self.name + if self.obsid is not 
def make_model(config):
    """
    Make the model with parameters set according to the config.

    Arguments:
      * config: dict-like configuration; `config["model"]' selects the
                model ("sbeta" or "dbeta"), and the section of the same
                name gives `initial, lower, upper' for each parameter

    Raises:
      * ValueError: unknown model name, or a parameter specification
                    that does not consist of exactly three values
    """
    modelname = config["model"]
    if modelname == "sbeta":
        # single-beta model
        model = FitModelSBeta()
    elif modelname == "dbeta":
        # double-beta model
        model = FitModelDBeta()
    else:
        raise ValueError("Invalid model: %s" % modelname)
    # Set initial values and bounds for the model parameters.
    # A config without the model's section keeps the built-in defaults.
    params = config.get(modelname) or {}
    for pname, spec in params.items():
        # A value without commas is a plain string, which would otherwise
        # be silently indexed character by character.
        if isinstance(spec, str) or len(spec) != 3:
            raise ValueError(
                "Invalid specification of parameter '%s': "
                "expect 'initial, lower, upper'" % pname)
        model.set_param(name=pname, value=float(spec[0]),
                        min=float(spec[1]), max=float(spec[2]))
    return model
class B3Spline:
    """
    Filter bank of the B3-spline wavelet (1D filters of 5 taps each).
    """
    # decomposition filters: `dec_lo' is the scaling function (phi);
    # `dec_hi' equals the unit impulse minus `dec_lo'
    dec_lo = np.array([1, 4, 6, 4, 1], dtype=float) / 16
    dec_hi = np.array([-1, -4, 10, -4, -1], dtype=float) / 16
    # reconstruction filters: both are the unit impulse
    rec_lo = np.array([0, 0, 1, 0, 0], dtype=float)
    rec_hi = np.array([0, 0, 1, 0, 0], dtype=float)
+ filters = [] + + phi = None # wavelet scaling function (2D) + level = 0 # number of transform level + decomposition = None # decomposed coefficients/images + reconstruction = None # reconstructed image + + # convolution boundary condition + boundary = "symm" + + def __init__(self, phi=B3Spline.dec_lo, level=None, boundary="symm", + data=None): + self.set_wavelet(phi=phi) + self.level = level + self.boundary = boundary + self.data = np.array(data) + + def reset(self): + """ + Reset the object attributes. + """ + self.data = None + self.phi = None + self.decomposition = None + self.reconstruction = None + self.level = 0 + self.filters = [] + self.boundary = "symm" + + def load_data(self, data): + self.reset() + self.data = np.array(data) + + def set_wavelet(self, phi): + self.reset() + phi = np.array(phi) + if phi.ndim == 1: + phi_ = phi.reshape(1, -1) + self.phi = np.dot(phi_.T, phi_) + elif phi.ndim == 2: + self.phi = phi + else: + raise ValueError("Invalid phi dimension") + + def calc_filters(self): + """ + Calculate the convolution filters of each scale. + Note: the zero-th scale filter (i.e., delta function) is the first + element, thus the array index is the same as the decomposition scale. + """ + self.filters = [] + # scale 0: delta function + h = np.array([[1]]) # NOTE: 2D + self.filters.append(h) + # scale 1 + h = self.phi[::-1, ::-1] + self.filters.append(h) + for scale in range(2, self.level+1): + h_up = self.zupsample(self.phi, order=scale-1) + h2 = signal.convolve2d(h_up[::-1, ::-1], h, mode="same", + boundary=self.boundary) + self.filters.append(h2) + + def transform(self, data, scale, boundary="symm"): + """ + Perform only one scale wavelet transform for the given data. 
+ + return: + [ approx, detail ] + """ + self.decomposition = [] + approx = signal.convolve2d(data, self.filters[scale], + mode="same", boundary=self.boundary) + detail = data - approx + return [approx, detail] + + def decompose(self, level, boundary="symm"): + """ + Perform IUWT decomposition in the plain loop way. + The filters of each scale/level are calculated first, then the + approximations of each scale/level are calculated by convolving the + raw/finest image with these filters. + + return: + [ W_1, W_2, ..., W_n, A_n ] + n = level + W: wavelet details + A: approximation + """ + self.boundary = boundary + if self.level != level or self.filters == []: + self.level = level + self.calc_filters() + self.decomposition = [] + approx = self.data + for scale in range(1, level+1): + # approximation: + approx2 = signal.convolve2d(self.data, self.filters[scale], + mode="same", boundary=self.boundary) + # wavelet details: + w = approx - approx2 + self.decomposition.append(w) + if scale == level: + self.decomposition.append(approx2) + approx = approx2 + return self.decomposition + + def decompose_recursive(self, level, boundary="symm"): + """ + Perform the IUWT decomposition in the recursive way. + + return: + [ W_1, W_2, ..., W_n, A_n ] + n = level + W: wavelet details + A: approximation + """ + self.level = level + self.boundary = boundary + self.decomposition = self.__decompose(self.data, self.phi, level=level) + return self.decomposition + + def __decompose(self, data, phi, level): + """ + 2D IUWT decomposition (or stationary wavelet transform). + + This is a convolution version, where kernel is zero-upsampled + explicitly. Not fast. + + Parameters: + - level : level of decomposition + - phi : low-pass filter kernel + - boundary : boundary conditions (passed to scipy.signal.convolve2d, + 'symm' by default) + + Returns: + list of wavelet details + last approximation. Each element in + the list is an image of the same size as the input image. 
+ """ + if level <= 0: + return data + shapecheck = map(lambda a,b:a>b, data.shape, phi.shape) + assert np.all(shapecheck) + # approximation: + approx = signal.convolve2d(data, phi[::-1, ::-1], mode="same", + boundary=self.boundary) + # wavelet details: + w = data - approx + phi_up = self.zupsample(phi, order=1) + shapecheck = map(lambda a,b:a>b, data.shape, phi_up.shape) + if level == 1: + return [w, approx] + elif not np.all(shapecheck): + print("Maximum allowed decomposition level reached", + file=sys.stderr) + return [w, approx] + else: + return [w] + self.__decompose(approx, phi_up, level-1) + + @staticmethod + def zupsample(data, order=1): + """ + Upsample data array by interleaving it with zero's. + + h{up_order: n}[l] = (1) h[l], if l % 2^n == 0; + (2) 0, otherwise + """ + shape = data.shape + new_shape = [ (2**order * (n-1) + 1) for n in shape ] + output = np.zeros(new_shape, dtype=data.dtype) + output[[ slice(None, None, 2**order) for d in shape ]] = data + return output + + def reconstruct(self, decomposition=None): + if decomposition is not None: + reconstruction = np.sum(decomposition, axis=0) + return reconstruction + else: + self.reconstruction = np.sum(self.decomposition, axis=0) + + def get_detail(self, scale): + """ + Get the wavelet detail coefficients of given scale. + Note: 1 <= scale <= level + """ + if scale < 1 or scale > self.level: + raise ValueError("Invalid scale") + return self.decomposition[scale-1] + + def get_approx(self): + """ + Get the approximation coefficients of the largest scale. + """ + return self.decomposition[-1] + + +class IUWT_VST(IUWT): + """ + IUWT with Multi-scale variance stabling transform. + + Refernce: + [1] Bo Zhang, Jalal M. Fadili & Jean-Luc Starck, + IEEE Trans. Image Processing, 17, 17, 2008 + """ + # VST coefficients and the corresponding asymptotic standard deviation + # of each scale. 
+ vst_coef = [] + + def reset(self): + super(self.__class__, self).reset() + vst_coef = [] + + def __decompose(self): + raise AttributeError("No '__decompose' attribute") + + @staticmethod + def soft_threshold(data, threshold): + if isinstance(data, np.ndarray): + data_th = data.copy() + data_th[np.abs(data) <= threshold] = 0.0 + data_th[data > threshold] -= threshold + data_th[data < -threshold] += threshold + else: + data_th = data + if np.abs(data) <= threshold: + data_th = 0.0 + elif data > threshold: + data_th -= threshold + else: + data_th += threshold + return data_th + + def tau(self, k, scale): + """ + Helper function used in VST coefficients calculation. + """ + return np.sum(np.power(self.filters[scale], k)) + + def filters_product(self, scale1, scale2): + """ + Calculate the scalar product of the filters of two scales, + considering only the overlapped part. + Helper function used in VST coefficients calculation. + """ + if scale1 > scale2: + filter_big = self.filters[scale1] + filter_small = self.filters[scale2] + else: + filter_big = self.filters[scale2] + filter_small = self.filters[scale1] + # crop the big filter to match the size of the small filter + size_big = filter_big.shape + size_small = filter_small.shape + size_diff2 = list(map(lambda a,b: (a-b)//2, size_big, size_small)) + filter_big_crop = filter_big[ + size_diff2[0]:(size_big[0]-size_diff2[0]), + size_diff2[1]:(size_big[1]-size_diff2[1])] + assert(np.all(list(map(lambda a,b: a==b, + size_small, filter_big_crop.shape)))) + product = np.sum(filter_small * filter_big_crop) + return product + + def calc_vst_coef(self): + """ + Calculate the VST coefficients and the corresponding + asymptotic standard deviation of each scale, according to the + calculated filters of each scale/level. 
+ """ + self.vst_coef = [] + for scale in range(self.level+1): + b = 2 * np.sqrt(np.abs(self.tau(1, scale)) / self.tau(2, scale)) + c = 7.0*self.tau(2, scale) / (8.0*self.tau(1, scale)) - \ + self.tau(3, scale) / (2.0*self.tau(2, scale)) + if scale == 0: + std = -1.0 + else: + std = np.sqrt((self.tau(2, scale-1) / \ + (4 * self.tau(1, scale-1)**2)) + \ + (self.tau(2, scale) / (4 * self.tau(1, scale)**2)) - \ + (self.filters_product(scale-1, scale) / \ + (2 * self.tau(1, scale-1) * self.tau(1, scale)))) + self.vst_coef.append({ "b": b, "c": c, "std": std }) + + def vst(self, data, scale, coupled=True): + """ + Perform variance stabling transform + + XXX: parameter `coupled' why?? + Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h + """ + self.vst_coupled = coupled + if self.vst_coef == []: + self.calc_vst_coef() + if coupled: + b = 1.0 + else: + b = self.vst_coef[scale]["b"] + data_vst = b * np.sqrt(np.abs(data + self.vst_coef[scale]["c"])) + return data_vst + + def ivst(self, data, scale, cbias=True): + """ + Inverse variance stabling transform + NOTE: assuming that `a_{j} + c^{j}' are all positive. + + XXX: parameter `cbias' why?? + `bias correction' is recommended while reconstruct the data + after estimation + Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h + """ + self.vst_cbias = cbias + if cbias: + cb = 1.0 / (self.vst_coef[scale]["b"] ** 2) + else: + cb = 0.0 + data_ivst = data ** 2 + cb - self.vst_coef[scale]["c"] + return data_ivst + + def is_significant(self, scale, fdr=0.1, independent=False): + """ + Multiple hypothesis testing with false discovery rate (FDR) control. + + If `independent=True': FDR <= (m0/m) * q + otherwise: FDR <= (m0/m) * q * (1 + 1/2 + 1/3 + ... 
+ 1/m) + + References: + [1] False discovery rate - Wikipedia + https://en.wikipedia.org/wiki/False_discovery_rate + """ + coef = self.get_detail(scale) + pvalues = 2 * (1 - sp.stats.norm.cdf(np.abs(coef) / \ + self.vst_coef[scale]["std"])) + p_sorted = pvalues.flatten() + p_sorted.sort() + N = len(p_sorted) + if independent: + cn = 1.0 + else: + cn = np.sum(1.0 / np.arange(1, N+1)) + p_comp = fdr * np.arange(N) / (N * cn) + comp = (p_sorted < p_comp) + # cutoff p-value after FDR control/correction + p_cutoff = np.max(p_sorted[comp]) + return (pvalues <= p_cutoff, p_cutoff) + + def denoise(self, fdr=0.1, fdr_independent=False): + """ + Denoise the wavelet coefficients by controlling FDR. + """ + self.fdr = fdr + self.fdr_indepent = fdr_independent + self.denoised = [] + # supports of significant coefficients of each scale + self.sig_supports = [None] # make index match the scale + self.p_cutoff = [None] + for scale in range(1, self.level+1): + coef = self.get_detail(scale) + sig, p_cutoff = self.is_significant(scale, fdr, fdr_independent) + coef[np.logical_not(sig)] = 0.0 + self.denoised.append(coef) + self.sig_supports.append(sig) + self.p_cutoff.append(p_cutoff) + # append the last approximation + self.denoised.append(self.get_approx()) + + def decompose(self, level, boundary="symm"): + """ + 2D IUWT decomposition with VST. 
+ """ + self.boundary = boundary + if self.level != level or self.filters == []: + self.level = level + self.calc_filters() + self.calc_vst_coef() + self.decomposition = [] + approx = self.data + for scale in range(1, level+1): + # approximation: + approx2 = signal.convolve2d(self.data, self.filters[scale], + mode="same", boundary=self.boundary) + # wavelet details: + w = self.vst(approx, scale=scale-1) - self.vst(approx2, scale=scale) + self.decomposition.append(w) + if scale == level: + self.decomposition.append(approx2) + approx = approx2 + return self.decomposition + + def reconstruct_ivst(self, denoised=True, positive_project=True): + """ + Reconstruct the original image from the *un-denoised* decomposition + by applying the inverse VST. + + This reconstruction result is also used as the `initial condition' + for the below `iterative reconstruction' algorithm. + + arguments: + * denoised: whether use th denoised data or the direct decomposition + * positive_project: whether replace negative values with zeros + """ + if denoised: + decomposition = self.denoised + else: + decomposition = self.decomposition + self.positive_project = positive_project + details = np.sum(decomposition[:-1], axis=0) + approx = self.vst(decomposition[-1], scale=self.level) + reconstruction = self.ivst(approx+details, scale=0) + if positive_project: + reconstruction[reconstruction < 0.0] = 0.0 + self.reconstruction = reconstruction + return reconstruction + + def reconstruct(self, denoised=True, niter=20, verbose=False): + """ + Reconstruct the original image using iterative method with + L1 regularization, because the denoising violates the exact inverse + procedure. 
+ + arguments: + * denoised: whether use the denoised coefficients + * niter: number of iterations + """ + if denoised: + decomposition = self.denoised + else: + decomposition = self.decomposition + # L1 regularization + lbd = 1.0 + delta = lbd / (niter - 1) + # initial solution + solution = self.reconstruct_ivst(denoised=denoised, + positive_project=True) + # + iuwt = IUWT(level=self.level) + iuwt.calc_filters() + # iterative reconstruction + for i in range(niter): + if verbose: + print("iter: %d" % i+1, file=sys.stderr) + tempd = self.data.copy() + solution_decomp = [] + for scale in range(1, self.level+1): + approx, detail = iuwt.transform(tempd, scale) + approx_sol, detail_sol = iuwt.transform(solution, scale) + # Update coefficients according to the significant supports, + # which are acquired during the denosing precodure with FDR. + sig = self.sig_supports[scale] + detail_sol[sig] = detail[sig] + detail_sol = self.soft_threshold(detail_sol, threshold=lbd) + # + solution_decomp.append(detail_sol) + tempd = approx.copy() + solution = approx_sol.copy() + # last approximation (the two are the same) + solution_decomp.append(approx) + # reconstruct + solution = iuwt.reconstruct(decomposition=solution_decomp) + # discard all negative values + solution[solution < 0] = 0.0 + # + lbd -= delta + # + self.reconstruction = solution + return self.reconstruction + diff --git a/python/xkeywordsync.py b/python/xkeywordsync.py new file mode 100644 index 0000000..73f48b9 --- /dev/null +++ b/python/xkeywordsync.py @@ -0,0 +1,533 @@ +#!/bin/usr/env python3 +# -*- coding: utf-8 -*- +# +# Credits: +# [1] Gaute Hope: gauteh/abunchoftags +# https://github.com/gauteh/abunchoftags/blob/master/keywsync.cc +# +# TODO: +# * Support case-insensitive tags merge +# (ref: http://stackoverflow.com/a/1480230) +# * Accept a specified mtime, and only deal with files with newer mtime. +# +# Aaron LI +# Created: 2016-01-24 +# + +""" +Sync message 'X-Keywords' header with notmuch tags. 
class SyncDirection(Enum):
    """
    Synchronization direction between the 'X-Keywords' header of the
    message files and the notmuch tags.
    """
    # Merge 'X-Keywords' and notmuch tags and update both
    MERGE_KEYWORDS_TAGS = 0
    # Sync 'X-Keywords' header to notmuch tags
    KEYWORDS_TO_TAGS = 1
    # Sync notmuch tags to 'X-Keywords' header
    TAGS_TO_KEYWORDS = 2
+ * You may only sync the '[Gmail]/All Mail' folder without other + folders exported according the labels by Gmail. + """ + # Replace some special characters before mapping keyword to tag + enable_replace_chars = True + chars_replace = { + '/' : '.', + } + # Mapping between (Gmail) keywords and notmuch tags (before ignoring tags) + keywords_mapping = { + '\\Inbox' : 'inbox', + '\\Important' : 'important', + '\\Starred' : 'flagged', + '\\Sent' : 'sent', + '\\Muted' : 'killed', + '\\Draft' : 'draft', + '\\Trash' : 'deleted', + '\\Junk' : 'spam', + } + # Tags ignored from syncing + # These tags are either internal tags or tags handled by maildir flags. + enable_ignore_tags = True + tags_ignored = set([ + 'new', 'unread', 'attachment', 'signed', 'encrypted', + 'flagged', 'replied', 'passed', 'draft', + ]) + # Ignore case when merging tags + tags_ignorecase = True + + # Whether the tags updated against the message 'X-Keywords' header + tags_updated = False + # Added & removed tags for notmuch database against 'X-Keywords' + tags_added = [] + tags_removed = [] + # Newly updated/merged notmuch tags against 'X-Keywords' + tags_new = [] + + # Whether the keywords updated against the notmuch tags + keywords_updated = False + # Added & removed tags for 'X-Keywords' against notmuch database + tags_kw_added = [] + tags_kw_removed = [] + # Newly updated/merged tags for 'X-Keywords' against notmuch database + tags_kw_new = [] + + def __init__(self, msg, filename=None): + self.message = msg + self.filename = filename + self.allfiles = [ fn for fn in msg.get_filenames() ] + self.tags = set(msg.get_tags()) + + def sync(self, direction, mode=SyncMode.ADD_REMOVE, + dryrun=False, verbose=False): + """ + Wrapper function to sync between 'X-Keywords' and notmuch tags. 
+ """ + if direction == SyncDirection.KEYWORDS_TO_TAGS: + self.sync_keywords_to_tags(sync_mode=mode, dryrun=dryrun, + verbose=verbose) + elif direction == SyncDirection.TAGS_TO_KEYWORDS: + self.sync_tags_to_keywords(sync_mode=mode, dryrun=dryrun, + verbose=verbose) + elif direction == SyncDirection.MERGE_KEYWORDS_TAGS: + self.merge_keywords_tags(sync_mode=mode, dryrun=dryrun, + verbose=verbose) + else: + raise ValueError("Invalid sync direction: %s" % direction) + + def sync_keywords_to_tags(self, sync_mode=SyncMode.ADD_REMOVE, + dryrun=False, verbose=False): + """ + Wrapper function to sync 'X-Keywords' to notmuch tags. + """ + self.get_keywords() + self.map_keywords() + self.merge_tags(sync_direction=SyncDirection.KEYWORDS_TO_TAGS, + sync_mode=sync_mode) + if dryrun or verbose: + print('* MSG: %s' % self.message) + print(' TAG: [%s] +[%s] -[%s] => [%s]' % ( + ','.join(self.tags), ','.join(self.tags_added), + ','.join(self.tags_removed), ','.join(self.tags_new))) + if not dryrun: + self.update_tags() + + def sync_tags_to_keywords(self, sync_mode=SyncMode.ADD_REMOVE, + dryrun=False, verbose=False): + """ + Wrapper function to sync notmuch tags to 'X-Keywords' + """ + self.get_keywords() + self.map_keywords() + self.merge_tags(sync_direction=SyncDirection.TAGS_TO_KEYWORDS, + sync_mode=sync_mode) + keywords_new = self.map_tags(tags=self.tags_kw_new) + if dryrun or verbose: + print('* MSG: %s' % self.message) + print('* FILES: %s' % ' ; '.join(self.allfiles)) + print(' XKW: {%s} +[%s] -[%s] => {%s}' % ( + ','.join(self.keywords), ','.join(self.tags_kw_added), + ','.join(self.tags_kw_removed), ','.join(keywords_new))) + if not dryrun: + self.update_keywords(keywords_new=keywords_new) + + def merge_keywords_tags(self, sync_mode=SyncMode.ADD_REMOVE, + dryrun=False, verbose=False): + """ + Wrapper function to merge 'X-Keywords' and notmuch tags + """ + self.get_keywords() + self.map_keywords() + self.merge_tags(sync_direction=SyncDirection.MERGE_KEYWORDS_TAGS, + 
sync_mode=sync_mode) + keywords_new = self.map_tags(tags=self.tags_kw_new) + if dryrun or verbose: + print('* MSG: %s' % self.message) + print('* FILES: %s' % ' ; '.join(self.allfiles)) + print(' TAG: [%s] +[%s] -[%s] => [%s]' % ( + ','.join(self.tags), ','.join(self.tags_added), + ','.join(self.tags_removed), ','.join(self.tags_new))) + print(' XKW: {%s} +[%s] -[%s] => {%s}' % ( + ','.join(self.keywords), ','.join(self.tags_kw_added), + ','.join(self.tags_kw_removed), ','.join(keywords_new))) + if not dryrun: + self.update_tags() + self.update_keywords(keywords_new=keywords_new) + + def get_keywords(self): + """ + Get 'X-Keywords' header from all files associated with the same + message, decode, split and merge. + + NOTE: Do NOT simply use the `message.get_header()` method, which + cannot get the complete keywords from all files. + """ + keywords_utf7 = [] + for fn in self.allfiles: + msg = email.message_from_file(open(fn, 'r')) + val = msg['X-Keywords'] + if val: + keywords_utf7.append(val) + else: + print("WARNING: 'X-Keywords' header not found or empty " +\ + "for file: %s" % fn, file=sys.stderr) + keywords_utf7 = ','.join(keywords_utf7) + if keywords_utf7 != '': + keywords = imapUTF7Decode(keywords_utf7.encode()).split(',') + keywords = [ kw.strip() for kw in keywords ] + # Remove duplications + keywords = set(keywords) + else: + keywords = set() + self.keywords = keywords + return keywords + + def map_keywords(self, keywords=None): + """ + Map keywords to notmuch tags according to the mapping table. 
+ """ + if keywords is None: + keywords = self.keywords + if self.enable_replace_chars: + # Replace specified characters in keywords + trans = str.maketrans(self.chars_replace) + keywords = [ kw.translate(trans) for kw in keywords ] + # Map keywords to tags + tags = set([ self.keywords_mapping.get(kw, kw) for kw in keywords ]) + self.tags_kw = tags + return tags + + def map_tags(self, tags=None): + """ + Map tags to keywords according to the inversed mapping table. + """ + if tags is None: + tags = self.tags + if self.enable_replace_chars: + # Inversely replace specified characters in tags + chars_replace_inv = { v: k for k, v in self.chars_replace.items() } + trans = str.maketrans(chars_replace_inv) + tags = [ tag.translate(trans) for tag in tags ] + # Map keywords to tags + keywords_mapping_inv = { v:k for k,v in self.keywords_mapping.items() } + keywords = set([ keywords_mapping_inv.get(tag, tag) for tag in tags ]) + self.keywords_tags = keywords + return keywords + + def merge_tags(self, sync_direction, sync_mode=SyncMode.ADD_REMOVE, + tags_nm=None, tags_kw=None): + """ + Merge the tags from notmuch database and 'X-Keywords' header, + according to the specified sync direction and operation restriction. 
+ + TODO: support case-insensitive set operations + """ + # Added & removed tags for notmuch database against 'X-Keywords' + tags_added = [] + tags_removed = [] + # Newly updated/merged notmuch tags against 'X-Keywords' + tags_new = [] + # Added & removed tags for 'X-Keywords' against notmuch database + tags_kw_added = [] + tags_kw_removed = [] + # Newly updated/merged tags for 'X-Keywords' against notmuch database + tags_kw_new = [] + # + if tags_nm is None: + tags_nm = self.tags + if tags_kw is None: + tags_kw = self.tags_kw + if self.enable_ignore_tags: + # Remove ignored tags before merge + tags_nm2 = tags_nm.difference(self.tags_ignored) + tags_kw2 = tags_kw.difference(self.tags_ignored) + else: + tags_nm2 = tags_nm + tags_kw2 = tags_kw + # + if sync_direction == SyncDirection.KEYWORDS_TO_TAGS: + # Sync 'X-Keywords' to notmuch tags + tags_added = tags_kw2.difference(tags_nm2) + tags_removed = tags_nm2.difference(tags_kw2) + elif sync_direction == SyncDirection.TAGS_TO_KEYWORDS: + # Sync notmuch tags to 'X-Keywords' + tags_kw_added = tags_nm2.difference(tags_kw2) + tags_kw_removed = tags_kw2.difference(tags_nm2) + elif sync_direction == SyncDirection.MERGE_KEYWORDS_TAGS: + # Merge both notmuch tags and 'X-Keywords' + tags_merged = tags_nm2.union(tags_kw2) + # notmuch tags + tags_added = tags_merged.difference(tags_nm2) + tags_removed = tags_nm2.difference(tags_merged) + # tags for 'X-Keywords' + tags_kw_added = tags_merged.difference(tags_kw2) + tags_kw_removed = tags_kw2.difference(tags_merged) + else: + raise ValueError("Invalid synchronization direction") + # Apply sync operation restriction + self.tags_added = [] + self.tags_removed = [] + self.tags_kw_added = [] + self.tags_kw_removed = [] + tags_new = tags_nm # Use un-ignored notmuch tags + tags_kw_new = tags_kw # Use un-ignored 'X-Keywords' tags + if sync_mode != SyncMode.REMOVE_ONLY: + self.tags_added = tags_added + self.tags_kw_added = tags_kw_added + tags_new = tags_new.union(tags_added) + tags_kw_new 
= tags_kw_new.union(tags_kw_added) + if sync_mode != SyncMode.ADD_ONLY: + self.tags_removed = tags_removed + self.tags_kw_removed = tags_kw_removed + tags_new = tags_new.difference(tags_removed) + tags_kw_new = tags_kw_new.difference(tags_kw_removed) + # + self.tags_new = tags_new + self.tags_kw_new = tags_kw_new + if self.tags_added or self.tags_removed: + self.tags_updated = True + if self.tags_kw_added or self.tags_kw_removed: + self.keywords_updated = True + # + return { + 'tags_updated' : self.tags_updated, + 'tags_added' : self.tags_added, + 'tags_removed' : self.tags_removed, + 'tags_new' : self.tags_new, + 'keywords_updated' : self.keywords_updated, + 'tags_kw_added' : self.tags_kw_added, + 'tags_kw_removed' : self.tags_kw_removed, + 'tags_kw_new' : self.tags_kw_new, + } + + def update_keywords(self, keywords_new=None, outfile=None): + """ + Encode the keywords (default: self.keywords_new) and write back to + all message files. + + If parameter 'outfile' specified, then write the updated message + to that file instead of overwriting. + + NOTE: + * The modification time of the message file should be kept to prevent + OfflineIMAP from treating it as a new one (and the previous a + deleted one). + * All files associated with the same message are updated to have + the same 'X-Keywords' header. 
+ """ + if not self.keywords_updated: + # keywords NOT updated, just skip + return + + if keywords_new is None: + keywords_new = self.keywords_new + # + if outfile is not None: + infile = self.allfiles[0:1] + outfile = [ os.path.expanduser(outfile) ] + else: + infile = self.allfiles + outfile = self.allfiles + # + for ifname, ofname in zip(infile, outfile): + msg = email.message_from_file(open(ifname, 'r')) + fstat = os.stat(ifname) + if keywords_new == []: + # Delete 'X-Keywords' header + print("WARNING: delete 'X-Keywords' header from file: %s" % + ifname, file=sys.stderr) + del msg['X-Keywords'] + else: + # Update 'X-Keywords' header + keywords = ','.join(keywords_new) + keywords_utf7 = imapUTF7Encode(keywords).decode() + # Delete then add, to avoid multiple occurrences + del msg['X-Keywords'] + msg['X-Keywords'] = keywords_utf7 + # Write updated message + with open(ofname, 'w') as fp: + fp.write(msg.as_string()) + # Reset the timestamps + os.utime(ofname, ns=(fstat.st_atime_ns, fstat.st_mtime_ns)) + + def update_tags(self, tags_added=None, tags_removed=None): + """ + Update notmuch tags according to keywords. + """ + if not self.tags_updated: + # tags NOT updated, just skip + return + + if tags_added is None: + tags_added = self.tags_added + if tags_removed is None: + tags_removed = self.tags_removed + # Use freeze/thaw for safer transactions to change tag values. + self.message.freeze() + for tag in tags_added: + self.message.add_tag(tag, sync_maildir_flags=False) + for tag in tags_removed: + self.message.remove_tag(tag, sync_maildir_flags=False) + self.message.thaw() + + +def get_notmuch_revision(dbpath=None): + """ + Get the current revision and UUID of notmuch database. 
+ """ + import subprocess + import tempfile + if dbpath: + tf = tempfile.NamedTemporaryFile() + # Create a minimal notmuch config for the specified dbpath + config = '[database]\npath=%s\n' % os.path.expanduser(dbpath) + tf.file.write(config.encode()) + tf.file.flush() + cmd = 'notmuch --config=%s count --lastmod' % tf.name + output = subprocess.check_output(cmd, shell=True) + tf.close() + else: + cmd = 'notmuch count --lastmod' + output = subprocess.check_output(cmd, shell=True) + # Extract output + dbinfo = output.decode().split() + return { 'revision': int(dbinfo[2]), 'uuid': dbinfo[1] } + + +def main(): + parser = argparse.ArgumentParser( + description="Sync message 'X-Keywords' header with notmuch tags.") + parser.add_argument("-V", "--version", action="version", + version="%(prog)s " + "v%s (%s)" % (__version__, __date__)) + parser.add_argument("-q", "--query", dest="query", required=True, + help="notmuch database query string") + parser.add_argument("-p", "--db-path", dest="dbpath", + help="notmuch database path (default to try user configuration)") + parser.add_argument("-n", "--dry-run", dest="dryrun", + action="store_true", help="dry run") + parser.add_argument("-v", "--verbose", dest="verbose", + action="store_true", help="show verbose information") + # Exclusive argument group for sync mode + exgroup1 = parser.add_mutually_exclusive_group(required=True) + exgroup1.add_argument("-m", "--merge-keywords-tags", + dest="direction_merge", action="store_true", + help="merge 'X-Keywords' and tags and update both") + exgroup1.add_argument("-k", "--keywords-to-tags", + dest="direction_keywords2tags", action="store_true", + help="sync 'X-Keywords' to notmuch tags") + exgroup1.add_argument("-t", "--tags-to-keywords", + dest="direction_tags2keywords", action="store_true", + help="sync notmuch tags to 'X-Keywords'") + # Exclusive argument group for tag operation mode + exgroup2 = parser.add_mutually_exclusive_group(required=False) + exgroup2.add_argument("-a", 
"--add-only", dest="mode_addonly", + action="store_true", help="only add notmuch tags") + exgroup2.add_argument("-r", "--remove-only", dest="mode_removeonly", + action="store_true", help="only remove notmuch tags") + # Parse + args = parser.parse_args() + # Sync direction + if args.direction_merge: + sync_direction = SyncDirection.MERGE_KEYWORDS_TAGS + elif args.direction_keywords2tags: + sync_direction = SyncDirection.KEYWORDS_TO_TAGS + elif args.direction_tags2keywords: + sync_direction = SyncDirection.TAGS_TO_KEYWORDS + else: + raise ValueError("Invalid synchronization direction") + # Sync mode + if args.mode_addonly: + sync_mode = SyncMode.ADD_ONLY + elif args.mode_removeonly: + sync_mode = SyncMode.REMOVE_ONLY + else: + sync_mode = SyncMode.ADD_REMOVE + # + if args.dbpath: + dbpath = os.path.abspath(os.path.expanduser(args.dbpath)) + else: + dbpath = None + # + db = Database(path=dbpath, create=False, mode=Database.MODE.READ_WRITE) + dbinfo = get_notmuch_revision(dbpath=dbpath) + q = Query(db, args.query) + total_msgs = q.count_messages() + msgs = q.search_messages() + # + if args.verbose: + print("# Notmuch database path: %s" % dbpath) + print("# Database revision: %d (uuid: %s)" % + (dbinfo['revision'], dbinfo['uuid'])) + print("# Query: %s" % args.query) + print("# Sync direction: %s" % sync_direction.name) + print("# Sync mode: %s" % sync_mode.name) + print("# Total messages to check: %d" % total_msgs) + print("# Dryn run: %s" % args.dryrun) + # + for msg in msgs: + kwmsg = KwMessage(msg) + kwmsg.sync(direction=sync_direction, mode=sync_mode, + dryrun=args.dryrun, verbose=args.verbose) + # + db.close() + + +if __name__ == "__main__": + main() + +# vim: set ts=4 sw=4 tw=0 fenc= ft=python: # |