import copy

import pytest
from bayesian_benchmarks.data import Boston

import torch
import gpytorch
from gpytorch.kernels import ScaleKernel, RBFKernel, MaternKernel
from skgpytorch.models import ExactGPRegressor
from .gpytorch_models import exact_gp_regressor_from_gpytorch


def test_exact_gp_regressor():
    """Check ``ExactGPRegressor`` predictions against a plain-GPyTorch reference.

    For every seed in ``range(5)`` and every kernel type (ARD RBF and ARD
    Matern, each wrapped in a ``ScaleKernel``), a reference model is built by
    ``exact_gp_regressor_from_gpytorch``. An ``ExactGPRegressor`` is then fitted
    on the same Boston training split for the same number of iterations. The
    predictive means and variances on the test split must agree under
    ``torch.allclose`` default tolerances.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_iters = 5

    data = Boston()
    # Overwrite the dataset arrays with float32 tensors on the target device;
    # targets are flattened to 1-D.
    data.X_train = torch.tensor(data.X_train, dtype=torch.float).to(device)
    data.Y_train = torch.tensor(data.Y_train, dtype=torch.float).ravel().to(device)
    data.X_test = torch.tensor(data.X_test, dtype=torch.float).to(device)
    data.Y_test = torch.tensor(data.Y_test, dtype=torch.float).ravel().to(device)
    n_features = data.X_train.shape[1]

    # Untrained kernel templates. They are never handed to a model directly.
    # A kernel is a module passed by reference, and fitting presumably updates
    # its hyperparameters in place. Sharing one object would let the reference
    # model's training leak into the model under test, and one seed's training
    # leak into the next.
    kernel_templates = [
        ScaleKernel(RBFKernel(ard_num_dims=n_features)),
        ScaleKernel(MaternKernel(ard_num_dims=n_features)),
    ]

    for seed in range(5):
        for template in kernel_templates:
            kernel_name = type(template.base_kernel).__name__

            # Each model gets its own copy, so both start from identical
            # initial hyperparameters and train independently.
            gpytorch_model = exact_gp_regressor_from_gpytorch(
                data, copy.deepcopy(template), seed, n_iters, device
            )
            gp = ExactGPRegressor(
                data.X_train,
                data.Y_train,
                copy.deepcopy(template),
                random_state=seed,
                device=device,
            )
            gp.fit(n_iters=n_iters)
            pred_mean, pred_var = gp.predict(data.X_test)

            ref_dist = gpytorch_model.pred_dist
            assert torch.allclose(pred_var, ref_dist.variance), (
                f"predictive variance mismatch (seed={seed}, kernel={kernel_name})"
            )
            assert torch.allclose(pred_mean, ref_dist.mean), (
                f"predictive mean mismatch (seed={seed}, kernel={kernel_name})"
            )
