"""Variables that are used in multiple tests."""

from chemprop.nn import BCELoss
from torch import Tensor, nn

from molpipeline.estimators.chemprop.component_wrapper import (
    MPNN,
    BinaryClassificationFFN,
    BondMessagePassing,
    SumAggregation,
)

# Model parameters whose values are copied by value but are too complex to
# compare for equality. For these parameters, tests only check the type.
# Whitespace-separated names, split into a list in the order written below.
NO_IDENTITY_CHECK = """
    model__agg
    model__message_passing
    model
    model__predictor
    model__predictor__criterion
    model__predictor__output_transform
""".split()

# Default parameters for the Chemprop model.
#
# Nested parameters use ``<component>__<parameter>`` keys. The two private
# tables below list the nested parameters of the trainer and the model without
# their prefix; ``DEFAULT_PARAMS`` re-attaches the prefix while keeping the
# key order exactly as listed here.

_LIGHTNING_TRAINER_DEFAULTS = {
    "enable_checkpointing": False,
    "enable_model_summary": False,
    "max_epochs": 500,
    "accelerator": "cpu",
    "default_root_dir": None,
    "limit_predict_batches": 1.0,
    "detect_anomaly": False,
    "reload_dataloaders_every_n_epochs": 0,
    "precision": "32-true",
    "min_steps": None,
    "max_time": None,
    "limit_train_batches": 1.0,
    "strategy": "auto",
    "gradient_clip_algorithm": None,
    "log_every_n_steps": 50,
    "limit_val_batches": 1.0,
    "gradient_clip_val": None,
    "overfit_batches": 0.0,
    "num_nodes": 1,
    "use_distributed_sampler": True,
    "check_val_every_n_epoch": 1,
    "benchmark": False,
    "inference_mode": True,
    "limit_test_batches": 1.0,
    "fast_dev_run": False,
    "logger": None,
    "max_steps": -1,
    "num_sanity_val_steps": 2,
    "devices": "auto",
    "min_epochs": None,
    "val_check_interval": 1.0,
    "barebones": False,
    "accumulate_grad_batches": 1,
    "deterministic": False,
    "enable_progress_bar": True,
}

# Complex components are given as classes rather than instances; the keys in
# NO_IDENTITY_CHECK are compared by type only.
_MODEL_DEFAULTS = {
    "agg__dim": 0,
    "agg": SumAggregation,
    "batch_norm": True,
    "final_lr": 0.0001,
    "init_lr": 0.0001,
    "max_lr": 0.001,
    "message_passing__activation": "relu",
    "message_passing__bias": False,
    "message_passing__d_e": 14,
    "message_passing__d_h": 300,
    "message_passing__d_v": 72,
    "message_passing__d_vd": None,
    "message_passing__depth": 3,
    "message_passing__dropout_rate": 0.0,
    "message_passing__undirected": False,
    "message_passing": BondMessagePassing,
    "metric_list": None,
    "predictor__activation": "relu",
    "warmup_epochs": 2,
    "predictor": BinaryClassificationFFN,
    "predictor__criterion": BCELoss,
    # Each task-weight entry gets its own Tensor, so the two values are
    # independent objects rather than one shared tensor.
    "predictor__criterion__task_weights": Tensor([1.0]),
    "predictor__dropout": 0,
    "predictor__hidden_dim": 300,
    "predictor__input_dim": 300,
    "predictor__n_layers": 1,
    "predictor__n_tasks": 1,
    "predictor__output_transform": nn.Identity,
    "predictor__task_weights": Tensor([1.0]),
    "predictor__threshold": None,
}

DEFAULT_PARAMS = {
    "batch_size": 64,
    "lightning_trainer": None,
    **{
        f"lightning_trainer__{key}": value
        for key, value in _LIGHTNING_TRAINER_DEFAULTS.items()
    },
    "model": MPNN,
    **{f"model__{key}": value for key, value in _MODEL_DEFAULTS.items()},
    "n_jobs": 1,
}
