aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xpython/adjust_spectrum_error.py170
-rwxr-xr-xpython/correct_crosstalk.py319
-rw-r--r--python/imapUTF7.py189
-rw-r--r--python/plot.py35
-rw-r--r--python/plot_tprofiles_zzh.py126
-rwxr-xr-xpython/randomize_events.py72
-rw-r--r--python/sbp_fit.py613
-rw-r--r--python/starlet.py513
-rw-r--r--python/xkeywordsync.py533
9 files changed, 2570 insertions, 0 deletions
diff --git a/python/adjust_spectrum_error.py b/python/adjust_spectrum_error.py
new file mode 100755
index 0000000..0f80ec7
--- /dev/null
+++ b/python/adjust_spectrum_error.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Squeeze the spectrum according to the grouping specification, then
+calculate the statistical errors for each group, and apply error
+adjustments (e.g., incorporate the systematic uncertainties).
+"""
+
+__version__ = "0.1.0"
+__date__ = "2016-01-11"
+
+
+import sys
+import argparse
+
+import numpy as np
+from astropy.io import fits
+
+
class Spectrum:
    """
    Keep the data of an OGIP-style spectrum and perform manipulations.

    The spectrum is read from the "SPECTRUM" extension of a FITS file.
    The typical workflow is:
    `squeezeByGrouping()' -> `calcStatErr()' -> `applySysErr()' ->
    `updateHeader()' -> `write()'.
    """
    # header of the "SPECTRUM" extension
    header = None
    # columns of the "SPECTRUM" extension
    channel = None
    counts = None
    grouping = None
    quality = None

    def __init__(self, specfile):
        """
        Load the header and the CHANNEL, COUNTS, GROUPING and QUALITY
        columns from the "SPECTRUM" extension of `specfile'.
        """
        # The context manager closes the file even if a column is missing.
        with fits.open(specfile) as f:
            spechdu = f['SPECTRUM']
            self.header = spechdu.header
            self.channel = spechdu.data.field('CHANNEL')
            self.counts = spechdu.data.field('COUNTS')
            self.grouping = spechdu.data.field('GROUPING')
            self.quality = spechdu.data.field('QUALITY')

    def squeezeByGrouping(self):
        """
        Squeeze the spectrum according to the grouping specification,
        i.e., sum the counts belonging to the same group, and place the
        sum as the first channel within each group with other channels
        of counts zero's.

        A channel with grouping value 1 starts a new group; any other
        value continues the current group.  Channels located before the
        first group start are summed into a leading group of their own,
        so that the squeezed array always has the same length as the
        input columns.

        The result is stored in `self.counts_squeezed' (int32 array).
        """
        counts_squeezed = []
        cnt_sum = 0
        cnt_num = 0
        for grp, cnt in zip(self.grouping, self.counts):
            if grp == 1 and cnt_num > 0:
                # a new group starts: save the previous one
                counts_squeezed.append(cnt_sum)
                counts_squeezed.extend([0] * (cnt_num - 1))
                cnt_sum = 0
                cnt_num = 0
            cnt_sum += cnt
            cnt_num += 1
        if cnt_num > 0:
            # save the last group
            counts_squeezed.append(cnt_sum)
            counts_squeezed.extend([0] * (cnt_num - 1))
        self.counts_squeezed = np.array(counts_squeezed, dtype=np.int32)

    def calcStatErr(self, gehrels=False):
        """
        Calculate the statistical errors for the grouped channels,
        and save them in `self.stat_err'.

        Channels with zero squeezed counts get a zero error.  If
        `gehrels' is true, the error is `1 + sqrt(N + 0.75)', otherwise
        it is `sqrt(N)'.
        """
        idx_nz = np.nonzero(self.counts_squeezed)
        stat_err = np.zeros(self.counts_squeezed.shape)
        if gehrels:
            # Gehrels
            stat_err[idx_nz] = 1 + np.sqrt(self.counts_squeezed[idx_nz] + 0.75)
        else:
            stat_err[idx_nz] = np.sqrt(self.counts_squeezed[idx_nz])
        self.stat_err = stat_err

    @staticmethod
    def parseSysErr(syserr):
        """
        Parse the string format of syserr supplied in the commandline.

        The syntax is "low:high:syserr[,low:high:syserr...]".
        Return a list of `(int low, int high, float syserr)' tuples.
        Raise `ValueError' if an item cannot be parsed.
        """
        syserr_spec = []
        for item in syserr.split(','):
            fields = item.strip().split(':')
            try:
                spec = (int(fields[0]), int(fields[1]), float(fields[2]))
            except (IndexError, ValueError) as exc:
                raise ValueError("invalid syserr specification") from exc
            syserr_spec.append(spec)
        return syserr_spec

    def applySysErr(self, syserr):
        """
        Apply systematic error adjustments to the above calculated
        statistical errors.

        For each `low:high:syserr' range the statistical errors are
        scaled by `sqrt(1 + syserr)'; channels outside every range keep
        their statistical error.  The result is stored in
        `self.stat_err_adjusted' and has the same length as
        `self.stat_err', so it can be written next to the other columns.
        """
        syserr_spec = self.parseSysErr(syserr)
        # Start from a copy, so that channels not covered by any range
        # are kept and the full-length array is preserved.
        err_adjusted = np.array(self.stat_err, dtype=float)
        for lo, hi, se in syserr_spec:
            # NOTE(review): this slice covers channels lo..hi-1 (1-based),
            # i.e., the `high' channel itself is excluded as in the
            # original code -- confirm whether `high' should be inclusive.
            err_adjusted[(lo-1):(hi-1)] = (
                self.stat_err[(lo-1):(hi-1)] * np.sqrt(1+se))
        self.stat_err_adjusted = err_adjusted

    def updateHeader(self):
        """
        Update header accordingly.
        """
        # POISSERR: errors are supplied explicitly in the STAT_ERR column
        self.header['POISSERR'] = False

    def write(self, filename, clobber=False):
        """
        Write the updated/modified spectrum block to file.

        `clobber': whether to overwrite `filename' if it already exists.
        """
        channel_col = fits.Column(name='CHANNEL', format='J',
                                  array=self.channel)
        counts_col = fits.Column(name='COUNTS', format='J',
                                 array=self.counts_squeezed)
        stat_err_col = fits.Column(name='STAT_ERR', format='D',
                                   array=self.stat_err_adjusted)
        grouping_col = fits.Column(name='GROUPING', format='I',
                                   array=self.grouping)
        quality_col = fits.Column(name='QUALITY', format='I',
                                  array=self.quality)
        spec_cols = fits.ColDefs([channel_col, counts_col, stat_err_col,
                                  grouping_col, quality_col])
        spechdu = fits.BinTableHDU.from_columns(spec_cols, header=self.header)
        # astropy renamed the `clobber' keyword to `overwrite'.
        spechdu.writeto(filename, overwrite=clobber)
+
+
def main():
    """
    Command line entry point: parse the arguments, then squeeze the
    input spectrum, calculate and adjust its errors, and write the
    result to the output file.
    """
    version_text = "%(prog)s " + "%s (%s)" % (__version__, __date__)
    syserr_help = "systematic error specification; " + \
        "syntax: ch1low:ch1high:syserr1,..."

    parser = argparse.ArgumentParser(
        description="Apply systematic error adjustments to spectrum.")
    parser.add_argument("-V", "--version", action="version",
                        version=version_text)
    parser.add_argument("infile", help="input spectrum file")
    parser.add_argument("outfile", help="output adjusted spectrum file")
    parser.add_argument("-e", "--syserr", dest="syserr", required=True,
                        help=syserr_help)
    parser.add_argument("-C", "--clobber", dest="clobber",
                        action="store_true",
                        help="overwrite output file if exists")
    parser.add_argument("-G", "--gehrels", dest="gehrels",
                        action="store_true", help="use Gehrels error?")
    options = parser.parse_args()

    # Processing pipeline; the order of the steps matters.
    spectrum = Spectrum(options.infile)
    spectrum.squeezeByGrouping()
    spectrum.calcStatErr(gehrels=options.gehrels)
    spectrum.applySysErr(syserr=options.syserr)
    spectrum.updateHeader()
    spectrum.write(options.outfile, clobber=options.clobber)
+
+
+if __name__ == "__main__":
+ main()
+
+
+# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: #
diff --git a/python/correct_crosstalk.py b/python/correct_crosstalk.py
new file mode 100755
index 0000000..4c7f820
--- /dev/null
+++ b/python/correct_crosstalk.py
@@ -0,0 +1,319 @@
+#!/usr/bin/env python3
+#
+# Correct the crosstalk effect of XMM spectra by subtracting the
+# scattered photons from surrounding regions, and by compensating
+# the photons scattered to surrounding regions, according to the
+# generated crosstalk ARFs.
+#
+# Sample config file (in `ConfigObj' syntax):
+#----------------------
+# fix_negative = True
+# verbose = True
+# clobber = False
+#
+# [reg2]
+# outfile = cc_reg2.pi
+# spec = reg2.pi
+# arf = reg2.arf
+# [[cross_in]]
+# [[[in1]]]
+# spec = reg1.pi
+# arf = reg1.arf
+# cross_arf = reg_1-2.arf
+# [[[in2]]]
+# spec = reg3.pi
+# arf = reg3.arf
+# cross_arf = reg_3-2.arf
+# [[cross_out]]
+# cross_arf = reg_2-1.arf, reg_2-3.arf
+#----------------------
+#
+# Weitian LI
+# Created: 2016-03-26
+# Updated: 2016-03-28
+#
+
+from astropy.io import fits
+import numpy as np
+from configobj import ConfigObj
+
+import sys
+import os
+import argparse
+from datetime import datetime
+
+
class ARF:
    """
    Deal with X-ray ARF file (.arf)

    The "SPECRESP" extension of the file is read on construction; the
    opened FITS object is kept in `fitsobj'.
    """
    filename = None
    fitsobj = None
    header = None
    energ_lo = None
    energ_hi = None
    specresp = None

    def __init__(self, filename):
        """
        Open the ARF file `filename' and load its "SPECRESP" extension.
        """
        self.filename = filename
        hdulist = fits.open(filename)
        self.fitsobj = hdulist
        specresp_hdu = hdulist["SPECRESP"]
        self.header = specresp_hdu.header
        table = specresp_hdu.data
        self.energ_lo = table["ENERG_LO"]
        self.energ_hi = table["ENERG_HI"]
        self.specresp = table["SPECRESP"]

    def get_data(self, copy=True):
        """
        Return the "SPECRESP" column data, as a copy unless `copy' is
        false.
        """
        return self.specresp.copy() if copy else self.specresp
+
+
class Spectrum:
    """
    Deal with X-ray spectrum (.pi)

    NOTE:
    The "COUNTS" column data are converted from "int32" to "float32".
    """
    filename = None
    # FITS object returned by `fits.open'
    fitsobj = None
    # header of "SPECTRUM" extension
    header = None
    # "CHANNEL" column of the "SPECTRUM" extension
    channel = None
    # name of the column containing the spectrum data, either "COUNTS" or "RATE"
    spec_colname = None
    # spectrum data
    spec_data = None
    # ARF object for this spectrum
    arf = None

    def __init__(self, filename, arffile):
        """
        Open the spectrum `filename' and its ARF `arffile'.

        Raise `ValueError' if the "SPECTRUM" extension has neither a
        "COUNTS" nor a "RATE" column.
        """
        self.filename = filename
        self.fitsobj = fits.open(filename)
        ext_spec = self.fitsobj["SPECTRUM"]
        self.header = ext_spec.header.copy(strip=True)
        colnames = ext_spec.columns.names
        if "COUNTS" in colnames:
            self.spec_colname = "COUNTS"
        elif "RATE" in colnames:
            self.spec_colname = "RATE"
        else:
            raise ValueError("Invalid spectrum file")
        self.channel = ext_spec.data["CHANNEL"].copy()
        self.spec_data = ext_spec.data.field(self.spec_colname)\
            .astype(np.float32)
        self.arf = ARF(arffile)

    def get_data(self, copy=True):
        """
        Return the spectrum data, as a copy unless `copy' is false.
        """
        if copy:
            return self.spec_data.copy()
        else:
            return self.spec_data

    def get_arf(self, copy=True):
        """
        Return the data of the associated ARF, or None if there is no
        ARF attached.
        """
        if self.arf is None:
            return None
        else:
            return self.arf.get_data(copy=copy)

    @staticmethod
    def _finite_ratio(numerator, denominator):
        """
        Return `numerator / denominator' element-wise, with every
        non-finite result (0/0 -> NaN, x/0 -> inf) replaced by zero, so
        that channels without effective area contribute nothing.
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = numerator / denominator
        ratio[~np.isfinite(ratio)] = 0.0
        return ratio

    def subtract(self, spectrum, cross_arf, verbose=False):
        """
        Subtract the photons that originate from the surrounding regions
        but were scattered into this spectrum due to the finite PSF.

        NOTE:
        The crosstalk ARF must be provided, since the `spectrum.arf' is
        required to be its ARF without taking crosstalk into account:
            spec1_new = spec1 - spec2 * (cross_arf_2_to_1 / arf2)
        """
        operation = " SUBTRACT: %s - (%s/%s) * %s" % (self.filename,
                cross_arf.filename, spectrum.arf.filename, spectrum.filename)
        if verbose:
            print(operation, file=sys.stderr)
        arf_ratio = self._finite_ratio(cross_arf.get_data(),
                                       spectrum.get_arf())
        self.spec_data -= spectrum.get_data() * arf_ratio
        # record history
        self.header.add_history(operation)

    def compensate(self, cross_arf, verbose=False):
        """
        Compensate the photons that originate from this region but were
        scattered into the surrounding regions due to the finite PSF.

        formula:
            spec1_new = spec1 + spec1 * (cross_arf_1_to_2 / arf1)
        """
        operation = " COMPENSATE: %s + (%s/%s) * %s" % (self.filename,
                cross_arf.filename, self.arf.filename, self.filename)
        if verbose:
            print(operation, file=sys.stderr)
        arf_ratio = self._finite_ratio(cross_arf.get_data(), self.get_arf())
        self.spec_data += self.get_data() * arf_ratio
        # record history
        self.header.add_history(operation)

    def fix_negative(self, verbose=False):
        """
        The subtractions may lead to negative counts, it may be necessary
        to fix these channels with negative values.

        Each negative channel is set to zero and its deficit is split
        equally between the next two channels; this is repeated until no
        negative channel remains.  Negative values in the last two
        channels are simply set to zero.
        """
        N = len(self.spec_data)
        # `np.flatnonzero' replaces the former `np.arange(N, dtype=np.int)'
        # indexing; the `np.int' alias no longer exists in NumPy.
        neg_channels = np.flatnonzero(self.spec_data < 0)
        if len(neg_channels) > 0:
            print("WARNING: %d channels have NEGATIVE counts" %
                  len(neg_channels), file=sys.stderr)
        i = 0
        while len(neg_channels) > 0:
            i += 1
            if verbose:
                print("*** Fixing negative channels: iteration %d ..." % i,
                      file=sys.stderr)
            for ch in neg_channels:
                neg_val = self.spec_data[ch]
                if ch < N-2:
                    self.spec_data[ch] = 0
                    self.spec_data[(ch+1):(ch+3)] -= 0.5 * np.abs(neg_val)
                else:
                    # just set to zero if it is the last 2 channels
                    self.spec_data[ch] = 0
            # update negative channels indices
            neg_channels = np.flatnonzero(self.spec_data < 0)
        if i > 0:
            print("*** Fixed negative channels ***", file=sys.stderr)

    def write(self, filename, clobber=False):
        """
        Create a new "SPECTRUM" table/extension and replace the original
        one, then write to output file.

        `clobber': whether to overwrite `filename' if it already exists.
        """
        # NOTE(review): the data column is always written as "COUNTS",
        # even if the input column was "RATE" -- confirm this is intended.
        ext_spec_cols = fits.ColDefs([
            fits.Column(name="CHANNEL", format="I", array=self.channel),
            fits.Column(name="COUNTS", format="E", unit="count",
                        array=self.spec_data)])
        ext_spec = fits.BinTableHDU.from_columns(ext_spec_cols,
                                                 header=self.header)
        self.fitsobj["SPECTRUM"] = ext_spec
        # astropy renamed the `clobber' keyword to `overwrite'.
        self.fitsobj.writeto(filename, overwrite=clobber, checksum=True)
+
+
class Crosstalk:
    """
    Crosstalk correction of one spectrum, as described by one section
    of the config file.
    """
    # `Spectrum' object for the spectrum to be corrected
    spectrum = None
    # XXX: these are deliberately NOT initialized to lists here; a list
    # defined at class level would be shared among all the instances.
    # `Spectrum' and `ARF' objects of the regions from which photons
    # were scattered into this spectrum.
    cross_in_spec = None
    cross_in_arf = None
    # `ARF' objects of the regions into which photons of this spectrum
    # were scattered.
    cross_out_arf = None
    # filename to which the corrected spectrum is written
    outfile = None

    def __init__(self, config):
        """
        `config': a section of the whole config file (`ConfigObj` object).
        """
        self.cross_in_spec = []
        self.cross_in_arf = []
        self.cross_out_arf = []
        # the spectrum that will be corrected
        self.spectrum = Spectrum(config["spec"], config["arf"])
        # neighbours whose photons leaked into this spectrum
        for neighbour in config["cross_in"].values():
            self.cross_in_spec.append(
                Spectrum(neighbour["spec"], neighbour["arf"]))
            self.cross_in_arf.append(ARF(neighbour["cross_arf"]))
        # neighbours that received photons from this spectrum (optional)
        if "cross_out" in config.sections:
            for arf_path in config["cross_out"].as_list("cross_arf"):
                self.cross_out_arf.append(ARF(arf_path))
        # where to write the corrected spectrum
        self.outfile = config["outfile"]

    def do_correction(self, fix_negative=False, verbose=False):
        """
        Apply the subtractions, then the compensations, and optionally
        fix the negative channels; the steps are recorded in the header
        history of the spectrum.
        """
        target = self.spectrum
        tool = os.path.basename(sys.argv[0])
        timestamp = datetime.utcnow().isoformat()
        target.header.add_history("Crosstalk Correction BEGIN")
        target.header.add_history(" TOOL: %s @ %s" % (tool, timestamp))
        # subtractions
        if verbose:
            print("INFO: apply subtractions ...", file=sys.stderr)
        for neighbour, leak_arf in zip(self.cross_in_spec,
                                       self.cross_in_arf):
            target.subtract(spectrum=neighbour, cross_arf=leak_arf,
                            verbose=verbose)
        # compensations
        if verbose:
            print("INFO: apply compensations ...", file=sys.stderr)
        for leak_arf in self.cross_out_arf:
            target.compensate(cross_arf=leak_arf, verbose=verbose)
        # negative channel values
        if fix_negative:
            if verbose:
                print("INFO: fix negative channel values ...",
                      file=sys.stderr)
            target.fix_negative(verbose=verbose)
        target.header.add_history("END Crosstalk Correction")

    def write(self, filename=None, clobber=False):
        """
        Write the corrected spectrum to `filename', which defaults to
        the `outfile' given in the config section.
        """
        destination = self.outfile if filename is None else filename
        self.spectrum.write(destination, clobber=clobber)
+
+
def main():
    """
    Command line entry point: read the config file and apply the
    crosstalk correction to every region section in it.

    The boolean options `fix_negative', `verbose' and `clobber' may be
    set in the config file; a flag given on the command line turns the
    option on regardless of the config value.
    """
    parser = argparse.ArgumentParser(
        description="Correct the crosstalk effects of XMM spectra")
    parser.add_argument("config", help="config file in which describes " +
                        "the crosstalk relations. ('ConfigObj' syntax)")
    parser.add_argument("-N", "--fix-negative", dest="fix_negative",
                        action="store_true",
                        help="fix negative channel values")
    parser.add_argument("-C", "--clobber", dest="clobber",
                        action="store_true",
                        help="overwrite output file if exists")
    parser.add_argument("-v", "--verbose", dest="verbose",
                        action="store_true",
                        help="show verbose information")
    options = parser.parse_args()

    config = ConfigObj(options.config)

    def resolve(key, cli_flag):
        # config value first (default: False), overridden by a set flag
        value = config.as_bool(key) if key in config.keys() else False
        return cli_flag if cli_flag else value

    fix_negative = resolve("fix_negative", options.fix_negative)
    verbose = resolve("verbose", options.verbose)
    clobber = resolve("clobber", options.clobber)

    for region in config.sections:
        if verbose:
            print("INFO: processing '%s' ..." % region, file=sys.stderr)
        correction = Crosstalk(config.get(region))
        correction.do_correction(fix_negative=fix_negative, verbose=verbose)
        correction.write(clobber=clobber)
+
+
+if __name__ == "__main__":
+ main()
+
+# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: #
diff --git a/python/imapUTF7.py b/python/imapUTF7.py
new file mode 100644
index 0000000..2e4db0a
--- /dev/null
+++ b/python/imapUTF7.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# This code was originally in PloneMailList, a GPL'd software.
+# http://svn.plone.org/svn/collective/mxmImapClient/trunk/imapUTF7.py
+# http://bugs.python.org/issue5305
+#
+# Port to Python 3.x
+# Credit: https://github.com/MarechJ/py3_imap_utf7
+#
+# 2016-01-23
+# Aaron LI
+#
+
+"""
+Imap folder names are encoded using a special version of utf-7 as defined in RFC
+2060 section 5.1.3.
+
+5.1.3. Mailbox International Naming Convention
+
+ By convention, international mailbox names are specified using a
+ modified version of the UTF-7 encoding described in [UTF-7]. The
+ purpose of these modifications is to correct the following problems
+ with UTF-7:
+
+ 1) UTF-7 uses the "+" character for shifting; this conflicts with
+ the common use of "+" in mailbox names, in particular USENET
+ newsgroup names.
+
+ 2) UTF-7's encoding is BASE64 which uses the "/" character; this
+ conflicts with the use of "/" as a popular hierarchy delimiter.
+
+ 3) UTF-7 prohibits the unencoded usage of "\"; this conflicts with
+ the use of "\" as a popular hierarchy delimiter.
+
+ 4) UTF-7 prohibits the unencoded usage of "~"; this conflicts with
+ the use of "~" in some servers as a home directory indicator.
+
+ 5) UTF-7 permits multiple alternate forms to represent the same
+ string; in particular, printable US-ASCII chararacters can be
+ represented in encoded form.
+
+ In modified UTF-7, printable US-ASCII characters except for "&"
+ represent themselves; that is, characters with octet values 0x20-0x25
+ and 0x27-0x7e. The character "&" (0x26) is represented by the two-
+ octet sequence "&-".
+
+ All other characters (octet values 0x00-0x1f, 0x7f-0xff, and all
+ Unicode 16-bit octets) are represented in modified BASE64, with a
+ further modification from [UTF-7] that "," is used instead of "/".
+ Modified BASE64 MUST NOT be used to represent any printing US-ASCII
+ character which can represent itself.
+
+ "&" is used to shift to modified BASE64 and "-" to shift back to US-
+ ASCII. All names start in US-ASCII, and MUST end in US-ASCII (that
+ is, a name that ends with a Unicode 16-bit octet MUST end with a "-
+ ").
+
+ For example, here is a mailbox name which mixes English, Japanese,
+ and Chinese text: ~peter/mail/&ZeVnLIqe-/&U,BTFw-
+"""
+
+
+import binascii
+import codecs
+
+
+## encoding
+
def modified_base64(s: str):
    """
    Return *s* in the modified BASE64 of RFC 2060: the UTF-16BE bytes
    of *s* are BASE64-encoded, the padding and trailing newline are
    stripped, and "/" is replaced by ",".
    """
    utf16 = s.encode('utf-16be')  # UTF-16, big-endian byte order
    return binascii.b2a_base64(utf16).rstrip(b'\n=').replace(b'/', b',')


def doB64(_in, r):
    """
    Flush the pending characters in the list *_in*: append them to *r*
    as one "&...-" shifted sequence and empty *_in* in place.
    Do nothing if *_in* is empty.
    """
    if _in:
        r.append(b'&' + modified_base64(''.join(_in)) + b'-')
        del _in[:]


def encoder(s: str, errors='strict'):
    """
    Encode the string *s* to modified UTF-7.

    Return a `(bytes, length consumed)` tuple as required by the codec
    interface.  *errors* is accepted so that the function can be called
    as `encode(input, errors)` by the codec machinery; it is ignored.
    """
    r = []
    _in = []
    for c in s:
        ordC = ord(c)
        if 0x20 <= ordC <= 0x25 or 0x27 <= ordC <= 0x7e:
            # printable US-ASCII (except "&") represents itself
            doB64(_in, r)
            r.append(c.encode('ascii'))
        elif c == '&':
            doB64(_in, r)
            r.append(b'&-')
        else:
            # collect until the next directly representable character
            _in.append(c)
    doB64(_in, r)
    return (b''.join(r), len(s))
+
+
+## decoding
+
def modified_unbase64(s: bytes):
    """
    Decode the modified BASE64 data *s* (without the "&" and "-"
    delimiters) and return the text decoded from UTF-16BE.
    """
    data = bytes(s).replace(b',', b'/')
    # restore the padding that modified BASE64 leaves out
    data += b'=' * (-len(data) % 4)
    return binascii.a2b_base64(data).decode('utf-16be')


def decoder(s: bytes, errors='strict'):
    """
    Decode the modified UTF-7 bytes *s*.

    Return a `(str, length consumed)` tuple as required by the codec
    interface.  *errors* is accepted so that the function can be called
    as `decode(input, errors)` by the codec machinery; it is ignored.
    """
    r = []
    # Pending shifted sequence including its leading "&"; empty when
    # the decoder is in the US-ASCII state.
    decode = bytearray()
    for c in s:
        if c == ord('&') and not decode:
            decode.append(ord('&'))
        elif c == ord('-') and decode:
            if len(decode) == 1:
                # "&-" stands for a literal "&"
                r.append('&')
            else:
                r.append(modified_unbase64(decode[1:]))
            decode = bytearray()
        elif decode:
            decode.append(c)
        else:
            r.append(chr(c))
    if decode:
        # shifted sequence not terminated by "-"
        r.append(modified_unbase64(decode[1:]))
    bin_str = ''.join(r)
    return (bin_str, len(s))
+
+
class StreamReader(codecs.StreamReader):
    """Stream reader decoding modified UTF-7 with `decoder`."""

    def decode(self, s, errors='strict'):
        # `errors` is part of the codec interface; it is not forwarded.
        decoded = decoder(s)
        return decoded
+
+
class StreamWriter(codecs.StreamWriter):
    """Stream writer encoding to modified UTF-7 with `encoder`."""

    # A stream writer is driven through `encode()`; the method was
    # previously defined under the name `decode` and thus never used.
    def encode(self, s, errors='strict'):
        # `errors` is part of the codec interface; it is not forwarded.
        return encoder(s)
+
+
def imap4_utf_7(name):
    """
    Codec search function for the 'imap4-utf-7' encoding.

    Return the `(encoder, decoder, StreamReader, StreamWriter)` tuple
    if *name* designates this codec, otherwise None.
    """
    # Depending on the Python version, the codec machinery passes the
    # requested name with hyphens either kept or turned into
    # underscores; accept both spellings.
    if name.replace('-', '_') == 'imap4_utf_7':
        return (encoder, decoder, StreamReader, StreamWriter)
    return None
+
+
+codecs.register(imap4_utf_7)
+
+
+## testing methods
+
def imapUTF7Encode(ust):
    "Returns the imap utf-7 encoded version (bytes) of the string"
    encoded = ust.encode('imap4-utf-7')
    return encoded
+
def imapUTF7EncodeSequence(seq):
    "Returns a list with the imap utf-7 encoded version of each string"
    return list(map(imapUTF7Encode, seq))
+
+
def imapUTF7Decode(st):
    "Returns the string decoded from the imap utf-7 encoded bytes"
    return st.decode('imap4-utf-7')
+
def imapUTF7DecodeSequence(seq):
    "Returns a list with the string decoded from each imap utf-7 item"
    return [imapUTF7Decode(itm) for itm in seq]
+
+
def utf8Decode(st):
    "Returns the string decoded from the utf-8 encoded bytes"
    return st.decode('utf-8')
+
+
def utf7SequenceToUTF8(seq):
    "Returns a list with each imap utf-7 item re-encoded as utf-8 bytes"
    return [itm.decode('imap4-utf-7').encode('utf-8') for itm in seq]
+
+
+__all__ = [ 'imapUTF7Encode', 'imapUTF7Decode' ]
+
+
if __name__ == '__main__':
    # (decoded text, expected modified UTF-7 bytes) pairs
    samples = (
        ('foo\r\n\nbar\n', b'foo&AA0ACgAK-bar&AAo-'),
        ('测试', b'&bUuL1Q-'),
        ('Hello 世界', b'Hello &ThZ1TA-'),
    )
    for text, wire in samples:
        # each direction on its own, then both round trips
        assert text == imapUTF7Decode(wire)
        assert wire == imapUTF7Encode(text)
        assert text == imapUTF7Decode(imapUTF7Encode(text))
        assert wire == imapUTF7Encode(imapUTF7Decode(wire))
    print("All tests passed!")
+
+# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: #
diff --git a/python/plot.py b/python/plot.py
new file mode 100644
index 0000000..b65f8a3
--- /dev/null
+++ b/python/plot.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+#
+# Credits: http://www.aosabook.org/en/matplotlib.html
+#
+# Aaron LI
+# 2016-03-14
+#
+
# Import the FigureCanvas from the backend of your choice
# and attach the Figure artist to it.
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
fig = Figure()
canvas = FigureCanvas(fig)

# Import the numpy library to generate the random numbers.
import numpy as np
x = np.random.randn(10000)

# Now use a figure method to create an Axes artist; the Axes artist is
# added automatically to the figure container fig.axes.
# Here "111" is from the MATLAB convention: create a grid with 1 row and 1
# column, and use the first cell in that grid for the location of the new
# Axes.
ax = fig.add_subplot(111)

# Call the Axes method hist to generate the histogram; hist creates a
# sequence of Rectangle artists for each histogram bar and adds them
# to the Axes container.  Here "100" means create 100 bins.
ax.hist(x, 100)

# Decorate the figure with a title and save it.
# A raw string is used because "\m" and "\s" are not valid escape
# sequences; the resulting title text is unchanged.
ax.set_title(r'Normal distribution with $\mu=0, \sigma=1$')
fig.savefig('matplotlib_histogram.png')
+
diff --git a/python/plot_tprofiles_zzh.py b/python/plot_tprofiles_zzh.py
new file mode 100644
index 0000000..e5824e9
--- /dev/null
+++ b/python/plot_tprofiles_zzh.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+#
+# Weitian LI
+# 2015-09-11
+#
+
+"""
+Plot a list of *temperature profiles* in a grid of subplots with Matplotlib.
+"""
+
+import matplotlib.pyplot as plt
+
+
def plot_tprofiles(tplist, nrows, ncols,
                   xlim=None, ylim=None, logx=False, logy=False,
                   xlab="", ylab="", title=""):
    """
    Plot a list of *temperature profiles* in a grid of subplots of size
    nrows x ncols.  Each subplot is related to a temperature profile.
    All the subplots share the same X and Y axes.
    The order is by row.

    The tplist is a list of dictionaries, each of which contains all the
    necessary data to make the subplot.

    The dictionary consists of the following components:
    tpdat = {
        "name": "NAME",
        "radius": [[radius points], [radius errors]],
        "temperature": [[temperature points], [temperature errors]],
        "radius_model": [radius points of the fitted model],
        "temperature_model": [
            [fitted model value],
            [lower bounds given by the model],
            [upper bounds given by the model]
        ]
    }

    Arguments:
        tplist - a list of dictionaries containing the data of each
                 temperature profile.
                 Note that the length of this list should equal to
                 nrows*ncols.
        nrows - number of rows of the subplots
        ncols - number of columns of the subplots
        xlim - limits of the X axis
        ylim - limits of the Y axis
        logx - whether to set the log scale for X axis
        logy - whether to set the log scale for Y axis
        xlab - label for the X axis
        ylab - label for the Y axis
        title - title for the whole plot

    Return:
        (fig, axarr) - the figure and the 2-D (nrows x ncols) array of
                       its axes.
    """
    assert len(tplist) == nrows*ncols, "tplist length != nrows*ncols"
    # All subplots share both X and Y axes.
    # `squeeze=False' keeps `axarr' 2-D even for a single row or column,
    # which the 2-D indexing below relies on.
    fig, axarr = plt.subplots(nrows, ncols, sharex=True, sharey=True,
                              squeeze=False)
    # Set title for the whole plot.
    if title != "":
        fig.suptitle(title)
    # Set xlab for the bottom row and ylab for the left column.
    if xlab != "":
        for ax in axarr[-1, :]:
            ax.set_xlabel(xlab)
    if ylab != "":
        for ax in axarr[:, 0]:
            ax.set_ylabel(ylab)
    for ax in axarr.flat:
        # Set xlim and ylim.
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        # Set xscale and yscale.
        # NOTE: `nonpositive' is the current name of the former
        # `nonposx'/`nonposy' keywords of Matplotlib.
        if logx:
            ax.set_xscale("log", nonpositive="clip")
        if logy:
            ax.set_yscale("log", nonpositive="clip")
    # Decrease the spacing between the subplots and suptitle
    fig.subplots_adjust(top=0.94)
    # Eliminate the spaces between each row and column.
    fig.subplots_adjust(hspace=0, wspace=0)
    # Hide X tick labels for all subplots but the bottom row.
    plt.setp([ax.get_xticklabels() for ax in axarr[:-1, :].flat],
             visible=False)
    # Hide Y tick labels for all subplots but the left column.
    plt.setp([ax.get_yticklabels() for ax in axarr[:, 1:].flat],
             visible=False)
    # Plot each temperature profile in the tplist
    for tpdat, ax in zip(tplist, axarr.flat):
        # Add text to display the name.
        # The text is placed at (0.95, 0.95), i.e., the top-right corner,
        # with respect to this subplot, and the top-right part of the text
        # is aligned to the above position.
        ax.text(0.95, 0.95, tpdat["name"],
                verticalalignment="top", horizontalalignment="right",
                transform=ax.transAxes, color="black", fontsize=10)
        # Plot data points
        if isinstance(tpdat["radius"][0], list) and \
                len(tpdat["radius"]) == 2 and \
                isinstance(tpdat["temperature"][0], list) and \
                len(tpdat["temperature"]) == 2:
            # Data points have symmetric errorbar
            ax.errorbar(tpdat["radius"][0], tpdat["temperature"][0],
                        xerr=tpdat["radius"][1], yerr=tpdat["temperature"][1],
                        color="black", linewidth=1.5, linestyle="None")
        else:
            # Without error bars a marker is required, otherwise nothing
            # is drawn at all (there is no connecting line either).
            ax.plot(tpdat["radius"], tpdat["temperature"],
                    color="black", linewidth=1.5, linestyle="None",
                    marker="o")
        # Plot model line and bounds band
        if isinstance(tpdat["temperature_model"][0], list) and \
                len(tpdat["temperature_model"]) == 3:
            # Model data have bounds
            ax.plot(tpdat["radius_model"], tpdat["temperature_model"][0],
                    color="blue", linewidth=1.0)
            # Plot model bounds band
            ax.fill_between(tpdat["radius_model"],
                            y1=tpdat["temperature_model"][1],
                            y2=tpdat["temperature_model"][2],
                            color="gray", alpha=0.5)
        else:
            ax.plot(tpdat["radius_model"], tpdat["temperature_model"],
                    color="blue", linewidth=1.5)
    return (fig, axarr)
+
+# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: #
diff --git a/python/randomize_events.py b/python/randomize_events.py
new file mode 100755
index 0000000..e1a6e31
--- /dev/null
+++ b/python/randomize_events.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+#
+# Randomize the (X,Y) position of each X-ray photon events according
+# to a Gaussian distribution of given sigma.
+#
+# References:
+# [1] G. Schellenberger, T.H. Reiprich, L. Lovisari, J. Nevalainen & L. David
+# 2015, A&A, 575, A30
+#
+#
+# Aaron LI
+# Created: 2016-03-24
+# Updated: 2016-03-24
+#
+
+from astropy.io import fits
+import numpy as np
+
+import os
+import sys
+import datetime
+import argparse
+
+
+CHANDRA_ARCSEC_PER_PIXEL = 0.492
+
def randomize_events(infile, outfile, sigma, clobber=False):
    """
    Randomize the position (X,Y) of each X-ray event according to a
    specified size/sigma Gaussian distribution.

    Arguments:
        infile - input event file; the events are read from its HDU 1
        outfile - output event file with the randomized positions
        sigma - sigma of the Gaussian distribution (unit: arcsec)
        clobber - whether to overwrite `outfile' if it already exists
    """
    sigma_pix = sigma / CHANDRA_ARCSEC_PER_PIXEL
    # The context manager closes the input file even if an error occurs.
    with fits.open(infile) as evt_fits:
        evt_table = evt_fits[1].data
        # (X,Y) physical coordinate
        evt_x = evt_table["x"]
        evt_y = evt_table["y"]
        # The offsets are cast to the dtype of the coordinate columns,
        # so that they can be added in place.
        rand_x = np.random.normal(scale=sigma_pix, size=evt_x.shape)\
            .astype(evt_x.dtype)
        rand_y = np.random.normal(scale=sigma_pix, size=evt_y.shape)\
            .astype(evt_y.dtype)
        evt_x += rand_x
        evt_y += rand_y
        # Add history to FITS header
        evt_hdr = evt_fits[1].header
        evt_hdr.add_history("TOOL: %s @ %s" % (
            os.path.basename(sys.argv[0]),
            datetime.datetime.utcnow().isoformat()))
        evt_hdr.add_history("COMMAND: %s" % " ".join(sys.argv))
        # astropy renamed the `clobber' keyword to `overwrite'.
        evt_fits.writeto(outfile, overwrite=clobber, checksum=True)
+
+
def main():
    """
    Command line entry point: parse the arguments and randomize the
    event positions of the input file.
    """
    parser = argparse.ArgumentParser(
        description="Randomize the (X,Y) of each X-ray event")
    parser.add_argument("infile", help="input event file")
    parser.add_argument("outfile", help="output randomized event file")
    # NOTE: the two parts of the help text are joined with a space
    # (previously rendered as "usedto").
    parser.add_argument("-s", "--sigma", dest="sigma",
                        required=True, type=float,
                        help="sigma/size of the Gaussian distribution used "
                             "to randomize the position of events "
                             "(unit: arcsec)")
    parser.add_argument("-C", "--clobber", dest="clobber",
                        action="store_true",
                        help="overwrite output file if exists")
    args = parser.parse_args()

    randomize_events(args.infile, args.outfile,
                     sigma=args.sigma, clobber=args.clobber)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/python/sbp_fit.py b/python/sbp_fit.py
new file mode 100644
index 0000000..0b7d29a
--- /dev/null
+++ b/python/sbp_fit.py
@@ -0,0 +1,613 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Aaron LI
+# Created: 2016-03-13
+# Updated: 2016-03-28
+#
+# Changelogs:
+# 2016-03-28:
+# * Add `main()', `make_model()'
+# * Use `configobj' to handle configurations
+# * Save fit results and plot
+# * Add `ci_report()'
+# 2016-03-14:
+# * Refactor classes `FitModelSBeta' and `FitModelDBeta'
+# * Add matplotlib plot support
+# * Add `ignore_data()' and `notice_data()' support
+# * Add classes `FitModelSBetaNorm' and `FitModelDBetaNorm'
+#
+
+"""
+Fit the surface brightness profile (SBP) with the single-beta model:
+ s(r) = s0 * [1.0 + (r/rc)^2] ^ (0.5-3*beta) + bkg
+or the double-beta model:
+ s(r) = s01 * [1.0 + (r/rc1)^2] ^ (0.5-3*beta1) +
+ s02 * [1.0 + (r/rc2)^2] ^ (0.5-3*beta2) + bkg
+
+Sample config file:
+-------------------------------------------------
+name = <NAME>
+obsid = <OBSID>
+
+sbpfile = sbprofile.txt
+
+model = sbeta
+outfile = sbpfit_sbeta.txt
+imgfile = sbpfit_sbeta.png
+
+#model = dbeta
+#outfile = sbpfit_dbeta.txt
+#imgfile = sbpfit_dbeta.png
+
+# data range to be ignored during fitting
+#ignore = 0.0-20.0,
+
+[sbeta]
+# name = initial, lower, upper
+s0 = 1.0e-8, 0.0, 1.0e-6
+rc = 30.0, 1.0, 1.0e4
+beta = 0.7, 0.3, 1.1
+bkg = 1.0e-9, 0.0, 1.0e-7
+
+[dbeta]
+s01 = 1.0e-8, 0.0, 1.0e-6
+rc1 = 50.0, 10.0, 1.0e4
+beta1 = 0.7, 0.3, 1.1
+s02 = 1.0e-8, 0.0, 1.0e-6
+rc2 = 30.0, 1.0, 5.0e2
+beta2 = 0.7, 0.3, 1.1
+bkg = 1.0e-9, 0.0, 1.0e-7
+-------------------------------------------------
+"""
+
+__version__ = "0.4.0"
+__date__ = "2016-03-28"
+
+import numpy as np
+import lmfit
+import matplotlib.pyplot as plt
+
+from configobj import ConfigObj
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+from matplotlib.figure import Figure
+
+import os
+import sys
+import re
+import argparse
+
+
+plt.style.use("ggplot")
+
+
class FitModel:
    """
    Base class of the fitting models.

    The supplied `func' should have the following syntax:
        y = f(x, params)
    where the `params' is the parameters to be fitted,
    and should be provided as well.

    Parameters:
      - name   : descriptive name of the model
      - func   : the model function `f(x, params)'
      - params : the parameters object (e.g., `lmfit.Parameters');
                 a new empty `lmfit.Parameters' is created if omitted
    """
    def __init__(self, name=None, func=None, params=None):
        self.name = name
        self.func = func
        # Create the container here rather than as a default argument,
        # otherwise one single object is shared by all the instances.
        self.params = lmfit.Parameters() if params is None else params

    def f(self, x):
        """
        Evaluate the model at `x' with the current parameters.
        """
        return self.func(x, self.params)

    def get_param(self, name=None):
        """
        Return the requested `Parameter' object, or the whole
        `Parameters' object if no name supplied.

        Raises `KeyError' if the requested parameter does not exist.
        """
        if name is None:
            return self.params
        return self.params[name]

    def set_param(self, name, *args, **kwargs):
        """
        Set the properties of the specified parameter; the extra arguments
        are passed to the `set()' method of the parameter.
        """
        self.params[name].set(*args, **kwargs)

    def plot(self, params, xdata, ax):
        """
        Plot the model evaluated with the given (fitted) `params' at
        `xdata' on the axes `ax'.
        """
        ax.plot(xdata, self.func(xdata, params), 'k-')
+
class FitModelSBeta(FitModel):
    """
    The single-beta model to be fitted.
    Single-beta model, with a constant background:
        s(r) = s0 * [1.0 + (r/rc)^2] ^ (0.5-3*beta) + bkg
    """
    # default parameters of the model
    params = lmfit.Parameters()
    params.add_many(  # (name, value, vary, min, max, expr)
            ("s0",   1.0e-8, True, 0.0, 1.0e-6, None),
            ("rc",   30.0,   True, 1.0, 1.0e4,  None),
            ("beta", 0.7,    True, 0.3, 1.1,    None),
            ("bkg",  1.0e-9, True, 0.0, 1.0e-7, None))

    @staticmethod
    def sbeta(r, params):
        """
        Evaluate the single-beta model (with background) at radius `r'.
        """
        parvals = params.valuesdict()
        s0 = parvals["s0"]
        rc = parvals["rc"]
        beta = parvals["beta"]
        bkg = parvals["bkg"]
        return s0 * np.power((1 + (r/rc)**2), (0.5 - 3*beta)) + bkg

    def __init__(self):
        # NOTE: `super(self.__class__, self)' recurses forever once this
        # class gets subclassed, so use the argument-less form.
        super().__init__(name="Single-beta",
                         func=self.sbeta, params=self.params)

    def plot(self, params, xdata, ax):
        """
        Plot the fitted model, as well as the fitted parameters.
        """
        super().plot(params, xdata, ax)
        ydata = self.sbeta(xdata, params)
        rc = params["rc"].value
        bkg = params["bkg"].value
        # fitted parameters: mark `rc' and `bkg' with dashed lines
        ax.vlines(x=rc, ymin=min(ydata), ymax=max(ydata),
                  linestyles="dashed")
        ax.hlines(y=bkg, xmin=min(xdata), xmax=max(xdata),
                  linestyles="dashed")
        ax.text(x=rc, y=min(ydata),
                s="beta: %.2f\nrc: %.2f" % (params["beta"].value, rc))
        ax.text(x=min(xdata), y=min(ydata),
                s="bkg: %.3e" % bkg,
                verticalalignment="top")
+
+
class FitModelDBeta(FitModel):
    """
    The double-beta model to be fitted.
    Double-beta model, with a constant background.

    NOTE:
    the first beta component (s01, rc1, beta1) describes the main and
    outer SBP; while the second beta component (s02, rc2, beta2) accounts
    for the central brightness excess.
    """
    # default parameters of the model
    params = lmfit.Parameters()
    params.add("s01", value=1.0e-8, min=0.0, max=1.0e-6)
    params.add("rc1", value=50.0, min=10.0, max=1.0e4)
    params.add("beta1", value=0.7, min=0.3, max=1.1)
    #params.add("df_s0", value=1.0e-8, min=0.0, max=1.0e-6)
    #params.add("s02", expr="s01 + df_s0")
    params.add("s02", value=1.0e-8, min=0.0, max=1.0e-6)
    #params.add("df_rc", value=30.0, min=0.0, max=1.0e4)
    #params.add("rc2", expr="rc1 - df_rc")
    params.add("rc2", value=20.0, min=1.0, max=5.0e2)
    params.add("beta2", value=0.7, min=0.3, max=1.1)
    params.add("bkg", value=1.0e-9, min=0.0, max=1.0e-7)

    @staticmethod
    def beta1(r, params):
        """
        This beta component describes the main/outer part of the SBP.
        The constant background is included in this component.
        """
        parvals = params.valuesdict()
        s01 = parvals["s01"]
        rc1 = parvals["rc1"]
        beta1 = parvals["beta1"]
        bkg = parvals["bkg"]
        return s01 * np.power((1 + (r/rc1)**2), (0.5 - 3*beta1)) + bkg

    @staticmethod
    def beta2(r, params):
        """
        This beta component describes the central/excess part of the SBP.
        """
        parvals = params.valuesdict()
        s02 = parvals["s02"]
        rc2 = parvals["rc2"]
        beta2 = parvals["beta2"]
        return s02 * np.power((1 + (r/rc2)**2), (0.5 - 3*beta2))

    @classmethod
    def dbeta(cls, r, params):
        """
        Evaluate the double-beta model, i.e., the sum of both components.
        """
        return cls.beta1(r, params) + cls.beta2(r, params)

    def __init__(self):
        # NOTE: `super(self.__class__, self)' recurses forever once this
        # class gets subclassed, so use the argument-less form.
        super().__init__(name="Double-beta",
                         func=self.dbeta, params=self.params)

    def plot(self, params, xdata, ax):
        """
        Plot the fitted model, and each beta component,
        as well as the fitted parameters.
        """
        super().plot(params, xdata, ax)
        beta1_ydata = self.beta1(xdata, params)
        beta2_ydata = self.beta2(xdata, params)
        ax.plot(xdata, beta1_ydata, 'b-.')
        ax.plot(xdata, beta2_ydata, 'b-.')
        # fitted parameters
        ydata = beta1_ydata + beta2_ydata
        ymin, ymax = min(ydata), max(ydata)
        rc1 = params["rc1"].value
        rc2 = params["rc2"].value
        bkg = params["bkg"].value
        ax.vlines(x=rc1, ymin=ymin, ymax=ymax, linestyles="dashed")
        ax.vlines(x=rc2, ymin=ymin, ymax=ymax, linestyles="dashed")
        ax.hlines(y=bkg, xmin=min(xdata), xmax=max(xdata),
                  linestyles="dashed")
        ax.text(x=rc1, y=ymin,
                s="beta1: %.2f\nrc1: %.2f" % (params["beta1"].value, rc1))
        ax.text(x=rc2, y=ymin,
                s="beta2: %.2f\nrc2: %.2f" % (params["beta2"].value, rc2))
        ax.text(x=min(xdata), y=ymin,
                s="bkg: %.3e" % bkg,
                verticalalignment="top")
+
+
class FitModelSBetaNorm(FitModel):
    """
    The single-beta model to be fitted.
    Single-beta model, with a constant background.
    The `s0' and `bkg' parameters are normalized by taking the base-10
    logarithm, i.e., the fitted parameters are `log10_s0' and `log10_bkg'.
    """
    # default parameters of the model
    params = lmfit.Parameters()
    params.add_many(  # (name, value, vary, min, max, expr)
            ("log10_s0",  -8.0, True, -12.0, -6.0,  None),
            ("rc",        30.0, True, 1.0,   1.0e4, None),
            ("beta",      0.7,  True, 0.3,   1.1,   None),
            ("log10_bkg", -9.0, True, -12.0, -7.0,  None))

    @staticmethod
    def sbeta(r, params):
        """
        Evaluate the single-beta model (with background) at radius `r'.
        """
        parvals = params.valuesdict()
        s0 = 10 ** parvals["log10_s0"]
        rc = parvals["rc"]
        beta = parvals["beta"]
        bkg = 10 ** parvals["log10_bkg"]
        return s0 * np.power((1 + (r/rc)**2), (0.5 - 3*beta)) + bkg

    def __init__(self):
        # NOTE: `super(self.__class__, self)' recurses forever once this
        # class gets subclassed, so use the argument-less form.
        super().__init__(name="Single-beta",
                         func=self.sbeta, params=self.params)

    def plot(self, params, xdata, ax):
        """
        Plot the fitted model, as well as the fitted parameters.
        """
        super().plot(params, xdata, ax)
        ydata = self.sbeta(xdata, params)
        rc = params["rc"].value
        # The background parameter of this model is named `log10_bkg'
        # (there is no `bkg' parameter), so convert it back for display.
        bkg = 10 ** params["log10_bkg"].value
        # fitted parameters
        ax.vlines(x=rc, ymin=min(ydata), ymax=max(ydata),
                  linestyles="dashed")
        ax.hlines(y=bkg, xmin=min(xdata), xmax=max(xdata),
                  linestyles="dashed")
        ax.text(x=rc, y=min(ydata),
                s="beta: %.2f\nrc: %.2f" % (params["beta"].value, rc))
        ax.text(x=min(xdata), y=min(ydata),
                s="bkg: %.3e" % bkg,
                verticalalignment="top")
+
+
class FitModelDBetaNorm(FitModel):
    """
    The double-beta model to be fitted.
    Double-beta model, with a constant background.
    The `s01', `s02' and `bkg' parameters are normalized by taking the
    base-10 logarithm, i.e., the fitted parameters are `log10_s01',
    `log10_s02' and `log10_bkg'.

    NOTE:
    the first beta component (s01, rc1, beta1) describes the main and
    outer SBP; while the second beta component (s02, rc2, beta2) accounts
    for the central brightness excess.
    """
    # default parameters of the model
    params = lmfit.Parameters()
    params.add("log10_s01", value=-8.0, min=-12.0, max=-6.0)
    params.add("rc1", value=50.0, min=10.0, max=1.0e4)
    params.add("beta1", value=0.7, min=0.3, max=1.1)
    #params.add("df_s0", value=1.0e-8, min=0.0, max=1.0e-6)
    #params.add("s02", expr="s01 + df_s0")
    params.add("log10_s02", value=-8.0, min=-12.0, max=-6.0)
    #params.add("df_rc", value=30.0, min=0.0, max=1.0e4)
    #params.add("rc2", expr="rc1 - df_rc")
    params.add("rc2", value=20.0, min=1.0, max=5.0e2)
    params.add("beta2", value=0.7, min=0.3, max=1.1)
    params.add("log10_bkg", value=-9.0, min=-12.0, max=-7.0)

    @staticmethod
    def beta1(r, params):
        """
        This beta component describes the main/outer part of the SBP.
        The constant background is included in this component.
        """
        parvals = params.valuesdict()
        s01 = 10 ** parvals["log10_s01"]
        rc1 = parvals["rc1"]
        beta1 = parvals["beta1"]
        bkg = 10 ** parvals["log10_bkg"]
        return s01 * np.power((1 + (r/rc1)**2), (0.5 - 3*beta1)) + bkg

    @staticmethod
    def beta2(r, params):
        """
        This beta component describes the central/excess part of the SBP.
        """
        parvals = params.valuesdict()
        s02 = 10 ** parvals["log10_s02"]
        rc2 = parvals["rc2"]
        beta2 = parvals["beta2"]
        return s02 * np.power((1 + (r/rc2)**2), (0.5 - 3*beta2))

    @classmethod
    def dbeta(cls, r, params):
        """
        Evaluate the double-beta model, i.e., the sum of both components.
        """
        return cls.beta1(r, params) + cls.beta2(r, params)

    def __init__(self):
        # NOTE: `super(self.__class__, self)' recurses forever once this
        # class gets subclassed, so use the argument-less form.
        super().__init__(name="Double-beta",
                         func=self.dbeta, params=self.params)

    def plot(self, params, xdata, ax):
        """
        Plot the fitted model, and each beta component,
        as well as the fitted parameters.
        """
        super().plot(params, xdata, ax)
        beta1_ydata = self.beta1(xdata, params)
        beta2_ydata = self.beta2(xdata, params)
        ax.plot(xdata, beta1_ydata, 'b-.')
        ax.plot(xdata, beta2_ydata, 'b-.')
        # fitted parameters
        ydata = beta1_ydata + beta2_ydata
        ymin, ymax = min(ydata), max(ydata)
        # Only the normalizations and the background are fitted in log
        # space: the core radii are plain `rc1'/`rc2', and the background
        # parameter is named `log10_bkg'.
        rc1 = params["rc1"].value
        rc2 = params["rc2"].value
        bkg = 10 ** params["log10_bkg"].value
        ax.vlines(x=rc1, ymin=ymin, ymax=ymax, linestyles="dashed")
        ax.vlines(x=rc2, ymin=ymin, ymax=ymax, linestyles="dashed")
        ax.hlines(y=bkg, xmin=min(xdata), xmax=max(xdata),
                  linestyles="dashed")
        ax.text(x=rc1, y=ymin,
                s="beta1: %.2f\nrc1: %.2f" % (params["beta1"].value, rc1))
        ax.text(x=rc2, y=ymin,
                s="beta2: %.2f\nrc2: %.2f" % (params["beta2"].value, rc2))
        ax.text(x=min(xdata), y=ymin,
                s="bkg: %.3e" % bkg,
                verticalalignment="top")
+
+
class SbpFit:
    """
    Class to handle the SBP fitting with single-/double-beta model.

    Parameters:
      - model  : the model to be fitted; it provides `func(x, params)',
                 `params' and `plot(params, xdata, ax)'
      - method : default minimization method passed to `lmfit'
      - xdata, ydata : data points of the profile
      - xerr, yerr   : errors of the data points (may be None)
      - name, obsid  : source name and observation ID, used in the plot
    """
    def __init__(self, model, method="lbfgsb",
                 xdata=None, ydata=None, xerr=None, yerr=None,
                 name=None, obsid=None):
        self.method = method
        self.model = model
        self.xdata = xdata
        self.ydata = ydata
        self.xerr = xerr
        self.yerr = yerr
        # mask of the noticed data points (True: noticed; False: ignored)
        # NOTE: the `np.bool' alias was removed from NumPy; use `bool'.
        if xdata is not None:
            self.mask = np.ones(xdata.shape, dtype=bool)
        else:
            self.mask = None
        self.name = name
        self.obsid = obsid
        # residual function; created on demand by `set_residual()'
        self.f_residual = None

    def set_source(self, name, obsid=None):
        """
        Set the source name and observation ID.
        """
        self.name = name
        self.obsid = obsid

    def load_data(self, xdata, ydata, xerr, yerr):
        """
        Load the data points and their errors; all points get noticed.
        """
        self.xdata = xdata
        self.ydata = ydata
        self.xerr = xerr
        self.yerr = yerr
        self.mask = np.ones(xdata.shape, dtype=bool)

    def _range_index(self, xmin, xmax):
        """
        Boolean index of the data points within range [xmin, xmax].
        """
        if xmin is None:
            xmin = min(self.xdata)
        if xmax is None:
            xmax = max(self.xdata)
        return np.logical_and(self.xdata >= xmin, self.xdata <= xmax)

    def ignore_data(self, xmin=None, xmax=None):
        """
        Ignore the data points within range [xmin, xmax].
        If xmin is None, then xmin=min(xdata);
        if xmax is None, then xmax=max(xdata).
        """
        self.mask[self._range_index(xmin, xmax)] = False
        # reset `f_residual'
        self.f_residual = None

    def notice_data(self, xmin=None, xmax=None):
        """
        Notice the data points within range [xmin, xmax].
        If xmin is None, then xmin=min(xdata);
        if xmax is None, then xmax=max(xdata).
        """
        self.mask[self._range_index(xmin, xmax)] = True
        # reset `f_residual'
        self.f_residual = None

    def set_residual(self):
        """
        Create the residual function `f_residual(params)' to be minimized.
        Only the noticed data points are used; the residuals are divided
        by `yerr' if it is available.
        """
        def f_residual(params):
            mask = self.mask
            # NOTE: the mask must be applied to `ydata' in both cases,
            # otherwise the shapes mismatch once data are ignored.
            resid = self.model.func(self.xdata[mask], params) - \
                    self.ydata[mask]
            if self.yerr is not None:
                resid = resid / self.yerr[mask]
            return resid
        self.f_residual = f_residual

    def fit(self, method=None):
        """
        Fit the model to the noticed data with the given method
        (default: the method set at construction).
        """
        if method is None:
            method = self.method
        if getattr(self, "f_residual", None) is None:
            self.set_residual()
        self.fitter = lmfit.Minimizer(self.f_residual, self.model.params)
        self.results = self.fitter.minimize(method=method)
        self.fitted_model = lambda x: self.model.func(x, self.results.params)

    def calc_ci(self, sigmas=(0.68, 0.90)):
        """
        Calculate the confidence intervals of the fitted parameters.
        NOTE: `fit()' must be called first.
        """
        # `conf_interval' requires the fitted results have valid `stderr',
        # so we need to re-fit the model with method `leastsq'.
        results = self.fitter.minimize(method="leastsq",
                                       params=self.results.params)
        self.ci, self.trace = lmfit.conf_interval(self.fitter, results,
                                                  sigmas=list(sigmas),
                                                  trace=True)

    @staticmethod
    def ci_report(ci, with_offset=True):
        """
        Format and return the report for confidence intervals.
        This function is intended to replace the `lmfit.ci_report()`
        which does not use scientific notation for small values.

        `ci' maps each parameter name to a list of (sigma, value) pairs,
        with the best-fit value (sigma = 0) in the middle.
        If `with_offset' is True, the bounds are reported as offsets
        with respect to the best-fit value.
        """
        pnames = list(ci.keys())
        if not pnames:
            return ""
        plen = max(map(len, pnames))
        name_fmt = "{:<%d}" % (plen+1)
        # number of confidence levels on each side of the best-fit value
        nside = len(ci.get(pnames[0])) // 2
        report = []
        # header
        sigmas = [v[0] for v in ci.get(pnames[0]) if v[0] > 0.0]
        fmt = name_fmt + " {:11.2%}" * nside + " _BEST_" + \
            " {:11.2%}" * nside
        report.append(fmt.format(" ", *sigmas))
        # parameters ci
        fmt = name_fmt + " {:+11.5g}" * nside + " {:11.5g}" + \
            " {:+11.5g}" * nside
        for par in pnames:
            values = [v[1] for v in ci.get(par)]
            if with_offset:
                mid = len(values) // 2
                best = values[mid]
                values = [v if i == mid else v - best
                          for i, v in enumerate(values)]
            report.append(fmt.format(par+":", *values))
        return "\n".join(report)

    def report(self, outfile=sys.stdout):
        """
        Print the fit results and the confidence intervals (if available)
        to the file object `outfile'.
        """
        if getattr(self, "results", None) is not None:
            p = lmfit.fit_report(self.results)
            print(p, file=outfile)
        if getattr(self, "ci", None) is not None:
            print("-----------------------------------------------",
                  file=outfile)
            p = self.ci_report(self.ci)
            print(p, file=outfile)

    def plot(self, ax=None, fig=None):
        """
        Plot the data points (the ignored ones with dash-dotted error
        bars) and the fitted model.  New figure and axes are created if
        `ax' is not given.

        Return: (fig, ax)
        """
        if ax is None:
            fig, ax = plt.subplots(1, 1)

        def select(arr, mask):
            # the errors are optional
            return None if arr is None else arr[mask]

        # noticed data points
        eb = ax.errorbar(self.xdata[self.mask], self.ydata[self.mask],
                         xerr=select(self.xerr, self.mask),
                         yerr=select(self.yerr, self.mask),
                         fmt="none")
        # ignored data points
        ignore_mask = np.logical_not(self.mask)
        if np.sum(ignore_mask) > 0:
            eb = ax.errorbar(self.xdata[ignore_mask],
                             self.ydata[ignore_mask],
                             xerr=select(self.xerr, ignore_mask),
                             yerr=select(self.yerr, ignore_mask),
                             fmt="none")
            eb[-1][0].set_linestyle("-.")
        # fitted model
        xmax = self.xdata[-1]
        if self.xerr is not None:
            xmax = xmax + self.xerr[-1]
        xpred = np.power(10, np.linspace(0, np.log10(xmax),
                                         2*len(self.xdata)))
        ypred = self.fitted_model(xpred)
        ymin = min(min(self.ydata), min(ypred))
        ymax = max(max(self.ydata), max(ypred))
        self.model.plot(params=self.results.params, xdata=xpred, ax=ax)
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlim(1.0, xmax)
        ax.set_ylim(ymin/1.2, ymax*1.2)
        name = str(self.name)
        if self.obsid is not None:
            name += "; %s" % self.obsid
        ax.set_title("Fitted Surface Brightness Profile (%s)" % name)
        ax.set_xlabel("Radius (pixel)")
        ax.set_ylabel(r"Surface Brightness (photons/cm$^2$/pixel$^2$/s)")
        ax.text(x=xmax, y=ymax,
                s="redchi: %.2f / %.2f = %.2f" % (
                    self.results.chisqr, self.results.nfree,
                    self.results.chisqr/self.results.nfree),
                horizontalalignment="right", verticalalignment="top")
        return (fig, ax)
+
+
def make_model(config):
    """
    Make the model with parameters set according to the config.

    The model is selected by `config["model"]' ("sbeta" or "dbeta").
    The config section of the same name gives, for each parameter, the
    list `initial, lower, upper'; if this section is missing, the default
    parameters of the model are kept.

    Raises `ValueError' if the model name is invalid.
    """
    modelname = config["model"]
    if modelname == "sbeta":
        # single-beta model
        model = FitModelSBeta()
    elif modelname == "dbeta":
        # double-beta model
        model = FitModelDBeta()
    else:
        raise ValueError("Invalid model: %s" % modelname)
    # set initial values and bounds for the model parameters
    params = config.get(modelname)
    if params is None:
        # no section for this model: use the defaults
        return model
    for p, value in params.items():
        model.set_param(name=p, value=float(value[0]),
                        min=float(value[1]), max=float(value[2]))
    return model
+
+
def main():
    """
    Command line entry point: read the config file, fit the surface
    brightness profile, then save the report and the plot.
    """
    version = "%s (%s)" % (__version__, __date__)
    parser = argparse.ArgumentParser(
        description="Fit surface brightness profile with "
                    "single-/double-beta model",
        epilog="Version: %s" % version)
    parser.add_argument("-V", "--version", action="version",
                        version="%(prog)s " + version)
    parser.add_argument("config", help="Config file for SBP fitting")
    args = parser.parse_args()

    config = ConfigObj(args.config)

    # build the model first, then the fit object with the profile data
    # (columns: x, xerr, y, yerr)
    model = make_model(config)
    sbpdata = np.loadtxt(config["sbpfile"])
    xdata, xerr, ydata, yerr = (sbpdata[:, col] for col in range(4))
    sbpfit = SbpFit(model=model, xdata=xdata, xerr=xerr,
                    ydata=ydata, yerr=yerr)
    sbpfit.set_source(config["name"], obsid=config.get("obsid"))

    # ignore the requested data ranges (each given as "xmin-xmax")
    if "ignore" in config:
        for ig in config.as_list("ignore"):
            xmin, xmax = [float(bound) for bound in ig.split("-")]
            sbpfit.ignore_data(xmin=xmin, xmax=xmax)

    # fit, calculate the confidence intervals, and report the results
    # both to the screen and to the output file
    sbpfit.fit()
    sbpfit.calc_ci()
    sbpfit.report()
    with open(config["outfile"], "w") as outfile:
        sbpfit.report(outfile=outfile)

    # make and save a plot; the Agg canvas is attached to the figure
    fig = Figure()
    FigureCanvas(fig)
    sbpfit.plot(ax=fig.add_subplot(111), fig=fig)
    fig.savefig(config["imgfile"])


if __name__ == "__main__":
    main()
+
+# vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: #
diff --git a/python/starlet.py b/python/starlet.py
new file mode 100644
index 0000000..67e69ba
--- /dev/null
+++ b/python/starlet.py
@@ -0,0 +1,513 @@
+# -*- coding: utf-8 -*-
+#
+# References:
+# [1] Jean-Luc Starck, Fionn Murtagh & Jalal M. Fadili
+# Sparse Image and Signal Processing: Wavelets, Curvelets, Morphological Diversity
+# Section 3.5, 6.6
+#
+# Credits:
+# [1] https://github.com/abrazhe/image-funcut/blob/master/imfun/atrous.py
+#
+# Aaron LI
+# Created: 2016-03-17
+# Updated: 2016-03-22
+#
+
+"""
+Starlet wavelet transform, i.e., isotropic undecimated wavelet transform
+(IUWT), or à trous wavelet transform.
+And multi-scale variance stabilizing transform (MS-VST).
+"""
+
+import sys
+
+import numpy as np
+import scipy as sp
+from scipy import signal
+
+
class B3Spline:
    """
    B3-spline wavelet.
    """
    # scaling function (phi)
    dec_lo = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16
    dec_hi = np.array([-1.0, -4.0, 10.0, -4.0, -1.0]) / 16
    rec_lo = np.array([0.0, 0.0, 1.0, 0.0, 0.0])
    rec_hi = np.array([0.0, 0.0, 1.0, 0.0, 0.0])


class IUWT:
    """
    Isotropic undecimated wavelet transform.

    Parameters:
      - phi      : wavelet scaling function, either 1D (then the 2D
                   kernel is its outer product) or 2D
      - level    : number of transform levels
      - boundary : boundary condition passed to `scipy.signal.convolve2d'
      - data     : the 2D data (image) to be transformed
    """
    ## Decomposition filters list:
    # a_{scale} = convolve(a_0, filters[scale])
    # Note: the zero-th scale filter (i.e., delta function) is the first
    # element, thus the array index is the same as the decomposition scale.
    filters = []

    phi = None            # wavelet scaling function (2D)
    level = 0             # number of transform level
    decomposition = None  # decomposed coefficients/images
    reconstruction = None # reconstructed image

    # convolution boundary condition
    boundary = "symm"

    def __init__(self, phi=B3Spline.dec_lo, level=None, boundary="symm",
                 data=None):
        self.set_wavelet(phi=phi)
        self.level = level
        self.boundary = boundary
        # keep `None' as it is, instead of wrapping it into an array
        self.data = None if data is None else np.array(data)

    def reset(self):
        """
        Reset the object attributes.
        """
        self.data = None
        self.phi = None
        self.decomposition = None
        self.reconstruction = None
        self.level = 0
        self.filters = []
        self.boundary = "symm"

    def load_data(self, data):
        """
        Load new data; the previous results are reset, while the wavelet
        and the boundary condition are kept.
        """
        phi, boundary = self.phi, self.boundary
        self.reset()
        # `reset()' also clears the wavelet, which is still needed
        self.phi = phi
        self.boundary = boundary
        self.data = np.array(data)

    def set_wavelet(self, phi):
        """
        Set the wavelet scaling function; the previous results are reset,
        while the loaded data are kept.

        Raises `ValueError' if `phi' is neither 1D nor 2D.
        """
        data = getattr(self, "data", None)
        self.reset()
        # `reset()' also clears the data, which are still valid
        self.data = data
        phi = np.array(phi)
        if phi.ndim == 1:
            phi_ = phi.reshape(1, -1)
            self.phi = np.dot(phi_.T, phi_)
        elif phi.ndim == 2:
            self.phi = phi
        else:
            raise ValueError("Invalid phi dimension")

    def calc_filters(self):
        """
        Calculate the convolution filters of each scale.
        Note: the zero-th scale filter (i.e., delta function) is the first
        element, thus the array index is the same as the decomposition scale.

        The filter of scale `j' is cumulative, i.e., it is the convolution
        of the zero-upsampled kernel of this scale with the filter of the
        previous scale, thus:
            a_{j} = convolve(a_0, filters[j])
        """
        level = self.level or 0
        self.filters = []
        # scale 0: delta function
        h = np.array([[1]])  # NOTE: 2D
        self.filters.append(h)
        # scale 1
        h = self.phi[::-1, ::-1]
        self.filters.append(h)
        for scale in range(2, level+1):
            h_up = self.zupsample(self.phi, order=scale-1)
            # Compose with the filter of the *previous* scale, using the
            # full (zero-padded) convolution in order to keep the whole
            # support of the composed filter.
            h = signal.convolve2d(h_up[::-1, ::-1], h, mode="full")
            self.filters.append(h)

    def transform(self, data, scale, boundary="symm"):
        """
        Perform only one scale wavelet transform for the given data,
        i.e., `data' is regarded as the approximation of scale `scale-1',
        and is convolved with the zero-upsampled kernel of this scale.

        NOTE: the `boundary' argument is not used; the boundary condition
              of the object (`self.boundary') is applied.

        return:
            [ approx, detail ]
        """
        if scale < 1:
            # scale 0: delta function
            kernel = np.array([[1]])
        else:
            kernel = self.zupsample(self.phi, order=scale-1)[::-1, ::-1]
        approx = signal.convolve2d(data, kernel,
                                   mode="same", boundary=self.boundary)
        detail = data - approx
        return [approx, detail]

    def decompose(self, level, boundary="symm"):
        """
        Perform IUWT decomposition in the plain loop way.
        The filters of each scale/level are calculated first, then the
        approximations of each scale/level are calculated by convolving the
        raw/finest image with these filters.

        return:
            [ W_1, W_2, ..., W_n, A_n ]
            n = level
            W: wavelet details
            A: approximation
        """
        self.boundary = boundary
        if self.level != level or self.filters == []:
            self.level = level
            self.calc_filters()
        self.decomposition = []
        approx = self.data
        for scale in range(1, level+1):
            # approximation:
            approx2 = signal.convolve2d(self.data, self.filters[scale],
                                        mode="same", boundary=self.boundary)
            # wavelet details:
            w = approx - approx2
            self.decomposition.append(w)
            if scale == level:
                self.decomposition.append(approx2)
            approx = approx2
        return self.decomposition

    def decompose_recursive(self, level, boundary="symm"):
        """
        Perform the IUWT decomposition in the recursive way.

        return:
            [ W_1, W_2, ..., W_n, A_n ]
            n = level
            W: wavelet details
            A: approximation
        """
        self.level = level
        self.boundary = boundary
        self.decomposition = self.__decompose(self.data, self.phi, level=level)
        return self.decomposition

    def __decompose(self, data, phi, level):
        """
        2D IUWT decomposition (or stationary wavelet transform).

        This is a convolution version, where kernel is zero-upsampled
        explicitly.  Not fast.

        Parameters:
        - data : the data (image) to be decomposed
        - phi : low-pass filter kernel
        - level : level of decomposition

        Returns:
            list of wavelet details + last approximation. Each element in
            the list is an image of the same size as the input image.

        Raises `ValueError' if the data are not larger than the kernel.
        """
        if level <= 0:
            return data
        # NOTE: `map()' returns an iterator in Python 3, which is always
        # true for `np.all()', so compare the shapes explicitly.
        if not all(a > b for a, b in zip(data.shape, phi.shape)):
            raise ValueError("Data should be larger than the kernel")
        # approximation:
        approx = signal.convolve2d(data, phi[::-1, ::-1], mode="same",
                                   boundary=self.boundary)
        # wavelet details:
        w = data - approx
        phi_up = self.zupsample(phi, order=1)
        fits_next = all(a > b for a, b in zip(data.shape, phi_up.shape))
        if level == 1:
            return [w, approx]
        elif not fits_next:
            print("Maximum allowed decomposition level reached",
                  file=sys.stderr)
            return [w, approx]
        else:
            return [w] + self.__decompose(approx, phi_up, level-1)

    @staticmethod
    def zupsample(data, order=1):
        """
        Upsample data array by interleaving it with zero's.

        h{up_order: n}[l] = (1) h[l], if l % 2^n == 0;
                            (2) 0, otherwise
        """
        shape = data.shape
        step = 2**order
        new_shape = [(step * (n-1) + 1) for n in shape]
        output = np.zeros(new_shape, dtype=data.dtype)
        # NOTE: a multi-dimensional index must be a tuple (not a list)
        output[tuple(slice(None, None, step) for d in shape)] = data
        return output

    def reconstruct(self, decomposition=None):
        """
        Reconstruct the image by summing up all the wavelet details and
        the last approximation.

        If `decomposition' is not given, the stored decomposition is used
        and the result is also stored in `self.reconstruction'.
        """
        if decomposition is not None:
            return np.sum(decomposition, axis=0)
        self.reconstruction = np.sum(self.decomposition, axis=0)
        return self.reconstruction

    def get_detail(self, scale):
        """
        Get the wavelet detail coefficients of given scale.
        Note: 1 <= scale <= level
        """
        if scale < 1 or scale > self.level:
            raise ValueError("Invalid scale")
        return self.decomposition[scale-1]

    def get_approx(self):
        """
        Get the approximation coefficients of the largest scale.
        """
        return self.decomposition[-1]
+
+
class IUWT_VST(IUWT):
    """
    IUWT with Multi-scale variance stabilizing transform.

    Reference:
    [1] Bo Zhang, Jalal M. Fadili & Jean-Luc Starck,
        IEEE Trans. Image Processing, 17, 17, 2008
    """
    # VST coefficients and the corresponding asymptotic standard deviation
    # of each scale.
    vst_coef = []

    def reset(self):
        """
        Reset the object attributes, including the VST coefficients.
        """
        super().reset()
        # NOTE: assign to the attribute (a bare name is only a local)
        self.vst_coef = []

    # NOTE(review): due to the name mangling of double-underscore names,
    # this method does not replace the `__decompose' of the parent class,
    # thus `decompose_recursive()' still works (without VST).
    def __decompose(self):
        raise AttributeError("No '__decompose' attribute")

    @staticmethod
    def soft_threshold(data, threshold):
        """
        Soft thresholding: values within [-threshold, threshold] are set
        to zero, and the others are shrunk towards zero by `threshold'.
        Both arrays and scalars are supported; the input is not modified.
        """
        if isinstance(data, np.ndarray):
            data_th = data.copy()
            data_th[np.abs(data) <= threshold] = 0.0
            data_th[data > threshold] -= threshold
            data_th[data < -threshold] += threshold
        else:
            data_th = data
            if np.abs(data) <= threshold:
                data_th = 0.0
            elif data > threshold:
                data_th -= threshold
            else:
                data_th += threshold
        return data_th

    def tau(self, k, scale):
        """
        Helper function used in VST coefficients calculation:
        sum of the k-th power of the filter of the given scale.
        """
        return np.sum(np.power(self.filters[scale], k))

    def filters_product(self, scale1, scale2):
        """
        Calculate the scalar product of the filters of two scales,
        considering only the overlapped part.
        Helper function used in VST coefficients calculation.
        """
        if scale1 > scale2:
            filter_big = self.filters[scale1]
            filter_small = self.filters[scale2]
        else:
            filter_big = self.filters[scale2]
            filter_small = self.filters[scale1]
        # crop the big filter to match the size of the small filter
        size_big = filter_big.shape
        size_small = filter_small.shape
        size_diff2 = [(a-b)//2 for a, b in zip(size_big, size_small)]
        filter_big_crop = filter_big[
                size_diff2[0]:(size_big[0]-size_diff2[0]),
                size_diff2[1]:(size_big[1]-size_diff2[1])]
        assert tuple(filter_big_crop.shape) == tuple(size_small)
        product = np.sum(filter_small * filter_big_crop)
        return product

    def calc_vst_coef(self):
        """
        Calculate the VST coefficients and the corresponding
        asymptotic standard deviation of each scale, according to the
        calculated filters of each scale/level.

        NOTE: the `std' of scale 0 is undefined and set to -1.0.
        """
        self.vst_coef = []
        for scale in range(self.level+1):
            tau1 = self.tau(1, scale)
            tau2 = self.tau(2, scale)
            tau3 = self.tau(3, scale)
            b = 2 * np.sqrt(np.abs(tau1) / tau2)
            c = 7.0*tau2 / (8.0*tau1) - tau3 / (2.0*tau2)
            if scale == 0:
                std = -1.0
            else:
                tau1_prev = self.tau(1, scale-1)
                tau2_prev = self.tau(2, scale-1)
                std = np.sqrt((tau2_prev / (4 * tau1_prev**2)) +
                              (tau2 / (4 * tau1**2)) -
                              (self.filters_product(scale-1, scale) /
                               (2 * tau1_prev * tau1)))
            self.vst_coef.append({ "b": b, "c": c, "std": std })

    def vst(self, data, scale, coupled=True):
        """
        Perform variance stabilizing transform

        XXX: parameter `coupled' why??
        Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h
        """
        self.vst_coupled = coupled
        if self.vst_coef == []:
            self.calc_vst_coef()
        if coupled:
            b = 1.0
        else:
            b = self.vst_coef[scale]["b"]
        data_vst = b * np.sqrt(np.abs(data + self.vst_coef[scale]["c"]))
        return data_vst

    def ivst(self, data, scale, cbias=True):
        """
        Inverse variance stabilizing transform
        NOTE: assuming that `a_{j} + c^{j}' are all positive.

        XXX: parameter `cbias' why??
             `bias correction' is recommended while reconstruct the data
             after estimation
        Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h
        """
        self.vst_cbias = cbias
        if self.vst_coef == []:
            self.calc_vst_coef()
        if cbias:
            cb = 1.0 / (self.vst_coef[scale]["b"] ** 2)
        else:
            cb = 0.0
        data_ivst = data ** 2 + cb - self.vst_coef[scale]["c"]
        return data_ivst

    def is_significant(self, scale, fdr=0.1, independent=False):
        """
        Multiple hypothesis testing with false discovery rate (FDR) control.

        If `independent=True': FDR <= (m0/m) * q
        otherwise: FDR <= (m0/m) * q * (1 + 1/2 + 1/3 + ... + 1/m)

        The k-th smallest p-value (k = 1, ..., m) is compared with
        `k * q / (m * cn)', and the largest p-value satisfying this
        condition is adopted as the cutoff.

        return:
            (significant_mask, p_cutoff)
            If nothing is significant, the mask is all False and the
            cutoff is 0.0.

        References:
        [1] False discovery rate - Wikipedia
            https://en.wikipedia.org/wiki/False_discovery_rate
        """
        # `import scipy' alone does not guarantee that `scipy.stats' is
        # loaded, so import the sub-module explicitly.
        from scipy import stats

        coef = self.get_detail(scale)
        std = self.vst_coef[scale]["std"]
        pvalues = 2 * (1 - stats.norm.cdf(np.abs(coef) / std))
        p_sorted = np.sort(pvalues, axis=None)
        N = len(p_sorted)
        if independent:
            cn = 1.0
        else:
            cn = np.sum(1.0 / np.arange(1, N+1))
        # NOTE: the ranks start from 1 (not 0)
        p_comp = fdr * np.arange(1, N+1) / (N * cn)
        comp = (p_sorted <= p_comp)
        if not np.any(comp):
            return (np.zeros(pvalues.shape, dtype=bool), 0.0)
        # cutoff p-value after FDR control/correction
        p_cutoff = np.max(p_sorted[comp])
        return (pvalues <= p_cutoff, p_cutoff)

    def denoise(self, fdr=0.1, fdr_independent=False):
        """
        Denoise the wavelet coefficients by controlling FDR.

        The results are stored in `self.denoised', and the supports of
        the significant coefficients in `self.sig_supports'; the stored
        decomposition is left untouched.
        """
        self.fdr = fdr
        self.fdr_indepent = fdr_independent  # (sic) kept for compatibility
        self.fdr_independent = fdr_independent
        self.denoised = []
        # supports of significant coefficients of each scale
        self.sig_supports = [None]  # make index match the scale
        self.p_cutoff = [None]
        for scale in range(1, self.level+1):
            # work on a copy, in order to keep the original decomposition
            coef = self.get_detail(scale).copy()
            sig, p_cutoff = self.is_significant(scale, fdr, fdr_independent)
            coef[np.logical_not(sig)] = 0.0
            self.denoised.append(coef)
            self.sig_supports.append(sig)
            self.p_cutoff.append(p_cutoff)
        # append the last approximation
        self.denoised.append(self.get_approx())

    def decompose(self, level, boundary="symm"):
        """
        2D IUWT decomposition with VST.

        return:
            [ W_1, W_2, ..., W_n, A_n ]
            n = level
            W: wavelet details (difference of the VST'ed approximations)
            A: approximation (without VST)
        """
        self.boundary = boundary
        if self.level != level or self.filters == []:
            self.level = level
            self.calc_filters()
            self.calc_vst_coef()
        self.decomposition = []
        approx = self.data
        for scale in range(1, level+1):
            # approximation:
            approx2 = signal.convolve2d(self.data, self.filters[scale],
                                        mode="same", boundary=self.boundary)
            # wavelet details:
            w = self.vst(approx, scale=scale-1) - \
                self.vst(approx2, scale=scale)
            self.decomposition.append(w)
            if scale == level:
                self.decomposition.append(approx2)
            approx = approx2
        return self.decomposition

    def reconstruct_ivst(self, denoised=True, positive_project=True):
        """
        Reconstruct the original image from the decomposition
        by applying the inverse VST.

        This reconstruction result is also used as the `initial condition'
        for the below `iterative reconstruction' algorithm.

        arguments:
        * denoised: whether use the denoised data or the direct decomposition
        * positive_project: whether replace negative values with zeros
        """
        if denoised:
            decomposition = self.denoised
        else:
            decomposition = self.decomposition
        self.positive_project = positive_project
        details = np.sum(decomposition[:-1], axis=0)
        approx = self.vst(decomposition[-1], scale=self.level)
        reconstruction = self.ivst(approx+details, scale=0)
        if positive_project:
            reconstruction[reconstruction < 0.0] = 0.0
        self.reconstruction = reconstruction
        return reconstruction

    def reconstruct(self, denoised=True, niter=20, verbose=False):
        """
        Reconstruct the original image using iterative method with
        L1 regularization, because the denoising violates the exact inverse
        procedure.

        NOTE: `denoise()' must be called first, since the supports of the
              significant coefficients are required.

        arguments:
        * denoised: whether use the denoised coefficients
        * niter: number of iterations
        * verbose: whether print the iteration number to stderr
        """
        # L1 regularization: decrease linearly from 1.0 to 0.0
        lbd = 1.0
        delta = lbd / (niter - 1) if niter > 1 else 0.0
        # initial solution
        solution = self.reconstruct_ivst(denoised=denoised,
                                         positive_project=True)
        # plain IUWT with the same wavelet and boundary condition
        iuwt = IUWT(phi=self.phi, level=self.level, boundary=self.boundary)
        iuwt.calc_filters()
        # iterative reconstruction
        for i in range(niter):
            if verbose:
                print("iter: %d" % (i+1), file=sys.stderr)
            tempd = self.data.copy()
            solution_decomp = []
            for scale in range(1, self.level+1):
                approx, detail = iuwt.transform(tempd, scale)
                approx_sol, detail_sol = iuwt.transform(solution, scale)
                # Update coefficients according to the significant supports,
                # which are acquired during the denoising procedure with FDR.
                sig = self.sig_supports[scale]
                detail_sol[sig] = detail[sig]
                detail_sol = self.soft_threshold(detail_sol, threshold=lbd)
                #
                solution_decomp.append(detail_sol)
                tempd = approx.copy()
                solution = approx_sol.copy()
            # last approximation (the two are the same)
            solution_decomp.append(approx)
            # reconstruct
            solution = iuwt.reconstruct(decomposition=solution_decomp)
            # discard all negative values
            solution[solution < 0] = 0.0
            #
            lbd -= delta
        #
        self.reconstruction = solution
        return self.reconstruction
+
diff --git a/python/xkeywordsync.py b/python/xkeywordsync.py
new file mode 100644
index 0000000..73f48b9
--- /dev/null
+++ b/python/xkeywordsync.py
@@ -0,0 +1,533 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Credits:
+# [1] Gaute Hope: gauteh/abunchoftags
+# https://github.com/gauteh/abunchoftags/blob/master/keywsync.cc
+#
+# TODO:
+# * Support case-insensitive tags merge
+# (ref: http://stackoverflow.com/a/1480230)
+# * Accept a specified mtime, and only deal with files with newer mtime.
+#
+# Aaron LI
+# Created: 2016-01-24
+#
+
+"""
+Sync message 'X-Keywords' header with notmuch tags.
+
+* tags-to-keywords:
+ Check if the messages in the query have a matching 'X-Keywords' header
+ to the list of notmuch tags.
+ If not, update the 'X-Keywords' and re-write the message.
+
+* keywords-to-tags:
+ Check if the messages in the query have matching notmuch tags to the
+ 'X-Keywords' header.
+ If not, update the tags in the notmuch database.
+
+* merge-keywords-tags:
+ Merge the 'X-Keywords' labels and notmuch tags, and update both.
+"""
+
+__version__ = "0.1.2"
+__date__ = "2016-01-25"
+
+import os
+import sys
+import argparse
+import email
+
+# Require Python 3.4, or install package 'enum34'
+from enum import Enum
+
+from notmuch import Database, Query
+
+from imapUTF7 import imapUTF7Decode, imapUTF7Encode
+
+
class SyncDirection(Enum):
    """
    Direction of the synchronization between the 'X-Keywords' header
    and the notmuch tags.
    """
    # Merge 'X-Keywords' and notmuch tags and update both
    MERGE_KEYWORDS_TAGS = 0
    # Sync 'X-Keywords' header to notmuch tags
    KEYWORDS_TO_TAGS = 1
    # Sync notmuch tags to 'X-Keywords' header
    TAGS_TO_KEYWORDS = 2
+
class SyncMode(Enum):
    """
    Restriction on the operations allowed during a synchronization.
    """
    # Allow add & remove tags/keywords
    ADD_REMOVE = 0
    # Only allow add tags/keywords
    ADD_ONLY = 1
    # Only allow remove tags/keywords
    REMOVE_ONLY = 2
+
+
class KwMessage:
    """
    Message class to deal with 'X-Keywords' header synchronization
    with notmuch tags.

    NOTE:
    * The same message may have multiple files with different keywords
      (e.g, the same message exported under each label by Gmail)
      managed by OfflineIMAP.
      For example: a message file in OfflineIMAP synced folder of
      '[Gmail]/All Mail' have keywords ['google', 'test']; however,
      the file in synced folder 'test' of the same message only have
      keywords ['google'] without the keyword 'test'.
    * All files associated to the same message are regarded as the same.
      The keywords are extracted from all files and merged.
      And the same updated keywords are written back to all files, which
      results all files finally having the same 'X-Keywords' header.
    * You may only sync the '[Gmail]/All Mail' folder without other
      folders exported according to the labels by Gmail.
    """
    # Replace some special characters before mapping keyword to tag
    enable_replace_chars = True
    chars_replace = {
        '/': '.',
    }
    # Mapping between (Gmail) keywords and notmuch tags (before ignoring tags)
    keywords_mapping = {
        '\\Inbox': 'inbox',
        '\\Important': 'important',
        '\\Starred': 'flagged',
        '\\Sent': 'sent',
        '\\Muted': 'killed',
        '\\Draft': 'draft',
        '\\Trash': 'deleted',
        '\\Junk': 'spam',
    }
    # Tags ignored from syncing
    # These tags are either internal tags or tags handled by maildir flags.
    enable_ignore_tags = True
    tags_ignored = set([
        'new', 'unread', 'attachment', 'signed', 'encrypted',
        'flagged', 'replied', 'passed', 'draft',
    ])
    # Ignore case when merging tags (NOTE: not implemented yet, see the
    # TODO of `merge_tags()`)
    tags_ignorecase = True

    # The defaults below are immutable on purpose: they are shared by all
    # instances until `merge_tags()` assigns the per-instance results.

    # Whether the tags updated against the message 'X-Keywords' header
    tags_updated = False
    # Added & removed tags for notmuch database against 'X-Keywords'
    tags_added = frozenset()
    tags_removed = frozenset()
    # Newly updated/merged notmuch tags against 'X-Keywords'
    tags_new = frozenset()

    # Whether the keywords updated against the notmuch tags
    keywords_updated = False
    # Added & removed tags for 'X-Keywords' against notmuch database
    tags_kw_added = frozenset()
    tags_kw_removed = frozenset()
    # Newly updated/merged tags for 'X-Keywords' against notmuch database
    tags_kw_new = frozenset()

    def __init__(self, msg, filename=None):
        """
        arguments:
        * msg: notmuch message object; its `get_filenames()` and
               `get_tags()` are queried here, and its tags are changed
               by `update_tags()`
        * filename: optional filename, only stored
        """
        self.message = msg
        self.filename = filename
        self.allfiles = list(msg.get_filenames())
        self.tags = set(msg.get_tags())

    def sync(self, direction, mode=SyncMode.ADD_REMOVE,
             dryrun=False, verbose=False):
        """
        Wrapper function to sync between 'X-Keywords' and notmuch tags.

        Raise `ValueError` if the direction is not a known `SyncDirection`.
        """
        if direction == SyncDirection.KEYWORDS_TO_TAGS:
            self.sync_keywords_to_tags(sync_mode=mode, dryrun=dryrun,
                                       verbose=verbose)
        elif direction == SyncDirection.TAGS_TO_KEYWORDS:
            self.sync_tags_to_keywords(sync_mode=mode, dryrun=dryrun,
                                       verbose=verbose)
        elif direction == SyncDirection.MERGE_KEYWORDS_TAGS:
            self.merge_keywords_tags(sync_mode=mode, dryrun=dryrun,
                                     verbose=verbose)
        else:
            raise ValueError("Invalid sync direction: %s" % direction)

    def _print_tags_report(self):
        """
        Print the pending changes of the notmuch tags.
        """
        print(' TAG: [%s] +[%s] -[%s] => [%s]' % (
            ','.join(self.tags), ','.join(self.tags_added),
            ','.join(self.tags_removed), ','.join(self.tags_new)))

    def _print_keywords_report(self, keywords_new):
        """
        Print the pending changes of the 'X-Keywords' header.
        """
        print(' XKW: {%s} +[%s] -[%s] => {%s}' % (
            ','.join(self.keywords), ','.join(self.tags_kw_added),
            ','.join(self.tags_kw_removed), ','.join(keywords_new)))

    def sync_keywords_to_tags(self, sync_mode=SyncMode.ADD_REMOVE,
                              dryrun=False, verbose=False):
        """
        Wrapper function to sync 'X-Keywords' to notmuch tags.
        """
        self.get_keywords()
        self.map_keywords()
        self.merge_tags(sync_direction=SyncDirection.KEYWORDS_TO_TAGS,
                        sync_mode=sync_mode)
        if dryrun or verbose:
            print('* MSG: %s' % self.message)
            self._print_tags_report()
        if not dryrun:
            self.update_tags()

    def sync_tags_to_keywords(self, sync_mode=SyncMode.ADD_REMOVE,
                              dryrun=False, verbose=False):
        """
        Wrapper function to sync notmuch tags to 'X-Keywords'
        """
        self.get_keywords()
        self.map_keywords()
        self.merge_tags(sync_direction=SyncDirection.TAGS_TO_KEYWORDS,
                        sync_mode=sync_mode)
        keywords_new = self.map_tags(tags=self.tags_kw_new)
        if dryrun or verbose:
            print('* MSG: %s' % self.message)
            print('* FILES: %s' % ' ; '.join(self.allfiles))
            self._print_keywords_report(keywords_new)
        if not dryrun:
            self.update_keywords(keywords_new=keywords_new)

    def merge_keywords_tags(self, sync_mode=SyncMode.ADD_REMOVE,
                            dryrun=False, verbose=False):
        """
        Wrapper function to merge 'X-Keywords' and notmuch tags
        """
        self.get_keywords()
        self.map_keywords()
        self.merge_tags(sync_direction=SyncDirection.MERGE_KEYWORDS_TAGS,
                        sync_mode=sync_mode)
        keywords_new = self.map_tags(tags=self.tags_kw_new)
        if dryrun or verbose:
            print('* MSG: %s' % self.message)
            print('* FILES: %s' % ' ; '.join(self.allfiles))
            self._print_tags_report()
            self._print_keywords_report(keywords_new)
        if not dryrun:
            self.update_tags()
            self.update_keywords(keywords_new=keywords_new)

    def get_keywords(self):
        """
        Get 'X-Keywords' header from all files associated with the same
        message, decode, split and merge.

        The merged set of keywords is stored in `self.keywords` and
        returned.  Empty keywords (e.g., caused by a trailing comma) are
        dropped.

        NOTE: Do NOT simply use the `message.get_header()` method, which
              cannot get the complete keywords from all files.
        """
        headers = []
        for fn in self.allfiles:
            # Close the file right after parsing instead of leaking it
            with open(fn, 'r') as fp:
                msg = email.message_from_file(fp)
            val = msg['X-Keywords']
            if val:
                headers.append(val)
            else:
                print("WARNING: 'X-Keywords' header not found or empty "
                      "for file: %s" % fn, file=sys.stderr)
        keywords_utf7 = ','.join(headers)
        if keywords_utf7:
            decoded = imapUTF7Decode(keywords_utf7.encode())
            # The set also removes the duplications
            keywords = {kw.strip() for kw in decoded.split(',')}
            keywords.discard('')
        else:
            keywords = set()
        self.keywords = keywords
        return keywords

    def map_keywords(self, keywords=None):
        """
        Map keywords (default: self.keywords) to notmuch tags according
        to the mapping table.

        The set of tags is stored in `self.tags_kw` and returned.
        """
        if keywords is None:
            keywords = self.keywords
        if self.enable_replace_chars:
            # Replace specified characters in keywords
            trans = str.maketrans(self.chars_replace)
            keywords = [kw.translate(trans) for kw in keywords]
        # Map keywords to tags
        tags = {self.keywords_mapping.get(kw, kw) for kw in keywords}
        self.tags_kw = tags
        return tags

    def map_tags(self, tags=None):
        """
        Map tags (default: self.tags) to keywords according to the
        inversed mapping table.

        The set of keywords is stored in `self.keywords_tags` and returned.

        NOTE: the inversed character replacement is applied to every
              occurrence, e.g., a '.' that is original to a tag is also
              turned into '/'.
        """
        if tags is None:
            tags = self.tags
        if self.enable_replace_chars:
            # Inversely replace specified characters in tags
            chars_replace_inv = {v: k for k, v in self.chars_replace.items()}
            trans = str.maketrans(chars_replace_inv)
            tags = [tag.translate(trans) for tag in tags]
        # Map tags to keywords
        keywords_mapping_inv = {
            v: k for k, v in self.keywords_mapping.items()
        }
        keywords = {keywords_mapping_inv.get(tag, tag) for tag in tags}
        self.keywords_tags = keywords
        return keywords

    def merge_tags(self, sync_direction, sync_mode=SyncMode.ADD_REMOVE,
                   tags_nm=None, tags_kw=None):
        """
        Merge the tags from notmuch database and 'X-Keywords' header,
        according to the specified sync direction and operation restriction.

        arguments:
        * sync_direction: a `SyncDirection` member
        * sync_mode: a `SyncMode` member restricting the operations
        * tags_nm: notmuch tags (default: self.tags)
        * tags_kw: tags mapped from 'X-Keywords' (default: self.tags_kw)

        return:
        A dictionary of the results, which are also stored as the
        attributes of the same names.

        Raise `ValueError` if the direction is not a known `SyncDirection`.

        TODO: support case-insensitive set operations
        """
        tags_nm = set(self.tags if tags_nm is None else tags_nm)
        tags_kw = set(self.tags_kw if tags_kw is None else tags_kw)
        if self.enable_ignore_tags:
            # Remove ignored tags before merge
            tags_nm2 = tags_nm.difference(self.tags_ignored)
            tags_kw2 = tags_kw.difference(self.tags_ignored)
        else:
            tags_nm2 = tags_nm
            tags_kw2 = tags_kw
        # Added & removed tags for notmuch database against 'X-Keywords'
        tags_added = set()
        tags_removed = set()
        # Added & removed tags for 'X-Keywords' against notmuch database
        tags_kw_added = set()
        tags_kw_removed = set()
        #
        if sync_direction == SyncDirection.KEYWORDS_TO_TAGS:
            # Sync 'X-Keywords' to notmuch tags
            tags_added = tags_kw2.difference(tags_nm2)
            tags_removed = tags_nm2.difference(tags_kw2)
        elif sync_direction == SyncDirection.TAGS_TO_KEYWORDS:
            # Sync notmuch tags to 'X-Keywords'
            tags_kw_added = tags_nm2.difference(tags_kw2)
            tags_kw_removed = tags_kw2.difference(tags_nm2)
        elif sync_direction == SyncDirection.MERGE_KEYWORDS_TAGS:
            # Merge both notmuch tags and 'X-Keywords': each side gets
            # what it misses from the union, and nothing is removed.
            tags_merged = tags_nm2.union(tags_kw2)
            tags_added = tags_merged.difference(tags_nm2)
            tags_kw_added = tags_merged.difference(tags_kw2)
        else:
            raise ValueError("Invalid synchronization direction")
        # Apply sync operation restriction
        if sync_mode == SyncMode.REMOVE_ONLY:
            tags_added = set()
            tags_kw_added = set()
        if sync_mode == SyncMode.ADD_ONLY:
            tags_removed = set()
            tags_kw_removed = set()
        # Start from the un-ignored tags, i.e., the ignored tags are kept
        self.tags_new = tags_nm.union(tags_added).difference(tags_removed)
        self.tags_kw_new = tags_kw.union(tags_kw_added).difference(
            tags_kw_removed)
        self.tags_added = tags_added
        self.tags_removed = tags_removed
        self.tags_kw_added = tags_kw_added
        self.tags_kw_removed = tags_kw_removed
        # Always (re)set the flags, thus a repeated merge cannot leave
        # a stale `True` behind.
        self.tags_updated = bool(tags_added or tags_removed)
        self.keywords_updated = bool(tags_kw_added or tags_kw_removed)
        #
        return {
            'tags_updated': self.tags_updated,
            'tags_added': self.tags_added,
            'tags_removed': self.tags_removed,
            'tags_new': self.tags_new,
            'keywords_updated': self.keywords_updated,
            'tags_kw_added': self.tags_kw_added,
            'tags_kw_removed': self.tags_kw_removed,
            'tags_kw_new': self.tags_kw_new,
        }

    def update_keywords(self, keywords_new=None, outfile=None):
        """
        Encode the keywords (default: mapped from self.tags_kw_new) and
        write back to all message files.

        If parameter 'outfile' specified, then write the updated message
        to that file instead of overwriting.

        Nothing is done if `self.keywords_updated` is not set.

        NOTE:
        * The modification time of the message file should be kept to prevent
          OfflineIMAP from treating it as a new one (and the previous a
          deleted one).
        * All files associated with the same message are updated to have
          the same 'X-Keywords' header.
        * The 'X-Keywords' header is deleted if there is no keyword left.
        """
        if not self.keywords_updated:
            # keywords NOT updated, just skip
            return

        if keywords_new is None:
            keywords_new = self.map_tags(tags=self.tags_kw_new)
        #
        if outfile is not None:
            infiles = self.allfiles[0:1]
            outfiles = [os.path.expanduser(outfile)]
        else:
            infiles = self.allfiles
            outfiles = self.allfiles
        #
        for ifname, ofname in zip(infiles, outfiles):
            with open(ifname, 'r') as fp:
                msg = email.message_from_file(fp)
            # Get the timestamps before the file may be overwritten
            fstat = os.stat(ifname)
            # Always delete first, to avoid multiple occurrences
            del msg['X-Keywords']
            if keywords_new:
                # Update 'X-Keywords' header; sorted to make the header
                # value reproducible between runs.
                keywords = ','.join(sorted(keywords_new))
                msg['X-Keywords'] = imapUTF7Encode(keywords).decode()
            else:
                # NOTE: test the emptiness instead of comparing with `[]`,
                #       which never matches an empty set.
                print("WARNING: delete 'X-Keywords' header from file: %s" %
                      ifname, file=sys.stderr)
            # Write updated message
            with open(ofname, 'w') as fp:
                fp.write(msg.as_string())
            # Reset the timestamps
            os.utime(ofname, ns=(fstat.st_atime_ns, fstat.st_mtime_ns))

    def update_tags(self, tags_added=None, tags_removed=None):
        """
        Update notmuch tags according to keywords.

        Nothing is done if `self.tags_updated` is not set.

        arguments:
        * tags_added: tags to add (default: self.tags_added)
        * tags_removed: tags to remove (default: self.tags_removed)
        """
        if not self.tags_updated:
            # tags NOT updated, just skip
            return

        if tags_added is None:
            tags_added = self.tags_added
        if tags_removed is None:
            tags_removed = self.tags_removed
        # Use freeze/thaw for safer transactions to change tag values.
        self.message.freeze()
        try:
            for tag in tags_added:
                self.message.add_tag(tag, sync_maildir_flags=False)
            for tag in tags_removed:
                self.message.remove_tag(tag, sync_maildir_flags=False)
        finally:
            # Do not leave the message frozen if a tag operation failed
            self.message.thaw()
+
+
def get_notmuch_revision(dbpath=None):
    """
    Get the current revision and UUID of notmuch database.

    The information is taken from the output of the command
    `notmuch count --lastmod`.

    arguments:
    * dbpath: notmuch database path; if not specified, the `notmuch`
              command uses its own configuration

    return:
    A dictionary with keys 'revision' (int) and 'uuid' (str).

    Raise `subprocess.CalledProcessError` if the `notmuch` command fails.
    """
    import subprocess
    import tempfile
    # Run the command as an argument list (without a shell), thus a path
    # with spaces or shell metacharacters cannot break the command.
    cmd = ['notmuch', 'count', '--lastmod']
    if dbpath:
        # Create a minimal notmuch config for the specified dbpath
        config = '[database]\npath=%s\n' % os.path.expanduser(dbpath)
        # The temporary file is removed on leaving the block, even if the
        # command failed.
        with tempfile.NamedTemporaryFile() as tf:
            tf.write(config.encode())
            tf.flush()
            cmd.insert(1, '--config=%s' % tf.name)
            output = subprocess.check_output(cmd)
    else:
        output = subprocess.check_output(cmd)
    # Extract output: fields[1] is the UUID, fields[2] the revision
    dbinfo = output.decode().split()
    return {'revision': int(dbinfo[2]), 'uuid': dbinfo[1]}
+
+
def main():
    """
    Command line entry: parse the arguments, query the notmuch database
    and sync every matched message.
    """
    parser = argparse.ArgumentParser(
        description="Sync message 'X-Keywords' header with notmuch tags.")
    parser.add_argument("-V", "--version", action="version",
                        version="%(prog)s " + "v%s (%s)" % (__version__,
                                                            __date__))
    parser.add_argument("-q", "--query", dest="query", required=True,
                        help="notmuch database query string")
    parser.add_argument("-p", "--db-path", dest="dbpath",
                        help="notmuch database path (default to try " +
                             "user configuration)")
    parser.add_argument("-n", "--dry-run", dest="dryrun",
                        action="store_true", help="dry run")
    parser.add_argument("-v", "--verbose", dest="verbose",
                        action="store_true", help="show verbose information")
    # Exclusive argument group for sync direction
    exgroup1 = parser.add_mutually_exclusive_group(required=True)
    exgroup1.add_argument("-m", "--merge-keywords-tags",
                          dest="direction_merge", action="store_true",
                          help="merge 'X-Keywords' and tags and update both")
    exgroup1.add_argument("-k", "--keywords-to-tags",
                          dest="direction_keywords2tags", action="store_true",
                          help="sync 'X-Keywords' to notmuch tags")
    exgroup1.add_argument("-t", "--tags-to-keywords",
                          dest="direction_tags2keywords", action="store_true",
                          help="sync notmuch tags to 'X-Keywords'")
    # Exclusive argument group for tag operation mode
    exgroup2 = parser.add_mutually_exclusive_group(required=False)
    exgroup2.add_argument("-a", "--add-only", dest="mode_addonly",
                          action="store_true", help="only add notmuch tags")
    exgroup2.add_argument("-r", "--remove-only", dest="mode_removeonly",
                          action="store_true",
                          help="only remove notmuch tags")
    # Parse
    args = parser.parse_args()
    # Sync direction
    if args.direction_merge:
        sync_direction = SyncDirection.MERGE_KEYWORDS_TAGS
    elif args.direction_keywords2tags:
        sync_direction = SyncDirection.KEYWORDS_TO_TAGS
    elif args.direction_tags2keywords:
        sync_direction = SyncDirection.TAGS_TO_KEYWORDS
    else:
        raise ValueError("Invalid synchronization direction")
    # Sync mode
    if args.mode_addonly:
        sync_mode = SyncMode.ADD_ONLY
    elif args.mode_removeonly:
        sync_mode = SyncMode.REMOVE_ONLY
    else:
        sync_mode = SyncMode.ADD_REMOVE
    #
    if args.dbpath:
        dbpath = os.path.abspath(os.path.expanduser(args.dbpath))
    else:
        dbpath = None
    #
    db = Database(path=dbpath, create=False, mode=Database.MODE.READ_WRITE)
    try:
        q = Query(db, args.query)
        total_msgs = q.count_messages()
        msgs = q.search_messages()
        #
        if args.verbose:
            # The revision is only reported here, so do not spawn the
            # `notmuch` command otherwise.
            dbinfo = get_notmuch_revision(dbpath=dbpath)
            print("# Notmuch database path: %s" % dbpath)
            print("# Database revision: %d (uuid: %s)" %
                  (dbinfo['revision'], dbinfo['uuid']))
            print("# Query: %s" % args.query)
            print("# Sync direction: %s" % sync_direction.name)
            print("# Sync mode: %s" % sync_mode.name)
            print("# Total messages to check: %d" % total_msgs)
            print("# Dry run: %s" % args.dryrun)
        #
        for msg in msgs:
            kwmsg = KwMessage(msg)
            kwmsg.sync(direction=sync_direction, mode=sync_mode,
                       dryrun=args.dryrun, verbose=args.verbose)
    finally:
        # Close the database even if the sync of a message failed
        db.close()
+
+
+if __name__ == "__main__":
+ main()
+
+# vim: set ts=4 sw=4 tw=0 fenc= ft=python: #