Metadata-Version: 2.1
Name: clip-jax
Version: 0.0.1
Summary: Training of CLIP in JAX
License-Expression: Apache-2.0
License-File: LICENSE.md
Requires-Dist: einops
Requires-Dist: flax
Requires-Dist: jax>=0.2.6
Requires-Dist: jaxlib
Requires-Dist: numpy
Requires-Dist: tensorflow-io[tensorflow-cpu]
Requires-Dist: transformers
Provides-Extra: dev
Requires-Dist: black[jupyter]; extra == 'dev'
Requires-Dist: isort; extra == 'dev'
Requires-Dist: optax; extra == 'dev'
Requires-Dist: tqdm; extra == 'dev'
Description-Content-Type: text/markdown

# CLIP-JAX

This repository is used to train CLIP models from 🤗 transformers using JAX.

## Installation

```bash
pip install -e .
# optional: also install the "dev" extra (black, isort, optax, tqdm)
pip install -e ".[dev]"
```

## Usage

1. Use [dataset/prepare_dataset.ipynb](dataset/prepare_dataset.ipynb) to prepare your dataset.
1. Train the model with [training/train_clip.py](training/train_clip.py).
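CLIP training usually optimizes a symmetric contrastive loss over a batch of matched image-text pairs. Each image should score highest against its own caption, and each caption highest against its own image. The snippet below is a minimal sketch of that standard objective in JAX. It assumes that [training/train_clip.py](training/train_clip.py) follows the usual CLIP formulation and is not copied from that script. The function name `clip_loss` is hypothetical. Its `logit_scale` argument is assumed to be already exponentiated, for example `1 / 0.07`.

```python
import jax
import jax.numpy as jnp


def clip_loss(image_embeds: jax.Array, text_embeds: jax.Array, logit_scale: float) -> jax.Array:
    """Symmetric contrastive (CLIP-style) loss; a hypothetical helper for illustration.

    image_embeds, text_embeds: (batch, dim) arrays where row i of each is a matched pair.
    logit_scale: scalar multiplier applied to cosine similarities (assumed already exponentiated).
    """
    # L2-normalize so the dot products below are cosine similarities
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
    text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)

    # (batch, batch) similarity matrix; the correct pairs lie on the diagonal
    logits_per_text = logit_scale * (text_embeds @ image_embeds.T)
    logits_per_image = logits_per_text.T

    def cross_entropy_on_diagonal(logits: jax.Array) -> jax.Array:
        # Row i's target class is i, so the target log-probabilities are the diagonal
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        return -jnp.mean(jnp.diagonal(log_probs))

    # Average the text-to-image and image-to-text directions
    return (cross_entropy_on_diagonal(logits_per_text) + cross_entropy_on_diagonal(logits_per_image)) / 2


if __name__ == "__main__":
    key_img, key_txt = jax.random.split(jax.random.PRNGKey(0))
    image_embeds = jax.random.normal(key_img, (8, 512))
    text_embeds = jax.random.normal(key_txt, (8, 512))
    print(clip_loss(image_embeds, text_embeds, logit_scale=1 / 0.07))
```

The loss approaches 0 when every image-text pair is perfectly aligned and orthogonal to the other pairs. It equals `log(batch_size)` when all similarities are identical. Averaging both directions keeps the two encoders symmetric.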

## Supported downstream tasks

- [x] Image classification with `FlaxCLIPVisionModelForImageClassification`

## TODO

- [ ] Add guides
- [ ] Add pre-trained models
- [ ] Add more downstream tasks
