311 lines
13 KiB
Python
311 lines
13 KiB
Python
"""
|
|
=========================================================
|
|
Full tutorial on calibrating Denoisers Using J-Invariance
|
|
=========================================================
|
|
|
|
In this example, we show how to find an optimally calibrated
|
|
version of any denoising algorithm.
|
|
|
|
The calibration method is based on the `noise2self` algorithm of [1]_.
|
|
|
|
.. [1] J. Batson & L. Royer. Noise2Self: Blind Denoising by Self-Supervision,
|
|
International Conference on Machine Learning, p. 524-533 (2019).
|
|
|
|
.. seealso::
|
|
A simple example of the method is given in
|
|
:ref:`sphx_glr_auto_examples_filters_plot_j_invariant.py`.
|
|
"""
|
|
|
|
#####################################################################
|
|
# Calibrating a wavelet denoiser
|
|
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
from matplotlib import gridspec
|
|
|
|
from skimage.data import chelsea, hubble_deep_field
|
|
from skimage.metrics import mean_squared_error as mse
|
|
from skimage.metrics import peak_signal_noise_ratio as psnr
|
|
from skimage.restoration import (calibrate_denoiser,
|
|
denoise_wavelet,
|
|
denoise_tv_chambolle, denoise_nl_means,
|
|
estimate_sigma)
|
|
from skimage.util import img_as_float, random_noise
|
|
from skimage.color import rgb2gray
|
|
from functools import partial
|
|
|
|
_denoise_wavelet = partial(denoise_wavelet, rescale_sigma=True)
|
|
|
|
image = img_as_float(chelsea())
|
|
sigma = 0.2
|
|
noisy = random_noise(image, var=sigma ** 2)
|
|
|
|
# Parameters to test when calibrating the denoising algorithm
|
|
parameter_ranges = {'sigma': np.arange(0.1, 0.3, 0.02),
|
|
'wavelet': ['db1', 'db2'],
|
|
'convert2ycbcr': [True, False],
|
|
'multichannel': [True]}
|
|
|
|
# Denoised image using default parameters of `denoise_wavelet`
|
|
default_output = denoise_wavelet(noisy, multichannel=True, rescale_sigma=True)
|
|
|
|
# Calibrate denoiser
|
|
calibrated_denoiser = calibrate_denoiser(noisy,
|
|
_denoise_wavelet,
|
|
denoise_parameters=parameter_ranges
|
|
)
|
|
|
|
# Denoised image using calibrated denoiser
|
|
calibrated_output = calibrated_denoiser(noisy)
|
|
|
|
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(15, 5))
|
|
|
|
for ax, img, title in zip(axes,
|
|
[noisy, default_output, calibrated_output],
|
|
['Noisy Image', 'Denoised (Default)',
|
|
'Denoised (Calibrated)']):
|
|
ax.imshow(img)
|
|
ax.set_title(title)
|
|
ax.set_yticks([])
|
|
ax.set_xticks([])
|
|
|
|
#####################################################################
|
|
# The Self-Supervised Loss and J-Invariance
|
|
# =========================================
|
|
# The key to this calibration method is the notion of J-invariance. A denoising
|
|
# function is J-invariant if the prediction it makes for each pixel does
|
|
# not depend on the value of that pixel in the original image. The prediction
|
|
# for each pixel may instead use all the relevant information contained in the
|
|
# rest of the image, which is typically quite significant. Any function
|
|
# can be converted into a J-invariant one using a simple masking procedure,
|
|
# as described in [1].
|
|
#
|
|
# The pixel-wise error of a J-invariant denoiser is uncorrelated
|
|
# to the noise, so long as the noise in each pixel is independent.
|
|
# Consequently, the average difference between the denoised image and the
|
|
# noisy image, the *self-supervised loss*, is the same as the
|
|
# difference between the denoised image and the original clean image, the
|
|
# *ground-truth loss* (up to a constant).
|
|
#
|
|
# This means that the best J-invariant denoiser for a given image can
|
|
# be found using the noisy data alone, by selecting the denoiser minimizing
|
|
# the self-supervised loss. Below, we demonstrate this
|
|
# for a family of wavelet denoisers with varying `sigma` parameter. The
|
|
# self-supervised loss (solid blue line) and the ground-truth loss (dashed
|
|
# blue line) have the same shape and the same minimizer.
|
|
#
|
|
|
|
from skimage.restoration.j_invariant import _invariant_denoise
|
|
|
|
sigma_range = np.arange(sigma/2, 1.5*sigma, 0.025)
|
|
|
|
parameters_tested = [{'sigma': sigma, 'convert2ycbcr': True, 'wavelet': 'db2',
|
|
'multichannel': True}
|
|
for sigma in sigma_range]
|
|
|
|
denoised_invariant = [_invariant_denoise(noisy, _denoise_wavelet,
|
|
denoiser_kwargs=params)
|
|
for params in parameters_tested]
|
|
|
|
self_supervised_loss = [mse(img, noisy) for img in denoised_invariant]
|
|
ground_truth_loss = [mse(img, image) for img in denoised_invariant]
|
|
|
|
opt_idx = np.argmin(self_supervised_loss)
|
|
plot_idx = [0, opt_idx, len(sigma_range) - 1]
|
|
|
|
get_inset = lambda x: x[25:225, 100:300]
|
|
|
|
plt.figure(figsize=(10, 12))
|
|
|
|
gs = gridspec.GridSpec(3, 3)
|
|
ax1 = plt.subplot(gs[0, :])
|
|
ax2 = plt.subplot(gs[1, :])
|
|
ax_image = [plt.subplot(gs[2, i]) for i in range(3)]
|
|
|
|
ax1.plot(sigma_range, self_supervised_loss, color='C0',
|
|
label='Self-Supervised Loss')
|
|
ax1.scatter(sigma_range[opt_idx], self_supervised_loss[opt_idx] + 0.0003,
|
|
marker='v', color='red', label='optimal sigma')
|
|
|
|
ax1.set_ylabel('MSE')
|
|
ax1.set_xticks([])
|
|
ax1.legend()
|
|
ax1.set_title('Self-Supervised Loss')
|
|
|
|
ax2.plot(sigma_range, ground_truth_loss, color='C0', linestyle='--',
|
|
label='Ground Truth Loss')
|
|
ax2.scatter(sigma_range[opt_idx], ground_truth_loss[opt_idx] + 0.0003,
|
|
marker='v', color='red', label='optimal sigma')
|
|
ax2.set_ylabel('MSE')
|
|
ax2.legend()
|
|
ax2.set_xlabel('sigma')
|
|
ax2.set_title('Ground-Truth Loss')
|
|
|
|
for i in range(3):
|
|
ax = ax_image[i]
|
|
ax.set_xticks([])
|
|
ax.set_yticks([])
|
|
ax.imshow(get_inset(denoised_invariant[plot_idx[i]]))
|
|
ax.set_xlabel('sigma = ' + str(np.round(sigma_range[plot_idx[i]], 2)))
|
|
|
|
for spine in ax_image[1].spines.values():
|
|
spine.set_edgecolor('red')
|
|
spine.set_linewidth(5)
|
|
|
|
#####################################################################
|
|
# Conversion to J-invariance
|
|
# =========================================
|
|
# The function `_invariant_denoise` acts as a J-invariant version of a
|
|
# given denoiser. It works by masking a fraction of the pixels, interpolating
|
|
# them, running the original denoiser, and extracting the values returned in
|
|
# the masked pixels. Iterating over the image results in a fully J-invariant
|
|
# output.
|
|
#
|
|
# For any given set of parameters, the J-invariant version of a denoiser
|
|
# is different from the original denoiser, but it is not necessarily better
|
|
# or worse. In the plot below, we see that, for the test image of a cat,
|
|
# the J-invariant version of a wavelet denoiser is significantly better
|
|
# than the original at small values of variance-reduction `sigma` and
|
|
# imperceptibly worse at larger values.
|
|
#
|
|
|
|
parameters_tested = [{'sigma': sigma, 'convert2ycbcr': True,
|
|
'wavelet': 'db2', 'multichannel': True}
|
|
for sigma in sigma_range]
|
|
|
|
denoised_original = [_denoise_wavelet(noisy, **params)
|
|
for params in parameters_tested]
|
|
|
|
ground_truth_loss_invariant = [mse(img, image) for img in denoised_invariant]
|
|
ground_truth_loss_original = [mse(img, image) for img in denoised_original]
|
|
|
|
fig, ax = plt.subplots(figsize=(10, 4))
|
|
|
|
ax.plot(sigma_range, ground_truth_loss_invariant, color='C0', linestyle='--',
|
|
label='J-invariant')
|
|
ax.plot(sigma_range, ground_truth_loss_original, color='C1', linestyle='--',
|
|
label='Original')
|
|
ax.scatter(sigma_range[opt_idx], ground_truth_loss[opt_idx] + 0.001,
|
|
marker='v', color='red')
|
|
ax.legend()
|
|
ax.set_title(
|
|
'J-Invariant Denoiser Has Comparable Or '
|
|
'Better Performance At Same Parameters'
|
|
)
|
|
ax.set_ylabel('MSE')
|
|
ax.set_xlabel('sigma')
|
|
|
|
#####################################################################
|
|
# Comparing Different Classes of Denoiser
|
|
# =========================================
|
|
# The self-supervised loss can be used to compare different classes of
|
|
# denoiser in addition to choosing parameters for a single class.
|
|
# This allows the user to, in an unbiased way, choose the best parameters
|
|
# for the best class of denoiser for a given image.
|
|
#
|
|
# Below, we show this for an image of the hubble deep field with significant
|
|
# speckle noise added. In this case, the J-invariant calibrated denoiser is
|
|
# better than the default denoiser in each of three families of denoisers --
|
|
# Non-local means, wavelet, and TV norm. Additionally, the self-supervised
|
|
# loss shows that the TV norm denoiser is the best for this noisy image.
|
|
#
|
|
|
|
image = rgb2gray(img_as_float(hubble_deep_field()[100:250, 50:300]))
|
|
|
|
sigma = 0.4
|
|
noisy = random_noise(image, mode='speckle', var=sigma ** 2)
|
|
|
|
parameter_ranges_tv = {'weight': np.arange(0.01, 0.3, 0.02)}
|
|
_, (parameters_tested_tv, losses_tv) = calibrate_denoiser(
|
|
noisy,
|
|
denoise_tv_chambolle,
|
|
denoise_parameters=parameter_ranges_tv,
|
|
extra_output=True)
|
|
print(f'Minimum self-supervised loss TV: {np.min(losses_tv):.4f}')
|
|
|
|
best_parameters_tv = parameters_tested_tv[np.argmin(losses_tv)]
|
|
denoised_calibrated_tv = _invariant_denoise(noisy, denoise_tv_chambolle,
|
|
denoiser_kwargs=best_parameters_tv)
|
|
denoised_default_tv = denoise_tv_chambolle(noisy, **best_parameters_tv)
|
|
|
|
psnr_calibrated_tv = psnr(image, denoised_calibrated_tv)
|
|
psnr_default_tv = psnr(image, denoised_default_tv)
|
|
|
|
parameter_ranges_wavelet = {'sigma': np.arange(0.01, 0.3, 0.03)}
|
|
_, (parameters_tested_wavelet, losses_wavelet) = calibrate_denoiser(
|
|
noisy,
|
|
_denoise_wavelet,
|
|
parameter_ranges_wavelet,
|
|
extra_output=True)
|
|
print(f'Minimum self-supervised loss wavelet: {np.min(losses_wavelet):.4f}')
|
|
|
|
best_parameters_wavelet = parameters_tested_wavelet[np.argmin(losses_wavelet)]
|
|
denoised_calibrated_wavelet = _invariant_denoise(
|
|
noisy, _denoise_wavelet,
|
|
denoiser_kwargs=best_parameters_wavelet)
|
|
denoised_default_wavelet = _denoise_wavelet(noisy, **best_parameters_wavelet)
|
|
|
|
psnr_calibrated_wavelet = psnr(image, denoised_calibrated_wavelet)
|
|
psnr_default_wavelet = psnr(image, denoised_default_wavelet)
|
|
|
|
sigma_est = estimate_sigma(noisy)
|
|
|
|
parameter_ranges_nl = {'sigma': np.arange(0.6, 1.4, 0.2) * sigma_est,
|
|
'h': np.arange(0.6, 1.2, 0.2) * sigma_est}
|
|
|
|
parameter_ranges_nl = {'sigma': np.arange(0.01, 0.3, 0.03)}
|
|
_, (parameters_tested_nl, losses_nl) = calibrate_denoiser(noisy,
|
|
denoise_nl_means,
|
|
parameter_ranges_nl,
|
|
extra_output=True)
|
|
print(f'Minimum self-supervised loss NL means: {np.min(losses_nl):.4f}')
|
|
|
|
best_parameters_nl = parameters_tested_nl[np.argmin(losses_nl)]
|
|
denoised_calibrated_nl = _invariant_denoise(noisy, denoise_nl_means,
|
|
denoiser_kwargs=best_parameters_nl)
|
|
denoised_default_nl = denoise_nl_means(noisy, **best_parameters_nl)
|
|
|
|
psnr_calibrated_nl = psnr(image, denoised_calibrated_nl)
|
|
psnr_default_nl = psnr(image, denoised_default_nl)
|
|
|
|
print(f' PSNR')
|
|
print(f'NL means (Default) : {psnr_default_nl:.1f}')
|
|
print(f'NL means (Calibrated): {psnr_calibrated_nl:.1f}')
|
|
print(f'Wavelet (Default) : {psnr_default_wavelet:.1f}')
|
|
print(f'Wavelet (Calibrated): {psnr_calibrated_wavelet:.1f}')
|
|
print(f'TV norm (Default) : {psnr_default_tv:.1f}')
|
|
print(f'TV norm (Calibrated): {psnr_calibrated_tv:.1f}')
|
|
|
|
plt.subplots(figsize=(10, 12))
|
|
plt.imshow(noisy, cmap='Greys_r')
|
|
plt.xticks([])
|
|
plt.yticks([])
|
|
plt.title('Noisy Image')
|
|
|
|
get_inset = lambda x: x[0:100, -140:]
|
|
|
|
fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(15, 8))
|
|
|
|
for ax in axes.ravel():
|
|
ax.set_xticks([])
|
|
ax.set_yticks([])
|
|
|
|
axes[0, 0].imshow(get_inset(denoised_default_nl), cmap='Greys_r')
|
|
axes[0, 0].set_title('NL Means Default')
|
|
axes[1, 0].imshow(get_inset(denoised_calibrated_nl), cmap='Greys_r')
|
|
axes[1, 0].set_title('NL Means Calibrated')
|
|
axes[0, 1].imshow(get_inset(denoised_default_wavelet), cmap='Greys_r')
|
|
axes[0, 1].set_title('Wavelet Default')
|
|
axes[1, 1].imshow(get_inset(denoised_calibrated_wavelet), cmap='Greys_r')
|
|
axes[1, 1].set_title('Wavelet Calibrated')
|
|
axes[0, 2].imshow(get_inset(denoised_default_tv), cmap='Greys_r')
|
|
axes[0, 2].set_title('TV Norm Default')
|
|
axes[1, 2].imshow(get_inset(denoised_calibrated_tv), cmap='Greys_r')
|
|
axes[1, 2].set_title('TV Norm Calibrated')
|
|
|
|
for spine in axes[1, 2].spines.values():
|
|
spine.set_edgecolor('red')
|
|
spine.set_linewidth(5)
|
|
|
|
plt.show()
|