145 lines
5.6 KiB
Cython
145 lines
5.6 KiB
Cython
#cython: cdivision=True
|
|
#cython: boundscheck=False
|
|
#cython: nonecheck=False
|
|
#cython: wraparound=False
|
|
|
|
import numpy as np
|
|
cimport numpy as cnp
|
|
|
|
from libc.math cimport exp, sqrt, ceil
|
|
from libc.float cimport DBL_MAX
|
|
|
|
cnp.import_array()
|
|
|
|
|
|
def _quickshift_cython(double[:, :, ::1] image, double kernel_size,
                       double max_dist, bint return_tree, int random_seed):
    """Segment an image using quickshift clustering in Color-(x,y) space.

    Produces an oversegmentation of the image using the quickshift
    mode-seeking algorithm: a Gaussian-weighted density is estimated for
    every pixel, each pixel is linked to the nearest pixel of higher density
    inside a local search window, links longer than ``max_dist`` are cut, and
    every pixel is finally labelled with the root of its tree.

    Parameters
    ----------
    image : (height, width, channels) ndarray of double, C-contiguous
        Input image. The first axis is treated as rows (height) and the
        second as columns (width).
    kernel_size : float
        Width of Gaussian kernel used in smoothing the
        sample density. Higher means fewer clusters.
        The search window extends ``ceil(3 * kernel_size)`` pixels in every
        direction around a pixel.
    max_dist : float
        Cut-off point for data distances.
        Higher means fewer clusters.
    return_tree : bool
        Whether to return the full segmentation hierarchy tree and distances.
    random_seed : int
        Random seed used for breaking ties.

    Returns
    -------
    segment_mask : (height, width) ndarray
        Integer mask indicating segment labels, numbered consecutively
        from 0.
    parent : (height, width) memoryview of intp
        Only returned if ``return_tree`` is true. Flat index
        (``row * width + column``) of the parent of every pixel; a pixel
        without a denser neighbor in its window is its own parent. Links are
        reported as found, i.e. before the ``max_dist`` cut-off is applied.
    dist_parent : (height, width) memoryview of double
        Only returned if ``return_tree`` is true. Distance in the joint
        color-position space between every pixel and its parent. Pixels
        without a denser neighbor hold ``sqrt(DBL_MAX)``.
    """

    random_state = np.random.RandomState(random_seed)

    # TODO join orphaned roots?
    # Some nodes might not have a point of higher density within the
    # search window. We could do a global search over these in the end.
    # Reference implementation doesn't do that, though, and it only has
    # an effect for very high max_dist.

    # window size for neighboring pixels to consider
    cdef double inv_kernel_size_sqr = -0.5 / (kernel_size * kernel_size)
    cdef int kernel_width = <int>ceil(3 * kernel_size)

    cdef Py_ssize_t height = image.shape[0]
    cdef Py_ssize_t width = image.shape[1]
    cdef Py_ssize_t channels = image.shape[2]

    # Start the density estimate from a tiny amount of noise instead of
    # zeros: this will break ties that otherwise would give us headache.
    # Initialising directly from the noise avoids an in-place operator on a
    # typed memoryview (which would round-trip through a Python object) and
    # a throw-away zero array.
    cdef double[:, ::1] densities = \
        random_state.normal(scale=0.00001, size=(height, width))

    cdef double current_density, closest, dist, t
    cdef Py_ssize_t r, c, r_, c_, channel, r_min, r_max, c_min, c_max
    cdef double* current_pixel_ptr

    # default parent to self
    cdef Py_ssize_t[:, ::1] parent = \
        np.arange(width * height, dtype=np.intp).reshape(height, width)
    cdef double[:, ::1] dist_parent = np.zeros((height, width), dtype=np.double)

    with nogil:
        # compute densities
        # current_pixel_ptr walks the C-contiguous image one pixel at a time
        # and always points at the channels of pixel (r, c).
        current_pixel_ptr = &image[0, 0, 0]
        for r in range(height):
            r_min = max(r - kernel_width, 0)
            r_max = min(r + kernel_width + 1, height)
            for c in range(width):
                c_min = max(c - kernel_width, 0)
                c_max = min(c + kernel_width + 1, width)
                for r_ in range(r_min, r_max):
                    for c_ in range(c_min, c_max):
                        # squared distance in color + (row, column) space
                        dist = 0
                        for channel in range(channels):
                            t = (current_pixel_ptr[channel] -
                                 image[r_, c_, channel])
                            dist += t * t
                        t = r - r_
                        dist += t * t
                        t = c - c_
                        dist += t * t
                        densities[r, c] += exp(dist * inv_kernel_size_sqr)
                current_pixel_ptr += channels

        # find nearest node with higher density
        current_pixel_ptr = &image[0, 0, 0]
        for r in range(height):
            r_min = max(r - kernel_width, 0)
            r_max = min(r + kernel_width + 1, height)
            for c in range(width):
                current_density = densities[r, c]
                closest = DBL_MAX
                c_min = max(c - kernel_width, 0)
                c_max = min(c + kernel_width + 1, width)
                for r_ in range(r_min, r_max):
                    for c_ in range(c_min, c_max):
                        if densities[r_, c_] > current_density:
                            dist = 0
                            # We compute the distances twice since otherwise
                            # we get crazy memory overhead
                            # (width * height * windowsize**2)
                            for channel in range(channels):
                                t = (current_pixel_ptr[channel] -
                                     image[r_, c_, channel])
                                dist += t * t
                            t = r - r_
                            dist += t * t
                            t = c - c_
                            dist += t * t
                            if dist < closest:
                                closest = dist
                                parent[r, c] = r_ * width + c_
                # stays sqrt(DBL_MAX) when no denser neighbor was found
                dist_parent[r, c] = sqrt(closest)
                current_pixel_ptr += channels

    # np.array copies, so the pruning below does not alter the tree that is
    # handed back when return_tree is set.
    dist_parent_flat = np.array(dist_parent).ravel()
    parent_flat = np.array(parent).ravel()

    # remove parents with distance > max_dist
    too_far = dist_parent_flat > max_dist
    parent_flat[too_far] = np.arange(width * height)[too_far]
    old = np.zeros_like(parent_flat)

    # flatten forest (mark each pixel with root of corresponding tree)
    # Each pass replaces every parent by its grandparent until nothing
    # changes any more.
    while (old != parent_flat).any():
        old = parent_flat
        parent_flat = parent_flat[parent_flat]

    # relabel the surviving root indices as consecutive integers from 0
    parent_flat = np.unique(parent_flat, return_inverse=True)[1]
    parent_flat = parent_flat.reshape(height, width)

    if return_tree:
        return parent_flat, parent, dist_parent
    return parent_flat
|