From e715adfe6b3021d605b31e431300d10870997017 Mon Sep 17 00:00:00 2001
From: Aaron LI <aaronly.me@outlook.com>
Date: Fri, 1 Apr 2016 22:08:46 +0800
Subject: correct_crosstalk.py: Split out class "SpectrumSet"; implement
 background subtraction

---
 python/correct_crosstalk.py | 209 ++++++++++++++++++++++++++++++++------------
 1 file changed, 154 insertions(+), 55 deletions(-)

diff --git a/python/correct_crosstalk.py b/python/correct_crosstalk.py
index 9f86763..c514504 100755
--- a/python/correct_crosstalk.py
+++ b/python/correct_crosstalk.py
@@ -1,6 +1,13 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 #
+# References:
+# [?] astropy - FITS format code
+#     http://docs.astropy.org/en/stable/io/fits/usage/table.html#column-creation
+# [?] XSPEC - Spectral Fitting
+#     https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/XspecSpectralFitting.html
+#
+#
 # Weitian LI
 # Created: 2016-03-26
 # Updated: 2016-04-01
@@ -40,14 +47,20 @@ fix_negative = True
 outfile = cc_reg2.pi
 spec = reg2.pi
 arf = reg2.arf
+rmf = reg2.rmf
+bkg = reg2_bkg.pi
   [[cross_in]]
     [[[in1]]]
     spec = reg1.pi
     arf = reg1.arf
+    rmf = reg1.rmf
+    bkg = reg1_bkg.pi
     cross_arf = reg_1-2.arf
     [[[in2]]]
     spec = reg3.pi
     arf = reg3.arf
+    rmf = reg3.rmf
+    bkg = reg3_bkg.pi
     cross_arf = reg_3-2.arf
   [[cross_out]]
   cross_arf = reg_2-1.arf, reg_2-3.arf
@@ -288,49 +301,62 @@ class RMF:  # {{{
 
 class Spectrum:  # {{{
     """
-    Deal with X-ray spectrum (.pi)
-
-    TODO/XXX:
-    * Add background spectrum fields
-    * Subtract background spectrum before correcting crosstalk effects
-    * Strip keywords ANCRFILE, RESPFILE, BACKFILE from the output spectrum
-    * Estimate channel errors by Monte Carlo simulations
+    Class that deals with the X-ray spectrum file (usually *.pi).
 
-    NOTE:
-    The "COUNTS" column data are converted from "int32" to "float32".
+    TODO:
+    * to implement the grouping function (and quality columns)
     """
-    filename = None
+    filename  = None
     # FITS object return by `fits.open()'
-    fitsobj = None
+    fitsobj   = None
     # header of "SPECTRUM" extension
-    header = None
+    header    = None
     # "SPECTRUM" extension data
-    channel = None
-    # name of the column containing the spectrum data ("COUNTS" or "RATE")
-    spec_colname = None
+    channel   = None
+    # name of the spectrum data column (i.e., type, "COUNTS" or "RATE")
+    spec_type = None
+    # unit of the spectrum data ("count" for "COUNTS", "count/s" for "RATE")
+    spec_unit = None
     # spectrum data
     spec_data = None
-    # ARF object for this spectrum
-    arf = None
-    # RMF object for this spectrum
-    rmf = None
+    # several important keywords
+    EXPOSURE  = None
+    BACKSCAL  = None
+    RESPFILE  = None
+    ANCRFILE  = None
+    BACKFILE  = None
+    # numpy dtype and FITS format code of the spectrum data
+    spec_dtype       = None
+    spec_fits_format = None
 
-    def __init__(self, filename, arffile):
+    def __init__(self, filename):
         self.filename = filename
-        self.fitsobj = fits.open(filename)
-        ext_spec = self.fitsobj["SPECTRUM"]
-        self.header = ext_spec.header.copy(strip=True)
-        colnames = ext_spec.columns.names
+        self.fitsobj  = fits.open(filename)
+        ext_spec      = self.fitsobj["SPECTRUM"]
+        self.header   = ext_spec.header.copy(strip=True)
+        colnames      = ext_spec.columns.names
         if "COUNTS" in colnames:
-            self.spec_colname = "COUNTS"
+            self.spec_type        = "COUNTS"
+            self.spec_unit        = "count"
+            self.spec_dtype       = np.int32
+            self.spec_fits_format = "J"
         elif "RATE" in colnames:
-            self.spec_colname = "RATE"
+            self.spec_type        = "RATE"
+            self.spec_unit        = "count/s"
+            self.spec_dtype       = np.float32
+            self.spec_fits_format = "E"
         else:
             raise ValueError("Invalid spectrum file")
-        self.channel = ext_spec.data["CHANNEL"].copy()
-        self.spec_data = ext_spec.data.field(self.spec_colname)\
-                .astype(np.float32)
-        self.arf = ARF(arffile)
+        self.channel   = ext_spec.data["CHANNEL"].copy()
+        self.spec_data = ext_spec.data.field(self.spec_type)\
+                .astype(self.spec_dtype)
+        # keywords
+        self.EXPOSURE = self.header.get("EXPOSURE")
+        self.BACKSCAL = self.header.get("BACKSCAL")
+        self.AREASCAL = self.header.get("AREASCAL")
+        self.RESPFILE = self.header.get("RESPFILE")
+        self.ANCRFILE = self.header.get("ANCRFILE")
+        self.BACKFILE = self.header.get("BACKFILE")
 
     def get_data(self, copy=True):
         if copy:
@@ -344,6 +370,71 @@ class Spectrum:  # {{{
         else:
             return self.channel
 
+    def reset_header_keywords(self,
+            keywords=("ANCRFILE", "RESPFILE", "BACKFILE")):
+        """
+        Reset those of `keywords' present in the header to "NONE".
+        """
+        for kw in keywords:
+            if kw in self.header:
+                self.header[kw] = "NONE"
+
+    def write(self, filename, clobber=False):
+        """
+        Build a new "SPECTRUM" extension from the current channel/spectrum
+        data and header, replace the original one, and write to `filename'.
+        """
+        ext_spec_cols = fits.ColDefs([
+                fits.Column(name="CHANNEL", format="I", array=self.channel),
+                fits.Column(name=self.spec_type, format=self.spec_fits_format,
+                    unit=self.spec_unit, array=self.spec_data)])
+        ext_spec = fits.BinTableHDU.from_columns(ext_spec_cols,
+                header=self.header)
+        self.fitsobj["SPECTRUM"] = ext_spec
+        self.fitsobj.writeto(filename, clobber=clobber, checksum=True)
+# class Spectrum }}}
+
+
+class SpectrumSet(Spectrum):  # {{{
+    """
+    This class handles a spectrum set: the source spectrum together with
+    its RMF, ARF, and background spectrum.
+
+    TODO:
+    * Subtract background spectrum before correcting crosstalk effects
+    * Estimate channel errors by Monte Carlo simulations
+
+    **NOTE**:
+    The "COUNTS" column data are converted from "int32" to "float32",
+    since this spectrum will be subtracted/compensated according to the
+    ratios of ARFs.
+    """
+    # ARF object for this spectrum
+    arf = None
+    # RMF object for this spectrum
+    rmf = None
+    # background Spectrum object for this spectrum
+    bkg = None
+
+    # numpy dtype and FITS format code to which the spectrum data are
+    # converted if the data are "COUNTS"
+    _spec_dtype       = np.float32
+    _spec_fits_format = "E"
+
+    def __init__(self, filename, arffile=None, rmffile=None, bkgfile=None):
+        super().__init__(filename)  # zero-arg form stays correct if subclassed
+        # convert spectrum data to `_spec_dtype' (float) if necessary
+        if self.spec_data.dtype != self._spec_dtype:
+            self.spec_data        = self.spec_data.astype(self._spec_dtype)
+            self.spec_dtype       = self._spec_dtype
+            self.spec_fits_format = self._spec_fits_format
+        if arffile is not None:
+            self.arf = ARF(arffile)
+        if rmffile is not None:
+            self.rmf = RMF(rmffile)
+        if bkgfile is not None:
+            self.bkg = Spectrum(bkgfile)
+
     def get_energy(self, mean="geometric"):
         """
         Get the energy values of each channel if RMF present.
@@ -362,11 +453,39 @@ class Spectrum:  # {{{
         else:
             return self.arf.get_data(copy=copy)
 
+    def subtract_bkg(self, inplace=True):
+        """
+        Subtract the background contribution from the source spectrum.
+        The `EXPOSURE', `BACKSCAL' and `AREASCAL' header values of both
+        spectra are used to scale the background before subtraction.
+
+        Arguments:
+          * inplace: whether to replace `spec_data' with the background-
+                     subtracted spectrum data; if True, the attribute
+                     `spec_bkg_subtracted' is also set to `True' once
+                     the subtraction has finished.
+
+        Return:
+          background-subtracted spectrum data
+        """
+        ratio = (self.EXPOSURE / self.bkg.EXPOSURE) * \
+                (self.BACKSCAL / self.bkg.BACKSCAL) * \
+                (self.AREASCAL / self.bkg.AREASCAL)
+        spec_data_subbkg = self.spec_data - ratio * self.bkg.get_data()
+        if inplace:
+            self.spec_data = spec_data_subbkg
+            self.spec_bkg_subtracted = True
+        return spec_data_subbkg
+
     def subtract(self, spectrum, cross_arf, verbose=False):
         """
         Subtract the photons that originate from the surrounding regions
         but were scattered into this spectrum due to the finite PSF.
 
+        The background of this spectrum and the given spectrum should
+        both be subtracted before applying this subtraction for crosstalk
+        correction, as well as the below `compensate()' procedure.
+
         NOTE:
         The crosstalk ARF must be provided, since the `spectrum.arf' is
         required to be its ARF without taking crosstalk into account:
@@ -430,29 +549,6 @@ class Spectrum:  # {{{
             neg_channels = np.arange(N, dtype=np.int)[neg_counts]
         if i > 0:
             print("*** Fixed negative channels ***", file=sys.stderr)
-
-    def reset_header_keywords(self,
-            keywords=["ANCRFILE", "RESPFILE", "BACKFILE"]):
-        """
-        Reset the keywords to "NONE" to avoid confusion or mistakes.
-        """
-        for kw in keywords:
-            if kw in self.header:
-                header[kw] = "NONE"
-
-    def write(self, filename, clobber=False):
-        """
-        Create a new "SPECTRUM" table/extension and replace the original
-        one, then write to output file.
-        """
-        ext_spec_cols = fits.ColDefs([
-                fits.Column(name="CHANNEL", format="I", array=self.channel),
-                fits.Column(name="COUNTS", format="E", unit="count",
-                    array=self.spec_data)])
-        ext_spec = fits.BinTableHDU.from_columns(ext_spec_cols,
-                header=self.header)
-        self.fitsobj["SPECTRUM"] = ext_spec
-        self.fitsobj.writeto(filename, clobber=clobber, checksum=True)
 # class Spectrum }}}
 
 
@@ -476,13 +572,16 @@ class Crosstalk:  # {{{
 
     def __init__(self, config):
         """
-        `config': a section of the whole config file (`ConfigObj` object).
+        Arguments:
+          * config: a section of the whole config file (`ConfigObj' object)
         """
         self.cross_in_spec = []
         self.cross_in_arf  = []
         self.cross_out_arf = []
         # this spectrum to be corrected
-        self.spectrum = Spectrum(config["spec"], config["arf"])
+        self.spectrum = Spectrum(filename=config["spec"],
+                arffile=config["arf"], rmffile=config["rmf"],
+                bkgfile=config["bkg"])
         # spectra and cross arf from which photons were scattered in
         for reg_in in config["cross_in"].values():
             spec = Spectrum(reg_in["spec"], reg_in["arf"])
-- 
cgit v1.2.2