Metadata-Version: 2.1
Name: egnn_jax
Version: 0.2
Summary: E(3) GNN in jax
Author: Gianluca Galletti
Author-email: g.galletti@tum.de
Classifier: Programming Language :: Python :: 3.8
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE

# E(n) Equivariant GNN in jax
Reimplementation of [EGNN](https://arxiv.org/abs/2102.09844) in jax. Original work by Victor Garcia Satorras, Emiel Hoogeboom and Max Welling.
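
For orientation, the layer update from the paper can be sketched in plain JAX. Edge messages are built from node features and the invariant squared distance. Coordinates are moved along relative position vectors, and features are updated from the aggregated messages. This is a minimal illustration on a single graph given as an edge list, not this package's API. The helpers `init_mlp` and `mlp`, the parameter layout and the layer sizes are hypothetical, and edge attributes are omitted.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: random (weight, bias) pairs for an MLP with the given layer sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    """SiLU MLP without an activation after the last layer."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = jax.nn.silu(x)
    return x


def egnn_layer(params, h, x, senders, receivers):
    """One E(n)-equivariant layer on a single graph given as an edge list j -> i."""
    n = h.shape[0]
    diff = x[receivers] - x[senders]                  # x_i - x_j
    dist2 = jnp.sum(diff**2, axis=-1, keepdims=True)  # ||x_i - x_j||^2, E(n)-invariant
    # Edge messages m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2)
    m_ij = mlp(params["edge"], jnp.concatenate([h[receivers], h[senders], dist2], axis=-1))
    # Coordinates: x_i + C * sum_j (x_i - x_j) * phi_x(m_ij), with C = 1 / (n - 1)
    x_new = x + jax.ops.segment_sum(diff * mlp(params["coord"], m_ij), receivers, n) / (n - 1)
    # Features: h_i' = phi_h(h_i, sum_j m_ij)
    m_i = jax.ops.segment_sum(m_ij, receivers, n)
    h_new = mlp(params["node"], jnp.concatenate([h, m_i], axis=-1))
    return h_new, x_new


# Equivariance check: rotate/reflect and translate the input coordinates
k_edge, k_coord, k_node, k_h, k_x, k_q, k_t = jax.random.split(jax.random.PRNGKey(0), 7)
n, f, hidden = 5, 4, 16
params = {
    "edge": init_mlp(k_edge, [2 * f + 1, hidden, hidden]),
    "coord": init_mlp(k_coord, [hidden, hidden, 1]),
    "node": init_mlp(k_node, [f + hidden, hidden, f]),
}
h = jax.random.normal(k_h, (n, f))
x = jax.random.normal(k_x, (n, 3))
# Fully connected graph without self-loops
receivers, senders = jnp.nonzero(1 - jnp.eye(n, dtype=jnp.int32))
q, _ = jnp.linalg.qr(jax.random.normal(k_q, (3, 3)))  # random orthogonal matrix
t = jax.random.normal(k_t, (3,))                      # random translation

h_ref, x_ref = egnn_layer(params, h, x, senders, receivers)
h_rot, x_rot = egnn_layer(params, h, x @ q.T + t, senders, receivers)
print(jnp.allclose(h_rot, h_ref, atol=1e-3), jnp.allclose(x_rot, x_ref @ q.T + t, atol=1e-3))
```

Features should come out unchanged and coordinates should transform exactly like the inputs. This is the E(n) equivariance the model is built around.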

## Installation
```
python -m pip install egnn-jax
```

Or clone this repository and install it locally (in editable mode):
```
python -m pip install -e .
```

### GPU support
Upgrade `jax` to the GPU version:
```
pip install --upgrade "jax[cuda]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
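
To confirm that the CUDA build is actually picked up, you can list the devices JAX sees. This is a quick check using standard `jax` calls. The exact device names depend on the hardware.

```python
import jax

# Devices visible to JAX: GPU devices should appear after installing the CUDA build,
# otherwise only CPU devices are listed.
print(jax.devices())
# Platform used by default for new computations, e.g. "gpu" or "cpu"
print(jax.default_backend())
```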

## Validation
The charged N-body experiment from the original paper is included for validation. Times are for the __model only__, on batches of 100 graphs, in (global) single precision.

|                  |  MSE  | Inference [ms]* |
|------------------|-------|-----------------|
| torch (original) | .0071 |      8.27       |
| jax (ours)       | .011  |      0.94       |

\* Remeasured on a Quadro RTX 4000.
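
Model-only timing in JAX has two pitfalls. The first call includes XLA compilation, and dispatch is asynchronous, so the clock has to wait for `block_until_ready()`. The sketch below shows one way to take such a measurement on a batch of 100 graphs merged into a single disjoint graph, with node indices offset per graph. It is not the benchmark script behind the table. The choice of 5 particles per graph is an assumption mirroring the charged N-body setup, and `toy_model` is a hypothetical stand-in for the EGNN. JAX already uses float32 unless `jax_enable_x64` is enabled, and the config call only makes that explicit.

```python
import time

import jax
import jax.numpy as jnp

# Global single precision (JAX's default; float64 is used only if jax_enable_x64 is set)
jax.config.update("jax_enable_x64", False)

N_GRAPHS, N_NODES = 100, 5  # assumed: 100 systems of 5 particles each


def batch_fully_connected(n_graphs, n_nodes):
    """Edge lists for n_graphs disjoint fully connected graphs (no self-loops)."""
    src, dst = jnp.nonzero(1 - jnp.eye(n_nodes, dtype=jnp.int32))
    offsets = jnp.arange(n_graphs)[:, None] * n_nodes  # shift node ids per graph
    return (src[None] + offsets).ravel(), (dst[None] + offsets).ravel()


@jax.jit
def toy_model(x, senders, receivers):
    """Hypothetical stand-in for the model: one message-passing step on coordinates."""
    msg = x[receivers] - x[senders]
    return x + jax.ops.segment_sum(msg, receivers, num_segments=x.shape[0]) / (N_NODES - 1)


senders, receivers = batch_fully_connected(N_GRAPHS, N_NODES)
x = jax.random.normal(jax.random.PRNGKey(0), (N_GRAPHS * N_NODES, 3))

toy_model(x, senders, receivers).block_until_ready()  # warm-up call triggers compilation

n_runs = 100
start = time.perf_counter()
for _ in range(n_runs):
    out = toy_model(x, senders, receivers)
out.block_until_ready()  # wait for the queued computations before stopping the clock
elapsed_ms = (time.perf_counter() - start) / n_runs * 1e3
print(f"{elapsed_ms:.3f} ms per batch")
```

Merging the batch into one graph keeps a single compiled function for every batch of the same size. Because edges never cross graph boundaries, the graphs remain independent.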

### Validation install

The N-body experiments are only included in the GitHub repository, so it needs to be cloned first:
```
git clone https://github.com/gerkone/egnn-jax
```

They are adapted from the original implementation, so `torch` and `torch_geometric` are additionally required (the CPU versions are sufficient):
```
pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r nbody/requirements.txt
```

### Validation usage
The charged N-body dataset has to be generated locally in the [/nbody/data](/nbody/data) directory:
```
python3 -u generate_dataset.py --num-train=3000
```
Then the model can be trained and evaluated (from the repository root) with:
```
python main.py --epochs=500 --batch-size=100 --lr=1e-4 --weight-decay=1e-8
```
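
The `--lr` and `--weight-decay` flags presumably act as the step size and the L2 penalty. As an illustration only, the sketch below applies them in a plain gradient-descent step on a position MSE loss. `predict` is a hypothetical linear stand-in for the model. `main.py` may well use an adaptive optimizer such as Adam instead, so this is not its training loop.

```python
import jax
import jax.numpy as jnp

LR, WEIGHT_DECAY = 1e-4, 1e-8  # values from the command above


def predict(params, x):
    """Hypothetical stand-in model: an affine map on particle coordinates."""
    return x @ params["w"] + params["b"]


def loss_fn(params, x, target):
    # MSE between predicted and target positions
    return jnp.mean((predict(params, x) - target) ** 2)


@jax.jit
def sgd_step(params, x, target):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, target)
    # L2 weight decay: for plain gradient descent, adding weight_decay * p to the gradient
    # is the same as adding weight_decay / 2 * ||p||^2 to the loss
    params = jax.tree_util.tree_map(lambda p, g: p - LR * (g + WEIGHT_DECAY * p), params, grads)
    return params, loss


x = jax.random.normal(jax.random.PRNGKey(0), (500, 3))
target = x + 0.1
params = {"w": jnp.eye(3), "b": jnp.zeros(3)}
initial_loss = loss_fn(params, x, target)
for _ in range(100):
    params, _ = sgd_step(params, x, target)
final_loss = loss_fn(params, x, target)
print(initial_loss, final_loss)
```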

## Acknowledgements
This implementation heavily borrows from the [original pytorch code](https://github.com/vgsatorras/egnn).
