import random
import tqdm
from datasets import load_dataset
import dspy


from pydantic import BaseModel, Field
from typing import List, Optional
import json


class SWEBenchData(BaseModel):
    """Typed record for a single SWE-bench task instance.

    Each field corresponds to a column of a SWE-bench dataset row (see
    ``get_instance_by_id``, which builds this model from ``**row``); the
    ``description`` strings document what each column holds. Declaration
    order defines the model's field order, so keep it stable.
    """
    # Identity and source-repository information.
    instance_id: str = Field(description="A formatted instance identifier, usually as repo_owner__repo_name-PR-number.")
    patch: str = Field(description="The gold patch generated by the PR minus test-related code.")
    repo: str = Field(description="The repository owner/name identifier from GitHub.")
    base_commit: str = Field(description="The commit hash of the repository representing the HEAD of the repository before the solution PR is applied.")
    # Natural-language context for the task.
    hints_text: str = Field(description="Comments made on the issue prior to the creation of the solution PR’s first commit creation date.")
    text: str = Field(description="The generated text according to the retrieval criterion and the style-2 prompt found in github:SWE-bench.")
    created_at: str = Field(description="The creation date of the pull request.")
    test_patch: str = Field(description="A test-file patch that was contributed by the solution PR.")
    problem_statement: str = Field(description="The issue title and body.")
    # Evaluation-environment information.
    version: str = Field(description="Installation version to use for running evaluation.")
    environment_setup_commit: str = Field(description="Commit hash to use for environment setup and installation.")
    # NOTE(review): the two test-list fields below are disabled. get_instance_by_id still
    # JSON-decodes these columns and passes them in, but because the model declares no such
    # fields they are presumably discarded by pydantic's default extra-field handling --
    # confirm before relying on them, and re-enable both together if they are needed.
    # FAIL_TO_PASS: Optional[List[str]] = Field(default=None, description="A list of strings that represent the set of tests resolved by the PR and tied to the issue resolution.")
    # PASS_TO_PASS: Optional[List[str]] = Field(default=None, description="A list of strings that represent tests that should pass before and after the PR application.")


def _decode_test_list(value):
    """Return a test-list column value decoded, or None when it is empty/missing.

    SWE-bench rows may store ``FAIL_TO_PASS`` / ``PASS_TO_PASS`` as a
    JSON-encoded string. A non-string value is assumed to be already decoded
    and is returned unchanged, so decoding twice is harmless.
    """
    if not value:
        return None
    if isinstance(value, str):
        return json.loads(value)
    return value


def get_instance_by_id(dataset, instance_id):
    """
    Retrieve a model instance from the dataset by instance_id.

    Scans ``dataset`` in order and builds a ``SWEBenchData`` from the first row
    whose ``instance_id`` matches. The matching row is shallow-copied before its
    ``FAIL_TO_PASS`` / ``PASS_TO_PASS`` columns are decoded, so the caller's
    rows are never modified and repeated lookups over the same list keep working.
    Missing test-list columns are treated as None instead of raising KeyError.

    Args:
    dataset (Iterable[Mapping]): The dataset containing the instances.
    instance_id (str): The instance identifier to search for.

    Returns:
    SWEBenchData: The found model instance or None if not found.
    """
    for item in dataset:
        if item['instance_id'] != instance_id:
            continue
        # Copy so that decoding below does not mutate the caller's data.
        record = dict(item)
        for key in ('FAIL_TO_PASS', 'PASS_TO_PASS'):
            record[key] = _decode_test_list(record.get(key))
        return SWEBenchData(**record)
    return None


class SWEBench:
    """SWE-bench splits wrapped as ``dspy.Example`` objects.

    The Hugging Face ``train`` split is divided into ``self.train`` and
    ``self.dev`` according to ``train_fraction``; the Hugging Face ``test``
    split becomes ``self.test`` after rows with an empty ``test_patch`` are
    dropped. Every example carries ``issue`` (problem statement followed by
    hints text) as its input, plus ``patch`` and ``test_patch``.
    """

    def __init__(self, do_shuffle=True, shuffle_seed=0, *, train_fraction=0.8,
                 dataset_name="princeton-nlp/SWE-bench_oracle") -> None:
        """Download the dataset and build the train/dev/test splits.

        Args:
            do_shuffle: Shuffle the train and test records before splitting.
            shuffle_seed: Seed for the shuffle. One RNG shuffles train first and
                then test, so a given seed always yields the same splits.
            train_fraction: Fraction of the HF ``train`` split kept in
                ``self.train``; the remainder becomes ``self.dev``.
            dataset_name: Hugging Face Hub dataset id, loaded with its
                ``'default'`` configuration.

        Raises:
            ValueError: If ``train_fraction`` is outside ``[0, 1]``.
        """
        super().__init__()

        if not 0.0 <= train_fraction <= 1.0:
            raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")

        # Load the dataset from the Hugging Face Hub
        dataset = load_dataset(dataset_name, 'default')

        official_train = [self._to_record(example) for example in tqdm.tqdm(dataset['train'])]
        # Test rows with an empty (or missing) test patch are skipped, presumably
        # because there is nothing to evaluate a candidate fix against.
        official_test = [
            self._to_record(example)
            for example in tqdm.tqdm(dataset['test'])
            if example['test_patch']
        ]

        # Optionally shuffle datasets (train first, then test, with one RNG)
        if do_shuffle:
            rng = random.Random(shuffle_seed)
            rng.shuffle(official_train)
            rng.shuffle(official_test)

        # Split the HF train split into train/dev at a single cut point.
        cut = int(train_fraction * len(official_train))
        trainset, devset = official_train[:cut], official_train[cut:]

        # Wrap data into dspy.Example format
        self.train = self._to_examples(trainset)
        self.dev = self._to_examples(devset)
        self.test = self._to_examples(official_test)

    @staticmethod
    def _to_record(example):
        """Map one dataset row to the ``issue``/``patch``/``test_patch`` dict."""
        return dict(
            issue=f"{example['problem_statement']}\n\n{example['hints_text']}",
            patch=example['patch'],
            test_patch=example['test_patch'],
        )

    @staticmethod
    def _to_examples(records):
        """Wrap record dicts as ``dspy.Example`` objects whose input is ``issue``."""
        return [dspy.Example(**record).with_inputs('issue') for record in records]



def main():
    """Build seeded SWE-bench splits and print how many examples each holds."""
    bench = SWEBench(do_shuffle=True, shuffle_seed=42)
    # (label, attribute) pairs, reported in train -> dev -> test order.
    for label, attr in (("Trainset", "train"), ("Devset", "dev"), ("Testset", "test")):
        print(f"{label} size: {len(getattr(bench, attr))}")


if __name__ == "__main__":
    # Only run the demo when executed as a script, not on import.
    main()
