# -*- coding: utf-8 -*-
from setuptools import setup

packages = \
['jax_metrics',
 'jax_metrics.losses',
 'jax_metrics.metrics',
 'jax_metrics.metrics.tm_port',
 'jax_metrics.metrics.tm_port.classification',
 'jax_metrics.metrics.tm_port.functional',
 'jax_metrics.metrics.tm_port.functional.classification',
 'jax_metrics.metrics.tm_port.utilities',
 'jax_metrics.regularizers']

package_data = \
{'': ['*']}

install_requires = \
['certifi>=2021.10.8,<2022.0.0',
 'einops>=0.4.0,<0.5.0',
 'optax>=0.1.1,<0.2.0',
 'treeo>=0.2.0.dev0,<0.3.0']

setup_kwargs = {
    'name': 'jax-metrics',
    'version': '0.1.0a2',
    'description': '',
    'long_description': '# JAX Metrics\n\n_A Metrics library for the JAX ecosystem_\n\n#### Main Features\n* Standard metrics that can be used in any JAX project.\n* Pytree abstractions that can natively integrate with all JAX APIs and pytree-supporting frameworks (flax.struct, equinox, treex, etc).\n* Distributed-friendly APIs that make it super easy to synchronize metrics across devices.\n* Automatic accumulation over epochs.\n\n\nJAX Metrics is implemented on top of [Treeo](https://github.com/cgarciae/treeo).\n\n## What is included?\n* The Keras-like `Loss` and `Metric` abstractions.\n* A `metrics` module containing popular metrics.\n* The `losses` and `regularizers` modules containing popular losses.\n* The `Metrics`, `Losses`, and `LossesAndMetrics` combinators.\n\n<!-- ## Why JAX Metrics? -->\n\n## Installation\nInstall using pip:\n```bash\npip install jax_metrics\n```\n\n## Status\nMetrics on this library are usually tested against their Keras or Torchmetrics counterparts for numerical equivalence. This code base comes from Treex and Elegy so it\'s already in use.\n\n## Getting Started\n\n### Metric\n\nThe `Metric` API consists of 3 basic methods:\n\n* `reset`: Used to both initialize and reset a metric.\n* `update`: Takes in new data and updates the metric state.\n* `compute`: Returns the current value of the metric.\n\nSimple usage looks like this:\n\n\n```python\nimport jax_metrics as jm\n\nmetric = jm.metrics.Accuracy()\n\n# Initialize the metric\nmetric = metric.reset()\n\n# Update the metric with a batch of predictions and labels\nmetric = metric.update(target=y, preds=logits)\n\n# Get the current value of the metric\nacc = metric.compute() # 0.95\n\n# alternatively, produce a logs dict\nlogs = metric.compute_logs() # {\'accuracy\': 0.95}\n```\n\nNote that `update` enforces the use of keyword arguments. 
Also, the `Metric.name` property is used as the key in the returned dict; by default this is the name of the class in lowercase, but it can be overridden in the constructor via the `name` argument.\n\n#### Typical Training Setup\n\nBecause Metrics are pytrees they can be used with `jit`, `pmap`, etc. In a more realistic scenario you will probably want to use them inside some of your JAX functions in a setup similar to this:\n\n```python\nimport jax\nimport jax_metrics as jm\n\nmetric = jm.metrics.Accuracy()\n\n@jax.jit\ndef init_step(metric: jm.Metric) -> jm.Metric:\n    return metric.reset()\n\n\ndef loss_fn(params, metric, x, y):\n    ...\n    metric = metric.update(target=y, preds=logits)\n    ...\n\n    return loss, metric\n\n@jax.jit\ndef train_step(params, metric, x, y):\n    grads, metric = jax.grad(loss_fn, has_aux=True)(\n        params, metric, x, y\n    )\n    ...\n    return params, metric\n```\n\nSince the loss function usually has access to the predictions and labels, it\'s usually where you would call `metric.update`, and the new metric state can be returned as an auxiliary output.\n\n#### Distributed Training\n\nJAX Metrics has a distributed-friendly API via the `batch_updates` and `aggregate` methods. A simple example of a loss function inside a data-parallel setup could look like this:\n\n```python\ndef loss_fn(params, metric, x, y):\n    ...\n    # compute batch update\n    batch_updates = metric.batch_updates(target=y, preds=logits)\n    # gather over all devices and aggregate\n    batch_updates = jax.lax.all_gather(batch_updates, "device").aggregate()\n    # update metric\n    metric = metric.merge(batch_updates)\n    ...\n```\n\nThe `batch_updates` method behaves similarly to `update` but returns a new metric state with only information about that batch; `jax.lax.all_gather` "gathers" the metric state over all devices and adds a new axis to the metric state; and `aggregate` reduces the metric state over all devices (first axis). 
Finally, `merge` combines the accumulated metric state over the previous batches with the batch updates.\n\n### Loss\n\nThe `Loss` API just consists of a `__call__` method. Simple usage looks like this:\n\n```python\nimport jax_metrics as jm\n\ncrossentropy = jm.losses.Crossentropy()\n\n# get reduced loss value\nloss = crossentropy(target=y, preds=logits) # 0.23\n```\nNote that losses are not pytrees so they should be marked as static. Similar to Keras, all losses have a `reduction` strategy that can be specified in the constructor and (usually) makes sure that the output is a scalar.\n\n<details>\n<summary><b>Why have losses in a metrics library?</b></summary>\n<!-- #### Why have losses in a metrics library? -->\n\nThere are a few reasons for having losses in a metrics library:\n\n1. Most code from this library was originally written for and will still be consumed by Elegy. Since Elegy needs support for calculating cumulative losses, as you will see later, a Metric abstraction called `Losses` was created for this.\n2. A couple of API design decisions are shared between the `Loss` and `Metric` APIs. These include:\n    * `__call__` and `update` both accept any number of keyword-only arguments. This is used to facilitate composition (see the [Combinators](#combinators) section).\n    * Both classes have the `index_into` and `map_arg` methods that allow them to modify how arguments are consumed.\n    * Argument names are standardized to be consistent whenever possible, e.g. both `metrics.Accuracy` and `losses.Crossentropy` use the `target` and `preds` arguments. 
This is super convenient for the `LossesAndMetrics` combinator.\n\n</details>\n\n### Combinators\nCombinators are instances of `Metric` that enable you to group together multiple instances while maintaining the same API.\n#### Metrics\nThe `Metrics` combinator lets you combine multiple metrics into a single metric.\n\n```python\nmetrics = jm.Metrics([\n    jm.metrics.Accuracy(),\n    jm.metrics.F1(), # not yet implemented 😅, coming soon?\n])\n\n# same API\nmetrics = metrics.reset()\n# same API\nmetrics = metrics.update(target=y, preds=logits)\n# compute now returns a dict\nmetrics.compute() # {\'accuracy\': 0.95, \'f1\': 0.87}\n# compute_logs returns the same dict in this case\nmetrics.compute_logs() # {\'accuracy\': 0.95, \'f1\': 0.87}\n```\n\nAs you can see, the `Metrics.update` method accepts and forwards all the arguments required by the individual metrics. In this example they use the same arguments, but in practice they may consume different subsets of the arguments. Also, if names are repeated then unique names are generated for each metric by appending a number to the metric name.\n\nIf a dictionary is used instead of a list, the keys are used instead of the `name` property of the metrics to determine the key in the returned dict.\n\n```python\nmetrics = jm.Metrics({\n    "acc": jm.metrics.Accuracy(),\n    "f_one": jm.metrics.F1(), # not yet implemented 😅, coming soon?\n})\n\n# same API\nmetrics = metrics.reset()\n# same API\nmetrics = metrics.update(target=y, preds=logits)\n# compute now returns a dict\nmetrics.compute() # {\'acc\': 0.95, \'f_one\': 0.87}\n# compute_logs returns the same dict in this case\nmetrics.compute_logs() # {\'acc\': 0.95, \'f_one\': 0.87}\n```\n\nYou can use nested structures of dicts and lists to group metrics; the keys of the dicts are used to determine group names. Group names and metric names are concatenated using `"/"`, e.g. 
`"{group_name}/{metric_name}"`.\n\n#### Losses\n\n`Losses` is a `Metric` combinator that behaves very similarly to `Metrics` but contains `Loss` instances. `Losses` calculates the cumulative **mean** value of each loss over the batches.\n\n```python\nlosses = jm.Losses([\n    jm.losses.Crossentropy(),\n    jm.regularizers.L2(1e-4),\n])\n\n# same API\nlosses = losses.reset()\n# same API\nlosses = losses.update(target=y, preds=logits, parameters=params)\n# compute now returns a dict\nlosses.compute() # {\'crossentropy\': 0.23, \'l2\': 0.005}\n# compute_logs returns the same dict in this case\nlosses.compute_logs() # {\'crossentropy\': 0.23, \'l2\': 0.005}\n# you can also compute the total loss\nloss = losses.total_loss() # 0.235\n```\n\nAs with `Metrics`, the `update` method accepts and forwards all the arguments required by the individual losses. In this example `target` and `preds` are used by `Crossentropy`, while `parameters` is used by `L2`. The `total_loss` method returns the sum of all values returned by `compute`.\n\nIf a dictionary is used instead of a list, the keys are used instead of the `name` property of the losses to determine the key in the returned dict.\n\n```python\nlosses = jm.Losses({\n    "xent": jm.losses.Crossentropy(),\n    "l_two": jm.regularizers.L2(1e-4),\n})\n\n# same API\nlosses = losses.reset()\n# same API\nlosses = losses.update(target=y, preds=logits, parameters=params)\n# compute now returns a dict\nlosses.compute() # {\'xent\': 0.23, \'l_two\': 0.005}\n# compute_logs returns the same dict in this case\nlosses.compute_logs() # {\'xent\': 0.23, \'l_two\': 0.005}\n# you can also compute the total loss\nloss = losses.total_loss() # 0.235\n```\n\nIf you want to use `Losses` to calculate the loss of a model, you should use `batch_updates` followed by `total_loss` to get the correct batch loss. 
For example, a loss function could be written as:\n\n```python\ndef loss_fn(..., losses):\n    ...\n    batch_updates = losses.batch_updates(target=y, preds=logits, parameters=params)\n    loss = batch_updates.total_loss()\n    losses = losses.merge(batch_updates)\n    ...\n    return loss, losses\n```\nFor convenience, the previous pattern can be simplified to a single line using the `loss_and_update` method:\n```python\ndef loss_fn(...):\n    ...\n    loss, losses = losses.loss_and_update(target=y, preds=logits, parameters=params)\n    ...\n    return loss, losses\n```\n#### LossesAndMetrics\n\nThe `LossesAndMetrics` combinator is a `Metric` that combines the `Losses` and `Metrics` combinators. Its main advantage over using these independently is that it can compute a single logs dictionary while making sure that names/keys remain unique in case of collisions.\n\n```python\nlosses_and_metrics = jm.LossesAndMetrics(\n    metrics=[\n        jm.metrics.Accuracy(),\n        jm.metrics.F1(), # not yet implemented 😅, coming soon?\n    ],\n    losses=[\n        jm.losses.Crossentropy(),\n        jm.regularizers.L2(1e-4),\n    ],\n)\n\n# same API\nlosses_and_metrics = losses_and_metrics.reset()\n# same API\nlosses_and_metrics = losses_and_metrics.update(\n    target=y, preds=logits, parameters=params\n)\n# compute now returns a dict\nlosses_and_metrics.compute() # {\'loss\': 0.235, \'accuracy\': 0.95, \'f1\': 0.87, \'crossentropy\': 0.23, \'l2\': 0.005}\n# compute_logs returns the same dict in this case\nlosses_and_metrics.compute_logs() # {\'loss\': 0.235, \'accuracy\': 0.95, \'f1\': 0.87, \'crossentropy\': 0.23, \'l2\': 0.005}\n# you can also compute the total loss\nloss = losses_and_metrics.total_loss() # 0.235\n```\n\nThanks to consistent naming, `Accuracy`, `F1` and `Crossentropy` all consume the same `target` and `preds` arguments, while `L2` consumes `parameters`. 
For convenience a `"loss"` key is added to the returned logs dictionary.\n\nIf you want to use `LossesAndMetrics` to calculate the loss of a model, you should use `batch_updates` followed by `total_loss` to get the correct batch loss. For example, a loss function could be written as:\n\n```python\ndef loss_fn(...):\n    ...\n    batch_updates = losses_and_metrics.batch_updates(\n        target=y, preds=logits, parameters=params\n    )\n    loss = batch_updates.total_loss()\n    losses_and_metrics = losses_and_metrics.merge(batch_updates)\n    ...\n    return loss, losses_and_metrics\n```\n\nFor convenience, the previous pattern can be simplified to a single line using the `loss_and_update` method:\n\n```python\ndef loss_fn(...):\n    ...\n    loss, losses_and_metrics = losses_and_metrics.loss_and_update(\n        target=y, preds=logits, parameters=params\n    )\n    ...\n    return loss, losses_and_metrics\n```\n\nIf the loss function is running in a distributed context (e.g. `pmap`) you can calculate the device-local loss and synchronize the metric state across devices like this:\n\n\n```python\ndef loss_fn(...):\n    ...\n    batch_updates = losses_and_metrics.batch_updates(\n        target=y, preds=logits, parameters=params\n    )\n    loss = batch_updates.total_loss()\n    batch_updates = jax.lax.all_gather(batch_updates, "device").aggregate()\n    losses_and_metrics = losses_and_metrics.merge(batch_updates)\n    ...\n    return loss, losses_and_metrics\n```',
    'author': 'Cristian Garcia',
    'author_email': 'cgarcia.e88@gmail.com',
    'maintainer': None,
    'maintainer_email': None,
    'url': 'https://cgarciae.github.io/jax_metrics',
    'packages': packages,
    'package_data': package_data,
    'install_requires': install_requires,
    'python_requires': '>=3.7,<3.11',
}


setup(**setup_kwargs)
