# Snapshot: swh:1:snp:38fb28ee2fbdb2310fc8c93e867e0ae2e97d4195
# Tip revision: 0e28c2957481be12e90990401218c344a2ad2cbf authored by Dawid on 12 October 2023, 16:48:31 UTC
# Tip revision message: typo fix in doc/intro
# File: EMD2d.py
#!/usr/bin/python
# coding: UTF-8
#
# Author:   Dawid Laszuk
# Contact:  https://github.com/laszukdawid/PyEMD/issues
#
# Edited:   07/07/2017
#
# Feel free to contact for any information.

import logging

import numpy as np

try:
    from scipy.interpolate import SmoothBivariateSpline as SBS

    # `scipy.ndimage.filters` and `scipy.ndimage.morphology` are deprecated
    # namespaces; the same functions are available directly from `scipy.ndimage`.
    from scipy.ndimage import binary_erosion, generate_binary_structure, maximum_filter
except ImportError as err:
    # Keep the original failure attached so the missing package is visible.
    raise ImportError(
        "EMD2D and BEMD are not supported. Feel free to play around and improve them. "
        "Required dependencies are in `requirements-extra`."
    ) from err


class EMD2D:
    """
    **Empirical Mode Decomposition** on images.

    **Important** This is an experimental module.
    Experiments performed using this module didn't provide acceptable results,
    either in actual output or in computational performance. The author is not
    an expert in image processing, so it's very likely that the code could
    be improved. Take your best shot.

    Method decomposes images into 2D representations of loose Intrinsic Mode
    Functions (IMFs).

    The current version of the algorithm detects local extrema, separately
    minima and maxima, and then connects them to create envelopes. These
    are then used to create a mean trend and subtracted from the input.

    Threshold values that control goodness of the decomposition:
        * `mse_thr` --- proto-IMF check: whether the mean square error is small.
        * `mean_thr` --- proto-IMF check: whether the mean value is small.
    """

    logger = logging.getLogger(__name__)

    def __init__(self, **config):
        # ProtoIMF related
        self.mse_thr = 0.01
        self.mean_thr = 0.01

        self.FIXE = 0
        self.FIXE_H = 0

        self.MAX_ITERATION = 1000

        # Update based on options
        for key in config.keys():
            if key in self.__dict__.keys():
                self.__dict__[key] = config[key]

    def __call__(self, image, max_imf=-1):
        return self.emd(image, max_imf=max_imf)

    def extract_max_min_spline(self, image):
        """Calculates top and bottom envelopes for image.

        Parameters
        ----------
        image : numpy 2D array

        Returns
        -------
        min_env : numpy 2D array
            Bottom envelope in form of an image.
        max_env : numpy 2D array
            Top envelope in form of an image.
        """

        big_image = self.prepare_image(image)
        big_min_peaks, big_max_peaks = self.find_extrema(big_image)

        # Prepare grid for interpolation. Doesn't seem necessary.
        xi = np.arange(image.shape[0], image.shape[0] * 2)
        yi = np.arange(image.shape[1], image.shape[1] * 2)

        big_min_image_val = big_image[big_min_peaks]
        big_max_image_val = big_image[big_max_peaks]
        min_env = self.spline_points(big_min_peaks[0], big_min_peaks[1], big_min_image_val, xi, yi)
        max_env = self.spline_points(big_max_peaks[0], big_max_peaks[1], big_max_image_val, xi, yi)

        return min_env, max_env

    @classmethod
    def prepare_image(cls, image):
        """Prepares image for edge extrapolation.
        Method bloats image by mirroring it along all axes. This turns
        extrapolation on edges into interpolation within bigger image.

        Parameters
        ----------
        image : numpy 2D array
            Image for which interpolation is required,

        Returns
        -------
        image : numpy 2D array
            Big image based on the input. Grid 3x3 where the center block is input and
            neighbouring panels are respective mirror images.
        """

        # TODO: This is nasty. Instead of bloating whole image and then trying to
        #      find all extrema, it's better to deal directly with indices.
        shape = image.shape
        big_image = np.zeros((shape[0] * 3, shape[1] * 3))

        image_lr = np.fliplr(image)
        image_ud = np.flipud(image)
        image_ud_lr = np.flipud(image_lr)
        image_lr_ud = np.fliplr(image_ud)

        # Fill center with default image
        big_image[shape[0] : 2 * shape[0], shape[1] : 2 * shape[1]] = image

        # Fill left center
        big_image[shape[0] : 2 * shape[0], : shape[1]] = image_lr

        # Fill right center
        big_image[shape[0] : 2 * shape[0], 2 * shape[1] :] = image_lr

        # Fill center top
        big_image[: shape[0], shape[1] : shape[1] * 2] = image_ud

        # Fill center bottom
        big_image[2 * shape[0] :, shape[1] : 2 * shape[1]] = image_ud

        # Fill left top
        big_image[: shape[0], : shape[1]] = image_ud_lr

        # Fill left bottom
        big_image[2 * shape[0] :, : shape[1]] = image_ud_lr

        # Fill right top
        big_image[: shape[0], 2 * shape[1] :] = image_lr_ud

        # Fill right bottom
        big_image[2 * shape[0] :, 2 * shape[1] :] = image_lr_ud

        return big_image

    @classmethod
    def spline_points(cls, X, Y, Z, xi, yi):
        """Interpolates for given set of points"""

        # SBS requires at least m=(kx+1)*(ky+1) points,
        # where kx=ky=3 (default) is the degree of bivariate spline.
        # Thus, if less than 16=(3+1)*(3+1) points, adjust kx & ky.
        spline = SBS(X, Y, Z)

        return spline(xi, yi)

    @classmethod
    def find_extrema(cls, image):
        """
        Finds extrema, both mininma and maxima, based on local maximum filter.
        Returns extrema in form of two rows, where the first and second are
        positions of x and y, respectively.

        Parameters
        ----------
        image : numpy 2D array
            Monochromatic image or any 2D array.

        Returns
        -------
        min_peaks : numpy array
            Minima positions.
        max_peaks : numpy array
            Maxima positions.
        """

        # define an 3x3 neighborhood
        neighborhood = generate_binary_structure(2, 2)

        # apply the local maximum filter; all pixel of maximal value
        # in their neighborhood are set to 1
        local_min = maximum_filter(-image, footprint=neighborhood) == -image
        local_max = maximum_filter(image, footprint=neighborhood) == image

        # can't distinguish between background zero and filter zero
        background = image == 0

        # appear along the bg border (artifact of the local max filter)
        eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)

        # we obtain the final mask, containing only peaks,
        # by removing the background from the local_max mask (xor operation)
        min_peaks = local_min ^ eroded_background
        max_peaks = local_max ^ eroded_background

        min_peaks = local_min
        max_peaks = local_max
        min_peaks[[0, -1], :] = False
        min_peaks[:, [0, -1]] = False
        max_peaks[[0, -1], :] = False
        max_peaks[:, [0, -1]] = False

        min_peaks = np.nonzero(min_peaks)
        max_peaks = np.nonzero(max_peaks)

        return min_peaks, max_peaks

    @classmethod
    def end_condition(cls, image, IMFs):
        """Determins whether decomposition should be stopped.

        Parameters
        ----------
        image : numpy 2D array
            Input image which is decomposed.
        IMFs : numpy 3D array
            Array for which first dimensions relates to respective IMF,
            i.e. (numIMFs, imageX, imageY).
        """
        rec = np.sum(IMFs, axis=0)

        # If reconstruction is perfect, no need for more tests
        if np.allclose(image, rec):
            return True

        return False

    def check_proto_imf(self, proto_imf, proto_imf_prev, mean_env):
        """Check whether passed (proto) IMF is actual IMF.
        Current condition is solely based on checking whether the mean is below threshold.

        Parameters
        ----------
        proto_imf : numpy 2D array
            Current iteration of proto IMF.
        proto_imf_prev : numpy 2D array
            Previous iteration of proto IMF.
        mean_env : numpy 2D array
            Local mean computed from top and bottom envelopes.

        Returns
        -------
        boolean
            Whether current proto IMF is actual IMF.
        """

        # TODO: Sifiting is very sensitive and subtracting const val can often flip
        #      maxima with minima in decompoisition and thus repeating above/below
        #      behaviour. For now, mean_env is checked whether close to zero excluding
        #      its offset.
        if np.all(np.abs(mean_env - mean_env.mean()) < self.mean_thr):
            # if np.all(np.abs(mean_env)<self.mean_thr):
            return True

        # If very little change with sifting
        if np.allclose(proto_imf, proto_imf_prev):
            return True

        # If IMF mean close to zero (below threshold)
        if np.mean(np.abs(proto_imf)) < self.mean_thr:
            return True

        # Everything relatively close to 0
        mse_proto_imf = np.mean(proto_imf * proto_imf)
        if mse_proto_imf < self.mse_thr:
            return True

        return False

    def emd(self, image, max_imf=-1):
        """Performs EMD on input image with specified parameters.

        Parameters
        ----------
        image : numpy 2D array,
            Image which will be decomposed.
        max_imf : int, (default: -1)
            IMF number to which decomposition should be performed.
            Negative value means *all*.

        Returns
        -------
        IMFs : numpy 3D array
            Set of IMFs in form of numpy array where the first dimension
            relates to IMF's ordinary number.
        """
        image_min, image_max = np.min(image), np.max(image)
        offset = image_min
        scale = image_max - image_min

        image_s = (image - offset) / scale

        imf = np.zeros(image.shape)
        imf_old = imf.copy()

        imfNo = 0
        IMF = np.empty((imfNo,) + image.shape)
        notFinished = True

        while notFinished:
            self.logger.debug("IMF -- " + str(imfNo))

            res = image_s - np.sum(IMF[:imfNo], axis=0)
            imf = res.copy()
            mean_env = np.zeros(image.shape)
            stop_sifting = False

            # Counters
            n = 0  # All iterations for current imf.
            n_h = 0  # counts when mean(proto_imf) < threshold

            while not stop_sifting and n < self.MAX_ITERATION:
                n += 1
                self.logger.debug("Iteration: " + str(n))

                min_peaks, max_peaks = self.find_extrema(imf)

                self.logger.debug("min_peaks = %i  |  max_peaks = %i", len(min_peaks[0]), len(max_peaks[0]))
                if len(min_peaks[0]) > 4 and len(max_peaks[0]) > 4:
                    imf_old = imf.copy()
                    imf = imf - mean_env

                    min_env, max_env = self.extract_max_min_spline(imf)

                    mean_env = 0.5 * (min_env + max_env)

                    imf_old = imf.copy()
                    imf = imf - mean_env

                    # Fix number of iterations
                    if self.FIXE:
                        if n >= self.FIXE + 1:
                            stop_sifting = True

                    # Fix number of iterations after number of zero-crossings
                    # and extrema differ at most by one.
                    elif self.FIXE_H:
                        if n == 1:
                            continue
                        if self.check_proto_imf(imf, imf_old, mean_env):
                            n_h += 1
                        else:
                            n_h = 0

                        # STOP if enough n_h
                        if n_h >= self.FIXE_H:
                            stop_sifting = True

                    # Stops after default stopping criteria are met
                    else:
                        if self.check_proto_imf(imf, imf_old, mean_env):
                            stop_sifting = True

                else:
                    notFinished = False
                    stop_sifting = True

            IMF = np.vstack((IMF, imf.copy()[None, :]))
            imfNo += 1

            if self.end_condition(image, IMF) or (max_imf > 0 and imfNo >= max_imf):
                notFinished = False
                break

        res = image_s - np.sum(IMF[:imfNo], axis=0)
        if not np.allclose(res, 0):
            IMF = np.vstack((IMF, res[None, :]))
            imfNo += 1

        IMF = IMF * scale
        IMF[-1] += offset
        return IMF


########################################
if __name__ == "__main__":
    # Example run: decompose a synthetic image and save a preview figure.
    print("Running example on EMD2D")
    PLOT = True

    logging.basicConfig(level=logging.DEBUG)

    # Generate image
    print("Generating image... ", end="")
    rows, cols = 1024, 1024
    row_scale, col_scale = 256, 256
    pi2 = 2 * np.pi
    x = np.arange(rows) / float(row_scale)
    y = np.arange(cols).reshape((-1, 1)) / float(col_scale)

    # The image is the sum of these terms, accumulated in this order.
    terms = (
        np.sin(2 * pi2 * x) * np.cos(y * 4 * pi2 + 4 * x * pi2),
        3 * np.sin(2 * pi2 * x),
        2,
        5 * x * y,
        2 * (y - 0.2) * y,
    )
    img = np.zeros((rows, cols))
    for term in terms:
        img = img + term
    print("Done")

    # Perform decomposition
    print("Performing decomposition... ", end="")
    emd2d = EMD2D()
    # emd2d.FIXE_H = 5
    IMFs = emd2d.emd(img, max_imf=4)
    imfNo = IMFs.shape[0]
    print("Done")

    if PLOT:
        print("Plotting results... ", end="")
        import pylab as plt

        # One panel for the input followed by one panel per component.
        panels = [("Input image", img)]
        panels.extend(("IMF %i" % (idx + 1), component) for idx, component in enumerate(IMFs))

        plt.figure(figsize=(4, 4 * (imfNo + 1)))
        for position, (title, data) in enumerate(panels, start=1):
            plt.subplot(imfNo + 1, 1, position)
            plt.imshow(data)
            plt.colorbar()
            plt.title(title)

        plt.savefig("image_decomp")
        print("Done")
# (web page navigation residue: "back to top")