diff options
-rwxr-xr-x | astro/msutils.py | 352 |
1 files changed, 352 insertions, 0 deletions
diff --git a/astro/msutils.py b/astro/msutils.py new file mode 100755 index 0000000..debbee7 --- /dev/null +++ b/astro/msutils.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2016-2017 Sphesihle Makhathini +# Copyright (c) 2017 Aaron LI +# GNU General Public License v2.0 (GPLv2) +# + +""" +MSUtils - A set of CASA MeasurementSet manipulation tools +Based on: https://github.com/SpheMakh/msutils +""" + +import argparse +from collections import OrderedDict +from pprint import pprint + +import numpy as np +from casacore import tables +from casacore.tables import table, maketabdesc, makearrcoldesc + + +def getinfo(msname): + """ + Summarize the basic information of a MS. + + Parameters + ---------- + msname : str + Name of the MS + """ + tab = tables.table(msname, ack=False) + + info = OrderedDict([ + ("Ncol", tab.ncols()), + ("Nrow", tab.nrows()), + ("Ncor", tab.getcell("DATA", 0).shape[-1]), + ("Info", tab.info()), + ("Keywords", tab.getkeywords().keys()), + ("Columns", tab.colnames()), + ("ColKeywords", OrderedDict([ + (cname, tab.getcolkeywords(cname)) for cname in tab.colnames() + ])), + ("Exposure", tab.getcell("EXPOSURE", 0)), + ("FIELD", OrderedDict()), + ("SPW", OrderedDict()), + ("SCAN", OrderedDict()), + ]) + tabs = { + "FIELD": tables.table(msname+"/FIELD", ack=False), + "SPW": tables.table(msname+"/SPECTRAL_WINDOW", ack=False), + } + + field_ids = tabs["FIELD"].getcol("SOURCE_ID") + info["FIELD"]["STATE_ID"] = [None]*len(field_ids) + info["FIELD"]["PERIOD"] = [None]*len(field_ids) + for fid in field_ids: + ftab = tab.query("FIELD_ID=={0:d}".format(fid)) + state_id = ftab.getcol("STATE_ID")[0] + info["FIELD"]["STATE_ID"][fid] = int(state_id) + scans = {} + total_length = 0 + for scan in set(ftab.getcol("SCAN_NUMBER")): + stab = ftab.query("SCAN_NUMBER=={0:d}".format(scan)) + length = (stab.getcol("TIME").max() - stab.getcol("TIME").min()) + stab.close() + scans[str(scan)] = length + total_length += length + + info["SCAN"][str(fid)] = scans + info["FIELD"]["PERIOD"][fid] = total_length + ftab.close() + + for key, _tab in tabs.items(): + if key == "SPW": + colnames = ["CHAN_FREQ", "MEAS_FREQ_REF", + "REF_FREQUENCY", "TOTAL_BANDWIDTH", + "NAME", "NUM_CHAN", "IF_CONV_CHAIN", + "NET_SIDEBAND", "FREQ_GROUP_NAME"] + else: + colnames = _tab.colnames() + for name in colnames: + try: + info[key][name] = _tab.getcol(name).tolist() + except AttributeError: + info[key][name] = _tab.getcol(name) + _tab.close() + + # Get the minimum and maximum baselines + uv = tab.getcol("UVW")[:, :2] + baselines = np.sqrt(np.sum(uv**2, axis=1)) + info["Baseline"] = {"min": baselines.min(), "max": baselines.max()} + + tab.close() + return info + + +def addcol(msname, colname=None, shape=None, + data_desc_type="array", + valuetype=None, + init_with=None, + coldesc=None, + coldmi=None, + clone="DATA", + rowchunk=None): + """ + Add a column to MS + + Parameters + ---------- + msanme : str + MS to which to add the column + colname : str + Name of the column to be added + shape : shape + valuetype : data type + data_desc_type : + * ``scalar`` - scalar elements + * ``array`` - array elements + init_with : value to initialize the column with + """ + tab = table(msname, readonly=False) + + if colname in tab.colnames(): + print("Column already exists") + return "exists" + + print("Attempting to add %s column to %s" % (colname, msname)) + + valuetype = valuetype or "complex" + + if coldesc: + data_desc = coldesc + shape = coldesc["shape"] + elif shape: + data_desc = maketabdesc(makearrcoldesc(colname, + init_with, + shape=shape, + valuetype=valuetype)) + elif valuetype == "scalar": + data_desc = maketabdesc(makearrcoldesc(colname, + init_with, + valuetype=valuetype)) + elif clone: + element = tab.getcell(clone, 0) + try: + shape = element.shape + data_desc = maketabdesc(makearrcoldesc(colname, + element.flatten()[0], + shape=shape, + valuetype=valuetype)) + except AttributeError: + shape = [] + data_desc = maketabdesc(makearrcoldesc(colname, + element, + valuetype=valuetype)) + + colinfo = [data_desc, coldmi] if coldmi else [data_desc] + tab.addcols(*colinfo) + + print("Column added successfully.") + + if init_with is None: + tab.close() + return "added" + else: + spwids = set(tab.getcol("DATA_DESC_ID")) + for spw in spwids: + print("Initializing column {0}. DDID is {1}".format(colname, spw)) + tab_spw = tab.query("DATA_DESC_ID=={0:d}".format(spw)) + nrows = tab_spw.nrows() + + rowchunk = rowchunk or nrows/10 + dshape = [0] + [a for a in shape] + for row0 in range(0, nrows, rowchunk): + nr = min(rowchunk, nrows-row0) + dshape[0] = nr + print("Wrtiting to column %s (rows %d to %d)" % + (colname, row0, row0+nr-1)) + dtype = init_with.dtype + tab_spw.putcol(colname, + np.ones(dshape, dtype=dtype) * init_with, + row0, nr) + tab_spw.close() + tab.close() + + +def sumcols(msname, col1=None, col2=None, outcol=None, cols=None, + subtract=False): + """ + Add col1 to col2, or sum columns in "cols" list. + + Parameters + ---------- + subtract : bool + Subtract ``col2`` from ``col1`` + """ + tab = table(msname, readonly=False) + if outcol not in tab.colnames(): + print("outcol {0:s} does not exist, will add it first.".format(outcol)) + addcol(msname, outcol, clone=col1 or cols[0]) + + spws = set(tab.getcol("DATA_DESC_ID")) + for spw in spws: + tab_spw = tab.query("DATA_DESC_ID=={0:d}".format(spw)) + nrows = tab_spw.nrows() + rowchunk = nrows//10 if nrows > 10000 else nrows + for row0 in range(0, nrows, rowchunk): + nr = min(rowchunk, nrows-row0) + print("Wrtiting to column %s (rows %d to %d)" % + (outcol, row0, row0+nr-1)) + if subtract: + data = (tab_spw.getcol(col1, row0, nr) - + tab_spw.getcol(col2, row0, nr)) + else: + cols = cols or [col1, col2] + data = 0 + for col in cols: + data += tab.getcol(col, row0, nr) + + tab_spw.putcol(outcol, data, row0, nr) + tab_spw.close() + + tab.close() + + +def copycol(msname, fromcol, tocol): + """ + Copy data from one column to another + """ + + tab = table(msname, readonly=False) + if tocol not in tab.colnames(): + addcol(msname, tocol, clone=fromcol) + + spws = set(tab.getcol("DATA_DESC_ID")) + for spw in spws: + tab_spw = tab.query("DATA_DESC_ID=={0:d}".format(spw)) + nrows = tab_spw.nrows() + rowchunk = nrows//10 if nrows > 5000 else nrows + for row0 in range(0, nrows, rowchunk): + nr = min(rowchunk, nrows-row0) + data = tab_spw.getcol(fromcol, row0, nr) + tab_spw.putcol(tocol, data, row0, nr) + + tab_spw.close() + tab.close() + + +def calc_vis_noise(msname, sefd, spw_id=0): + """ + Calculate the nominal per-visibility noise + """ + tab = table(msname) + spwtab = table(msname + "/SPECTRAL_WINDOW") + + freq0 = spwtab.getcol("CHAN_FREQ")[spw_id, 0] + wavelength = 300e+6/freq0 + bw = spwtab.getcol("CHAN_WIDTH")[spw_id, 0] + dt = tab.getcol("EXPOSURE", 0, 1)[0] + dtf = (tab.getcol("TIME", tab.nrows()-1, 1)-tab.getcol("TIME", 0, 1))[0] + + tab.close() + spwtab.close() + + print("%s: frequency %.2f MHz (lambda=%.2fm)" % + (msname, freq0/1e6, wavelength)) + print("%s: bandwidth %.2g kHz, %.2fs integration, %.2fh synthesis" % + (bw*1e-3, dt, dtf/3600)) + noise = sefd / np.sqrt(abs(2*bw*dt)) + print("SEFD of %.2f Jy gives per-visibility noise of %.2f mJy" % + (sefd, noise*1000)) + + return noise + + +def addnoise(msname, column="MODEL_DATA", + noise=0, sefd=551, + rowchunk=None, + addToCol=None, + spw_id=None): + """ + Add Gaussian noise to MS, given a stdandard deviation (noise). + This noise can be also be calculated given SEFD value + """ + + tab = table(msname, readonly=False) + + multi_chan_noise = False + if hasattr(noise, "__iter__"): + multi_chan_noise = True + elif hasattr(sefd, "__iter__"): + multi_chan_noise = True + else: + noise = noise or calc_vis_noise(msname, sefd=sefd, + spw_id=spw_id or 0) + + spws = set(tab.getcol("DATA_DESC_ID")) + for spw in spws: + tab_spw = tab.query("DATA_DESC_ID=={0:d}".format(spw)) + nrows = tab_spw.nrows() + nchan, ncor = tab_spw.getcell("DATA", 0).shape + rowchunk = rowchunk or nrows/10 + for row0 in range(0, nrows, rowchunk): + nr = min(rowchunk, nrows-row0) + data = (np.random.randn(nr, nchan, ncor) + + 1j*np.random.randn(nr, nchan, ncor)) + if multi_chan_noise: + noise = noise[np.newaxis, :, np.newaxis] + data *= noise + + if addToCol: + data += tab_spw.getcol(addToCol, row0, nr) + print("%s + noise --> %s (rows %d to %d)" % + (addToCol, column, row0, row0+nr-1)) + else: + print("Adding noise to column %s (rows %d to %d)" % + (column, row0, row0+nr-1)) + + tab_spw.putcol(column, data, row0, nr) + tab_spw.close() + + tab.close() + + +def cmd_info(args): + """ + Sub-command: "info", show MS basic information + """ + msname = args.msname + info = getinfo(msname) + pprint(info) + + +def main(): + parser = argparse.ArgumentParser( + description="CASA MeasurementSet (MS) manipulation utilities") + + subparsers = parser.add_subparsers(dest="subparser_name", + title="sub-commands", + help="additional help") + + # sub-command: "info" + parser_info = subparsers.add_parser("info", help="show MS basic info") + parser_info.add_argument("msname", help="MS name") + parser_info.set_defaults(func=cmd_info) + + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() |