113 lines
4.5 KiB
Python
113 lines
4.5 KiB
Python
"""
|
|
Trainable segmentation using local features and random forests
|
|
==============================================================
|
|
|
|
A pixel-based segmentation is computed here using local features based on
|
|
local intensity, edges and textures at different scales. A user-provided
|
|
mask is used to identify different regions. The pixels of the mask are used
|
|
to train a random-forest classifier [1]_ from scikit-learn. Unlabeled pixels
|
|
are then labeled from the prediction of the classifier.
|
|
|
|
This segmentation algorithm is called trainable segmentation in other software
|
|
such as ilastik [2]_ or ImageJ [3]_ (where it is also called "weka
|
|
segmentation").
|
|
|
|
.. [1] https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html
|
|
.. [2] https://www.ilastik.org/documentation/pixelclassification/pixelclassification
|
|
.. [3] https://imagej.net/Trainable_Weka_Segmentation#Training_features_.282D.29
|
|
"""
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from skimage import data, segmentation, feature, future
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
from functools import partial
|
|
|
|
# Load the sample image and keep its first 900 rows and 900 columns as the
# image the classifier is trained on.
full_img = data.skin()
img = full_img[:900, :900]

# Training labels are painted as axis-aligned rectangles: each entry pairs a
# class label with the region that receives it. Pixels that stay at 0 are
# not assigned to any class. Rectangles keep the example short; visualization
# libraries such as plotly (and possibly napari) can be used to draw a mask
# on the image instead.
training_labels = np.zeros(img.shape[:2], dtype=np.uint8)
labeled_regions = (
    (1, np.s_[:130]),
    (1, np.s_[:170, :400]),
    (2, np.s_[600:900, 200:650]),
    (3, np.s_[330:430, 210:320]),
    (4, np.s_[260:340, 60:170]),
    (4, np.s_[150:200, 720:860]),
)
for label, region in labeled_regions:
    training_labels[region] = label
|
|
|
|
# Smallest and largest scale (sigma) at which local features are computed.
sigma_min = 1
sigma_max = 16

# Freeze the feature-extraction settings once, so that exactly the same
# features can be computed again for other images later on.
# ``channel_axis=-1`` declares the last axis of the image as the colour
# channels; it replaces the ``multichannel=True`` flag, which is deprecated
# since scikit-image 0.19.
features_func = partial(
    feature.multiscale_basic_features,
    intensity=True,
    edges=False,
    texture=True,
    sigma_min=sigma_min,
    sigma_max=sigma_max,
    channel_axis=-1,
)
features = features_func(img)

# Train a random forest on the labeled pixels, then predict a label for
# every pixel of the image from the same features.
clf = RandomForestClassifier(
    n_estimators=50, n_jobs=-1, max_depth=10, max_samples=0.05
)
clf = future.fit_segmenter(training_labels, features, clf)
result = future.predict_segmenter(features, clf)
|
|
|
|
# Left panel: the image with the predicted region boundaries drawn on top,
# plus contours of the training mask. Right panel: the predicted labels.
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(9, 4))
ax_overlay, ax_labels = ax

boundaries = segmentation.mark_boundaries(img, result, mode='thick')
ax_overlay.imshow(boundaries)
ax_overlay.contour(training_labels)
ax_overlay.set_title('Image, mask and segmentation boundaries')

ax_labels.imshow(result)
ax_labels.set_title('Segmentation')

fig.tight_layout()
|
|
|
|
##############################################################################
# Feature importance
# ------------------
#
# We inspect below the importance of the different features, as computed by
# scikit-learn. Intensity features have a much higher importance than texture
# features. It can be tempting to use this information to reduce the number of
# features given to the classifier, in order to reduce the computing time.
# However, this can lead to overfitting and a degraded result at the boundary
# between regions.

fig, ax = plt.subplots(1, 2, figsize=(9, 4))

# Split the importances into three consecutive runs, one per colour channel
# (plotted below in red, green and blue).
n_features = len(clf.feature_importances_)
feature_importance = (
    clf.feature_importances_[:n_features // 3],
    clf.feature_importances_[n_features // 3:2 * n_features // 3],
    clf.feature_importances_[2 * n_features // 3:],
)

# Scales in powers of two from sigma_min to sigma_max, both ends included.
sigmas = np.logspace(
    np.log2(sigma_min),
    np.log2(sigma_max),
    num=int(np.log2(sigma_max) - np.log2(sigma_min) + 1),
    base=2,
    endpoint=True,
)

# Within one channel the values are read with a stride of three: offset 0 is
# plotted as the intensity feature, offsets 1 and 2 as the texture features
# (round and square markers respectively).
for ch, color in enumerate(['r', 'g', 'b']):
    ax[0].plot(sigmas, feature_importance[ch][::3], 'o', color=color)
    ax[1].plot(sigmas, feature_importance[ch][1::3], 'o', color=color)
    ax[1].plot(sigmas, feature_importance[ch][2::3], 's', color=color)

# Titles and axis labels do not depend on the channel, so set them once.
ax[0].set_title("Intensity features")
ax[0].set_xlabel("$\\sigma$")
ax[1].set_title("Texture features")
ax[1].set_xlabel("$\\sigma$")

fig.tight_layout()
|
|
|
|
##############################################################################
# Fitting new images
# ------------------
#
# A classifier trained with `fit_segmenter` can be reused to segment other
# images, provided they show similar objects acquired in similar conditions.
# In the example below we simply use a different part of the same image.

img_new = full_img[:700, 900:]

# Compute the same features as for training, then predict with the
# already-trained classifier.
features_new = features_func(img_new)
result_new = future.predict_segmenter(features_new, clf)

fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(6, 4))
ax_image, ax_prediction = ax

boundaries_new = segmentation.mark_boundaries(
    img_new, result_new, mode='thick'
)
ax_image.imshow(boundaries_new)
ax_image.set_title('Image')

ax_prediction.imshow(result_new)
ax_prediction.set_title('Segmentation')

fig.tight_layout()

plt.show()
|