# CofeehousePy/deps/scikit-image/doc/examples/filters/plot_j_invariant_tutorial.py
#
# 311 lines
# 13 KiB
# Python

"""
=========================================================
Full Tutorial on Calibrating Denoisers Using J-Invariance
=========================================================

In this example, we show how to find an optimally calibrated
version of any denoising algorithm.

The calibration method is based on the `noise2self` algorithm of [1]_.

.. [1] J. Batson & L. Royer. Noise2Self: Blind Denoising by Self-Supervision,
       International Conference on Machine Learning, p. 524-533 (2019).

.. seealso::
   A simple example of the method is given in
   :ref:`sphx_glr_auto_examples_filters_plot_j_invariant.py`.
"""
#####################################################################
# Calibrating a wavelet denoiser
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import gridspec
from skimage.data import chelsea, hubble_deep_field
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.restoration import (calibrate_denoiser,
denoise_wavelet,
denoise_tv_chambolle, denoise_nl_means,
estimate_sigma)
from skimage.util import img_as_float, random_noise
from skimage.color import rgb2gray
from functools import partial
# Wavelet denoiser with `rescale_sigma` fixed to True, so that calibration
# only has to vary the parameters listed in `parameter_ranges`.
_denoise_wavelet = partial(denoise_wavelet, rescale_sigma=True)

# Clean reference image and a copy corrupted by noise of variance sigma ** 2.
sigma = 0.2
image = img_as_float(chelsea())
noisy = random_noise(image, var=sigma ** 2)

# Parameters to test when calibrating the denoising algorithm
parameter_ranges = dict(
    sigma=np.arange(0.1, 0.3, 0.02),
    wavelet=['db1', 'db2'],
    convert2ycbcr=[True, False],
    multichannel=[True],
)

# Denoised image using default parameters of `denoise_wavelet`
default_output = denoise_wavelet(noisy, multichannel=True, rescale_sigma=True)

# Calibrate denoiser
calibrated_denoiser = calibrate_denoiser(
    noisy, _denoise_wavelet, denoise_parameters=parameter_ranges)

# Denoised image using calibrated denoiser
calibrated_output = calibrated_denoiser(noisy)

# Side-by-side comparison: noisy input, default denoiser, calibrated denoiser.
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(15, 5))
panels = (
    ('Noisy Image', noisy),
    ('Denoised (Default)', default_output),
    ('Denoised (Calibrated)', calibrated_output),
)
for ax, (title, img) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.set_yticks([])
    ax.set_xticks([])
#####################################################################
# The Self-Supervised Loss and J-Invariance
# =========================================
# The key to this calibration method is the notion of J-invariance. A denoising
# function is J-invariant if the prediction it makes for each pixel does
# not depend on the value of that pixel in the original image. The prediction
# for each pixel may instead use all the relevant information contained in the
# rest of the image, which is typically quite significant. Any function
# can be converted into a J-invariant one using a simple masking procedure,
# as described in [1]_.
#
# The pixel-wise error of a J-invariant denoiser is uncorrelated
# to the noise, so long as the noise in each pixel is independent.
# Consequently, the average difference between the denoised image and the
# noisy image, the *self-supervised loss*, is the same as the
# difference between the denoised image and the original clean image, the
# *ground-truth loss* (up to a constant).
#
# This means that the best J-invariant denoiser for a given image can
# be found using the noisy data alone, by selecting the denoiser minimizing
# the self-supervised loss. Below, we demonstrate this
# for a family of wavelet denoisers with varying `sigma` parameter. The
# self-supervised loss (solid blue line) and the ground-truth loss (dashed
# blue line) have the same shape and the same minimizer.
#
from skimage.restoration.j_invariant import _invariant_denoise


def get_inset(x):
    """Return the fixed crop (rows 25:225, columns 100:300) shown in the panels."""
    return x[25:225, 100:300]


# Grid of `sigma` values centred on the noise level used to corrupt the image.
sigma_range = np.arange(sigma/2, 1.5*sigma, 0.025)

# One keyword set per grid point; every parameter except `sigma` is fixed.
parameters_tested = [
    dict(sigma=level, convert2ycbcr=True, wavelet='db2', multichannel=True)
    for level in sigma_range
]

# J-invariant version of the wavelet denoiser, evaluated at every grid point.
denoised_invariant = [
    _invariant_denoise(noisy, _denoise_wavelet, denoiser_kwargs=kwargs)
    for kwargs in parameters_tested
]

# Loss against the noisy input (computable without the clean image) and
# loss against the clean image (for reference only).
self_supervised_loss = [mse(result, noisy) for result in denoised_invariant]
ground_truth_loss = [mse(result, image) for result in denoised_invariant]

# Grid index minimizing the self-supervised loss, and the three grid points
# (first, optimal, last) whose results are displayed as images.
opt_idx = np.argmin(self_supervised_loss)
plot_idx = [0, opt_idx, len(sigma_range) - 1]

# Layout: two full-width loss curves stacked above a row of three image crops.
plt.figure(figsize=(10, 12))
gs = gridspec.GridSpec(3, 3)
ax1 = plt.subplot(gs[0, :])
ax2 = plt.subplot(gs[1, :])
ax_image = [plt.subplot(gs[2, col]) for col in range(3)]

# Top row: self-supervised loss. The marker is offset slightly above the
# curve value at the optimum.
ax1.plot(sigma_range, self_supervised_loss, color='C0',
         label='Self-Supervised Loss')
ax1.scatter(sigma_range[opt_idx], self_supervised_loss[opt_idx] + 0.0003,
            marker='v', color='red', label='optimal sigma')
ax1.set_ylabel('MSE')
ax1.set_xticks([])
ax1.legend()
ax1.set_title('Self-Supervised Loss')

# Middle row: ground-truth loss, marked at the same (self-supervised) optimum.
ax2.plot(sigma_range, ground_truth_loss, color='C0', linestyle='--',
         label='Ground Truth Loss')
ax2.scatter(sigma_range[opt_idx], ground_truth_loss[opt_idx] + 0.0003,
            marker='v', color='red', label='optimal sigma')
ax2.set_ylabel('MSE')
ax2.legend()
ax2.set_xlabel('sigma')
ax2.set_title('Ground-Truth Loss')

# Bottom row: crops of the denoised result at the selected grid points.
for ax, idx in zip(ax_image, plot_idx):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(get_inset(denoised_invariant[idx]))
    ax.set_xlabel('sigma = ' + str(np.round(sigma_range[idx], 2)))

# Outline the middle panel (the optimum) in red.
for spine in ax_image[1].spines.values():
    spine.set_edgecolor('red')
    spine.set_linewidth(5)
#####################################################################
# Conversion to J-invariance
# =========================================
# The function `_invariant_denoise` acts as a J-invariant version of a
# given denoiser. It works by masking a fraction of the pixels, interpolating
# them, running the original denoiser, and extracting the values returned in
# the masked pixels. Iterating over the image results in a fully J-invariant
# output.
#
# For any given set of parameters, the J-invariant version of a denoiser
# is different from the original denoiser, but it is not necessarily better
# or worse. In the plot below, we see that, for the test image of a cat,
# the J-invariant version of a wavelet denoiser is significantly better
# than the original at small values of variance-reduction `sigma` and
# imperceptibly worse at larger values.
#
# Same parameter grid as for the J-invariant sweep: only `sigma` varies.
parameters_tested = [
    dict(sigma=level, convert2ycbcr=True, wavelet='db2', multichannel=True)
    for level in sigma_range
]

# Plain (non-J-invariant) wavelet denoiser at every grid point.
denoised_original = [
    _denoise_wavelet(noisy, **kwargs) for kwargs in parameters_tested
]

# Error against the clean image for both variants of the denoiser.
ground_truth_loss_invariant = [
    mse(result, image) for result in denoised_invariant
]
ground_truth_loss_original = [
    mse(result, image) for result in denoised_original
]

fig, ax = plt.subplots(figsize=(10, 4))
curves = (
    (ground_truth_loss_invariant, 'C0', 'J-invariant'),
    (ground_truth_loss_original, 'C1', 'Original'),
)
for losses, colour, label in curves:
    ax.plot(sigma_range, losses, color=colour, linestyle='--', label=label)

# Mark the self-supervised optimum on the J-invariant curve, offset slightly
# above the curve value.
ax.scatter(sigma_range[opt_idx], ground_truth_loss_invariant[opt_idx] + 0.001,
           marker='v', color='red')
ax.legend()
ax.set_title(
    'J-Invariant Denoiser Has Comparable Or '
    'Better Performance At Same Parameters'
)
ax.set_ylabel('MSE')
ax.set_xlabel('sigma')
#####################################################################
# Comparing Different Classes of Denoiser
# =========================================
# The self-supervised loss can be used to compare different classes of
# denoiser in addition to choosing parameters for a single class.
# This allows the user to, in an unbiased way, choose the best parameters
# for the best class of denoiser for a given image.
#
# Below, we show this for an image of the Hubble Deep Field with significant
# speckle noise added. In this case, the J-invariant calibrated denoiser is
# better than the default denoiser in each of three families of denoisers --
# Non-local means, wavelet, and TV norm. Additionally, the self-supervised
# loss shows that the TV norm denoiser is the best for this noisy image.
#
# Grayscale crop of the Hubble Deep Field image, then a speckle-noise copy.
hubble_crop = hubble_deep_field()[100:250, 50:300]
image = rgb2gray(img_as_float(hubble_crop))
sigma = 0.4
noisy = random_noise(image, mode='speckle', var=sigma ** 2)

# --- TV norm: calibrate over `weight` ---
parameter_ranges_tv = {'weight': np.arange(0.01, 0.3, 0.02)}
tv_search = calibrate_denoiser(
    noisy,
    denoise_tv_chambolle,
    denoise_parameters=parameter_ranges_tv,
    extra_output=True,
)
# With `extra_output=True` the second item holds the tested parameter sets
# and their self-supervised losses.
parameters_tested_tv, losses_tv = tv_search[1]
print(f'Minimum self-supervised loss TV: {np.min(losses_tv):.4f}')

best_parameters_tv = parameters_tested_tv[np.argmin(losses_tv)]

# "Calibrated" is the J-invariant denoiser at the best parameters; "default"
# is the plain denoiser called with those same parameters.
denoised_calibrated_tv = _invariant_denoise(
    noisy, denoise_tv_chambolle, denoiser_kwargs=best_parameters_tv)
denoised_default_tv = denoise_tv_chambolle(noisy, **best_parameters_tv)

psnr_calibrated_tv = psnr(image, denoised_calibrated_tv)
psnr_default_tv = psnr(image, denoised_default_tv)
# --- Wavelet: calibrate over `sigma` ---
parameter_ranges_wavelet = {'sigma': np.arange(0.01, 0.3, 0.03)}
wavelet_search = calibrate_denoiser(
    noisy,
    _denoise_wavelet,
    denoise_parameters=parameter_ranges_wavelet,
    extra_output=True,
)
# With `extra_output=True` the second item holds the tested parameter sets
# and their self-supervised losses.
parameters_tested_wavelet, losses_wavelet = wavelet_search[1]
print(f'Minimum self-supervised loss wavelet: {np.min(losses_wavelet):.4f}')

best_parameters_wavelet = parameters_tested_wavelet[np.argmin(losses_wavelet)]

# "Calibrated" is the J-invariant denoiser at the best parameters; "default"
# is the plain denoiser called with those same parameters.
denoised_calibrated_wavelet = _invariant_denoise(
    noisy, _denoise_wavelet, denoiser_kwargs=best_parameters_wavelet)
denoised_default_wavelet = _denoise_wavelet(noisy, **best_parameters_wavelet)

psnr_calibrated_wavelet = psnr(image, denoised_calibrated_wavelet)
psnr_default_wavelet = psnr(image, denoised_default_wavelet)
# --- Non-local means: calibrate over `sigma` ---
# NOTE: this is the grid the script has always effectively used. An earlier
# `sigma`/`h` grid scaled by `estimate_sigma(noisy)` was assigned to the same
# name and overwritten before it ever reached `calibrate_denoiser`, so that
# dead assignment (and the otherwise unused noise estimate) has been removed.
parameter_ranges_nl = {'sigma': np.arange(0.01, 0.3, 0.03)}
_, (parameters_tested_nl, losses_nl) = calibrate_denoiser(
    noisy,
    denoise_nl_means,
    denoise_parameters=parameter_ranges_nl,
    extra_output=True)
print(f'Minimum self-supervised loss NL means: {np.min(losses_nl):.4f}')

best_parameters_nl = parameters_tested_nl[np.argmin(losses_nl)]

# "Calibrated" is the J-invariant denoiser at the best parameters; "default"
# is the plain denoiser called with those same parameters.
denoised_calibrated_nl = _invariant_denoise(noisy, denoise_nl_means,
                                            denoiser_kwargs=best_parameters_nl)
denoised_default_nl = denoise_nl_means(noisy, **best_parameters_nl)

psnr_calibrated_nl = psnr(image, denoised_calibrated_nl)
psnr_default_nl = psnr(image, denoised_default_nl)
# PSNR summary: one row per denoiser family and variant.
psnr_report = [
    ('NL means (Default) ', psnr_default_nl),
    ('NL means (Calibrated)', psnr_calibrated_nl),
    ('Wavelet (Default) ', psnr_default_wavelet),
    ('Wavelet (Calibrated)', psnr_calibrated_wavelet),
    ('TV norm (Default) ', psnr_default_tv),
    ('TV norm (Calibrated)', psnr_calibrated_tv),
]
print(' PSNR')
for label, value in psnr_report:
    print(f'{label}: {value:.1f}')
# Full noisy input on its own figure.
plt.subplots(figsize=(10, 12))
plt.imshow(noisy, cmap='Greys_r')
plt.xticks([])
plt.yticks([])
plt.title('Noisy Image')


def get_inset(x):
    """Return the fixed crop (first 100 rows, last 140 columns) to display."""
    return x[0:100, -140:]


# 2 x 3 grid: one column per denoiser family, plain denoiser on the top row
# and its J-invariant counterpart on the bottom row.
fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(15, 8))
for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])

columns = (
    ('NL Means', denoised_default_nl, denoised_calibrated_nl),
    ('Wavelet', denoised_default_wavelet, denoised_calibrated_wavelet),
    ('TV Norm', denoised_default_tv, denoised_calibrated_tv),
)
for (top, bottom), (family, default_img, calibrated_img) in zip(axes.T,
                                                                columns):
    top.imshow(get_inset(default_img), cmap='Greys_r')
    top.set_title(family + ' Default')
    bottom.imshow(get_inset(calibrated_img), cmap='Greys_r')
    bottom.set_title(family + ' Calibrated')

# Outline the calibrated TV-norm panel in red.
for spine in axes[1, 2].spines.values():
    spine.set_edgecolor('red')
    spine.set_linewidth(5)

plt.show()