Metadata-Version: 2.1
Name: fastmtl
Version: 1.1.0
Summary: Multi-task learning utilities for fastai
Home-page: https://github.com/bdsaglam/fastmtl/tree/master/
Author: Barış Deniz Sağlam
Author-email: bdsaglam@gmail.com
License: Apache Software License 2.0
Keywords: deep learning,multi-task learning,fastai,pytorch
Platform: UNKNOWN
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: License :: OSI Approved :: Apache Software License
Requires-Python: >=3.7
Description-Content-Type: text/markdown
Provides-Extra: dev
License-File: LICENSE

# fastmtl
> Multi-task learning utilities for fastai


## Install

`pip install fastmtl`

## Usage

### Loss

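Apply a loss function to each model output and combine the per-output losses into a weighted sum. For instance, if the first model output is for classification and the second is for regression:
```py
from fastai.losses import CrossEntropyLossFlat, MSELossFlat
from fastmtl.loss import CombinedLoss

# Cross-entropy for the first output (classification), MSE for the second (regression);
# the regression loss is weighted 3x relative to the classification loss.
loss_func = CombinedLoss(CrossEntropyLossFlat(), MSELossFlat(), weight=[1.0, 3.0])
```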
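To make the weighted-sum idea concrete, here is a minimal framework-agnostic sketch. `WeightedSumLoss` is a hypothetical class written for illustration, not fastmtl's `CombinedLoss` implementation. It assumes fastai's convention of calling the loss as `loss_func(pred, *yb)`, where `pred` is the tuple of model outputs and `yb` holds one target per output. Because it only uses `*` and `+`, the same code works on floats or on PyTorch tensors.

```python
from typing import Callable, Optional, Sequence


class WeightedSumLoss:
    """Hypothetical illustration of a multi-task loss (not fastmtl's actual code).

    Applies loss_funcs[i] to (preds[i], targets[i]) and returns
    sum_i weight[i] * loss_i.
    """

    def __init__(self, *loss_funcs: Callable, weight: Optional[Sequence[float]] = None):
        if weight is None:
            # Assumption: equal weighting when no weights are given
            weight = [1.0] * len(loss_funcs)
        if len(weight) != len(loss_funcs):
            raise ValueError("need exactly one weight per loss function")
        self.loss_funcs = loss_funcs
        self.weight = list(weight)

    def __call__(self, preds: Sequence, *targets):
        n = len(self.loss_funcs)
        if len(preds) != n or len(targets) != n:
            raise ValueError("expected one prediction and one target per loss function")
        # Each task's loss sees only its own output/target pair; results are combined linearly.
        return sum(w * f(p, t) for w, f, p, t in zip(self.weight, self.loss_funcs, preds, targets))


# Usage with toy scalar losses standing in for CrossEntropyLossFlat / MSELossFlat
abs_err = lambda p, t: abs(p - t)
sq_err = lambda p, t: (p - t) ** 2
loss = WeightedSumLoss(abs_err, sq_err, weight=[1.0, 3.0])
value = loss((2.0, 1.0), 0.0, 3.0)  # 1.0 * |2 - 0| + 3.0 * (1 - 3)**2 = 14.0
```

Weighting the losses lets you balance tasks whose losses live on different scales; the weights are hyperparameters you would typically tune.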

### Metric

Apply separate metrics to each model output. For instance, if a model performs both classification and regression, each output can be evaluated with the metrics relevant to it. Assuming the model outputs a tuple of tensors for classification and regression, respectively:

```py
from fastai.learner import Learner
from fastai.metrics import F1Score, R2Score
from fastmtl.metric import mtl_metrics

# Metrics for the classification output; custom names keep the logged columns distinguishable
clf_f1_macro = F1Score(average='macro')
clf_f1_macro.name = 'clf_f1(macro)'
clf_f1_micro = F1Score(average='micro')
clf_f1_micro.name = 'clf_f1(micro)'

# Metric for the regression output
reg_r2 = R2Score()
reg_r2.name = 'reg_r2'

# The first list holds metrics for classification (first model output),
# the second list holds metrics for regression (second model output)
metrics = mtl_metrics([clf_f1_macro, clf_f1_micro], [reg_r2])

learn = Learner(
    ...
    metrics=metrics,
)
```
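The routing idea behind per-output metrics can be sketched in plain Python. The helpers `per_output` and `multi_task_metrics` below are hypothetical and are not how `mtl_metrics` is actually implemented; they only show one way to send output `i` and target `i` to the metrics in group `i`. The sketch assumes fastai's convention of calling a functional metric as `metric(pred, *yb)`, and uses toy list-based metrics instead of `F1Score`/`R2Score`.

```python
from typing import Callable, Sequence


def per_output(fn: Callable, i: int) -> Callable:
    """Hypothetical helper: adapt a single-task metric fn(pred, target) to output i."""
    def wrapped(preds: Sequence, *targets):
        # Select only the i-th model output and the i-th target
        return fn(preds[i], targets[i])
    # Keep the original name so logged results stay readable
    wrapped.__name__ = getattr(fn, "__name__", f"metric_{i}")
    return wrapped


def multi_task_metrics(*metric_groups: Sequence[Callable]) -> list:
    """Flatten metric groups, binding each metric to the index of the output it evaluates."""
    return [per_output(fn, i) for i, group in enumerate(metric_groups) for fn in group]


# Toy metrics standing in for classification and regression metrics
def accuracy(pred, target):
    return sum(p == t for p, t in zip(pred, target)) / len(target)


def mae(pred, target):
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


metrics = multi_task_metrics([accuracy], [mae])
preds = ([1, 0, 1, 1], [0.5, 2.0])    # (class predictions, regression predictions)
targets = ([1, 0, 0, 1], [1.0, 2.0])  # (class labels, regression targets)
results = {m.__name__: m(preds, *targets) for m in metrics}
# results == {'accuracy': 0.75, 'mae': 0.25}
```

Binding the output index when each metric is wrapped means the metric itself never needs to know it is part of a multi-task setup.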

## Tutorials

[Video distortion detection](https://bdsaglam.github.io/fastmtl/tutorial.vqa)

## TODO
- [ ] Support tabular learner
- [ ] Support fastai>=2.7


