# EmbeddingRWKV

A high-efficiency text embedding and reranking model based on the RWKV architecture.

## 🚀 Quick Start (End-to-End)

Get text embeddings in just a few lines. The tokenizer and model are designed to work seamlessly together.

> **Note**: Always set `add_eos=True` during tokenization. The model relies on the EOS token (`65535`) to mark the end of a sentence for correct embedding generation.

```python
import os
# Set environment for JIT compilation (Optional, set to '1' for CUDA acceleration)
os.environ["RWKV_CUDA_ON"] = '1'

from rwkv_emb.tokenizer import RWKVTokenizer
from rwkv_emb.model import EmbeddingRWKV

# 1. Initialize Tokenizer & Model
model = EmbeddingRWKV(model_path='/path/to/model.pth')
tokenizer = RWKVTokenizer()

# 2. Tokenize (Text -> Tokens)
text = "Hello world! This is RWKV embedding."
# Important: Enable add_eos=True to append the required EOS token (65535)
tokens = tokenizer.encode(text, add_eos=True)
print(f"Tokens: {tokens}") 

# 3. Inference (Tokens -> Embedding)
embedding, state = model.forward(tokens, None)

print(f"Embedding shape: {embedding.shape}")
```
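
Because the model depends on the trailing EOS token, a small guard can catch sequences that were tokenized without `add_eos=True`. The sketch below assumes tokens are plain Python lists of integers, as the list operations elsewhere in this README suggest. `ensure_eos` is a hypothetical helper for illustration and is not part of `rwkv_emb`.

```python
EOS_ID = 65535  # EOS token id stated in the note above


def ensure_eos(tokens, eos_id=EOS_ID):
    """Return a copy of `tokens` that ends with the EOS token.

    Hypothetical helper (not part of rwkv_emb); assumes a list of ints.
    """
    tokens = list(tokens)
    # Append EOS only when the sequence is empty or does not already end with it
    if not tokens or tokens[-1] != eos_id:
        tokens.append(eos_id)
    return tokens


print(ensure_eos([10, 20, 30]))
print(ensure_eos([10, 20, 65535]))
```

Running such a check before `model.forward` is optional when `add_eos=True` is always passed; it mainly helps when token lists come from another source.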

## ⚡ Batch Inference & Performance Guide

For production use cases, running inference in batches is significantly faster than encoding texts one at a time.

### ⚠️ Critical Performance Tip: Pad to Same Length

While the model supports batches with variable sequence lengths, **we strongly recommend padding all sequences to the same length** for maximum GPU throughput.

  - **Pad Token**: `0`, prepended to shorter sequences (left padding)
  - **Performance**: Fixed-length batches allow the CUDA kernel to parallelize computation efficiently. Variable-length batches will trigger a slower execution path.

### Batch Example (Recommended)

```python
# Example: Batching two sentences of different lengths
sentences = [
    "Short sentence.",
    "This is a slightly longer sentence for demonstration."
]

# 1. Tokenize all
batch_tokens = [tokenizer.encode(s, add_eos=True) for s in sentences]

# 2. Left pad to the longest sequence in the batch using 0
max_len = max(len(t) for t in batch_tokens)

for i in range(len(batch_tokens)):
    pad_len = max_len - len(batch_tokens[i])
    # Prepend 0s so the real tokens and the EOS token stay at the end
    batch_tokens[i] = [0] * pad_len + batch_tokens[i]

# batch_tokens is now a rectangular matrix (List of Lists with same length)
print(f"Padded Batch: {batch_tokens}")

# 3. Fast Inference
embeddings, states = model.forward(batch_tokens, None)

# embeddings shape: [Batch_Size, Embedding_Dim]
print("Batch Embeddings:", embeddings.shape)
print("Batch States:", states[0].shape, states[1].shape)
```
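
The padding step above can be factored into a reusable function. This is a minimal sketch, assuming token sequences are plain lists of integers and that `0` is the pad token as stated earlier. `left_pad_batch` is a hypothetical name and is not part of `rwkv_emb`.

```python
def left_pad_batch(token_lists, pad_id=0):
    """Left-pad every sequence with `pad_id` to the length of the longest one.

    Hypothetical helper (not part of rwkv_emb). Returns new lists and
    leaves the inputs unmodified.
    """
    max_len = max((len(t) for t in token_lists), default=0)
    return [[pad_id] * (max_len - len(t)) + list(t) for t in token_lists]


# Made-up token ids for illustration; 65535 is the EOS token
padded = left_pad_batch([[5, 6, 65535], [1, 2, 3, 4, 65535]])
print(padded)  # [[0, 0, 5, 6, 65535], [1, 2, 3, 4, 65535]]
```

The result is a rectangular list of lists that can be passed to `model.forward` in the same way as `batch_tokens` in the example above.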

## 🎯 RWKVReRanker (State-based Reranker)

The `RWKVReRanker` utilizes the final hidden state produced by the main `EmbeddingRWKV` model to score the relevance between a query and a document.

### How it works (Online Mode)

1.  **Format** the Query and Document using the online template.
2.  Run the **Embedding Model** to generate the final State.
3.  Feed the **Attention State** (`state[1]`) into the **ReRanker** to get a relevance score.

### 📝 Online Mode Usage Example

```python
import torch
from rwkv_emb.tokenizer import RWKVTokenizer
from rwkv_emb.model import EmbeddingRWKV, RWKVReRanker

# 1. Load Models
# The ReRanker weights are stored in a separate checkpoint
emb_model = EmbeddingRWKV(model_path='/path/to/EmbeddingRWKV.pth')
reranker = RWKVReRanker(model_path='/path/to/RWKVReRanker.pth')

tokenizer = RWKVTokenizer()

# 2. Prepare Data (Query + Candidate Documents)
query = "What represents the end of a sequence?"
documents = [
    "The EOS token is used to mark the end of a sentence.",
    "Apples are red and delicious fruits.",
    "Machine learning requires large datasets."
]

# 3. Construct Input Pairs
# We treat the Query and Document as a single sequence.
pairs = []
online_template = "Instruct: Given a query, retrieve documents that answer the query\nDocument: {document}\nQuery: {query}"
for doc in documents:
    # Format: Instruct + Document + Query
    text = online_template.format(document=doc, query=query)
    pairs.append(text)

# 4. Tokenize & Pad (Critical for Batch Performance)
batch_tokens = [tokenizer.encode(p, add_eos=True) for p in pairs]

# Left-pad with 0 to the same length, matching the batch example above,
# so the EOS token remains the last token of every sequence
max_len = max(len(t) for t in batch_tokens)
for i in range(len(batch_tokens)):
    batch_tokens[i] = [0] * (max_len - len(batch_tokens[i])) + batch_tokens[i]

# 5. Get States from Embedding Model
# We don't need the embedding output here, we only need the final 'state'
with torch.no_grad():
    _, state = emb_model.forward(batch_tokens, None)

# 6. Score with ReRanker
# The ReRanker expects the Attention State: state[1]
# state[1] shape: [Layers, Batch, Heads, HeadSize, HeadSize]
logits = reranker.forward(state[1])
scores = torch.sigmoid(logits) # Convert logits to probabilities (0-1)

# 7. Print Results
print("\nReRanker Scores:")
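# flatten().tolist() yields plain floats whether scores has shape [Batch] or [Batch, 1]
for doc, score in zip(documents, scores.flatten().tolist()):
    print(f"[{score:.4f}] {doc}")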
```
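
In practice the scores are usually used to reorder the candidates. The sketch below sorts documents by descending score. It assumes the scores have already been converted to plain Python floats; `rank_documents` is a hypothetical helper, and the score values are made up for illustration rather than taken from real model output.

```python
def rank_documents(documents, scores):
    """Return (score, document) pairs sorted from most to least relevant.

    Hypothetical helper (not part of rwkv_emb); expects one float per document.
    """
    if len(documents) != len(scores):
        raise ValueError("documents and scores must have the same length")
    # Sort on the score only, so ties keep their original relative order
    return sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)


# Illustrative values only, not real model output
example_docs = ["doc A", "doc B", "doc C"]
example_scores = [0.12, 0.91, 0.40]

ranked = rank_documents(example_docs, example_scores)
for score, doc in ranked:
    print(f"[{score:.4f}] {doc}")
```

With the tensor from the example above, `scores.flatten().tolist()` is one way to obtain the list of floats before calling the helper.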