aboutsummaryrefslogtreecommitdiffstats
path: root/astro/spectrum/adjust_spectrum_error.py
blob: 0f80ec7ea92d1498d9f247151dcc0c8620d68270 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Squeeze the spectrum according to the grouping specification, then
calculate the statistical errors for each group, and apply error
adjustments (e.g., incorporate the systematic uncertainties).

Command-line synopsis (see ``main()``)::

    <script> [-C] [-G] -e LOW:HIGH:SYSERR[,LOW:HIGH:SYSERR...] infile outfile

The input spectrum is read from the ``SPECTRUM`` extension of ``infile``;
the adjusted spectrum (with a STAT_ERR column) is written to ``outfile``.
"""

__version__ = "0.1.0"
__date__ = "2016-01-11"


import sys
import argparse

import numpy as np
from astropy.io import fits


class Spectrum:
    """
    Spectrum class to keep spectrum information and perform manipulations.

    The spectrum is read from the ``SPECTRUM`` extension of a FITS file
    (columns CHANNEL, COUNTS, GROUPING and QUALITY).  The intended
    processing order is::

        squeezeByGrouping() -> calcStatErr() -> applySysErr()
        -> updateHeader() -> write()
    """
    header = None
    channel = None
    counts = None
    grouping = None
    quality = None
    # Derived quantities, filled in by the processing methods below.
    counts_squeezed = None      # set by squeezeByGrouping()
    stat_err = None             # set by calcStatErr()
    stat_err_adjusted = None    # set by applySysErr()

    def __init__(self, specfile):
        """
        Read header and columns of the ``SPECTRUM`` extension of *specfile*.
        """
        # Context manager: the file is closed even if the extension or
        # one of the required columns is missing.
        with fits.open(specfile) as f:
            spechdu = f['SPECTRUM']
            self.header = spechdu.header
            self.channel = spechdu.data.field('CHANNEL')
            self.counts = spechdu.data.field('COUNTS')
            self.grouping = spechdu.data.field('GROUPING')
            self.quality = spechdu.data.field('QUALITY')

    def squeezeByGrouping(self):
        """
        Squeeze the spectrum according to the grouping specification,
        i.e., sum the counts belonging to the same group, and place the
        sum as the first channel within each group with other channels
        of counts zero's.

        A channel with GROUPING == 1 starts a new group; any other value
        continues the current group.  The first channel always starts a
        group, so counts in leading channels are never dropped and the
        output always has the same length as ``self.counts``.

        The result is stored in ``self.counts_squeezed`` (int32 array).
        """
        counts = np.asarray(self.counts)
        grouping = np.asarray(self.grouping)
        nchan = len(counts)
        squeezed = np.zeros(nchan, dtype=np.int32)
        if nchan > 0:
            starts = np.flatnonzero(grouping == 1)
            if starts.size == 0 or starts[0] != 0:
                # Treat the first channel as an implicit group start.
                starts = np.concatenate(([0], starts))
            # reduceat sums counts[starts[i]:starts[i+1]] for each group.
            squeezed[starts] = np.add.reduceat(counts, starts)
        self.counts_squeezed = squeezed

    def calcStatErr(self, gehrels=False):
        """
        Calculate the statistical errors for the grouped channels,
        and save as the STAT_ERR column (``self.stat_err``).

        Channels with non-zero squeezed counts get ``sqrt(N)``, or the
        Gehrels approximation ``1 + sqrt(N + 0.75)`` if *gehrels* is true.
        All other channels (non-first channels of a group, and groups
        with zero counts) get an error of zero.
        """
        # NOTE(review): zero-count groups also get 0 error here, even in
        # Gehrels mode (where 1 + sqrt(0.75) might be expected) -- confirm.
        idx_nz = np.nonzero(self.counts_squeezed)
        stat_err = np.zeros(self.counts_squeezed.shape)
        if gehrels:
            # Gehrels
            stat_err[idx_nz] = 1 + np.sqrt(self.counts_squeezed[idx_nz] + 0.75)
        else:
            stat_err[idx_nz] = np.sqrt(self.counts_squeezed[idx_nz])
        self.stat_err = stat_err

    @staticmethod
    def parseSysErr(syserr):
        """
        Parse the string format of syserr supplied in the commandline.

        The format is a comma-separated list of ``low:high:syserr`` items,
        e.g. ``"1:100:0.05, 101:200:0.1"``.

        Returns a list of ``(low, high, syserr)`` tuples with types
        ``(int, int, float)``.

        Raises ValueError if an item does not have exactly three fields,
        a field cannot be converted, or ``low > high``.
        """
        syserr_spec = []
        for item in syserr.split(','):
            item = item.strip()
            fields = item.split(':')
            if len(fields) != 3:
                raise ValueError("invalid syserr specification: %r" % item)
            try:
                lo, hi, se = int(fields[0]), int(fields[1]), float(fields[2])
            except ValueError as err:
                raise ValueError(
                    "invalid syserr specification: %r" % item) from err
            if lo > hi:
                raise ValueError(
                    "invalid syserr specification (low > high): %r" % item)
            syserr_spec.append((lo, hi, se))
        return syserr_spec

    def applySysErr(self, syserr):
        """
        Apply systematic error adjustments to the above calculated
        statistical errors.

        For every ``low:high:syserr`` item, the statistical errors of the
        channels whose CHANNEL value lies in [low, high] (inclusive) are
        scaled by ``sqrt(1 + syserr)``.  Channels not covered by any item
        keep their statistical error; where ranges overlap, the item
        listed last wins.

        The result is stored in ``self.stat_err_adjusted``, an array with
        the same length as ``self.stat_err``.
        """
        # NOTE(review): the common convention adds (syserr * counts) in
        # quadrature; this keeps the original sqrt(1 + syserr) scaling --
        # confirm the intended definition of syserr.
        syserr_spec = self.parseSysErr(syserr)
        stat_err = np.asarray(self.stat_err, dtype=float)
        channel = np.asarray(self.channel)
        err_adjusted = stat_err.copy()
        for lo, hi, se in syserr_spec:
            # Select by CHANNEL value rather than array index, so the
            # result does not depend on whether channels start at 0 or 1.
            mask = (channel >= lo) & (channel <= hi)
            err_adjusted[mask] = stat_err[mask] * np.sqrt(1 + se)
        self.stat_err_adjusted = err_adjusted

    def updateHeader(self):
        """
        Update header accordingly.

        POISSERR is set to False because explicit errors are now supplied
        in the STAT_ERR column.
        """
        # POISSERR
        self.header['POISSERR'] = False

    def write(self, filename, clobber=False):
        """
        Write the updated/modified spectrum block to file.

        The squeezed counts and the adjusted errors are written together
        with the original CHANNEL, GROUPING and QUALITY columns.  If
        *clobber* is true, an existing *filename* is overwritten.
        """
        channel_col = fits.Column(name='CHANNEL', format='J',
                                  array=self.channel)
        counts_col = fits.Column(name='COUNTS', format='J',
                                 array=self.counts_squeezed)
        stat_err_col = fits.Column(name='STAT_ERR', format='D',
                                   array=self.stat_err_adjusted)
        grouping_col = fits.Column(name='GROUPING', format='I',
                                   array=self.grouping)
        quality_col = fits.Column(name='QUALITY', format='I',
                                  array=self.quality)
        spec_cols = fits.ColDefs([channel_col, counts_col, stat_err_col,
                                  grouping_col, quality_col])
        spechdu = fits.BinTableHDU.from_columns(spec_cols, header=self.header)
        # astropy replaced the deprecated `clobber` keyword with `overwrite`.
        spechdu.writeto(filename, overwrite=clobber)


def main():
    """
    Command-line entry point.

    Reads the input spectrum, squeezes it by its grouping, computes the
    statistical errors (optionally Gehrels), applies the systematic error
    specification, updates the header and writes the output file.
    """
    version_string = "%(prog)s {} ({})".format(__version__, __date__)
    syserr_help = ("systematic error specification; "
                   "syntax: ch1low:ch1high:syserr1,...")

    parser = argparse.ArgumentParser(
        description="Apply systematic error adjustments to spectrum.")
    parser.add_argument("-V", "--version",
                        action="version", version=version_string)
    parser.add_argument("infile", help="input spectrum file")
    parser.add_argument("outfile", help="output adjusted spectrum file")
    parser.add_argument("-e", "--syserr",
                        dest="syserr", required=True, help=syserr_help)
    parser.add_argument("-C", "--clobber",
                        dest="clobber", action="store_true",
                        help="overwrite output file if exists")
    parser.add_argument("-G", "--gehrels",
                        dest="gehrels", action="store_true",
                        help="use Gehrels error?")
    options = parser.parse_args()

    spectrum = Spectrum(options.infile)
    spectrum.squeezeByGrouping()
    spectrum.calcStatErr(gehrels=options.gehrels)
    spectrum.applySysErr(syserr=options.syserr)
    spectrum.updateHeader()
    spectrum.write(options.outfile, clobber=options.clobber)


if __name__ == "__main__":
    main()


#  vim: set ts=4 sw=4 tw=0 fenc=utf-8 ft=python: #