109 lines
3.1 KiB
Python
109 lines
3.1 KiB
Python
import pytest
|
|
from functools import partial
|
|
import numpy as np
|
|
from skimage.future import fit_segmenter, predict_segmenter, TrainableSegmenter
|
|
from skimage.feature import multiscale_basic_features
|
|
from scipy import spatial
|
|
|
|
|
|
class DummyNNClassifier:
    """Minimal 1-nearest-neighbour classifier used as a stand-in in these tests.

    It exposes the ``fit`` / ``predict`` pair that the segmentation API under
    test calls on its ``clf`` argument.
    """

    def fit(self, X, labels):
        """Remember the training samples and index them with a k-d tree.

        Parameters
        ----------
        X : array-like
            Training feature vectors, one per row.
        labels : ndarray
            Label for each row of ``X``; indexed by neighbour position in
            :meth:`predict`.
        """
        self.X = X
        self.labels = labels
        self.tree = spatial.cKDTree(X)

    def predict(self, X):
        """Return, for each row of ``X``, the label of its nearest training row.

        A ``ValueError`` from ``cKDTree.query`` (e.g. when the row length does
        not match the training data) is deliberately left to propagate.
        """
        _, neighbor_idx = self.tree.query(X)
        return self.labels[neighbor_idx]
|
|
|
|
|
|
def test_trainable_segmentation_singlechannel():
    """Segment a noisy two-band grayscale image from sparse border labels.

    The top half of the image is ~1 and the bottom half ~0. Only the first
    two rows (label 1) and the last two rows (label 2) are annotated; every
    pixel of each half must end up with that half's label.
    """
    # Seeded generator so the noisy image, and hence the outcome, is reproducible.
    rng = np.random.RandomState(0)
    img = np.zeros((20, 20))
    img[:10] = 1
    img += 0.05 * rng.standard_normal(img.shape)

    labels = np.zeros_like(img, dtype=np.uint8)
    labels[:2] = 1   # annotations on the bright half
    labels[-2:] = 2  # annotations on the dark half

    features = multiscale_basic_features(
        img,
        edges=False,
        texture=False,
        sigma_min=0.5,
        sigma_max=2,
    )
    clf = fit_segmenter(labels, features, DummyNNClassifier())
    out = predict_segmenter(features, clf)

    assert np.all(out[:10] == 1)
    assert np.all(out[10:] == 2)
|
|
|
|
|
|
def test_trainable_segmentation_multichannel():
    """Segment a noisy two-band 3-channel image from sparse border labels.

    All channels of the top half are ~1 and of the bottom half ~0. Only the
    first two rows (label 1) and the last two rows (label 2) are annotated;
    every pixel of each half must end up with that half's label.
    """
    # Seeded generator so the noisy image, and hence the outcome, is reproducible.
    rng = np.random.RandomState(0)
    img = np.zeros((20, 20, 3))
    img[:10] = 1
    img += 0.05 * rng.standard_normal(img.shape)

    # One label per pixel, so the label image drops the channel axis.
    labels = np.zeros_like(img[..., 0], dtype=np.uint8)
    labels[:2] = 1   # annotations on the bright half
    labels[-2:] = 2  # annotations on the dark half

    features = multiscale_basic_features(
        img,
        edges=False,
        texture=False,
        sigma_min=0.5,
        sigma_max=2,
        multichannel=True,
    )
    clf = fit_segmenter(labels, features, DummyNNClassifier())
    out = predict_segmenter(features, clf)

    assert np.all(out[:10] == 1)
    assert np.all(out[10:] == 2)
|
|
|
|
|
|
def test_trainable_segmentation_predict():
    """Predicting on mismatched features raises a ValueError mentioning it.

    A classifier is trained on multiscale features of a small image. It is
    then queried with unrelated random features, and the resulting error
    message must contain ``'type of features'``.
    """
    # Seeded generator so both the training image and the bogus features are
    # reproducible between runs.
    rng = np.random.RandomState(0)
    img = np.zeros((20, 20))
    img[:10] = 1
    img += 0.05 * rng.standard_normal(img.shape)

    labels = np.zeros_like(img, dtype=np.uint8)
    labels[:2] = 1
    labels[-2:] = 2

    features = multiscale_basic_features(
        img,
        edges=False,
        texture=False,
        sigma_min=0.5,
        sigma_max=2,
    )
    clf = fit_segmenter(labels, features, DummyNNClassifier())

    # Random features whose last-axis length (20) is intended not to match
    # the number of features used for training.
    test_features = rng.random_sample((5, 20, 20))
    with pytest.raises(ValueError, match='type of features'):
        predict_segmenter(test_features, clf)
|
|
|
|
|
|
def test_trainable_segmentation_oo():
    """Segment a noisy two-band image through the ``TrainableSegmenter`` API.

    The top half of the image is ~1 and the bottom half ~0. Only the first
    two rows (label 1) and the last two rows (label 2) are annotated; every
    pixel of each half must end up with that half's label.
    """
    # Seeded generator so the noisy image, and hence the outcome, is reproducible.
    rng = np.random.RandomState(0)
    img = np.zeros((20, 20))
    img[:10] = 1
    img += 0.05 * rng.standard_normal(img.shape)

    labels = np.zeros_like(img, dtype=np.uint8)
    labels[:2] = 1   # annotations on the bright half
    labels[-2:] = 2  # annotations on the dark half

    # The segmenter computes features itself, so it takes a one-argument
    # callable with the feature options pre-bound.
    features_func = partial(
        multiscale_basic_features,
        edges=False,
        texture=False,
        sigma_min=0.5,
        sigma_max=2,
    )
    segmenter = TrainableSegmenter(clf=DummyNNClassifier(), features_func=features_func)
    segmenter.fit(img, labels)
    out = segmenter.predict(img)

    assert np.all(out[:10] == 1)
    assert np.all(out[10:] == 2)
|