CofeehousePy/deps/scikit-image/skimage/registration/_masked_phase_cross_correla...

288 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Implementation of the masked normalized cross-correlation.
Based on the following publication:
D. Padfield. Masked object registration in the Fourier domain.
IEEE Transactions on Image Processing (2012)
and the author's original MATLAB implementation, available on this website:
http://www.dirkpadfield.com/
"""
import numpy as np
from functools import partial
from .._shared.fft import fftmodule, next_fast_len
def _masked_phase_cross_correlation(reference_image, moving_image,
                                    reference_mask, moving_mask=None,
                                    overlap_ratio=0.3):
    """Masked image translation registration by masked normalized
    cross-correlation.

    Parameters
    ----------
    reference_image : ndarray
        Reference image.
    moving_image : ndarray
        Image to register. Must be same dimensionality as ``reference_image``,
        but not necessarily the same size.
    reference_mask : ndarray
        Boolean mask for ``reference_image``. The mask should evaluate
        to ``True`` (or 1) on valid pixels. ``reference_mask`` should
        have the same shape as ``reference_image``.
    moving_mask : ndarray or None, optional
        Boolean mask for ``moving_image``. The mask should evaluate to ``True``
        (or 1) on valid pixels. ``moving_mask`` should have the same shape
        as ``moving_image``. If ``None``, ``reference_mask`` will be used.
    overlap_ratio : float, optional
        Minimum allowed overlap ratio between images. The correlation for
        translations corresponding with an overlap ratio lower than this
        threshold will be ignored. A lower `overlap_ratio` leads to smaller
        maximum translation, while a higher `overlap_ratio` leads to greater
        robustness against spurious matches due to small overlap between
        masked images.

    Returns
    -------
    shifts : ndarray
        Shift vector (in pixels) required to register ``moving_image``
        with ``reference_image``. Axis ordering is consistent with
        numpy (e.g. Z, Y, X)

    Raises
    ------
    ValueError
        If ``moving_mask`` is ``None`` while the two images have different
        shapes, if an image and its mask have different shapes, or if the
        two images do not have the same number of dimensions.

    References
    ----------
    .. [1] Dirk Padfield. Masked Object Registration in the Fourier Domain.
           IEEE Transactions on Image Processing, vol. 21(5),
           pp. 2706-2718 (2012). :DOI:`10.1109/TIP.2011.2181402`
    .. [2] D. Padfield. "Masked FFT registration". In Proc. Computer Vision and
           Pattern Recognition, pp. 2918-2925 (2010).
           :DOI:`10.1109/CVPR.2010.5540032`
    """
    if moving_mask is None:
        if reference_image.shape != moving_image.shape:
            raise ValueError(
                "Input images have different shapes, moving_mask must "
                "be explicitely set.")
        moving_mask = reference_mask.astype(bool)

    # We need masks to be of the same size as their respective images
    for (im, mask) in [(reference_image, reference_mask),
                       (moving_image, moving_mask)]:
        if im.shape != mask.shape:
            raise ValueError(
                "Image sizes must match their respective mask sizes.")

    # The shift vector has one entry per axis, so both images must have the
    # same number of axes (the docstring already requires this).
    if reference_image.ndim != moving_image.ndim:
        raise ValueError(
            "Input images must have the same number of dimensions.")

    # Correlate along *every* axis. Transforming only axes (0, 1) would leave
    # any further axis (e.g. X of a Z, Y, X volume) uncorrelated, making the
    # shift reported along it meaningless. For 2-D input this is unchanged.
    xcorr = cross_correlate_masked(moving_image, reference_image, moving_mask,
                                   reference_mask,
                                   axes=tuple(range(reference_image.ndim)),
                                   mode='full', overlap_ratio=overlap_ratio)

    # Generalize to the average of multiple equal maxima
    maxima = np.stack(np.nonzero(xcorr == xcorr.max()), axis=1)
    center = np.mean(maxima, axis=0)
    # In 'full' mode, zero displacement sits at index (reference size - 1).
    shifts = center - np.array(reference_image.shape) + 1

    # The mismatch in size will impact the center location of the
    # cross-correlation
    size_mismatch = (np.array(moving_image.shape)
                     - np.array(reference_image.shape))

    return -shifts + (size_mismatch / 2)
def cross_correlate_masked(arr1, arr2, m1, m2, mode='full', axes=(-2, -1),
                           overlap_ratio=0.3):
    """
    Masked normalized cross-correlation between arrays.

    Parameters
    ----------
    arr1 : ndarray
        First array.
    arr2 : ndarray
        Second array. The dimensions of `arr2` along axes that are not
        transformed should be equal to that of `arr1`.
    m1 : ndarray
        Mask of `arr1`. The mask should evaluate to `True`
        (or 1) on valid pixels. `m1` should have the same shape as `arr1`.
    m2 : ndarray
        Mask of `arr2`. The mask should evaluate to `True`
        (or 1) on valid pixels. `m2` should have the same shape as `arr2`.
    mode : {'full', 'same'}, optional
        'full':
            This returns the correlation at each point of overlap. At
            the end-points of the correlation, the signals do not overlap
            completely, and boundary effects may be seen.
        'same':
            The output is the same size as `arr1`, centered with respect
            to the `full` output. Boundary effects are less prominent.
    axes : tuple of ints, optional
        Axes along which to compute the cross-correlation. Negative values
        count from the last axis, as usual in NumPy.
    overlap_ratio : float, optional
        Minimum allowed overlap ratio between images. The correlation for
        translations corresponding with an overlap ratio lower than this
        threshold will be ignored. A lower `overlap_ratio` leads to smaller
        maximum translation, while a higher `overlap_ratio` leads to greater
        robustness against spurious matches due to small overlap between
        masked images.

    Returns
    -------
    out : ndarray
        Masked normalized cross-correlation, real-valued and clipped to
        [-1, 1]. Positions whose mask overlap falls below the
        `overlap_ratio` threshold are set to 0.

    Raises
    ------
    ValueError : if correlation `mode` is not valid, or array dimensions along
        non-transformation axes are not equal.

    References
    ----------
    .. [1] Dirk Padfield. Masked Object Registration in the Fourier Domain.
           IEEE Transactions on Image Processing, vol. 21(5),
           pp. 2706-2718 (2012). :DOI:`10.1109/TIP.2011.2181402`
    .. [2] D. Padfield. "Masked FFT registration". In Proc. Computer Vision and
           Pattern Recognition, pp. 2918-2925 (2010).
           :DOI:`10.1109/CVPR.2010.5540032`
    """
    if mode not in {'full', 'same'}:
        raise ValueError("Correlation mode '{}' is not valid.".format(mode))

    fixed_image = np.array(arr1, dtype=float)
    fixed_mask = np.array(m1, dtype=bool)
    moving_image = np.array(arr2, dtype=float)
    moving_mask = np.array(m2, dtype=bool)
    eps = np.finfo(float).eps

    # Map in-range negative axes to their non-negative equivalents. Without
    # this, the set difference below treats e.g. -1 and ndim - 1 as distinct
    # axes. Arrays of different size along the (default, negative)
    # transformation axes would then be wrongly rejected. Out-of-range values
    # are left untouched so they still fail further down.
    ndim = fixed_image.ndim
    axes = tuple(ax + ndim if -ndim <= ax < 0 else ax for ax in axes)

    # Array dimensions along non-transformation axes should be equal.
    for axis in set(range(ndim)) - set(axes):
        if fixed_image.shape[axis] != moving_image.shape[axis]:
            raise ValueError(
                "Array shapes along non-transformation axes should be "
                "equal, but dimensions along axis {a} are not".format(a=axis))

    # Determine final size along transformation axes.
    # It might be faster to compute the Fourier transforms in a slightly
    # larger shape (`fast_shape`). After all transforms are done we slice
    # back to `final_shape` using `final_slice`. The shape is taken from the
    # converted array so that array-likes (e.g. lists) are accepted too.
    final_shape = list(fixed_image.shape)
    for axis in axes:
        final_shape[axis] = (fixed_image.shape[axis]
                             + moving_image.shape[axis] - 1)
    final_shape = tuple(final_shape)
    final_slice = tuple(slice(0, int(sz)) for sz in final_shape)

    # Pad the transform axes to a length for which the FFT is efficient
    # (a product of small primes, as chosen by `next_fast_len`).
    fast_shape = tuple(next_fast_len(final_shape[ax]) for ax in axes)

    # We use numpy.fft or the new scipy.fft because they allow leaving the
    # non-transform axes unchanged, which was not possible with
    # scipy.fftpack's fftn/ifftn in older versions of SciPy.
    # E.g. arr shape (2, 3, 7), transform along axes (0, 1) with shape (4, 4)
    # results in arr_fft shape (4, 4, 7)
    fft = partial(fftmodule.fftn, s=fast_shape, axes=axes)
    _complex_ifft = partial(fftmodule.ifftn, s=fast_shape, axes=axes)

    def ifft(spectrum):
        # Every quantity we invert is a product of transforms of real
        # arrays, so the true result is real; drop the round-off imaginary
        # part so the output (and comparisons on it) are purely real.
        return _complex_ifft(spectrum).real

    fixed_image[np.logical_not(fixed_mask)] = 0.0
    moving_image[np.logical_not(moving_mask)] = 0.0

    # N-dimensional analog to rotation by 180deg is flip over all relevant axes.
    # See [1] for discussion.
    rotated_moving_image = np.flip(moving_image, axis=axes)
    rotated_moving_mask = np.flip(moving_mask, axis=axes)

    fixed_fft = fft(fixed_image)
    rotated_moving_fft = fft(rotated_moving_image)
    fixed_mask_fft = fft(fixed_mask)
    rotated_moving_mask_fft = fft(rotated_moving_mask)

    # Calculate the number of overlapping masked pixels at every shift.
    # Rounding removes FFT round-off; the floor at `eps` avoids division by
    # zero where the masks do not overlap at all.
    number_overlap_masked_px = ifft(rotated_moving_mask_fft * fixed_mask_fft)
    number_overlap_masked_px[:] = np.round(number_overlap_masked_px)
    number_overlap_masked_px[:] = np.fmax(number_overlap_masked_px, eps)

    masked_correlated_fixed_fft = ifft(rotated_moving_mask_fft * fixed_fft)
    masked_correlated_rotated_moving_fft = ifft(
        fixed_mask_fft * rotated_moving_fft)

    numerator = ifft(rotated_moving_fft * fixed_fft)
    numerator -= (masked_correlated_fixed_fft
                  * masked_correlated_rotated_moving_fft
                  / number_overlap_masked_px)

    fixed_squared_fft = fft(np.square(fixed_image))
    fixed_denom = ifft(rotated_moving_mask_fft * fixed_squared_fft)
    fixed_denom -= (np.square(masked_correlated_fixed_fft)
                    / number_overlap_masked_px)
    # Clamp small negative values caused by round-off before the sqrt.
    fixed_denom[:] = np.fmax(fixed_denom, 0.0)

    rotated_moving_squared_fft = fft(np.square(rotated_moving_image))
    moving_denom = ifft(fixed_mask_fft * rotated_moving_squared_fft)
    moving_denom -= (np.square(masked_correlated_rotated_moving_fft)
                     / number_overlap_masked_px)
    moving_denom[:] = np.fmax(moving_denom, 0.0)

    denom = np.sqrt(fixed_denom * moving_denom)

    # Slice back to expected convolution shape.
    numerator = numerator[final_slice]
    denom = denom[final_slice]
    number_overlap_masked_px = number_overlap_masked_px[final_slice]

    if mode == 'same':
        _centering = partial(_centered,
                             newshape=fixed_image.shape, axes=axes)
        denom = _centering(denom)
        numerator = _centering(numerator)
        number_overlap_masked_px = _centering(number_overlap_masked_px)

    # Pixels where `denom` is very small will introduce large
    # numbers after division. To get around this problem,
    # we zero-out problematic pixels.
    tol = 1e3 * eps * np.max(np.abs(denom), axis=axes, keepdims=True)
    nonzero_indices = denom > tol

    out = np.zeros_like(denom)
    out[nonzero_indices] = numerator[nonzero_indices] / denom[nonzero_indices]
    np.clip(out, a_min=-1, a_max=1, out=out)

    # Apply overlap ratio threshold
    number_px_threshold = overlap_ratio * np.max(number_overlap_masked_px,
                                                 axis=axes, keepdims=True)
    out[number_overlap_masked_px < number_px_threshold] = 0.0

    return out
def _centered(arr, newshape, axes):
""" Return the center `newshape` portion of `arr`, leaving axes not
in `axes` untouched. """
newshape = np.asarray(newshape)
currshape = np.array(arr.shape)
slices = [slice(None, None)] * arr.ndim
for ax in axes:
startind = (currshape[ax] - newshape[ax]) // 2
endind = startind + newshape[ax]
slices[ax] = slice(startind, endind)
return arr[tuple(slices)]
def _flip(arr, axes=None):
""" Reverse array over many axes. Generalization of arr[::-1] for many
dimensions. If `axes` is `None`, flip along all axes. """
if axes is None:
reverse = [slice(None, None, -1)] * arr.ndim
else:
reverse = [slice(None, None, None)] * arr.ndim
for axis in axes:
reverse[axis] = slice(None, None, -1)
return arr[tuple(reverse)]