import torch
from torch import nn
import numpy as np
from ._script_info import _script_info

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "rnn_forecast"
]

def rnn_forecast(model: nn.Module, start_sequence: torch.Tensor, steps: int, device: str = 'cpu'):
    """
    Runs an autoregressive (closed-loop) forecast with a trained RNN-based model.

    At every step the model is run on the current window. The prediction for the
    last time step becomes the next value. The window is then rolled forward by
    dropping its oldest row and appending that prediction, so its length stays
    equal to ``len(start_sequence)``.

    Args:
        model (nn.Module): The trained model (e.g. an LSTM/GRU-based network).
            It is called with a tensor of shape
            ``(1, sequence_length, num_features)`` (a batch dimension is
            prepended), so it presumably expects batch-first input. Its output
            may be:
              * a tensor of shape ``(1, seq_len, out_features)``, in which case
                the last time step is used;
              * a tensor of shape ``(1, out_features)``, used directly;
              * a tuple/list whose first element is one of the above (e.g. the
                ``(output, hidden)`` pair returned by ``nn.LSTM`` / ``nn.GRU``).
            ``out_features`` must equal ``num_features`` so each prediction can
            be fed back into the window.
        start_sequence (torch.Tensor): The initial window, shape
            ``(sequence_length, num_features)``.
        steps (int): Number of future time steps to predict. ``0`` yields an
            empty array.
        device (str, optional): The device to run the forecast on
            ('cpu', 'cuda', 'mps'). Defaults to 'cpu'.

    Returns:
        np.ndarray: A 1-D array of length ``steps * out_features`` holding the
        predictions in time order. The features of each step are contiguous.

    Raises:
        ValueError: If ``steps`` is negative, or if the model output has a rank
            other than 2 or 3.

    Note:
        The model is moved to ``device`` and stays there afterwards. It runs in
        eval mode during the forecast, and its previous train/eval mode is
        restored on return.
    """
    if steps < 0:
        raise ValueError(f"steps must be non-negative, got {steps}")

    was_training = model.training
    model.eval()
    model.to(device)

    current_sequence = start_sequence.to(device)
    predictions = []

    try:
        with torch.no_grad():
            for _ in range(steps):
                output = model(current_sequence.unsqueeze(0))  # add batch dimension

                # Raw nn.LSTM / nn.GRU return (output, hidden); keep only the output.
                if isinstance(output, (tuple, list)):
                    output = output[0]

                if output.dim() == 3:
                    # (batch, seq_len, features): the forecast is the last time step.
                    next_pred = output[0, -1, :]
                elif output.dim() == 2:
                    # (batch, features): the model already emits only the last step.
                    next_pred = output[0]
                else:
                    raise ValueError(
                        f"Unsupported model output shape {tuple(output.shape)}; "
                        "expected (batch, seq_len, features) or (batch, features)."
                    )

                # reshape (not view) so non-contiguous outputs are also accepted.
                next_pred = next_pred.reshape(1, -1)
                predictions.append(next_pred)

                # Roll the window: drop the oldest row, append the new prediction.
                current_sequence = torch.cat([current_sequence[1:], next_pred], dim=0)
    finally:
        # Do not leave a model that was mid-training stuck in eval mode.
        model.train(was_training)

    if not predictions:
        # np.concatenate([]) would raise; return an empty array of the input's dtype.
        return current_sequence.new_empty(0).cpu().numpy()

    # One device->host transfer for all steps instead of one per step.
    return torch.cat(predictions, dim=0).reshape(-1).cpu().numpy()


def info():
    """Show this module's public API via the shared ``_script_info`` helper."""
    # The original body only referenced ``_script_info`` without calling it, so
    # ``info()`` did nothing. NOTE(review): assumes ``_script_info`` takes the
    # list of public names (``__all__``) as its single argument; confirm against
    # ``._script_info``.
    _script_info(__all__)
