Metadata-Version: 2.1
Name: pyqg-jax
Version: 0.6.0
Summary: Quasigeostrophic model in JAX (port of PyQG)
Author: Karl Otness
Requires-Python: >=3.8
Description-Content-Type: text/markdown
Classifier: Development Status :: 2 - Pre-Alpha
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Atmospheric Science
Classifier: Topic :: Scientific/Engineering :: Physics
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Dist: jax>=0.3.21
Requires-Dist: jaxtyping
Project-URL: Bug Tracker, https://github.com/karlotness/pyqg-jax/issues
Project-URL: Documentation, https://pyqg-jax.readthedocs.io
Project-URL: Homepage, https://github.com/karlotness/pyqg-jax
Project-URL: Source Code, https://github.com/karlotness/pyqg-jax

# PyQG JAX Port

[![PyQG-JAX on PyPI](https://img.shields.io/pypi/v/pyqg-jax)][pypi]
[![Documentation](https://readthedocs.org/projects/pyqg-jax/badge/?version=latest)][docs]
[![Tests](https://github.com/karlotness/pyqg-jax/actions/workflows/test.yml/badge.svg)][tests]

This is a partial port of [PyQG](https://github.com/pyqg/pyqg) to
[JAX](https://github.com/google/jax), which enables GPU acceleration,
batching, automatic differentiation, and other JAX transformations.
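As a small illustration of batching, `jax.vmap` can turn a single-state initializer into one that builds a whole ensemble from split PRNG keys. This is a sketch under stated assumptions: `toy_init` is a hypothetical stand-in for a real model initializer, and `make_ensemble` is likewise a hypothetical helper, not part of pyqg-jax.

```python
import jax
import jax.numpy as jnp


def make_ensemble(init_fn, key, num_members):
    """Build a batch of initial states, one per member, from a single key.

    ``init_fn`` is any function mapping a PRNG key to a state pytree.
    (Hypothetical helper for illustration, not a pyqg-jax function.)
    """
    keys = jax.random.split(key, num_members)
    # vmap maps init_fn over the leading axis of `keys`, so every array in
    # the returned pytree gains a leading ensemble axis.
    return jax.vmap(init_fn)(keys)


def toy_init(key):
    # Hypothetical stand-in for a model initializer: a small random field
    # with 2 layers on a 4x4 grid.
    return jax.random.normal(key, (2, 4, 4))


ensemble = make_ensemble(toy_init, jax.random.PRNGKey(0), 5)
print(ensemble.shape)  # (5, 2, 4, 4)
```

Assuming the initializer is compatible with `jax.vmap`, passing `stepped_model.create_initial_state` from the short example under Usage in place of `toy_init` should give a batched state, and `jax.vmap(stepped_model.step_model)` could then advance every member at once.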

- **Documentation:** https://pyqg-jax.readthedocs.io/en/latest/
- **Source Code:** https://github.com/karlotness/pyqg-jax
- **Bug Reports:** https://github.com/karlotness/pyqg-jax/issues

⚠️ **Warning:** this is a partial, early stage port. There may be bugs
and other numerical issues. The API may evolve as work continues.

## Installation
Install from PyPI using pip:
```console
$ python -m pip install pyqg-jax
```
This should install the required dependencies, but JAX itself may
require special attention, particularly for GPU support. Follow the
[JAX installation
instructions](https://github.com/google/jax#installation).
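After installing, a quick way to confirm which backend JAX picked up (for instance, whether a GPU is visible) is the sketch below. It only uses JAX's own `jax.devices` and `jax.default_backend`; if only a CPU device appears on a machine with a GPU, revisit the JAX installation instructions.

```python
import jax

# Report the installed JAX version and the hardware it will run on.
print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # e.g. "cpu" or "gpu"
print("Devices:", jax.devices())
```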

## Usage
[Documentation][docs] is a work in progress. The parameters of the
`QGModel` implemented here are the same as those of the model in the
original PyQG, so consult the [PyQG
documentation](https://pyqg.readthedocs.io/en/latest/) for details.

However, a few overarching changes were made to make the models
JAX-compatible:

1. The model state is now a separate, immutable object rather than
   being stored as attributes of the `QGModel` class.

2. Time-stepping is now separated from the models. Use
   `steppers.AB3Stepper` for the same time stepping as in the original
   `QGModel`.

3. Random initialization requires an explicit `key` variable as with
   all JAX random number generation.
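The three changes above can be illustrated with a deliberately tiny toy model. This is a sketch only: `ToyState`, `toy_create_initial_state`, `toy_tendency`, and `forward_euler_step` are hypothetical names, not pyqg-jax API, and linear damping stands in for the real QG physics. It shows an immutable state object, a model that only computes tendencies while a separate function handles time stepping, and initialization driven entirely by an explicit PRNG key.

```python
import typing

import jax
import jax.numpy as jnp


class ToyState(typing.NamedTuple):
    """Hypothetical immutable state; NamedTuples are JAX pytrees automatically."""

    q: jnp.ndarray
    t: jnp.ndarray


def toy_create_initial_state(key):
    # All randomness flows from the explicit key (change 3): the same key
    # always reproduces the same initial field.
    q = 1e-6 * jax.random.normal(key, (2, 8, 8))
    return ToyState(q=q, t=jnp.zeros(()))


def toy_tendency(state):
    # Hypothetical "model" (change 2): it only reports dq/dt and knows
    # nothing about time stepping. Linear damping replaces QG physics here.
    return -0.1 * state.q


def forward_euler_step(tendency_fn, state, dt):
    # Hypothetical "stepper": returns a *new* state instead of mutating the
    # old one (change 1). Forward Euler is used only for brevity; AB3Stepper
    # uses a third-order Adams-Bashforth scheme instead.
    dq = tendency_fn(state)
    return state._replace(q=state.q + dt * dq, t=state.t + dt)


key = jax.random.PRNGKey(0)
state0 = toy_create_initial_state(key)
# The tendency function is a static (non-array) argument to jit.
step = jax.jit(forward_euler_step, static_argnums=0)
state1 = step(toy_tendency, state0, 1.0)
# state0 is untouched; state1 is a separate object with t == 1.0.
```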

The `QGModel` uses double precision (`float64`) values for part of its
computation regardless of the precision setting, so make sure 64-bit
support is enabled in JAX. [See the
documentation](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
for details. One option is to set the following environment variables:
```bash
export JAX_ENABLE_X64=True
export JAX_DEFAULT_DTYPE_BITS=32
```
or use the [`%env`
magic](https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-env)
in a Jupyter notebook.
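The JAX documentation linked above also describes enabling 64-bit mode from Python with `jax.config.update`, which can be more convenient in a script. A minimal sketch; unlike the environment variables above, it only sets `jax_enable_x64` and does not reproduce `JAX_DEFAULT_DTYPE_BITS=32`, so plain calls such as `jnp.array(1.0)` will then default to 64-bit types.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit support. Run this at startup, before creating any arrays
# or models, so every later computation sees the setting.
jax.config.update("jax_enable_x64", True)

x = jnp.zeros(3, dtype=jnp.float64)
print(x.dtype)  # float64
```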

### Short Example
Below is a short example that initializes a `QGModel`, adds a
parameterization, and takes a single step.
```pycon
>>> import pyqg_jax
>>> import jax
>>> # Construct model, parameterization, and time-stepper
>>> stepped_model = pyqg_jax.steppers.SteppedModel(
...     model=pyqg_jax.parameterizations.smagorinsky.apply_parameterization(
...         pyqg_jax.qg_model.QGModel(),
...         constant=0.08,
...     ),
...     stepper=pyqg_jax.steppers.AB3Stepper(dt=3600.0),
... )
>>> # Initialize the model state (wrapped in stepper and parameterization state)
>>> stepper_state = stepped_model.create_initial_state(
...     jax.random.PRNGKey(0)
... )
>>> # Compute next state
>>> next_stepper_state = stepped_model.step_model(stepper_state)
>>> # Unwrap the result from the stepper and parameterization
>>> next_param_state = next_stepper_state.state
>>> next_model_state = next_param_state.model_state
>>> final_q = next_model_state.q
```
For repeated time-stepping, combine `step_model` with
[`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html).
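A minimal sketch of that pattern follows. `roll_out` is a hypothetical helper, not part of pyqg-jax: it wraps any `state -> state` step function in `jax.lax.scan` and records `output_fn(state)` after every step. The toy check uses a scalar that halves each step in place of a real model state.

```python
import jax
import jax.numpy as jnp


def roll_out(step_fn, init_state, num_steps, output_fn=lambda state: state):
    """Apply ``step_fn`` repeatedly using ``jax.lax.scan``.

    Returns the final state and ``output_fn`` applied to every intermediate
    state, stacked along a new leading axis. (Hypothetical helper.)
    """

    def body(state, _):
        next_state = step_fn(state)
        # Carry the new state into the next iteration and record its output.
        return next_state, output_fn(next_state)

    return jax.lax.scan(body, init_state, None, length=num_steps)


# Toy check with a scalar "state" that halves every step.
final, history = roll_out(lambda x: 0.5 * x, jnp.array(8.0), 3)
# final == 1.0 and history == [4.0, 2.0, 1.0]
```

With the short example above, a call such as `roll_out(stepped_model.step_model, stepper_state, 100, output_fn=lambda s: s.state.model_state.q)` should advance 100 steps while keeping only `q` from each step. Recording just `q` avoids storing per-step copies of the stepper's bookkeeping (an AB3 scheme has to carry previous tendencies). Wrapping the call in `jax.jit` compiles the whole loop.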

## License
This software is distributed under the MIT license. See LICENSE.txt
for the license text.

[pypi]: https://pypi.org/project/pyqg-jax
[docs]: https://pyqg-jax.readthedocs.io/en/latest/
[tests]: https://github.com/karlotness/pyqg-jax/actions

