From f372ced0663836c4bde13794cafd25cb51111d3c Mon Sep 17 00:00:00 2001
From: Aaron LI <aly@aaronly.me>
Date: Sun, 13 Aug 2017 01:45:30 +0800
Subject: sky.py: Add float32/clobber/checksum etc. to SkyBase

---
 fg21sim/sky.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 61 insertions(+), 6 deletions(-)

(limited to 'fg21sim')

diff --git a/fg21sim/sky.py b/fg21sim/sky.py
index 1fca698..2e5d4a8 100644
--- a/fg21sim/sky.py
+++ b/fg21sim/sky.py
@@ -28,12 +28,54 @@ from .utils.units import UnitConversions as AUC
 logger = logging.getLogger(__name__)
 
 
-class SkyPatch:
 class SkyBase:
     """
     The base class for both the sky patch and HEALPix all-sky
     map classes.
+
+    Attributes
+    ----------
+    type_ : str
+        The type of the sky image
+        Values: ``patch`` or ``healpix``
+    data : `~numpy.ndarray`
+        The data array read from input sky image, or to be written into
+        output FITS file.
+    frequency_ : float
+        The frequency of the input/output sky image.
+        Unit: [MHz]
+    creator_ : str
+        The creator of the (output) sky image.
+        Default: ``__name__``
+    header_ : `~astropy.io.fits.Header`
+        The FITS header information of the input/output file.
+    float32_ : bool
+        Whether to use single/float32 data type to save the sky image?
+        Default: True
+    clobber_ : bool, optional
+        Whether to overwrite the existing output file.
+        Default: False
+    checksum_ : bool, optional
+        Whether to calculate the checksum data for the output
+        FITS file, which may cost some time.
+        Default: False
     """
+    def __init__(self, float32=True, clobber=False, checksum=False):
+        self.type_ = None
+        self.data = None
+        self.frequency_ = None
+        self.creator_ = __name__
+        self.header_ = fits.Header()
+        self.float32_ = float32
+        self.clobber_ = clobber
+        self.checksum_ = checksum
+
+    @property
+    def shape(self):
+        """
+        Numpy array shape of the (current/output) sky data.
+        """
+        return self.data.shape
 
     @property
     def frequency(self):
@@ -142,6 +184,7 @@ class SkyBase:
         raise NotImplementedError
 
 
+class SkyPatch(SkyBase):
     """
     Support reading & writing FITS file of sky patches.
 
@@ -195,8 +238,8 @@ class SkyBase:
     coordinates = None
 
     def __init__(self, size, pixelsize, center=(0.0, 0.0),
-                 infile=None, frequency=None):
-        self.xcenter, self.ycenter = center
+                 infile=None, frequency=None, **kwargs):
+        super().__init__(**kwargs)
         self.xsize, self.ysize = size
         self.pixelsize = pixelsize
         if infile is not None:
@@ -453,10 +496,12 @@ class SkyBase:
         return (lon, lat)
 
 
-class SkyHealpix:
+class SkyHealpix(SkyBase):
     """
     Support the HEALPix all-sky map.
 
+    XXX/TODO: Update against ``SkyBase`` and ``SkyPatch``!!
+
     Parameters
     ----------
     nside : int
@@ -637,6 +682,9 @@ class SkyHealpix:
         return (lon, lat)
 
 
+##########################################################################
+
+
 def get_sky(configs):
     """
     Sky class factory function to support both the sky patch and
@@ -648,6 +696,13 @@ def get_sky(configs):
         An `ConfigManager` object contains default and user configurations.
         For more details, see the example config specification.
     """
+    # Parameters for the base sky class
+    kwargs = {
+        "float32": configs.getn("output/use_float"),
+        "clobber": configs.getn("output/clobber"),
+        "checksum": configs.getn("output/checksum"),
+    }
+
     skytype = configs.getn("sky/type")
     if skytype == "patch":
         sec = "sky/patch"
@@ -657,10 +712,10 @@ def get_sky(configs):
         ycenter = configs.getn(sec+"/ycenter")
         pixelsize = configs.getn(sec+"/pixelsize")
         return SkyPatch(size=(xsize, ysize), pixelsize=pixelsize,
-                        center=(xcenter, ycenter))
+                        center=(xcenter, ycenter), **kwargs)
     elif skytype == "healpix":
         sec = "sky/healpix"
         nside = configs.getn(sec+"/nside")
-        return SkyHealpix(nside=nside)
+        return SkyHealpix(nside=nside, **kwargs)
     else:
         raise ValueError("unknown sky type: %s" % skytype)
-- 
cgit v1.2.2