Metadata-Version: 1.0
Name: rmi-pytorch
Version: 0.1.1
Summary: Region Mutual Information loss in PyTorch
Home-page: https://github.com/RElbers/region-mutual-information-pytorch
Author: Robin Elbers
Author-email: UNKNOWN
License: MIT
Description: Region Mutual Information loss
        ==============================
        
        PyTorch implementation of the `Region Mutual Information Loss for
        Semantic Segmentation <https://arxiv.org/abs/1910.12037>`__.
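        
        As a rough illustration of what the loss measures: the paper represents
        each pixel by the vector of values in its ``R x R`` neighbourhood and
        minimises ``0.5 * log det(M)``, where
        ``M = Σ_Y - Cov(Y, P) Σ_P^-1 Cov(Y, P)^T`` is the approximate posterior
        covariance of the target regions ``Y`` given the predicted regions
        ``P`` (constant terms dropped). The following is a minimal NumPy sketch
        of that quantity for a single channel. It is not this package's
        implementation. It omits the BCE term, downsampling and batching. The
        helper names ``region_vectors`` and ``rmi_lower_bound_loss`` are
        hypothetical, and so is the ``eps`` added to the diagonals for
        numerical stability (its default mirrors the ``EPSILON`` value quoted
        below).
        
        .. code:: python
        
            import numpy as np
        
        
            def region_vectors(x, radius=3):
                """Stack every radius x radius neighbourhood of a 2-D array.
        
                Returns an array of shape (radius * radius, n_regions), one
                column per neighbourhood.
                """
                h, w = x.shape
                rows = []
                for dy in range(radius):
                    for dx in range(radius):
                        # Shifted crop: element (i, j) is x[i + dy, j + dx].
                        crop = x[dy:h - radius + 1 + dy, dx:w - radius + 1 + dx]
                        rows.append(crop.ravel())
                return np.stack(rows)
        
        
            def rmi_lower_bound_loss(pred, target, radius=3, eps=5e-4):
                """Hypothetical helper: 0.5 * log det(M) for one channel."""
                y = region_vectors(np.asarray(target, dtype=np.float64), radius)
                p = region_vectors(np.asarray(pred, dtype=np.float64), radius)
                d, n = y.shape
                # Centre each of the d region dimensions over all regions.
                y = y - y.mean(axis=1, keepdims=True)
                p = p - p.mean(axis=1, keepdims=True)
                eye = np.eye(d)
                sigma_y = y @ y.T / n
                sigma_p = p @ p.T / n + eps * eye  # regularised so it is invertible
                cov_yp = y @ p.T / n
                # M = Sigma_Y - Cov(Y, P) Sigma_P^-1 Cov(Y, P)^T
                m = sigma_y - cov_yp @ np.linalg.solve(sigma_p, cov_yp.T)
                # Symmetrise (rounding error) and add eps so M stays positive definite.
                m = 0.5 * (m + m.T) + eps * eye
                _, logdet = np.linalg.slogdet(m)
                return 0.5 * logdet
        
        For a random binary target, passing the target itself as the prediction
        gives a much lower value than an independent uniformly random
        prediction, because the target regions are then almost fully explained
        by the predicted regions.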
        
        Example usage
        -------------
        
        With logits:
        
        .. code:: python
        
            import torch
            from rmi import RMILoss
        
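            # with_logits=True: pred holds raw, unnormalized scores (logits).
            loss = RMILoss(with_logits=True)
        
            batch_size, classes, height, width = 5, 4, 64, 64
            pred = torch.randn(batch_size, classes, height, width, requires_grad=True)
            # Binary target (values 0.0 or 1.0) with the same shape as pred.
            target = torch.empty(batch_size, classes, height, width).random_(2)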
        
            output = loss(pred, target)
            output.backward()
        
        With probabilities:
        
        .. code:: python
        
            import torch
            from torch import nn
            from rmi import RMILoss
        
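            # Converts raw scores to probabilities in (0, 1) before the loss.
            m = nn.Sigmoid()
            # with_logits=False: the loss expects probabilities, not logits.
            loss = RMILoss(with_logits=False)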
        
            batch_size, classes, height, width = 5, 4, 64, 64
            pred = torch.randn(batch_size, classes, height, width, requires_grad=True)
            target = torch.empty(batch_size, classes, height, width).random_(2)
        
            output = loss(m(pred), target)
            output.backward()
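        
        The two examples differ only in where the sigmoid is applied. The loss
        also contains a BCE component (see the plots below). Assuming that
        ``with_logits=True`` means the sigmoid is applied inside the loss (the
        flag's name suggests this, but check the source), the BCE part can be
        computed directly from logits with the numerically stable form
        ``max(x, 0) - x * y + log(1 + exp(-|x|))``. This form equals the BCE of
        ``sigmoid(x)``. The NumPy sketch below checks that identity. The helpers
        ``sigmoid``, ``bce_from_logits`` and ``bce_from_probs`` are
        hypothetical and not taken from this package.
        
        .. code:: python
        
            import numpy as np
        
        
            def sigmoid(x):
                return 1.0 / (1.0 + np.exp(-x))
        
        
            def bce_from_logits(logits, target):
                """Mean BCE computed directly from logits (numerically stable form)."""
                return np.mean(
                    np.maximum(logits, 0.0)
                    - logits * target
                    + np.log1p(np.exp(-np.abs(logits)))
                )
        
        
            def bce_from_probs(probs, target, eps=1e-12):
                """Mean BCE computed from probabilities; eps guards against log(0)."""
                probs = np.clip(probs, eps, 1.0 - eps)
                return np.mean(-(target * np.log(probs) + (1 - target) * np.log(1 - probs)))
        
        Working from logits avoids computing ``log(sigmoid(x))`` for large
        ``|x|``, where the probability rounds to 0 or 1 in floating point.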
        
        Graphs
        ------
        
        Value of the loss between a prediction and the target, without the BCE
        component. The target is a random binary 256x256 matrix. For
        ``Random``, the prediction is a 256x256 matrix of probabilities sampled
        uniformly at random. For ``All zero``, it is a 256x256 matrix of zeros.
        For ``1- target``, it is the inverse of the target. Each prediction is
        then interpolated towards the target with
        ``input_i = (1 - α) * input + α * target``, so ``α = 0`` gives the
        unchanged prediction and ``α = 1`` gives the target itself.
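        
        The experiment described above can be sketched as follows. This NumPy
        sketch builds the three starting predictions and evaluates a loss along
        the interpolation path. ``loss_fn`` is a hypothetical placeholder for
        the loss without its BCE component. The number of α steps (``steps``)
        is an assumption, since it is not stated here.
        
        .. code:: python
        
            import numpy as np
        
        
            def make_baselines(target, rng):
                """The three starting predictions, for a binary target."""
                return {
                    "Random": rng.random(target.shape),  # uniform probabilities in [0, 1)
                    "All zero": np.zeros_like(target),
                    "1- target": 1.0 - target,  # inverse of the target
                }
        
        
            def interpolate(pred, target, alpha):
                """input_i = (1 - alpha) * input + alpha * target."""
                return (1.0 - alpha) * pred + alpha * target
        
        
            def sweep(loss_fn, size=256, steps=11, seed=0):
                """Evaluate loss_fn(pred, target) along each interpolation path.
        
                Returns the alpha values and one list of loss values per baseline.
                """
                rng = np.random.default_rng(seed)
                target = (rng.random((size, size)) > 0.5).astype(np.float64)
                alphas = np.linspace(0.0, 1.0, steps)
                curves = {}
                for name, pred in make_baselines(target, rng).items():
                    curves[name] = [loss_fn(interpolate(pred, target, a), target) for a in alphas]
                return alphas, curves
        
        Plotting each list in ``curves`` against ``alphas`` gives one line per
        starting prediction.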
        
        .. image:: https://raw.githubusercontent.com/RElbers/region-mutual-information-pytorch/main/imgs/loss.png
        
        Difference between the output of this implementation and that of the
        implementation in the official `repository <https://github.com/ZJULearning/RMI>`__,
        using ``EPSILON = 0.0005`` and ``pool='max'``.
        
        .. image:: https://raw.githubusercontent.com/RElbers/region-mutual-information-pytorch/main/imgs/diff.png
        
        Execution time on tensors with a batch size of 8 and 21 classes, for
        increasing spatial sizes (shapes are batch x classes x height x width).
        
        +----------------+--------------+--------------+
        | Tensor shape   | This package | Official     |
        +================+==============+==============+
        | 8x21x32x32     | 6.5722ms     | 6.3261ms     |
        +----------------+--------------+--------------+
        | 8x21x64x64     | 11.8159ms    | 12.6169ms    |
        +----------------+--------------+--------------+
        | 8x21x128x128   | 39.9946ms    | 40.3798ms    |
        +----------------+--------------+--------------+
        | 8x21x256x256   | 160.0352ms   | 160.9543ms   |
        +----------------+--------------+--------------+
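        
        The hardware and measurement method behind these timings are not
        stated. A generic way to produce comparable numbers is to average many
        calls after a few warm-up runs. The sketch below uses only the standard
        library. ``time_ms``, ``repeats``, ``warmup`` and ``sync`` are
        hypothetical names. When timing CUDA tensors, pass
        ``torch.cuda.synchronize`` as ``sync``, because GPU kernels launch
        asynchronously and would otherwise not be included in the measured time.
        
        .. code:: python
        
            import time
        
        
            def time_ms(fn, *args, repeats=100, warmup=10, sync=None):
                """Average wall-clock time of fn(*args) in milliseconds.
        
                sync: optional callable invoked before the start and end clock
                reads so that pending asynchronous work is included.
                """
                for _ in range(warmup):
                    fn(*args)  # warm-up calls (caches, lazy initialisation)
                if sync is not None:
                    sync()
                start = time.perf_counter()
                for _ in range(repeats):
                    fn(*args)
                if sync is not None:
                    sync()
                elapsed = time.perf_counter() - start
                return 1000.0 * elapsed / repeats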
        
        
        
Platform: UNKNOWN
