Metadata-Version: 2.1
Name: batchdist
Version: 0.1.2
Summary: A small pytorch package for efficiently running pair-wise operations such as distances on the batch-level.
Home-page: https://github.com/mi92/batchdist
License: BSD-3-Clause
Author: Michael Moor
Author-email: michael.moor@bsse.ethz.ch
Requires-Python: >=3.7,<4.0
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Requires-Dist: torch (>=1.6.0,<2.0.0)
Requires-Dist: unittest (>=0.0,<0.1)
Project-URL: Repository, https://github.com/mi92/batchdist
Description-Content-Type: text/markdown

# batchdist  

This is a small PyTorch-based package for efficient batched operations, e.g. computing distances between all instance pairs of a batch of data without slowly looping over them.

Several torch modules/methods promise to handle batches but return only a vector of pairwise results (see the example below) instead of the full matrix. This package serves as a tool to wrap such methods so that they return full matrices (e.g. distance matrices) using fast, batched operations (without loops).
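The general idea can be sketched as follows. This sketch assumes the wrapped callable accepts two equally sized batches and returns one value per paired row. Every (i, j) combination is materialised as one large batch of n·m pairs, the callable is invoked once, and the resulting vector is reshaped to n × m. The helper name `pairwise_matrix` is hypothetical, and this is not necessarily how `batchdist` implements it internally.

```python
import torch


def pairwise_matrix(fn, x, y):
    """Hypothetical helper: turn a paired-distance callable into a full matrix.

    fn(a, b) is assumed to map two batches of shape [k, ...] to a vector of
    shape [k] holding fn(a_i, b_i) for each row i.
    """
    n, m = x.shape[0], y.shape[0]
    # Each row of x is repeated m times consecutively: x0, x0, ..., x1, x1, ...
    x_rep = x.repeat_interleave(m, dim=0)
    # The whole batch y is tiled n times: y0, y1, ..., y0, y1, ...
    y_rep = y.repeat(n, *([1] * (y.dim() - 1)))
    # A single batched call evaluates all n*m pairs without Python loops.
    flat = fn(x_rep, y_rep)
    # Entry k = i*m + j corresponds to the pair (x_i, y_j).
    return flat.reshape(n, m)


def dummy_distance(x, y):
    # Paired "distance" as in the example below: one value per row.
    return x.sum(dim=(1, 2)) + y.sum(dim=(1, 2))


x = torch.rand(5, 4, 3)
y = torch.rand(7, 4, 3)
D = pairwise_matrix(dummy_distance, x, y)
print(D.shape)  # torch.Size([5, 7])
```

This approach trades memory for speed: the expanded inputs hold n·m copies of the samples, so very large batches may need to be processed in chunks.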

## Example  

First, let's define a custom distance function that only computes paired distances within batches: given two batches of 10 samples each, it returns a distance vector of shape (10,) rather than a 10 x 10 matrix.
```python
>>> import torch
>>> def dummy_distance(x, y):
...     """
...     This is a dummy distance d which allows for a batch dimension
...     (say with n instances in a batch), but does not return the full
...     n x n distance matrix, only an n-dimensional vector of the
...     paired distances d(x_i, y_i) for all i in (1, ..., n).
...     """
...     x_ = x.sum(dim=(1, 2))
...     y_ = y.sum(dim=(1, 2))
...     return x_ + y_
...

>>> # batchdist wraps a torch module around this callable to compute
>>> # the full n x n matrix with batched operations (no loops).
>>> import batchdist as bd
>>> batched = bd.BatchDistance(dummy_distance)

>>> # generate data (two batches of 256 samples, each of shape [4, 3])
>>> x1 = torch.rand(256, 4, 3)
>>> x2 = torch.rand(256, 4, 3)

>>> out1 = batched(x1, x2)  # distance matrix of shape [256, 256]
```
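To check that a wrapped callable returns the right matrix, it can help to compare it against the slow double loop that batched evaluation is meant to replace. The sketch below builds such a reference for the `dummy_distance` from the example (reimplemented here so the snippet runs on its own) and compares it with the closed form D[i, j] = sum(x_i) + sum(y_j), which follows directly from the definition of `dummy_distance`. The name `loop_reference` is hypothetical; in practice the other side of the comparison would be the output of `bd.BatchDistance` on a small batch.

```python
import torch


def dummy_distance(x, y):
    # Same dummy distance as in the example: one value per paired row.
    return x.sum(dim=(1, 2)) + y.sum(dim=(1, 2))


def loop_reference(fn, x, y):
    """Hypothetical slow reference: evaluate fn on every (x_i, y_j) pair."""
    n, m = x.shape[0], y.shape[0]
    out = torch.empty(n, m)
    for i in range(n):
        for j in range(m):
            # Slice with i:i+1 to keep the batch dimension fn expects.
            out[i, j] = fn(x[i:i + 1], y[j:j + 1])[0]
    return out


x1 = torch.rand(8, 4, 3)
x2 = torch.rand(8, 4, 3)
ref = loop_reference(dummy_distance, x1, x2)
# For this dummy distance the full matrix is an outer sum of row sums.
closed_form = x1.sum(dim=(1, 2))[:, None] + x2.sum(dim=(1, 2))[None, :]
print(torch.allclose(ref, closed_form))  # True
```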
 
For more details, consult the included examples.


