diff options
author | Aaron LI <aaronly.me@outlook.com> | 2016-03-31 10:49:48 +0800 |
---|---|---|
committer | Aaron LI <aaronly.me@outlook.com> | 2016-03-31 10:49:48 +0800 |
commit | a374fac3362a0204a93a3a50011366239a80fc1b (patch) | |
tree | 8b5e9646e8a7e715942b96aea8267e27157a3c36 /python/starlet.py | |
parent | d6cb7d0636c94612c667d6555001ec50dd81ea4f (diff) | |
download | atoolbox-a374fac3362a0204a93a3a50011366239a80fc1b.tar.bz2 |
add several python scripts
Diffstat (limited to 'python/starlet.py')
-rw-r--r-- | python/starlet.py | 513 |
1 files changed, 513 insertions, 0 deletions
diff --git a/python/starlet.py b/python/starlet.py new file mode 100644 index 0000000..67e69ba --- /dev/null +++ b/python/starlet.py @@ -0,0 +1,513 @@ +# -*- coding: utf-8 -*- +# +# References: +# [1] Jean-Luc Starck, Fionn Murtagh & Jalal M. Fadili +# Sparse Image and Signal Processing: Wavelets, Curvelets, Morphological Diversity +# Section 3.5, 6.6 +# +# Credits: +# [1] https://github.com/abrazhe/image-funcut/blob/master/imfun/atrous.py +# +# Aaron LI +# Created: 2016-03-17 +# Updated: 2016-03-22 +# + +""" +Starlet wavelet transform, i.e., isotropic undecimated wavelet transform +(IUWT), or à trous wavelet transform. +And multi-scale variance stabling transform (MS-VST). +""" + +import sys + +import numpy as np +import scipy as sp +from scipy import signal + + +class B3Spline: + """ + B3-spline wavelet. + """ + # scaling function (phi) + dec_lo = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16 + dec_hi = np.array([-1.0, -4.0, 10.0, -4.0, -1.0]) / 16 + rec_lo = np.array([0.0, 0.0, 1.0, 0.0, 0.0]) + rec_hi = np.array([0.0, 0.0, 1.0, 0.0, 0.0]) + + +class IUWT: + """ + Isotropic undecimated wavelet transform. + """ + ## Decomposition filters list: + # a_{scale} = convole(a_0, filters[scale]) + # Note: the zero-th scale filter (i.e., delta function) is the first + # element, thus the array index is the same as the decomposition scale. + filters = [] + + phi = None # wavelet scaling function (2D) + level = 0 # number of transform level + decomposition = None # decomposed coefficients/images + reconstruction = None # reconstructed image + + # convolution boundary condition + boundary = "symm" + + def __init__(self, phi=B3Spline.dec_lo, level=None, boundary="symm", + data=None): + self.set_wavelet(phi=phi) + self.level = level + self.boundary = boundary + self.data = np.array(data) + + def reset(self): + """ + Reset the object attributes. + """ + self.data = None + self.phi = None + self.decomposition = None + self.reconstruction = None + self.level = 0 + self.filters = [] + self.boundary = "symm" + + def load_data(self, data): + self.reset() + self.data = np.array(data) + + def set_wavelet(self, phi): + self.reset() + phi = np.array(phi) + if phi.ndim == 1: + phi_ = phi.reshape(1, -1) + self.phi = np.dot(phi_.T, phi_) + elif phi.ndim == 2: + self.phi = phi + else: + raise ValueError("Invalid phi dimension") + + def calc_filters(self): + """ + Calculate the convolution filters of each scale. + Note: the zero-th scale filter (i.e., delta function) is the first + element, thus the array index is the same as the decomposition scale. + """ + self.filters = [] + # scale 0: delta function + h = np.array([[1]]) # NOTE: 2D + self.filters.append(h) + # scale 1 + h = self.phi[::-1, ::-1] + self.filters.append(h) + for scale in range(2, self.level+1): + h_up = self.zupsample(self.phi, order=scale-1) + h2 = signal.convolve2d(h_up[::-1, ::-1], h, mode="same", + boundary=self.boundary) + self.filters.append(h2) + + def transform(self, data, scale, boundary="symm"): + """ + Perform only one scale wavelet transform for the given data. + + return: + [ approx, detail ] + """ + self.decomposition = [] + approx = signal.convolve2d(data, self.filters[scale], + mode="same", boundary=self.boundary) + detail = data - approx + return [approx, detail] + + def decompose(self, level, boundary="symm"): + """ + Perform IUWT decomposition in the plain loop way. + The filters of each scale/level are calculated first, then the + approximations of each scale/level are calculated by convolving the + raw/finest image with these filters. + + return: + [ W_1, W_2, ..., W_n, A_n ] + n = level + W: wavelet details + A: approximation + """ + self.boundary = boundary + if self.level != level or self.filters == []: + self.level = level + self.calc_filters() + self.decomposition = [] + approx = self.data + for scale in range(1, level+1): + # approximation: + approx2 = signal.convolve2d(self.data, self.filters[scale], + mode="same", boundary=self.boundary) + # wavelet details: + w = approx - approx2 + self.decomposition.append(w) + if scale == level: + self.decomposition.append(approx2) + approx = approx2 + return self.decomposition + + def decompose_recursive(self, level, boundary="symm"): + """ + Perform the IUWT decomposition in the recursive way. + + return: + [ W_1, W_2, ..., W_n, A_n ] + n = level + W: wavelet details + A: approximation + """ + self.level = level + self.boundary = boundary + self.decomposition = self.__decompose(self.data, self.phi, level=level) + return self.decomposition + + def __decompose(self, data, phi, level): + """ + 2D IUWT decomposition (or stationary wavelet transform). + + This is a convolution version, where kernel is zero-upsampled + explicitly. Not fast. + + Parameters: + - level : level of decomposition + - phi : low-pass filter kernel + - boundary : boundary conditions (passed to scipy.signal.convolve2d, + 'symm' by default) + + Returns: + list of wavelet details + last approximation. Each element in + the list is an image of the same size as the input image. + """ + if level <= 0: + return data + shapecheck = map(lambda a,b:a>b, data.shape, phi.shape) + assert np.all(shapecheck) + # approximation: + approx = signal.convolve2d(data, phi[::-1, ::-1], mode="same", + boundary=self.boundary) + # wavelet details: + w = data - approx + phi_up = self.zupsample(phi, order=1) + shapecheck = map(lambda a,b:a>b, data.shape, phi_up.shape) + if level == 1: + return [w, approx] + elif not np.all(shapecheck): + print("Maximum allowed decomposition level reached", + file=sys.stderr) + return [w, approx] + else: + return [w] + self.__decompose(approx, phi_up, level-1) + + @staticmethod + def zupsample(data, order=1): + """ + Upsample data array by interleaving it with zero's. + + h{up_order: n}[l] = (1) h[l], if l % 2^n == 0; + (2) 0, otherwise + """ + shape = data.shape + new_shape = [ (2**order * (n-1) + 1) for n in shape ] + output = np.zeros(new_shape, dtype=data.dtype) + output[[ slice(None, None, 2**order) for d in shape ]] = data + return output + + def reconstruct(self, decomposition=None): + if decomposition is not None: + reconstruction = np.sum(decomposition, axis=0) + return reconstruction + else: + self.reconstruction = np.sum(self.decomposition, axis=0) + + def get_detail(self, scale): + """ + Get the wavelet detail coefficients of given scale. + Note: 1 <= scale <= level + """ + if scale < 1 or scale > self.level: + raise ValueError("Invalid scale") + return self.decomposition[scale-1] + + def get_approx(self): + """ + Get the approximation coefficients of the largest scale. + """ + return self.decomposition[-1] + + +class IUWT_VST(IUWT): + """ + IUWT with Multi-scale variance stabling transform. + + Refernce: + [1] Bo Zhang, Jalal M. Fadili & Jean-Luc Starck, + IEEE Trans. Image Processing, 17, 17, 2008 + """ + # VST coefficients and the corresponding asymptotic standard deviation + # of each scale. + vst_coef = [] + + def reset(self): + super(self.__class__, self).reset() + vst_coef = [] + + def __decompose(self): + raise AttributeError("No '__decompose' attribute") + + @staticmethod + def soft_threshold(data, threshold): + if isinstance(data, np.ndarray): + data_th = data.copy() + data_th[np.abs(data) <= threshold] = 0.0 + data_th[data > threshold] -= threshold + data_th[data < -threshold] += threshold + else: + data_th = data + if np.abs(data) <= threshold: + data_th = 0.0 + elif data > threshold: + data_th -= threshold + else: + data_th += threshold + return data_th + + def tau(self, k, scale): + """ + Helper function used in VST coefficients calculation. + """ + return np.sum(np.power(self.filters[scale], k)) + + def filters_product(self, scale1, scale2): + """ + Calculate the scalar product of the filters of two scales, + considering only the overlapped part. + Helper function used in VST coefficients calculation. + """ + if scale1 > scale2: + filter_big = self.filters[scale1] + filter_small = self.filters[scale2] + else: + filter_big = self.filters[scale2] + filter_small = self.filters[scale1] + # crop the big filter to match the size of the small filter + size_big = filter_big.shape + size_small = filter_small.shape + size_diff2 = list(map(lambda a,b: (a-b)//2, size_big, size_small)) + filter_big_crop = filter_big[ + size_diff2[0]:(size_big[0]-size_diff2[0]), + size_diff2[1]:(size_big[1]-size_diff2[1])] + assert(np.all(list(map(lambda a,b: a==b, + size_small, filter_big_crop.shape)))) + product = np.sum(filter_small * filter_big_crop) + return product + + def calc_vst_coef(self): + """ + Calculate the VST coefficients and the corresponding + asymptotic standard deviation of each scale, according to the + calculated filters of each scale/level. + """ + self.vst_coef = [] + for scale in range(self.level+1): + b = 2 * np.sqrt(np.abs(self.tau(1, scale)) / self.tau(2, scale)) + c = 7.0*self.tau(2, scale) / (8.0*self.tau(1, scale)) - \ + self.tau(3, scale) / (2.0*self.tau(2, scale)) + if scale == 0: + std = -1.0 + else: + std = np.sqrt((self.tau(2, scale-1) / \ + (4 * self.tau(1, scale-1)**2)) + \ + (self.tau(2, scale) / (4 * self.tau(1, scale)**2)) - \ + (self.filters_product(scale-1, scale) / \ + (2 * self.tau(1, scale-1) * self.tau(1, scale)))) + self.vst_coef.append({ "b": b, "c": c, "std": std }) + + def vst(self, data, scale, coupled=True): + """ + Perform variance stabling transform + + XXX: parameter `coupled' why?? + Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h + """ + self.vst_coupled = coupled + if self.vst_coef == []: + self.calc_vst_coef() + if coupled: + b = 1.0 + else: + b = self.vst_coef[scale]["b"] + data_vst = b * np.sqrt(np.abs(data + self.vst_coef[scale]["c"])) + return data_vst + + def ivst(self, data, scale, cbias=True): + """ + Inverse variance stabling transform + NOTE: assuming that `a_{j} + c^{j}' are all positive. + + XXX: parameter `cbias' why?? + `bias correction' is recommended while reconstruct the data + after estimation + Credit: MSVST-V1.0/src/libmsvst/B3VSTAtrous.h + """ + self.vst_cbias = cbias + if cbias: + cb = 1.0 / (self.vst_coef[scale]["b"] ** 2) + else: + cb = 0.0 + data_ivst = data ** 2 + cb - self.vst_coef[scale]["c"] + return data_ivst + + def is_significant(self, scale, fdr=0.1, independent=False): + """ + Multiple hypothesis testing with false discovery rate (FDR) control. + + If `independent=True': FDR <= (m0/m) * q + otherwise: FDR <= (m0/m) * q * (1 + 1/2 + 1/3 + ... + 1/m) + + References: + [1] False discovery rate - Wikipedia + https://en.wikipedia.org/wiki/False_discovery_rate + """ + coef = self.get_detail(scale) + pvalues = 2 * (1 - sp.stats.norm.cdf(np.abs(coef) / \ + self.vst_coef[scale]["std"])) + p_sorted = pvalues.flatten() + p_sorted.sort() + N = len(p_sorted) + if independent: + cn = 1.0 + else: + cn = np.sum(1.0 / np.arange(1, N+1)) + p_comp = fdr * np.arange(N) / (N * cn) + comp = (p_sorted < p_comp) + # cutoff p-value after FDR control/correction + p_cutoff = np.max(p_sorted[comp]) + return (pvalues <= p_cutoff, p_cutoff) + + def denoise(self, fdr=0.1, fdr_independent=False): + """ + Denoise the wavelet coefficients by controlling FDR. + """ + self.fdr = fdr + self.fdr_indepent = fdr_independent + self.denoised = [] + # supports of significant coefficients of each scale + self.sig_supports = [None] # make index match the scale + self.p_cutoff = [None] + for scale in range(1, self.level+1): + coef = self.get_detail(scale) + sig, p_cutoff = self.is_significant(scale, fdr, fdr_independent) + coef[np.logical_not(sig)] = 0.0 + self.denoised.append(coef) + self.sig_supports.append(sig) + self.p_cutoff.append(p_cutoff) + # append the last approximation + self.denoised.append(self.get_approx()) + + def decompose(self, level, boundary="symm"): + """ + 2D IUWT decomposition with VST. + """ + self.boundary = boundary + if self.level != level or self.filters == []: + self.level = level + self.calc_filters() + self.calc_vst_coef() + self.decomposition = [] + approx = self.data + for scale in range(1, level+1): + # approximation: + approx2 = signal.convolve2d(self.data, self.filters[scale], + mode="same", boundary=self.boundary) + # wavelet details: + w = self.vst(approx, scale=scale-1) - self.vst(approx2, scale=scale) + self.decomposition.append(w) + if scale == level: + self.decomposition.append(approx2) + approx = approx2 + return self.decomposition + + def reconstruct_ivst(self, denoised=True, positive_project=True): + """ + Reconstruct the original image from the *un-denoised* decomposition + by applying the inverse VST. + + This reconstruction result is also used as the `initial condition' + for the below `iterative reconstruction' algorithm. + + arguments: + * denoised: whether use th denoised data or the direct decomposition + * positive_project: whether replace negative values with zeros + """ + if denoised: + decomposition = self.denoised + else: + decomposition = self.decomposition + self.positive_project = positive_project + details = np.sum(decomposition[:-1], axis=0) + approx = self.vst(decomposition[-1], scale=self.level) + reconstruction = self.ivst(approx+details, scale=0) + if positive_project: + reconstruction[reconstruction < 0.0] = 0.0 + self.reconstruction = reconstruction + return reconstruction + + def reconstruct(self, denoised=True, niter=20, verbose=False): + """ + Reconstruct the original image using iterative method with + L1 regularization, because the denoising violates the exact inverse + procedure. + + arguments: + * denoised: whether use the denoised coefficients + * niter: number of iterations + """ + if denoised: + decomposition = self.denoised + else: + decomposition = self.decomposition + # L1 regularization + lbd = 1.0 + delta = lbd / (niter - 1) + # initial solution + solution = self.reconstruct_ivst(denoised=denoised, + positive_project=True) + # + iuwt = IUWT(level=self.level) + iuwt.calc_filters() + # iterative reconstruction + for i in range(niter): + if verbose: + print("iter: %d" % i+1, file=sys.stderr) + tempd = self.data.copy() + solution_decomp = [] + for scale in range(1, self.level+1): + approx, detail = iuwt.transform(tempd, scale) + approx_sol, detail_sol = iuwt.transform(solution, scale) + # Update coefficients according to the significant supports, + # which are acquired during the denosing precodure with FDR. + sig = self.sig_supports[scale] + detail_sol[sig] = detail[sig] + detail_sol = self.soft_threshold(detail_sol, threshold=lbd) + # + solution_decomp.append(detail_sol) + tempd = approx.copy() + solution = approx_sol.copy() + # last approximation (the two are the same) + solution_decomp.append(approx) + # reconstruct + solution = iuwt.reconstruct(decomposition=solution_decomp) + # discard all negative values + solution[solution < 0] = 0.0 + # + lbd -= delta + # + self.reconstruction = solution + return self.reconstruction + |