aboutsummaryrefslogtreecommitdiffstats
path: root/python/msvst_starlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/msvst_starlet.py')
-rw-r--r--python/msvst_starlet.py602
1 files changed, 602 insertions, 0 deletions
diff --git a/python/msvst_starlet.py b/python/msvst_starlet.py
new file mode 100644
index 0000000..b90adf9
--- /dev/null
+++ b/python/msvst_starlet.py
@@ -0,0 +1,602 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# References:
+# [1] Jean-Luc Starck, Fionn Murtagh & Jalal M. Fadili
+# Sparse Image and Signal Processing: Wavelets, Curvelets, Morphological Diversity
+# Section 3.5, 6.6
+#
+# Credits:
+# [1] https://github.com/abrazhe/image-funcut/blob/master/imfun/atrous.py
+#
+# Aaron LI
+# Created: 2016-03-17
+# Updated: 2016-04-20
+#
+# ChangeLog:
+# 2016-04-20:
+# * Add argparse and main() for scripting
+#
+
+"""
+Starlet wavelet transform, i.e., isotropic undecimated wavelet transform
+(IUWT), or à trous wavelet transform.
+And multi-scale variance stabling transform (MS-VST), which can be used
+to effectively remove the Poisson noises.
+"""
+
+__version__ = "0.2.0"
+__date__ = "2016-04-20"
+
+
+import sys
+import os
+import argparse
+from datetime import datetime
+
+import numpy as np
+import scipy as sp
+from scipy import signal
+from astropy.io import fits
+
+
+class B3Spline: # {{{
+ """
+ B3-spline wavelet.
+ """
+ # scaling function (phi)
+ dec_lo = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16
+ dec_hi = np.array([-1.0, -4.0, 10.0, -4.0, -1.0]) / 16
+ rec_lo = np.array([0.0, 0.0, 1.0, 0.0, 0.0])
+ rec_hi = np.array([0.0, 0.0, 1.0, 0.0, 0.0])
+# B3Spline }}}
+
+
+class IUWT: # {{{
+ """
+ Isotropic undecimated wavelet transform.
+ """
+ ## Decomposition filters list:
+ # a_{scale} = convole(a_0, filters[scale])
+ # Note: the zero-th scale filter (i.e., delta function) is the first
+ # element, thus the array index is the same as the decomposition scale.
+ filters = []
+
+ phi = None # wavelet scaling function (2D)
+ level = 0 # number of transform level
+ decomposition = None # decomposed coefficients/images
+ reconstruction = None # reconstructed image
+
+ # convolution boundary condition
+ boundary = "symm"
+
+ def __init__(self, phi=B3Spline.dec_lo, level=None, boundary="symm",
+ data=None):
+ self.set_wavelet(phi=phi)
+ self.level = level
+ self.boundary = boundary
+ self.data = np.array(data)
+
+ def reset(self):
+ """
+ Reset the object attributes.
+ """
+ self.data = None
+ self.phi = None
+ self.decomposition = None
+ self.reconstruction = None
+ self.level = 0
+ self.filters = []
+ self.boundary = "symm"
+
+ def load_data(self, data):
+ self.reset()
+ self.data = np.array(data)
+
+ def set_wavelet(self, phi):
+ self.reset()
+ phi = np.array(phi)
+ if phi.ndim == 1:
+ phi_ = phi.reshape(1, -1)
+ self.phi = np.dot(phi_.T, phi_)
+ elif phi.ndim == 2:
+ self.phi = phi
+ else:
+ raise ValueError("Invalid phi dimension")
+
+ def calc_filters(self):
+ """
+ Calculate the convolution filters of each scale.
+ Note: the zero-th scale filter (i.e., delta function) is the first
+ element, thus the array index is the same as the decomposition scale.
+ """
+ self.filters = []
+ # scale 0: delta function
+ h = np.array([[1]]) # NOTE: 2D
+ self.filters.append(h)
+ # scale 1
+ h = self.phi[::-1, ::-1]
+ self.filters.append(h)
+ for scale in range(2, self.level+1):
+ h_up = self.zupsample(self.phi, order=scale-1)
+ h2 = signal.convolve2d(h_up[::-1, ::-1], h, mode="same",
+ boundary=self.boundary)
+ self.filters.append(h2)
+
+ def transform(self, data, scale, boundary="symm"):
+ """
+ Perform only one scale wavelet transform for the given data.
+
+ return:
+ [ approx, detail ]
+ """
+ self.decomposition = []
+ approx = signal.convolve2d(data, self.filters[scale],
+ mode="same", boundary=self.boundary)
+ detail = data - approx
+ return [approx, detail]
+
+ def decompose(self, level, boundary="symm"):
+ """
+ Perform IUWT decomposition in the plain loop way.
+ The filters of each scale/level are calculated first, then the
+ approximations of each scale/level are calculated by convolving the
+ raw/finest image with these filters.
+
+ return:
+ [ W_1, W_2, ..., W_n, A_n ]
+ n = level
+ W: wavelet details
+ A: approximation
+ """
+ self.boundary = boundary
+ if self.level != level or self.filters == []:
+ self.level = level
+ self.calc_filters()
+ self.decomposition = []
+ approx = self.data
+ for scale in range(1, level+1):
+ # approximation:
+ approx2 = signal.convolve2d(self.data, self.filters[scale],
+ mode="same", boundary=self.boundary)
+ # wavelet details:
+ w = approx - approx2
+ self.decomposition.append(w)
+ if scale == level:
+ self.decomposition.append(approx2)
+ approx = approx2
+ return self.decomposition
+
+ def decompose_recursive(self, level, boundary="symm"):
+ """
+ Perform the IUWT decomposition in the recursive way.
+
+ return:
+ [ W_1, W_2, ..., W_n, A_n ]
+ n = level
+ W: wavelet details
+ A: approximation
+ """
+ self.level = level
+ self.boundary = boundary
+ self.decomposition = self.__decompose(self.data, self.phi, level=level)
+ return self.decomposition
+
+ def __decompose(self, data, phi, level):
+ """
+ 2D IUWT decomposition (or stationary wavelet transform).
+
+ This is a convolution version, where kernel is zero-upsampled
+ explicitly. Not fast.
+
+ Parameters:
+ - level : level of decomposition
+ - phi : low-pass filter kernel
+ - boundary : boundary conditions (passed to scipy.signal.convolve2d,
+ 'symm' by default)
+
+ Returns:
+ list of wavelet details + last approximation. Each element in
+ the list is an image of the same size as the input image.
+ """
+ if level <= 0:
+ return data
+ shapecheck = map(lambda a,b:a>b, data.shape, phi.shape)
+ assert np.all(shapecheck)
+ # approximation:
+ approx = signal.convolve2d(data, phi[::-1, ::-1], mode="same",
+ boundary=self.boundary)
+ # wavelet details:
+ w = data - approx
+ phi_up = self.zupsample(phi, order=1)
+ shapecheck = map(lambda a,b:a>b, data.shape, phi_up.shape)
+ if level == 1:
+ return [w, approx]
+ elif not np.all(shapecheck):
+ print("Maximum allowed decomposition level reached",
+ file=sys.stderr)
+ return [w, approx]
+ else:
+ return [w] + self.__decompose(approx, phi_up, level-1)
+
+ @staticmethod
+ def zupsample(data, order=1):
+ """
+ Upsample data array by interleaving it with zero's.
+
+ h{up_order: n}[l] = (1) h[l], if l % 2^n == 0;
+ (2) 0, otherwise
+ """
+ shape = data.shape
+ new_shape = [ (2**order * (n-1) + 1) for n in shape ]
+ output = np.zeros(new_shape, dtype=data.dtype)
+ output[[ slice(None, None, 2**order) for d in shape ]] = data
+ return output
+
+ def reconstruct(self, decomposition=None):
+ if decomposition is not None:
+ reconstruction = np.sum(decomposition, axis=0)
+ return reconstruction
+ else:
+ self.reconstruction = np.sum(self.decomposition, axis=0)
+
+ def get_detail(self, scale):
+ """
+ Get the wavelet detail coefficients of given scale.
+ Note: 1 <= scale <= level
+ """
+ if scale < 1 or scale > self.level:
+ raise ValueError("Invalid scale")
+ return self.decomposition[scale-1]
+
+ def get_approx(self):
+ """
+ Get the approximation coefficients of the largest scale.
+ """
+ return self.decomposition[-1]
+# IUWT }}}
+
+
+class IUWT_VST(IUWT): # {{{
+ """
+ IUWT with Multi-scale variance stabling transform.
+
+ Refernce:
+ [1] Bo Zhang, Jalal M. Fadili & Jean-Luc Starck,
+ IEEE Trans. Image Processing, 17, 17, 2008
+ """
+ # VST coefficients and the corresponding asymptotic standard deviation
+ # of each scale.
+ vst_coef = []
+
+ def reset(self):
+ super(self.__class__, self).reset()
+ vst_coef = []
+
+ def __decompose(self):
+ raise AttributeError("No '__decompose' attribute")
+
+ @staticmethod
+ def soft_threshold(data, threshold):
+ if isinstance(data, np.ndarray):
+ data_th = data.copy()
+ data_th[np.abs(data) <= threshold] = 0.0
+ data_th[data > threshold] -= threshold
+ data_th[data < -threshold] += threshold
+ else:
+ data_th = data
+ if np.abs(data) <= threshold:
+ data_th = 0.0
+ elif data > threshold:
+ data_th -= threshold
+ else:
+ data_th += threshold
+ return data_th
+
+ def tau(self, k, scale):
+ """
+ Helper function used in VST coefficients calculation.
+ """
+ return np.sum(np.power(self.filters[scale], k))
+
+ def filters_product(self, scale1, scale2):
+ """
+ Calculate the scalar product of the filters of two scales,
+ considering only the overlapped part.
+ Helper function used in VST coefficients calculation.
+ """
+ if scale1 > scale2:
+ filter_big = self.filters[scale1]
+ filter_small = self.filters[scale2]
+ else:
+ filter_big = self.filters[scale2]
+ filter_small = self.filters[scale1]
+ # crop the big filter to match the size of the small filter
+ size_big = filter_big.shape
+ size_small = filter_small.shape
+ size_diff2 = list(map(lambda a,b: (a-b)//2, size_big, size_small))
+ filter_big_crop = filter_big[
+ size_diff2[0]:(size_big[0]-size_diff2[0]),
+ size_diff2[1]:(size_big[1]-size_diff2[1])]
+ assert(np.all(list(map(lambda a,b: a==b,
+ size_small, filter_big_crop.shape))))
+ product = np.sum(filter_small * filter_big_crop)
+ return product
+
+ def calc_vst_coef(self):
+ """
+ Calculate the VST coefficients and the corresponding
+ asymptotic standard deviation of each scale, according to the
+ calculated filters of each scale/level.
+ """
+ self.vst_coef = []
+ for scale in range(self.level+1):
+ b = 2 * np.sqrt(np.abs(self.tau(1, scale)) / self.tau(2, scale))
+ c = 7.0*self.tau(2, scale) / (8.0*self.tau(1, scale)) - \
+ self.tau(3, scale) / (2.0*self.tau(2, scale))
+ if scale == 0:
+ std = -1.0
+ else:
+ std = np.sqrt((self.tau(2, scale-1) / \
+ (4 * self.tau(1, scale-1)**2)) + \
+ (self.tau(2, scale) / (4 * self.tau(1, scale)**2)) - \
+ (self.filters_product(scale-1, scale) / \
+ (2 * self.tau(1, scale-1) * self.tau(1, scale))))
+ self.vst_coef.append({ "b": b, "c": c, "std": std })
+
+ def vst(self, data, scale, coupled=True):
+ """
+ Perform variance stabling transform
+
+ XXX: parameter `coupled' why??
+ Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h
+ """
+ self.vst_coupled = coupled
+ if self.vst_coef == []:
+ self.calc_vst_coef()
+ if coupled:
+ b = 1.0
+ else:
+ b = self.vst_coef[scale]["b"]
+ data_vst = b * np.sqrt(np.abs(data + self.vst_coef[scale]["c"]))
+ return data_vst
+
+ def ivst(self, data, scale, cbias=True):
+ """
+ Inverse variance stabling transform
+ NOTE: assuming that `a_{j} + c^{j}' are all positive.
+
+ XXX: parameter `cbias' why??
+ `bias correction' is recommended while reconstruct the data
+ after estimation
+ Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h
+ """
+ self.vst_cbias = cbias
+ if cbias:
+ cb = 1.0 / (self.vst_coef[scale]["b"] ** 2)
+ else:
+ cb = 0.0
+ data_ivst = data ** 2 + cb - self.vst_coef[scale]["c"]
+ return data_ivst
+
+ def is_significant(self, scale, fdr=0.1, independent=False):
+ """
+ Multiple hypothesis testing with false discovery rate (FDR) control.
+
+ `independent': whether the test statistics of all the null
+ hypotheses are independent.
+ If `independent=True': FDR <= (m0/m) * q
+ otherwise: FDR <= (m0/m) * q * (1 + 1/2 + 1/3 + ... + 1/m)
+
+ References:
+ [1] False discovery rate - Wikipedia
+ https://en.wikipedia.org/wiki/False_discovery_rate
+ """
+ coef = self.get_detail(scale)
+ pvalues = 2 * (1 - sp.stats.norm.cdf(np.abs(coef) / \
+ self.vst_coef[scale]["std"]))
+ p_sorted = pvalues.flatten()
+ p_sorted.sort()
+ N = len(p_sorted)
+ if independent:
+ cn = 1.0
+ else:
+ cn = np.sum(1.0 / np.arange(1, N+1))
+ p_comp = fdr * np.arange(N) / (N * cn)
+ comp = (p_sorted < p_comp)
+ # cutoff p-value after FDR control/correction
+ p_cutoff = np.max(p_sorted[comp])
+ return (pvalues <= p_cutoff, p_cutoff)
+
+ def denoise(self, fdr=0.1, fdr_independent=False):
+ """
+ Denoise the wavelet coefficients by controlling FDR.
+ """
+ self.fdr = fdr
+ self.fdr_indepent = fdr_independent
+ self.denoised = []
+ # supports of significant coefficients of each scale
+ self.sig_supports = [None] # make index match the scale
+ self.p_cutoff = [None]
+ for scale in range(1, self.level+1):
+ coef = self.get_detail(scale)
+ sig, p_cutoff = self.is_significant(scale, fdr, fdr_independent)
+ coef[np.logical_not(sig)] = 0.0
+ self.denoised.append(coef)
+ self.sig_supports.append(sig)
+ self.p_cutoff.append(p_cutoff)
+ # append the last approximation
+ self.denoised.append(self.get_approx())
+
+ def decompose(self, level, boundary="symm"):
+ """
+ 2D IUWT decomposition with VST.
+ """
+ self.boundary = boundary
+ if self.level != level or self.filters == []:
+ self.level = level
+ self.calc_filters()
+ self.calc_vst_coef()
+ self.decomposition = []
+ approx = self.data
+ for scale in range(1, level+1):
+ # approximation:
+ approx2 = signal.convolve2d(self.data, self.filters[scale],
+ mode="same", boundary=self.boundary)
+ # wavelet details:
+ w = self.vst(approx, scale=scale-1) - self.vst(approx2, scale=scale)
+ self.decomposition.append(w)
+ if scale == level:
+ self.decomposition.append(approx2)
+ approx = approx2
+ return self.decomposition
+
+ def reconstruct_ivst(self, denoised=True, positive_project=True):
+ """
+ Reconstruct the original image from the *un-denoised* decomposition
+ by applying the inverse VST.
+
+ This reconstruction result is also used as the `initial condition'
+ for the below `iterative reconstruction' algorithm.
+
+ arguments:
+ * denoised: whether use th denoised data or the direct decomposition
+ * positive_project: whether replace negative values with zeros
+ """
+ if denoised:
+ decomposition = self.denoised
+ else:
+ decomposition = self.decomposition
+ self.positive_project = positive_project
+ details = np.sum(decomposition[:-1], axis=0)
+ approx = self.vst(decomposition[-1], scale=self.level)
+ reconstruction = self.ivst(approx+details, scale=0)
+ if positive_project:
+ reconstruction[reconstruction < 0.0] = 0.0
+ self.reconstruction = reconstruction
+ return reconstruction
+
+ def reconstruct(self, denoised=True, niter=20, verbose=False):
+ """
+ Reconstruct the original image using iterative method with
+ L1 regularization, because the denoising violates the exact inverse
+ procedure.
+
+ arguments:
+ * denoised: whether use the denoised coefficients
+ * niter: number of iterations
+ """
+ if denoised:
+ decomposition = self.denoised
+ else:
+ decomposition = self.decomposition
+ # L1 regularization
+ lbd = 1.0
+ delta = lbd / (niter - 1)
+ # initial solution
+ solution = self.reconstruct_ivst(denoised=denoised,
+ positive_project=True)
+ #
+ iuwt = IUWT(level=self.level)
+ iuwt.calc_filters()
+ # iterative reconstruction
+ if verbose:
+ print("Iteratively reconstructing (%d times): " % niter,
+ end="", flush=True, file=sys.stderr)
+ for i in range(niter):
+ if verbose:
+ print("%d..." % i, end="", flush=True, file=sys.stderr)
+ tempd = self.data.copy()
+ solution_decomp = []
+ for scale in range(1, self.level+1):
+ approx, detail = iuwt.transform(tempd, scale)
+ approx_sol, detail_sol = iuwt.transform(solution, scale)
+ # Update coefficients according to the significant supports,
+ # which are acquired during the denosing precodure with FDR.
+ sig = self.sig_supports[scale]
+ detail_sol[sig] = detail[sig]
+ detail_sol = self.soft_threshold(detail_sol, threshold=lbd)
+ #
+ solution_decomp.append(detail_sol)
+ tempd = approx.copy()
+ solution = approx_sol.copy()
+ # last approximation (the two are the same)
+ solution_decomp.append(approx)
+ # reconstruct
+ solution = iuwt.reconstruct(decomposition=solution_decomp)
+ # discard all negative values
+ solution[solution < 0] = 0.0
+ #
+ lbd -= delta
+ if verbose:
+ print("DONE!", flush=True, file=sys.stderr)
+ #
+ self.reconstruction = solution
+ return self.reconstruction
+# IUWT_VST }}}
+
+
+def main():
+ # commandline arguments parser
+ parser = argparse.ArgumentParser(
+ description="Poisson Noise Removal with Multi-scale Variance " + \
+ "Stabling Transform and Wavelet Transform",
+ epilog="Version: %s (%s)" % (__version__, __date__))
+ parser.add_argument("-l", "--level", dest="level",
+ type=int, default=5,
+ help="level of the IUWT decomposition")
+ parser.add_argument("-r", "--fdr", dest="fdr",
+ type=float, default=0.1,
+ help="false discovery rate")
+ parser.add_argument("-I", "--fdr-independent", dest="fdr_independent",
+ action="store_true", default=False,
+ help="whether the FDR null hypotheses are independent")
+ parser.add_argument("-n", "--niter", dest="niter",
+ type=int, default=20,
+ help="number of iterations for reconstruction")
+ parser.add_argument("-v", "--verbose", dest="verbose",
+ action="store_true", default=False,
+ help="show verbose progress")
+ parser.add_argument("-C", "--clobber", dest="clobber",
+ action="store_true", default=False,
+ help="overwrite output file if exists")
+ parser.add_argument("infile", help="input image with Poisson noises")
+ parser.add_argument("outfile", help="output denoised image")
+ args = parser.parse_args()
+
+ if args.verbose:
+ print("infile: '%s'" % args.infile, file=sys.stderr)
+ print("outfile: '%s'" % args.outfile, file=sys.stderr)
+ print("level: %s" % args.level, file=sys.stderr)
+ print("fdr: %s" % args.fdr, file=sys.stderr)
+ print("fdr_independent: %s" % args.fdr_independent, file=sys.stderr)
+ print("niter: %s\n" % args.niter, file=sys.stderr)
+
+ imgfits = fits.open(args.infile)
+ img = imgfits[0].data
+ # Remove Poisson noises
+ msvst = IUWT_VST(data=img)
+ if args.verbose:
+ print("INFO: MSVST decomposing ...", file=sys.stderr)
+ msvst.decompose(level=args.level)
+ if args.verbose:
+ print("INFO: MSVST denosing ...", file=sys.stderr)
+ msvst.denoise(fdr=args.fdr, fdr_independent=args.fdr_independent)
+ if args.verbose:
+ print("INFO: MSVST reconstructing (this may take a while) ...",
+ file=sys.stderr)
+ msvst.reconstruct(denoised=True, niter=args.niter, verbose=args.verbose)
+ img_denoised = msvst.reconstruction
+ # Output
+ imgfits[0].data = img_denoised
+ imgfits[0].header.add_history("%s: Removed Poisson Noises @ %s" % (
+ os.path.basename(sys.argv[0]), datetime.utcnow().isoformat()))
+ imgfits[0].header.add_history(" TOOL: %s (v%s)" % (
+ os.path.basename(sys.argv[0]), __version__))
+ imgfits[0].header.add_history(" PARAM: %s" % " ".join(sys.argv[1:]))
+ imgfits.writeto(args.outfile, checksum=True, clobber=args.clobber)
+
+
+if __name__ == "__main__":
+ main()
+