# Source code for mridc.collections.common.losses.ssim

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI

import torch
import torch.nn as nn
import torch.nn.functional as F


class SSIMLoss(nn.Module):
    """
    SSIM loss module, returning ``1 - mean(SSIM)``.

    Local statistics are computed with a uniform ``win_size x win_size`` averaging window applied without padding
    (a "valid" convolution). Variances and the covariance are rescaled by ``NP / (NP - 1)`` with ``NP = win_size**2``.
    """

    def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03):
        """
        Parameters
        ----------
        win_size: Side length of the square averaging window. Must be at least 2, because the covariance
            normalization divides by ``win_size**2 - 1``.
        k1: k1 parameter for SSIM calculation; the stabilizer ``C1 = (k1 * data_range) ** 2`` is added to the
            mean-based terms.
        k2: k2 parameter for SSIM calculation; the stabilizer ``C2 = (k2 * data_range) ** 2`` is added to the
            variance/covariance-based terms.

        Raises
        ------
        ValueError: If ``win_size`` is smaller than 2.
        """
        super().__init__()
        if win_size < 2:
            # With win_size == 1 the normalization below would divide by zero.
            raise ValueError(f"win_size must be >= 2 for the covariance normalization, got {win_size}.")
        self.win_size = win_size
        self.k1, self.k2 = k1, k2
        # Uniform averaging kernel of shape (1, 1, win_size, win_size); it is replicated per channel in forward().
        self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2)
        num_pixels = win_size**2
        self.cov_norm = num_pixels / (num_pixels - 1)

    def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor:
        """
        Compute the SSIM loss between ``X`` and ``Y``.

        Parameters
        ----------
        X: First input tensor, shaped ``(batch, channels, height, width)``. Every channel is filtered independently
            with the same averaging window.
        Y: Second input tensor, same shape as ``X``.
        data_range: Data range of the input tensors. Either one value per batch element (tensor of shape
            ``(batch,)``) or a single value (0-d tensor or Python number) shared by the whole batch. It is moved to
            the device of ``X`` if needed.

        Returns
        -------
        SSIM loss: ``1 - S.mean()``, where the mean runs over batch, channels and all valid window positions.

        Raises
        ------
        AssertionError: If the ``w`` buffer is not a tensor.
        """
        if not isinstance(self.w, torch.Tensor):  # type: ignore
            raise AssertionError("SSIMLoss buffer 'w' must be a torch.Tensor.")

        # Channel axis is the third-from-last dim, so unbatched (C, H, W) inputs keep working with conv2d.
        num_channels = X.shape[-3]
        # Cast the window locally instead of reassigning the registered buffer, so forward() has no side effects.
        # Replicating it to (C, 1, win, win) with groups=C filters each channel separately (identical for C == 1).
        w = self.w.to(X).repeat(num_channels, 1, 1, 1)  # type: ignore

        def _window_mean(t: torch.Tensor) -> torch.Tensor:
            # Windowed local mean of ``t``, computed channel by channel.
            return F.conv2d(t, w, groups=num_channels)

        # Shape (batch, 1, 1, 1) or (1, 1, 1, 1) so the constants broadcast over channels and spatial positions.
        data_range = torch.as_tensor(data_range, device=X.device).reshape(-1, 1, 1, 1)
        C1 = (self.k1 * data_range) ** 2
        C2 = (self.k2 * data_range) ** 2

        ux = _window_mean(X)
        uy = _window_mean(Y)
        uxx = _window_mean(X * X)
        uyy = _window_mean(Y * Y)
        uxy = _window_mean(X * Y)
        vx = self.cov_norm * (uxx - ux * ux)
        vy = self.cov_norm * (uyy - uy * uy)
        vxy = self.cov_norm * (uxy - ux * uy)

        A1, A2, B1, B2 = (2 * ux * uy + C1, 2 * vxy + C2, ux**2 + uy**2 + C1, vx + vy + C2)
        D = B1 * B2
        S = (A1 * A2) / D
        return 1 - S.mean()