Metadata-Version: 2.3
Name: decent-optim-wrapper
Version: 0.1.0
Summary: A wrapper for optimizers in PyTorch to enable decentralized training
Author: Zesen Wang
Author-email: Zesen Wang <zesen@kth.se>
Requires-Dist: loguru>=0.7.3
Requires-Dist: torch>=2.8.0
Requires-Dist: torchvision>=0.23.0
Requires-Python: >=3.12
Description-Content-Type: text/markdown

# decent-optim-wrapper

A PyTorch optimizer wrapper for decentralized distributed training. This package enables efficient decentralized optimization by wrapping any PyTorch optimizer and managing communication between processes using various network topologies.

## Features

- 🔄 **Decentralized Training**: Enable decentralized optimization without a central parameter server
- 🌐 **Multiple Topologies**: Support for Ring, Complete, and custom topologies
- 📦 **Efficient Communication**: Bucket-based gradient communication for reduced overhead
- ⚡ **Asynchronous Operations**: Non-blocking communication for improved performance
- 🎯 **PyTorch Native**: Seamless integration with existing PyTorch training code
- 🔧 **Flexible**: Works with any PyTorch optimizer (SGD, Adam, AdamW, etc.)

## Installation

```bash
pip install decent-optim-wrapper
```

Or install with uv:

```bash
uv add decent-optim-wrapper
```

Or install from source:

```bash
git clone https://github.com/yourusername/decent-optim-wrapper.git
cd decent-optim-wrapper
pip install -e .
```

## Requirements

- Python >= 3.12
- PyTorch >= 2.8.0
- torchvision >= 0.23.0
- loguru >= 0.7.3

## Quick Start

Here's a basic example of using `DecentOptimWrapper` in a distributed training setup:

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim import SGD
from decent_optim_wrapper.wrapper import DecentOptimWrapper

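# Initialize distributed training
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()
local_world_size = torch.cuda.device_count()
# Bind this process to one GPU (assumes one process per GPU on each node)
torch.cuda.set_device(rank % local_world_size)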

# Create your model and base optimizer
model = nn.Linear(10, 10).cuda()
base_optimizer = SGD(model.parameters(), lr=0.01)

# Wrap with DecentOptimWrapper
optimizer = DecentOptimWrapper(
    optimizer=base_optimizer,
    rank=rank,
    world_size=world_size,
    local_world_size=local_world_size,
    topology='ring',  # or 'complete'
    bucket_cap_mb=25
)

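# Training loop (`dataloader` is assumed to be defined elsewhere)
criterion = nn.MSELoss()  # example loss; replace with your own
num_epochs = 10           # example value
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        # Move the batch to the same device as the model
        inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()  # Handles decentralized averaging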
```

## API Reference

### DecentOptimWrapper

The main wrapper class for decentralized optimization.

#### Parameters

- **optimizer** (`torch.optim.Optimizer`): The base PyTorch optimizer to wrap (e.g., SGD, Adam, AdamW)
- **rank** (`int`): The rank of the current process in the distributed setup
- **world_size** (`int`): Total number of processes participating in training
- **local_world_size** (`int`): Number of processes in the local node/group
- **topology** (`str`): Network topology for communication. Options: `'ring'`, `'complete'`
- **bucket_cap_mb** (`int`, optional): Maximum bucket size in megabytes for gradient bucketing. Default: `25`

#### Methods

##### `step(closure=None)`

Performs a single optimization step with decentralized averaging.

```python
optimizer.step()
```

**Note**: The `closure` parameter is not supported in this implementation.

##### `zero_grad(set_to_none=True)`

Clears the gradients of all optimized parameters.

```python
optimizer.zero_grad()
```

**Parameters**:
- **set_to_none** (`bool`): If `True`, sets gradients to `None` instead of zero. Default: `True`

##### `global_avg(may_revert=True)`

Performs a global average of parameters across all processes (centralized operation).

```python
optimizer.global_avg()
```

**Parameters**:
- **may_revert** (`bool`): If `True`, the average can later be undone with `revert_global_avg()`. Default: `True`

##### `revert_global_avg()`

Reverts the last global average operation, restoring the previous parameter values.

```python
optimizer.revert_global_avg()
```
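
The relationship between `global_avg` and `revert_global_avg` can be pictured with a toy stand-in that keeps a backup of the local values before averaging. This is only a sketch of the semantics described above, using plain Python lists; `ToyWorker` is hypothetical and does not reflect how the package stores or communicates parameters.

```python
class ToyWorker:
    """Holds one process's parameters and an optional pre-average backup."""

    def __init__(self, params: list[float]) -> None:
        self.params = list(params)
        self._backup: list[float] | None = None

    def global_avg(self, all_params: list[list[float]], may_revert: bool = True) -> None:
        # Keep a copy of the local values only if a later revert is allowed.
        self._backup = list(self.params) if may_revert else None
        count = len(all_params)
        # Element-wise mean over the parameters of all processes.
        self.params = [sum(column) / count for column in zip(*all_params)]

    def revert_global_avg(self) -> None:
        if self._backup is None:
            raise RuntimeError("no global average to revert")
        self.params, self._backup = self._backup, None


worker = ToyWorker([1.0, 4.0])
worker.global_avg([[1.0, 4.0], [3.0, 8.0]])
print(worker.params)  # [2.0, 6.0]
worker.revert_global_avg()
print(worker.params)  # [1.0, 4.0]
```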

## Topologies

The wrapper supports different communication topologies for decentralized training:

### Ring Topology

In a ring topology, each process communicates with its two neighbors in a circular arrangement. This provides a good balance between communication efficiency and convergence speed.

```python
optimizer = DecentOptimWrapper(
    optimizer=base_optimizer,
    rank=rank,
    world_size=world_size,
    local_world_size=local_world_size,
    topology='ring'
)
```

**Requirements**:
- World size must be even

Each process alternates between communicating with its left and its right neighbor.
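
One way to picture the alternating pattern is as a pairing schedule: on even steps rank 0 pairs with rank 1, rank 2 with rank 3, and so on; on odd steps the pairs shift by one rank. The helper below is a hypothetical illustration (`ring_pairs` is not part of the package, and the actual group assignment may differ). It also shows why an even world size is needed for every rank to have a partner.

```python
def ring_pairs(world_size: int, step: int) -> list[tuple[int, int]]:
    """Return disjoint neighbor pairs; the pairing shifts by one rank on odd steps."""
    if world_size < 2 or world_size % 2 != 0:
        raise ValueError("ring pairing needs an even world size of at least 2")
    offset = step % 2
    return [
        ((start + offset) % world_size, (start + offset + 1) % world_size)
        for start in range(0, world_size, 2)
    ]


print(ring_pairs(4, step=0))  # [(0, 1), (2, 3)]
print(ring_pairs(4, step=1))  # [(1, 2), (3, 0)]
```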

### Complete Topology

In a complete topology, all processes communicate with each other in every step, achieving faster convergence at the cost of higher communication overhead.

```python
optimizer = DecentOptimWrapper(
    optimizer=base_optimizer,
    rank=rank,
    world_size=world_size,
    local_world_size=local_world_size,
    topology='complete'
)
```

**Requirements**:
- World size must be greater than 1
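
As a rough illustration of why the complete topology converges faster: if every process averages with all the others, a single round is enough to reach consensus. The sketch below uses plain Python floats in place of parameter tensors; `complete_average` is a hypothetical helper, not part of the package.

```python
def complete_average(values: list[float]) -> list[float]:
    """Replace every process's value with the mean over all processes."""
    if len(values) < 2:
        raise ValueError("complete topology needs more than one process")
    mean = sum(values) / len(values)
    return [mean] * len(values)


print(complete_average([1.0, 2.0, 3.0, 6.0]))  # [3.0, 3.0, 3.0, 3.0]
```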

### Custom Topologies

You can implement custom topologies by extending the `Topology` class:

```python
from decent_optim_wrapper.topo import Topology, TopologyFactory

class MyCustomTopology(Topology):
    def assign_groups(self):
        # Implement your custom group assignment logic
        # Return a list of lists of lists representing groups for each rank
        groups = [[] for _ in range(self._world_size)]
        # ... your logic here ...
        return groups

# Register the custom topology
TopologyFactory.add_topology('my_custom', MyCustomTopology)

# Use it
optimizer = DecentOptimWrapper(
    optimizer=base_optimizer,
    rank=rank,
    world_size=world_size,
    local_world_size=local_world_size,
    topology='my_custom'
)
```

## Advanced Usage

### Bucket Configuration

The `bucket_cap_mb` parameter controls how parameters are grouped for communication. Larger buckets reduce communication overhead but increase memory usage:

```python
# Smaller buckets (more communication ops, less memory)
optimizer = DecentOptimWrapper(
    optimizer=base_optimizer,
    rank=rank,
    world_size=world_size,
    local_world_size=local_world_size,
    topology='ring',
    bucket_cap_mb=10  # 10 MB buckets
)

# Larger buckets (fewer communication ops, more memory)
optimizer = DecentOptimWrapper(
    optimizer=base_optimizer,
    rank=rank,
    world_size=world_size,
    local_world_size=local_world_size,
    topology='ring',
    bucket_cap_mb=100  # 100 MB buckets
)
```
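
To make the trade-off concrete, here is a minimal sketch of greedy size-based bucketing, assuming parameters are packed in order until the next one would push a bucket over the cap. `make_buckets` and the sizes are hypothetical; the package's actual bucketing logic may differ.

```python
def make_buckets(param_sizes_mb: list[float], bucket_cap_mb: float) -> list[list[int]]:
    """Group parameter indices in order so each bucket stays within the cap.

    A parameter larger than the cap gets a bucket of its own.
    """
    buckets: list[list[int]] = []
    current: list[int] = []
    current_size = 0.0
    for index, size in enumerate(param_sizes_mb):
        # Close the current bucket if adding this parameter would exceed the cap.
        if current and current_size + size > bucket_cap_mb:
            buckets.append(current)
            current, current_size = [], 0.0
        current.append(index)
        current_size += size
    if current:
        buckets.append(current)
    return buckets


sizes = [8.0, 8.0, 8.0, 40.0, 4.0, 4.0]
print(make_buckets(sizes, bucket_cap_mb=10))   # [[0], [1], [2], [3], [4, 5]]
print(make_buckets(sizes, bucket_cap_mb=100))  # [[0, 1, 2, 3, 4, 5]]
```

With the smaller cap, the same parameters need five communication operations instead of one.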

### Global Averaging for Evaluation

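Each process holds slightly different parameters during decentralized training. Before evaluation, you may want to perform a global average so that every process evaluates the same averaged model: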

```python
# Before evaluation
optimizer.global_avg()

# Evaluate model
model.eval()
with torch.no_grad():
    for batch in val_loader:
        outputs = model(batch)
        # ... evaluation logic ...

# Optionally revert to continue decentralized training
optimizer.revert_global_avg()
model.train()
```

### Integration with Learning Rate Schedulers

The wrapper works seamlessly with PyTorch learning rate schedulers:

```python
from torch.optim.lr_scheduler import StepLR

base_optimizer = SGD(model.parameters(), lr=0.1)
optimizer = DecentOptimWrapper(
    optimizer=base_optimizer,
    rank=rank,
    world_size=world_size,
    local_world_size=local_world_size,
    topology='ring'
)

scheduler = StepLR(base_optimizer, step_size=10, gamma=0.1)

for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = train_step(model, batch)
        loss.backward()
        optimizer.step()
    
    scheduler.step()  # Update learning rate
```
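
For reference, `StepLR` multiplies the learning rate by `gamma` every `step_size` epochs. The resulting schedule can be checked with plain Python; `step_lr` below is a hypothetical helper that mirrors the settings in the example, not a PyTorch function.

```python
def step_lr(base_lr: float, step_size: int, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` scheduler steps under a StepLR-style decay."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 9, 10, 20):
    print(epoch, f"{step_lr(0.1, 10, 0.1, epoch):.4f}")
# 0 0.1000
# 9 0.1000
# 10 0.0100
# 20 0.0010
```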

## How It Works

The `DecentOptimWrapper` implements decentralized optimization through the following mechanism:

1. **Bucketing**: Parameters are grouped into buckets based on `bucket_cap_mb` for efficient communication
2. **Local Update**: Each process performs a local gradient descent step using the wrapped optimizer
3. **Neighbor Averaging**: Parameters are averaged with neighboring processes according to the topology
4. **Non-blocking Communication**: Communication happens asynchronously so that it overlaps with computation

This approach enables each process to maintain its own model while gradually converging through periodic averaging with neighbors, eliminating the need for a central parameter server.
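
The local-update and neighbor-averaging steps can be sketched with a toy simulation in plain Python. It assumes each worker minimizes its own quadratic loss `0.5 * (x - target)**2` and averages with one ring neighbor per step; `simulate` and this pairing scheme are illustrative assumptions, and bucketing and asynchronous communication are left out.

```python
def simulate(targets: list[float], lr: float = 0.1, steps: int = 200) -> list[float]:
    """Toy decentralized SGD: a local step, then averaging with one ring neighbor."""
    n = len(targets)
    if n < 2 or n % 2 != 0:
        raise ValueError("this pairing scheme assumes an even number of workers")
    x = [0.0] * n
    for step in range(steps):
        # Local update: the gradient of 0.5 * (x - target)^2 is (x - target).
        x = [xi - lr * (xi - ti) for xi, ti in zip(x, targets)]
        # Neighbor averaging: disjoint pairs, shifted by one rank on odd steps.
        offset = step % 2
        for start in range(0, n, 2):
            a, b = (start + offset) % n, (start + offset + 1) % n
            x[a] = x[b] = (x[a] + x[b]) / 2
    return x


final = simulate([1.0, 2.0, 3.0, 6.0])
print([round(v, 2) for v in final])
```

No worker ever sees all four targets, yet each one ends near their mean (3.0); the small remaining disagreement between workers comes from the constant learning rate.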

## Contributing

Contributions are welcome! Please feel free to submit issues or pull requests.

## Author

**Zesen Wang**  
Email: zesen@kth.se

## License

This project is licensed under the [MIT License](./LICENSE).

## Citation

If you use this package in your research, please cite:

```bibtex
@software{decent_optim_wrapper,
  author = {Wang, Zesen},
  title = {decent-optim-wrapper: A PyTorch Wrapper for Decentralized Optimization},
  year = {2025},
  url = {https://github.com/yourusername/decent-optim-wrapper}
}
```
