aboutsummaryrefslogtreecommitdiffstats
path: root/astro/fitstable_merge.py
blob: 3732d9979ca5375b57ac6d1c103ac5b0f05804a5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#!/usr/bin/env python3
#
# Copyright (c) 2015 Weitian LI <weitian@aaronly.me>
# MIT license
#

"""
Merge several (>=2) of FITS file.

By default the *first* extend tables are merged and write out to a new
FITS file containing the *common* columns.  If the data types of the
columns of each FITS table do not match, then the data type of the column
of the *first* FITS table is used, and other columns are coerced.

If the FITS files have only *1* HDU (i.e., the Primary HDU), then data of
these HDU's are summed up to make up the output FITS file (an image),
on conditional that the shapes of all these HDU's are the same.
"""

# version and release date of this tool (reported by the "--version" option)
__version__ = "0.3.0"
__date__ = "2015/06/17"

# default blockname to be merged
# (used when an input file is given without a blockname)
BLOCKNAME_DFT = "EVENTS"

# print debug messages to stderr when true
DEBUG = True

import sys
import argparse
import re

import numpy as np
try:
    from astropy.io import fits
except ImportError:
    try:
        import pyfits as fits
    except ImportError:
        raise ImportError("cannot import 'astropy.io.fits' or 'pyfits'")


def merge2fits(file1, file2, block1=1, block2=1, columns=None):
    """
    Merge the given blocks (extension tables) of *two* FITS files,
    and return the merged FITS object.

    The rows of the second table are appended after the rows of the first
    table.  The column definitions, the table header and the Primary HDU
    are taken from the first file; the columns of the second table are
    cast to the data types of the first table and its header is ignored.

    TODO:
        * log history to header
        * support to merge image FITS files

    Arguments:
        file1, file2: input two FITS files (filenames or opened
            FITS objects)
        block1, block2: table number or table name to be merged
        columns: the columns to be merged (case-insensitive); either an
            iterable of column names, or a string of column names
            separated by commas and/or whitespace; by default to merge
            the common columns

    Return:
        the merged FITS object

    Raises:
        NotImplementedError: if either file has only one HDU, or if block
            0 is requested (image merge is not supported)
        ValueError: if a requested column is not present in both tables
    """
    # open file if provide filename
    if isinstance(file1, str):
        file1 = fits.open(file1)
    if isinstance(file2, str):
        file2 = fits.open(file2)
    # if has only *1* HDU => image
    if len(file1) == 1:
        block1 = 0
    if len(file2) == 1:
        block2 = 0
    if block1 == 0 or block2 == 0:
        # TODO
        raise NotImplementedError("image FITS merge currently not supported!")
    # get table to be merged
    table1 = file1[block1]
    table2 = file2[block2]
    # create column names to be merged
    # get names of all columns (convert to upper case)
    colnames1 = [col.name.upper() for col in table1.columns]
    colnames2 = [col.name.upper() for col in table2.columns]
    # the common column names, sorted according to the column order of
    # the first table
    colnames_common = sorted(set(colnames1).intersection(colnames2),
            key=colnames1.index)
    if columns is not None:
        if isinstance(columns, str):
            # A string holds one or more names; it must be split into
            # names, NOT into its individual characters.
            columnlist = columns.upper().replace(",", " ").split()
        else:
            columnlist = [name.upper() for name in columns]
        # check the specified columns whether in the above colnames_common
        for name in columnlist:
            if name not in colnames_common:
                raise ValueError("column '%s' not found in both files" % name)
        # use the specified columns
        colnames_common = columnlist
    # "STATUS" columns don't have equal-length format, so remove it
    if "STATUS" in colnames_common:
        colnames_common.remove("STATUS")
    if DEBUG:
        print("DEBUG: columns to merge: ", colnames_common, file=sys.stderr)
    # Build the merged table from the column definitions of the first
    # table, with enough rows to also hold the rows of the second table.
    nrow1 = table1.data.shape[0]
    nrow2 = table2.data.shape[0]
    hdu_merged = fits.BinTableHDU.from_columns(
            fits.ColDefs([table1.columns[name] for name in colnames_common]),
            nrows=nrow1+nrow2)
    # fill the rows after the first table with the data of the second
    # table, coerced to the data type of the first table's column
    for name in colnames_common:
        if DEBUG:
            print("DEBUG: merging column: ", name, file=sys.stderr)
        dtype = hdu_merged.columns[name].array.dtype
        hdu_merged.columns[name].array[nrow1:] = \
                table2.columns[name].array.astype(dtype)
    # process headers, based on the header of the first FITS file
    # DO NOT strip the base header, in order to keep the position of
    # XTENSION/BITPIX/NAXIS/NAXIS1/NAXIS2/PCOUNT/GCOUNT/TFIELDS keywords.
    header = table1.header.copy()  # do not strip
    # The header of the second FITS file is deliberately IGNORED to avoid
    # keyword conflicts.
    # remove the original TLMIN??/TLMAX??/TTYPE??/TFORM??/TUNIT?? keywords
    del_key_startswith(header,
            startswith=["TLMIN", "TLMAX", "TTYPE", "TFORM", "TUNIT"],
            lastlength=len(header))
    # update with new TLMIN??/TLMAX??/TTYPE??/TFORM??/TUNIT?? keywords
    header.update(hdu_merged.header)
    hdu_merged.header = header
    # copy PrimaryHDU from first FITS
    primary_hdu = file1[0].copy()
    # make HDUList and return
    return fits.HDUList([primary_hdu, hdu_merged])


def del_key_startswith(header, startswith, lastlength=0):
    """
    Delete the keys which start with the specified strings.

    The matching keys are collected first and deleted afterwards, so that
    the header is never modified while it is being iterated over.
    (Deleting during the iteration is presumably why the deletion formerly
    had to be repeated until the header stopped shrinking.)

    Arguments:
        header: FITS table header; the keys are deleted in place
        startswith: a string, or a list of strings; If a key starts with
            any of these strings, then the key-value pair is removed.
            The strings are matched literally and case-insensitively.
            Nothing is deleted if no string is given.
        lastlength: not used any more; only kept for backward
            compatibility with existing callers
    """
    if isinstance(startswith, str):
        # a single prefix, NOT a sequence of one-character prefixes
        prefixes = (startswith.upper(),)
    else:
        prefixes = tuple(prefix.upper() for prefix in startswith)
    # take a snapshot of the keys to delete before touching the header
    matched = [key for key in header.keys()
               if key.upper().startswith(prefixes)]
    for key in matched:
        # a repeated keyword may already be gone after its first deletion
        if key in header:
            del header[key]


def get_filename_blockname(pstr):
    """
    Separate the provided 'pstr' (parameter string) into filename and
    blockname.

    The blockname is expected to be appended to the filename within
    square brackets, e.g., "infile.fits[EVENTS]".  Only such a trailing
    "[blockname]" is split off, therefore the filename itself may contain
    spaces or brackets.

    Arguments:
        pstr: parameter string, i.e., "filename" or "filename[blockname]"

    Return:
        a tuple of (filename, blockname); if 'pstr' does not have a
        blockname, then 'pstr' is returned unchanged as the filename
        together with the default blockname.
    """
    # filename (non-greedy), followed by one trailing "[blockname]"
    matched = re.match(r"^(.+?)\s*\[\s*([^\[\]]+?)\s*\]$", pstr.strip())
    if matched is None:
        return (pstr, BLOCKNAME_DFT)
    return (matched.group(1), matched.group(2))


def main():
    """
    Command line interface: merge the given blocks of all the input FITS
    files one after another, and write the merged FITS to the output file.

    The block of the first input file is used to look up the intermediate
    merge results; each of the other input files may specify its own block.
    """
    import inspect

    parser = argparse.ArgumentParser(
            description="Merge several FITS files with the common columns.")
    parser.add_argument("-V", "--version", action="version",
            version="%(prog)s " + "%s (%s)" % (__version__, __date__))
    parser.add_argument("infile1", help="input FITS file 1; "
            "The blockname can be appended, e.g., infile1.fits[EVENTS]")
    parser.add_argument("infile2", nargs="+",
            help="input FITS file 2 and more")
    parser.add_argument("outfile", help="merged output file")
    parser.add_argument("-c", "--columns", dest="columns",
            help="list of columns to be merged (comma separated)")
    parser.add_argument("-C", "--clobber", dest="clobber",
            action="store_true", help="overwrite output file if exists")
    args = parser.parse_args()
    if DEBUG:
        print("DEBUG: infile2: ", args.infile2, file=sys.stderr)

    # None => merge all the common columns; must be defined even if the
    # "--columns" option is not given
    columns = None
    if args.columns:
        columns = args.columns.upper().replace(",", " ").split()
    file1, block1 = get_filename_blockname(args.infile1)
    merged_fits = fits.open(file1)
    for fitsfile in args.infile2:
        # split filename and block name
        file2, block2 = get_filename_blockname(fitsfile)
        merged_fits = merge2fits(merged_fits, file2, block1, block2, columns)
    # Newer astropy versions name the keyword "overwrite", while older
    # versions (and pyfits) name it "clobber"; use what is available.
    if "overwrite" in inspect.signature(merged_fits.writeto).parameters:
        merged_fits.writeto(args.outfile, checksum=True,
                overwrite=args.clobber)
    else:
        merged_fits.writeto(args.outfile, checksum=True,
                clobber=args.clobber)


# run the command line interface only when executed as a script
if __name__ == "__main__":
    main()