Metadata-Version: 2.2
Name: tn4ml
Version: 1.0.5
Summary: Tensor Networks for Machine Learning
Home-page: https://github.com/bsc-quantic/tn4ml/tree/master
Author: Ema Puljak, Sergio Sanchez Ramirez, Sergi Masot Llima, Jofre Vallès-Muns
License: MIT
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: autoray>=0.3.0
Requires-Dist: dask
Requires-Dist: funcy
Requires-Dist: numpy
Requires-Dist: opt_einsum
Requires-Dist: quimb>=1.4.1
Requires-Dist: jaxlib
Requires-Dist: jax
Requires-Dist: optax
Requires-Dist: flax
Requires-Dist: pandas
Requires-Dist: nevergrad
Requires-Dist: chocolate
Requires-Dist: baytune
Requires-Dist: optuna
Requires-Dist: scikit-optimize
Requires-Dist: kahypar
Provides-Extra: docs
Requires-Dist: sphinx<8.0.0; extra == "docs"
Requires-Dist: sphinx-book-theme; extra == "docs"
Requires-Dist: ipykernel; extra == "docs"
Requires-Dist: nbsphinx; extra == "docs"
Requires-Dist: myst-parser; extra == "docs"
Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
Requires-Dist: sphinxcontrib-devhelp; extra == "docs"
Requires-Dist: sphinxcontrib-htmlhelp; extra == "docs"
Requires-Dist: sphinxcontrib-jsmath; extra == "docs"
Requires-Dist: sphinxcontrib-qthelp; extra == "docs"
Requires-Dist: sphinxcontrib-serializinghtml; extra == "docs"
Requires-Dist: sphinx-rtd-theme; extra == "docs"
Requires-Dist: sphinx-copybutton; extra == "docs"
Requires-Dist: sphinx-gallery; extra == "docs"
Requires-Dist: tensorflow; extra == "docs"
Requires-Dist: matplotlib; extra == "docs"
Provides-Extra: test
Requires-Dist: pytest; extra == "test"
Provides-Extra: examples
Requires-Dist: matplotlib; extra == "examples"
Requires-Dist: scikit-learn; extra == "examples"
Requires-Dist: argparse; extra == "examples"
Requires-Dist: tensorflow; extra == "examples"
Requires-Dist: seaborn; extra == "examples"
Dynamic: author
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

<img src="docs/_static/logo.png" position="center" alt="logo" width="500" height="200">

# Tensor Networks for Machine Learning
![Static Badge](https://img.shields.io/badge/tests-passing-blue)
![Static Badge](https://img.shields.io/badge/docs-passing-green)<br>
**tn4ml** is a Python library that handles tensor networks for machine learning applications.<br>
It is built on top of **Quimb**, for tensor network objects, and **JAX**, for the optimization pipeline.<br>
For now, the library supports 1D Tensor Network structures: 
- **Matrix Product State**
- **Matrix Product Operator**
- **Spaced Matrix Product Operator**
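
To make the MPS structure concrete, here is a minimal sketch in plain NumPy. It is independent of tn4ml's and Quimb's APIs. It stores an open-boundary MPS as a list of tensors indexed (left bond, physical, right bond), which is an assumed convention, and contracts the MPS with a product state one site at a time. The helper names `random_mps`, `contract_with_product_state` and `to_dense` are hypothetical and exist only for this illustration.

```python
import numpy as np

def random_mps(n_sites, phys_dim=2, bond_dim=4, seed=0):
    """Random open-boundary MPS as a list of (left, phys, right) tensors."""
    rng = np.random.default_rng(seed)
    tensors = []
    for i in range(n_sites):
        left = 1 if i == 0 else bond_dim            # boundary bonds have size 1
        right = 1 if i == n_sites - 1 else bond_dim
        tensors.append(rng.normal(size=(left, phys_dim, right)))
    return tensors

def contract_with_product_state(mps, local_vectors):
    """Contract the MPS with a product state by sweeping left to right."""
    env = np.ones(1)
    for tensor, vec in zip(mps, local_vectors):
        # absorb the physical leg: (left, phys, right) . (phys,) -> (left, right)
        env = env @ np.einsum("lpr,p->lr", tensor, vec)
    return env.item()

def to_dense(mps):
    """Brute force: contract every bond into one tensor with phys_dim**N entries."""
    psi = mps[0]
    for tensor in mps[1:]:
        psi = np.tensordot(psi, tensor, axes=([-1], [0]))
    return psi.squeeze(axis=(0, -1))    # drop the two size-1 boundary bonds

n = 5
mps = random_mps(n)
rng = np.random.default_rng(1)
vecs = [rng.normal(size=2) for _ in range(n)]

fast = contract_with_product_state(mps, vecs)
slow = to_dense(mps)
for v in vecs:    # contract the dense tensor one physical leg at a time
    slow = np.tensordot(slow, v, axes=([0], [0]))
print(np.isclose(fast, slow))
```

The left-to-right sweep costs O(N·d·D²) operations for N sites, physical dimension d and bond dimension D. The dense check, in contrast, builds a tensor with d^N entries. This gap is what keeps MPS-based models tractable for long inputs.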

It supports different **embedding** functions, **initialization** techniques, **objective functions** and **optimization strategies**.<br>
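
An embedding maps raw features to the local vectors that a tensor network contracts with. A common choice in the tensor-network machine learning literature is the trigonometric feature map of Stoudenmire and Schwab. It sends each scalar feature x in [0, 1] to the 2-dimensional vector [cos(πx/2), sin(πx/2)], and the full input is then the tensor product of these local vectors. The NumPy sketch below illustrates this map under that assumption. `trigonometric_embedding` is a hypothetical name, not tn4ml's embedding API, so check the documentation for the embeddings the library actually provides.

```python
import numpy as np

def trigonometric_embedding(x):
    """Map each feature x in [0, 1] to the unit vector [cos(pi*x/2), sin(pi*x/2)]."""
    x = np.asarray(x, dtype=float)
    return np.stack([np.cos(np.pi * x / 2), np.sin(np.pi * x / 2)], axis=-1)

pixels = np.array([0.0, 0.5, 1.0])      # e.g. grayscale values rescaled to [0, 1]
phi = trigonometric_embedding(pixels)    # shape (3, 2): one local vector per feature
print(phi.shape)
print(np.linalg.norm(phi, axis=-1))      # each local vector has unit norm
```

Because every local vector has unit norm, the resulting product state is also normalized.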

## Installation

First, create a virtual environment (for example with `venv`, `pyenv` or `conda`). Then install the package and its dependencies.
<br>

**With** `pip` (tag v1.0.5):
```bash
pip install tn4ml
```
<br>

or **directly from github**:
```bash
pip install -U git+https://github.com/bsc-quantic/tn4ml.git
```
<br>

If you want to test and edit the code, clone the repository and install it in editable mode:
```bash
git clone https://github.com/bsc-quantic/tn4ml.git
pip install -e tn4ml/
```
To install the optional dependencies for *docs*, *test* or *examples*, use the corresponding extra. The quotes prevent shells such as `zsh` from expanding the brackets:

```bash
pip install "tn4ml[docs]"
pip install "tn4ml[test]"
pip install "tn4ml[examples]"
```


**Accelerated runtime** <br>

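(Optional) To improve numerical precision, set these flags at the beginning of your script, before creating any arrays. They enable 64-bit floats and the most accurate matrix-multiplication mode, usually at some cost in speed:
```python
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_matmul_precision", "highest")
```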
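
JAX uses 32-bit floats by default unless `jax_enable_x64` is enabled. The NumPy-only sketch below requires no JAX, and its variable names are illustrative. It shows the kind of rounding that double precision avoids: an update of 1e-8 is lost entirely in float32 but kept in float64. Errors of this size can accumulate when many tensors are contracted in sequence.

```python
import numpy as np

# Machine epsilon: the gap between 1.0 and the next representable number
eps32 = np.finfo(np.float32).eps
eps64 = np.finfo(np.float64).eps

small = 1e-8
# float32 rounds 1 + 1e-8 back to exactly 1.0, since 1e-8 < eps32 / 2 ...
lost_in_float32 = np.float32(1.0) + np.float32(small) == np.float32(1.0)
# ... while float64 resolves the difference easily
kept_in_float64 = np.float64(1.0) + np.float64(small) != np.float64(1.0)

print(eps32, eps64)
print(lost_in_float32, kept_in_float64)
```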

**Running on GPU**<br>
Before anything else, install a CUDA-enabled version of `JAX` that can run on GPU.<br>
See the installation instructions here: [jax[cuda]](https://docs.jax.dev/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-via-pip-easier) <br>

Next, at the beginning of your script (before importing `jax`), set:
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use GPU 0 - or set any GPU ID
import jax
jax.config.update("jax_platform_name", 'gpu')
```
Then, when training a `Model`, set:
```python
device = 'gpu'
model.configure(device=device)
```

## Documentation
Visit [tn4ml.readthedocs.io](https://tn4ml.readthedocs.io/en/latest/)

## Example notebooks

[TN for Classification](docs/source/examples/mnist_classification.ipynb)<br>
[TN for Anomaly Detection](docs/source/examples/mnist_ad.ipynb)<br>
[TN for Anomaly Detection with DMRG-like method](docs/source/examples/mnist_ad_sweeps.ipynb)

## Examples from the paper
[Breast Cancer Classification](docs/source/examples/supervised)<br>
[Unsupervised learning with MNIST](docs/source/examples/unsupervised)


## Citation

If you use **tn4ml** in your work, please cite the following paper: [arXiv:2502.13090](https://arxiv.org/abs/2502.13090)

```bibtex
@article{puljak2025tn4mltensornetworktraining,
      title={tn4ml: Tensor Network Training and Customization for Machine Learning}, 
      author={Ema Puljak and Sergio Sanchez-Ramirez and Sergi Masot-Llima and Jofre Vallès-Muns and Artur Garcia-Saez and Maurizio Pierini},
      year={2025},
      eprint={2502.13090},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2502.13090}, 
      }
```


## License
MIT license - check it out [here](LICENSE)
