Metadata-Version: 2.4
Name: master_agent
Version: 0.0.67
Summary: A library providing the tools to solve complex environments in Minigrid using LgTS
Author: Stevo Huncho
Author-email: stevo@stevohuncho.com
Keywords: reinforcement learning,actor-critic,a2c,ppo,multi-processes,gpu,teacher student,ts
Description-Content-Type: text/markdown
License-File: LICENCE
Requires-Dist: torch
Requires-Dist: minigrid
Requires-Dist: numpy
Requires-Dist: gymnasium
Requires-Dist: stable_baselines3
Requires-Dist: opencv-python
Requires-Dist: imageio
Requires-Dist: matplotlib
Requires-Dist: openai
Dynamic: author
Dynamic: author-email
Dynamic: description
Dynamic: description-content-type
Dynamic: keywords
Dynamic: license-file
Dynamic: requires-dist
Dynamic: summary

# master-minigrid-agent
A Python module for training an RL agent on any Minigrid environment using LgTS.

## Installation
```bash
pip install master-agent
```

## Description
A Python library providing tools for an **all-in-one** implementation of the [GTRI Research Paper](https://arxiv.org/pdf/2310.09454) `LgTS: Dynamic Task Sampling using LLM-generated sub-goals for Reinforcement Learning Agents`.

### Includes
- Prebuilt Minigrid Environments
- LLM-based providers for Subtask generation + evaluation
- Teacher Student Algorithm implementation using PPO policies
- Automatic Minigrid Tileset Identification

### Methodology
The methodology is based on the [GTRI Research Paper](https://arxiv.org/pdf/2310.09454).
#### Brief Overview
`llm.gen_2d_array() -> create DAG -> use DAG to train set of policies using Teacher Student algorithm`.
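
The "create DAG" step can be pictured as merging the paths of the 2D array into one graph, where each sub-goal is a node and each pair of consecutive sub-goals is an edge. The sketch below is illustrative only: `paths_to_dag` is a hypothetical helper, not part of `master_agent`, and it assumes that identical sub-goal strings refer to the same node and that all paths list shared sub-goals in a consistent order (otherwise the result could contain cycles).

```python
def paths_to_dag(paths):
    """Merge sub-goal paths into an adjacency map: node -> set of successor nodes."""
    dag = {}
    for path in paths:
        # Register every sub-goal, so terminal nodes appear with no successors
        for node in path:
            dag.setdefault(node, set())
        # Each pair of consecutive sub-goals becomes a directed edge
        for src, dst in zip(path, path[1:]):
            dag[src].add(dst)
    return dag


paths = [
    ['At(OutsideRoom)', 'Holding(Key1)', 'Unlocked(Door1)', 'At(Green_Goal)'],
    ['At(OutsideRoom)', 'Holding(Key3)', 'At(Green_Goal)'],
]
dag = paths_to_dag(paths)
for node, successors in dag.items():
    print(node, "->", sorted(successors))
```

Paths that share a prefix collapse onto the same nodes, so a policy trained for a shared sub-goal only has to be learned once.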

## Prebuilt Minigrid Environments
`master_agent.llm` provides **7** customized environments based on the research paper and designed for evaluation of RL success on specific obstacles.
- **Complex Env** (a copy of the example environment from the [GTRI Research Paper](https://arxiv.org/pdf/2310.09454))
- **KeyOne Env**
- **KeyTwo Env**
- **LavaIsWall Env**
- **+ No Lava Variants**

## Subtask Generation
Generate a **2D Array of Paths** using the `SubtasksGenerator` class. Each inner list is one path: an ordered sequence of sub-goal predicates that ends at the goal.
### Example of 2D Array
```py
[
    ['At(OutsideRoom)', 'Holding(Key1)', 'Unlocked(Door1)', 'At(Green_Goal)'], 
    ['At(OutsideRoom)', 'Holding(Key2)', 'Unlocked(Door2)', 'At(Green_Goal)'], 
    ['At(OutsideRoom)', 'Holding(Key3)', 'At(Green_Goal)'], 
    ['At(OutsideRoom)', 'At(Wall)', 'At(Green_Goal)'],
]
```

### Generation + Validation
```py
from llm.client import LlmClient
from llm.subtasks import SubtasksGenerator, validate_subtask_paths

# Create the LLM client (llm_api_key, llm_model, and llm_base_url must be defined by the caller)
llm_client = LlmClient(llm_api_key, llm_model, llm_base_url)
# Create subtasks generator
subtasks_gen = SubtasksGenerator(llm_client)
objects = ["Key1", "Key2", "Key3", "Door1", "Door2"]
# Generate paths (2D Array Output)
subtask_paths = subtasks_gen.gen_subtask_paths(objects)
# Validate paths
try:
    validate_subtask_paths(subtask_paths, objects)
except Exception as e:
    print(f"Validation failed: {e}")
```
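
To give a feel for what path validation might involve, here is a minimal sketch of some plausible checks. It is not the implementation of `validate_subtask_paths`; the helper `check_paths`, the `Name(Argument)` grammar, and the rule that only `Holding` and `Unlocked` are checked against the object list are all assumptions made for illustration.

```python
import re

# Assumed predicate grammar: Name(Argument), e.g. Holding(Key1)
PREDICATE = re.compile(r"^(\w+)\((\w+)\)$")


def check_paths(paths, objects):
    """Raise ValueError on the first problem found; return None if all checks pass."""
    if not paths or any(not path for path in paths):
        raise ValueError("every path needs at least one sub-goal")
    for path in paths:
        for subgoal in path:
            match = PREDICATE.match(subgoal)
            if match is None:
                raise ValueError(f"malformed sub-goal: {subgoal!r}")
            name, arg = match.groups()
            # Only object-related predicates are compared with the object list
            if name in ("Holding", "Unlocked") and arg not in objects:
                raise ValueError(f"unknown object in {subgoal!r}")
    # All paths are expected to finish at the same final sub-goal
    if len({path[-1] for path in paths}) != 1:
        raise ValueError("paths do not share a final sub-goal")


objects = ["Key1", "Key2", "Key3", "Door1", "Door2"]
example_paths = [
    ['At(OutsideRoom)', 'Holding(Key1)', 'Unlocked(Door1)', 'At(Green_Goal)'],
    ['At(OutsideRoom)', 'Holding(Key3)', 'At(Green_Goal)'],
]
check_paths(example_paths, objects)
print("paths look well-formed")
```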

## Teacher Student Training
Use the `TeacherStudent` class to train an **RL Agent** on the generated **2D Array of Paths** until it masters the environment.

### Create Teacher Student Algorithm
```py
from master_agent.rl.teacher_student import TeacherStudent

ts = TeacherStudent(subtask_paths)
print("Training the model...")
ts.train()
print("Training complete.")

print("Demonstrating learned path...")
ts.demo_learned_path()
```
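
For readers new to Teacher-Student curricula, the following sketch shows the general idea in isolation: a teacher tracks how quickly the student is improving on each sub-task and mostly samples the sub-task with the largest recent change. This is a generic illustration, not the internals of `TeacherStudent`; the class `EpsilonGreedyTeacher`, its parameters, and the simulated student are all hypothetical.

```python
import random


class EpsilonGreedyTeacher:
    """Chooses the next sub-task from a running estimate of student progress."""

    def __init__(self, tasks, epsilon=0.1, alpha=0.5, seed=0):
        self.q = {task: 0.0 for task in tasks}           # smoothed learning progress
        self.last_score = {task: 0.0 for task in tasks}  # most recent score per task
        self.epsilon = epsilon                            # exploration probability
        self.alpha = alpha                                # smoothing factor
        self.rng = random.Random(seed)

    def sample(self):
        # Explore a random task with probability epsilon
        if self.rng.random() < self.epsilon:
            return self.rng.choice(list(self.q))
        # Otherwise pick the task whose score is changing the most
        return max(self.q, key=lambda task: abs(self.q[task]))

    def update(self, task, score):
        # Learning progress is the change in score since the task was last trained
        progress = score - self.last_score[task]
        self.last_score[task] = score
        self.q[task] += self.alpha * (progress - self.q[task])


tasks = ["Holding(Key1)", "Unlocked(Door1)", "At(Green_Goal)"]
teacher = EpsilonGreedyTeacher(tasks, epsilon=0.2)
skill = {task: 0.0 for task in tasks}  # stand-in for a real student's success rate
for _ in range(20):
    task = teacher.sample()
    skill[task] = min(1.0, skill[task] + 0.1)  # pretend each training round helps a little
    teacher.update(task, skill[task])
print(teacher.q)
```

In a real setup the simulated `skill` update would be replaced by a round of PPO training on the sampled sub-task, followed by an evaluation that produces the score.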

## VLM Identification
This project also automates **Object Detection** within the Minigrid environment. The `master-agent` package currently provides the `TileIdentifier` class for this purpose. We recommend using a GPT-based model such as `openai/gpt-4o-mini`.

### Unidentified Tileset Identification
```py
import os
from dotenv import load_dotenv
from .identify import TileIdentifier
from .client import LlmClient
from envs.complexEnv import ComplexEnv

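# Load credentials from a .env file (the variable names below are placeholders, adjust them to your setup)
load_dotenv()
llm_api_key = os.getenv("LLM_API_KEY")
llm_model = os.getenv("LLM_MODEL")
llm_base_url = os.getenv("LLM_BASE_URL")

llm_client = LlmClient(llm_api_key, llm_model, llm_base_url)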
# Create tileset identifier
identifier = TileIdentifier(llm_client)
env = ComplexEnv(render_mode='rgb_array', highlight=False) # Removing highlight for accurate tileset representation
env.reset()
# Generate unidentified tileset
unidentified_tileset = identifier.parse_tileset(env.render())
# Validate tileset
identifier.validate_unidentified_tileset(unidentified_tileset, env)
```
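
Conceptually, parsing a tileset means cutting the rendered frame into equally sized square tiles. The sketch below shows one way to do that with NumPy. It is an assumption-laden illustration rather than the code behind `parse_tileset`: the helper `split_into_tiles` is hypothetical, and it assumes 32-pixel tiles and a frame whose height and width are exact multiples of the tile size.

```python
import numpy as np


def split_into_tiles(frame, tile_size=32):
    """Split an (H, W, C) image into an array of shape (rows, cols, tile, tile, C)."""
    height, width, channels = frame.shape
    if height % tile_size or width % tile_size:
        raise ValueError("frame dimensions must be multiples of tile_size")
    rows, cols = height // tile_size, width // tile_size
    # (rows, tile, cols, tile, C) -> (rows, cols, tile, tile, C)
    return frame.reshape(rows, tile_size, cols, tile_size, channels).swapaxes(1, 2)


# Synthetic 2x3 grid of black tiles with one white tile at row 1, column 0
frame = np.zeros((64, 96, 3), dtype=np.uint8)
frame[32:64, 0:32] = 255
tiles = split_into_tiles(frame)
print(tiles.shape)
print(len(np.unique(tiles.reshape(-1, 32, 32, 3), axis=0)), "unique tiles")
```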

### Display Tileset
```py
import matplotlib.pyplot as plt
import numpy as np

unique_tiles = np.unique(unidentified_tileset.reshape(-1, 32, 32, 3), axis=0)
print(f"Number of unique tiles: {len(unique_tiles)}")

# Create a mapping of tile IDs to their positions in the grid
tile_positions = {}
for tile_id, tile in enumerate(unique_tiles):
    tile_positions[tile_id] = []
    for row_idx, row in enumerate(unidentified_tileset):
        for col_idx, grid_tile in enumerate(row):
            if np.array_equal(grid_tile, tile):
                tile_positions[tile_id].append((row_idx, col_idx))

# Create a figure with subplots for each unique tile
num_tiles = len(unique_tiles)
num_cols = 5
num_rows = (num_tiles + num_cols - 1) // num_cols

fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 3 * num_rows))

# Flatten the axs array for easier indexing
axs = axs.flatten()

# Plot each unique tile in a separate subplot with its ID
for i, tile in enumerate(unique_tiles):
    axs[i].imshow(tile)
    axs[i].set_title(f"Tile ID: {i}")
    axs[i].set_xticks([])
    axs[i].set_yticks([])

# Hide any leftover subplots that have no tile to show
for ax in axs[num_tiles:]:
    ax.axis("off")

# Adjust spacing between subplots
plt.subplots_adjust(wspace=0.1, hspace=0.1)

# Show the figure
plt.show()

# Print the mapping of tile IDs to their positions in the grid
# (coordinates are 1-indexed, with the origin at the bottom-left corner)
num_grid_rows = unidentified_tileset.shape[0]
for tile_id, positions in tile_positions.items():
    coords = [f"({col + 1}, {num_grid_rows - row})" for row, col in positions]
    print(f"Tile ID: {tile_id}, Coordinate Positions: {coords}")
```

### Identify Tileset
```py
import os
from dotenv import load_dotenv
from .identify import TileIdentifier
from .client import LlmClient
from envs.complexEnv import ComplexEnv

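# Load credentials from a .env file (the variable names below are placeholders, adjust them to your setup)
load_dotenv()
llm_api_key = os.getenv("LLM_API_KEY")
llm_model = os.getenv("LLM_MODEL")
llm_base_url = os.getenv("LLM_BASE_URL")

llm_client = LlmClient(llm_api_key, llm_model, llm_base_url)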
# Create tileset identifier
identifier = TileIdentifier(llm_client)
env = ComplexEnv(render_mode='rgb_array', highlight=False) # Removing highlight for accurate tileset representation
env.reset()
# Generate unidentified tileset
unidentified_tileset = identifier.parse_tileset(env.render())
# Validate tileset
identifier.validate_unidentified_tileset(unidentified_tileset, env)
# Identify tileset
tileset = identifier.identify_tiles(unidentified_tileset)

for tile in tileset.tiles:
    print(tile.name, tile.world_obj, tile.positions)
```
