From fcb46e8afd5d89e1f36014c18205ef7ec772c246 Mon Sep 17 00:00:00 2001
From: Aaron LI <aaronly.me@outlook.com>
Date: Tue, 25 Oct 2016 21:07:54 +0800
Subject: grid.py: Optimize "map_grid_to_healpix()" using "numba.jit"

There are some limitations with the numba JIT, which are commented in
the code.
---
 fg21sim/utils/grid.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

(limited to 'fg21sim')

diff --git a/fg21sim/utils/grid.py b/fg21sim/utils/grid.py
index 8831000..a3030ac 100644
--- a/fg21sim/utils/grid.py
+++ b/fg21sim/utils/grid.py
@@ -7,11 +7,11 @@ Grid utilities.
 
 
 import numpy as np
-from scipy import ndimage
-import healpy as hp
 import numba as nb
+from scipy import ndimage
 
 from .draw import ellipse
+from .healpix import ang2pix_ring
 
 
 @nb.jit(nopython=True)
@@ -180,6 +180,9 @@ def make_grid_ellipse(center, size, resolution, rotation=None):
     return (lon, lat, gridmap)
 
 
+@nb.jit(nb.types.Tuple((nb.int64[:], nb.float64[:]))(
+    nb.types.UniTuple(nb.float64[:, :], 3), nb.int64),
+        nopython=True)
 def map_grid_to_healpix(grid, nside):
     """Map the filled coordinate grid to the HEALPix map (RING ordering).
 
@@ -213,17 +216,22 @@ def map_grid_to_healpix(grid, nside):
     XXX/TODO:
     - Implement the flux-preserving algorithm (reference ???)
     """
-    lon, lat, gridmap = grid
+    # XXX: ``numba`` does not support using 2D array as indexes
+    lon = grid[0].flatten()
+    lat = grid[1].flatten()
+    gridmap = grid[2].flatten()
     phi = np.radians(lon)
     theta = np.radians(90.0 - lat)
-    ipix = hp.ang2pix(nside, theta, phi, nest=False)
+    ipix = ang2pix_ring(nside, theta, phi)
     # Get the corresponding input grid pixels for each HEALPix pixel
-    indexes, counts = np.unique(ipix, return_counts=True)
-    shape = (len(indexes), max(counts))
-    datamap = np.zeros(shape) * np.nan
-    # TODO: how to avoid this explicit loop ??
+    # XXX: ``numba`` currently does not support ``numpy.unique()``
+    ipix_perm = ipix.argsort()
+    ipix_sorted = ipix[ipix_perm]
+    idx_uniq = np.concatenate((np.array([True]),
+                               ipix_sorted[1:] != ipix_sorted[:-1]))
+    indexes = ipix_sorted[idx_uniq]
+    values = np.zeros(indexes.shape)
     for i, idx in enumerate(indexes):
-        pixels = gridmap[ipix == idx]
-        datamap[i, :len(pixels)] = pixels
-    values = np.nanmean(datamap, axis=1)
+        # XXX: ``numba`` does not support using 2D array as indexes
+        values[i] = np.mean(gridmap[ipix == idx])
     return (indexes, values)
-- 
cgit v1.2.2