//! Traits for solvers.
//! These traits are meant to be usable as [trait objects](https://doc.rust-lang.org/reference/items/traits.html#object-safety).
use sli_collections::rc::Rc;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

use crate::fodot::collections::set::Set;
use crate::fodot::error::GetRangeError;
use crate::fodot::structure::{ArgsRef, GlobModel, IntoIterCompleteModel, Model, PartialStructure};
use crate::fodot::theory::{Assertions, Theory};
use crate::fodot::vocabulary::PfuncRc;
use comp_core::solver::Solver as CCSolver;
pub use comp_core::solver::{InterpMethod, MeasurementEnder, SatResult, TimeMeasurements, Timings};

/// The result from a [Solver::check_get_model] call.
// `Sat` carries a whole `GlobModel` by value; boxing it would change the public API.
#[allow(clippy::large_enum_variant)]
pub enum ModelResult {
    /// The theory has a model, and as such is satisfiable.
    Sat(GlobModel),
    /// The theory has no model, and as such is unsatisfiable.
    Unsat,
    /// The solver could not decide whether the theory is satisfiable.
    Unknown,
}

/// An owning, reference-counted handle to the constraints (assertions) of a theory.
///
/// Used in [GlobModel] and [Model].
#[derive(Clone)]
pub struct Constraints {
    // Not read yet; holding the `Rc` keeps the assertions alive alongside the model.
    #[allow(unused)]
    assertions: Rc<Assertions>,
}

/// All methods a SLI solver has available for use.
///
/// The constructors require `Self: Sized`; every other method can be called through a
/// `dyn Solver` trait object.
pub trait Solver<'a> {
    /// Initialize a solver with the given [Theory], using the default [InterpMethod].
    ///
    /// See [Solver::initialize_with] and [Solver::initialize_with_timing] for more options.
    fn initialize(fodot_theory: &'a Theory) -> Self
    where
        Self: Sized,
    {
        Self::initialize_with(fodot_theory, Default::default())
    }

    /// Initialize the solver with the given [InterpMethod].
    fn initialize_with(fodot_theory: &'a Theory, interp_method: InterpMethod) -> Self
    where
        Self: Sized,
    {
        // `()` is passed as the [Timings] sink; presumably it records nothing.
        Self::initialize_with_timing(fodot_theory, interp_method, &mut ())
    }

    /// Initialize the solver with an [InterpMethod] and a [Timings] object to measure internal
    /// timings.
    fn initialize_with_timing(
        fodot_theory: &'a Theory,
        interp_method: InterpMethod,
        timings: &mut dyn Timings,
    ) -> Self
    where
        Self: Sized;

    /// Get the current constraints.
    fn get_constraints(&self) -> Constraints;

    /// Check if the current solver state is satisfiable.
    fn check(&mut self) -> SatResult;

    /// Check satisfiability for the current solver state and get the model if it is satisfiable.
    ///
    /// # Panics
    ///
    /// Panics if [Solver::check] reports [SatResult::Sat] but [Solver::get_model] returns
    /// [None], which indicates a broken solver implementation.
    fn check_get_model(&mut self) -> ModelResult {
        match self.check() {
            SatResult::Sat => ModelResult::Sat(
                self.get_model()
                    .expect("Internal Error: solver reported `Sat` but produced no model"),
            ),
            SatResult::Unsat => ModelResult::Unsat,
            SatResult::Unknown => ModelResult::Unknown,
        }
    }

    /// Get a model of the current solver state if possible, or [None] if no model is available.
    fn get_model(&mut self) -> Option<GlobModel>;
    /// Invalidate the model of the current solver state.
    ///
    /// Used by [ModelIterator] after each model so the next query can produce another one.
    fn next_model(&mut self);

    /// Propagate the constraints, returns a [PartialStructure] if the theory is satisfiable.
    ///
    /// This [PartialStructure] is the current state of the solver after propagating.
    fn propagate(&mut self) -> Option<PartialStructure>;

    /// Queries the solver about the values of the [Pfunc](crate::fodot::vocabulary::Pfunc)
    /// applied to the given [Args](crate::fodot::structure::Args), returned as a [Set].
    ///
    /// NOTE(review): an earlier version of this doc called the result "the certainly falses";
    /// confirm whether the set holds the possible values or the excluded ones.
    ///
    /// Returns `Ok(None)` if the solver is in an unsatisfiable state.
    ///
    /// # Errors
    ///
    /// Returns a [GetRangeError] if the query cannot be answered.
    fn get_range(&mut self, pfunc: PfuncRc, args: ArgsRef) -> Result<Option<Set>, GetRangeError>;

    /// The [Theory] this solver works on.
    fn theory(&self) -> &Theory;
}

/// Extension trait that lets a [Solver] enumerate its models as an [Iterator].
///
/// It is implemented for every [Solver], including unsized ones such as `dyn Solver<'a>`, so
/// boxed or borrowed trait objects can iterate models too.
pub trait SolverIter<'a>: Solver<'a> {
    /// Returns an iterator for iterating over the models generated by the solver.
    /// Use [take](std::iter::Iterator::take) for limiting the amount of models.
    fn iter_models(&mut self) -> ModelIterator<'_, 'a, Self> {
        ModelIterator::new(self)
    }
}

// `?Sized` so trait objects (`dyn Solver`) get `iter_models` too; `ModelIterator` already
// accepts unsized solvers.
impl<'a, T> SolverIter<'a> for T where T: Solver<'a> + ?Sized {}

/// An [Iterator] that repeatedly asks a [Solver] for its next [GlobModel].
///
/// The solver is mutably borrowed for `'a`; the solver itself works on a theory borrowed
/// for `'t`, which must outlive that borrow.
pub struct ModelIterator<'a, 't: 'a, T>
where
    T: Solver<'t> + ?Sized,
{
    // The solver that is queried for models.
    solver: &'a mut T,
    // Carries the theory lifetime `'t`, which no field otherwise mentions.
    phantom_data: PhantomData<&'t ()>,
}

impl<'a, 't, T: Solver<'t> + ?Sized> ModelIterator<'a, 't, T>
where
    't: 'a,
{
    pub fn new(solver: &'a mut T) -> Self {
        Self {
            solver,
            phantom_data: PhantomData,
        }
    }

    pub fn complete_with_infinite(self) -> CompleteModelIterator<'a, 't, T> {
        CompleteModelIterator::new(self).disable_skip_infinite()
    }

    pub fn complete(self) -> CompleteModelIterator<'a, 't, T> {
        CompleteModelIterator::new(self)
    }
}

impl<'t, T> Iterator for ModelIterator<'_, 't, T>
where
    T: Solver<'t> + ?Sized,
{
    type Item = GlobModel;

    // Each call asks the solver for a model. On success, the solver's current model is
    // invalidated via `next_model` before the model is handed out. `Unsat`/`Unknown` end
    // the iteration.
    fn next(&mut self) -> Option<Self::Item> {
        if let ModelResult::Sat(found) = self.solver.check_get_model() {
            self.solver.next_model();
            Some(found)
        } else {
            None
        }
    }
}

/// Flattens a [ModelIterator] into an iterator over complete [Model]s.
///
/// Every [GlobModel] produced by the wrapped [ModelIterator] is expanded into all the [Model]s
/// it contains. Those are yielded one by one before the next [GlobModel] is requested.
pub struct CompleteModelIterator<'a, 't, T>
where
    T: Solver<'t> + ?Sized,
{
    // Source of the global models to expand.
    model_iter: ModelIterator<'a, 't, T>,
    // Models of the global model currently being expanded, if any.
    cur_glob_model_iter: Option<IntoIterCompleteModel>,
    // Passed to each new `IntoIterCompleteModel` via `skip_infinite`.
    skip_infinite: bool,
}

impl<'a, 't, T: Solver<'t> + ?Sized> CompleteModelIterator<'a, 't, T> {
    /// Wraps `model_iter`. Infinite values are not skipped by default.
    fn new(model_iter: ModelIterator<'a, 't, T>) -> Self {
        Self {
            model_iter,
            cur_glob_model_iter: None,
            skip_infinite: false,
        }
    }

    /// Disables iteration over infinite values.
    pub fn enable_skip_infinite(self) -> Self {
        self.skip_infinite(true)
    }

    /// Disables iteration over infinite values in place.
    pub fn mut_enable_skip_infinite(&mut self) {
        self.mut_skip_infinite(true)
    }

    /// Enables iteration over infinite values.
    pub fn disable_skip_infinite(self) -> Self {
        self.skip_infinite(false)
    }

    /// Enables iteration over infinite values in place.
    pub fn mut_disable_skip_infinite(&mut self) {
        self.mut_skip_infinite(false)
    }

    /// Disables iteration over infinite values if `value` is `true`, enables it otherwise.
    pub fn skip_infinite(mut self, value: bool) -> Self {
        self.mut_skip_infinite(value);
        self
    }

    /// Disables iteration over infinite values if `value` is `true`, enables it otherwise.
    /// Does this in place.
    ///
    /// The setting is applied when the next [GlobModel] is fetched. A [GlobModel] that is
    /// currently being expanded keeps the setting it was started with.
    pub fn mut_skip_infinite(&mut self, value: bool) {
        self.skip_infinite = value;
    }
}

impl<'t, T> Iterator for CompleteModelIterator<'_, 't, T>
where
    T: Solver<'t> + ?Sized,
{
    type Item = Model;

    // Drain the current global model first. Once it is exhausted (or none exists yet), pull
    // the next global model and start expanding it. Iteration stops when the underlying
    // `ModelIterator` runs out.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let pending = self
                .cur_glob_model_iter
                .as_mut()
                .and_then(|models| models.next());
            if pending.is_some() {
                return pending;
            }
            let glob_model = self.model_iter.next()?;
            let expanded = glob_model.into_iter_models();
            self.cur_glob_model_iter = Some(expanded.skip_infinite(self.skip_infinite));
        }
    }
}

pub mod wrapper {
    //! Adapter exposing a `comp_core` solver through the [Solver] trait.
    use crate::fodot::{self, TryIntoCtx, error::DomainMismatch, structure::DomainFullRc};
    use comp_core::node::{BoolElement, ElementNode, NodeEnum};

    use super::*;

    /// Wraps a `comp_core` solver `T` together with the [Theory] it was initialized from.
    ///
    /// Implements [Solver] by translating between `fodot` values and `comp_core` values.
    /// It dereferences to the wrapped solver for direct access to its own API.
    pub struct CCWrapper<'a, T> {
        solver: T,
        theory: &'a Theory,
    }

    /// Gives shared access to the wrapped `comp_core` solver.
    impl<T> Deref for CCWrapper<'_, T> {
        type Target = T;

        fn deref(&self) -> &Self::Target {
            &self.solver
        }
    }

    /// Gives mutable access to the wrapped `comp_core` solver.
    impl<T> DerefMut for CCWrapper<'_, T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.solver
        }
    }

    impl<'a, T: CCSolver<'a>> CCWrapper<'a, T> {
        /// Checks whether the `nth_assertion`-th formula of the solver's constraints has been
        /// reduced to a boolean constant.
        ///
        /// Returns `Some(value)` when that formula's first node is the boolean constant `value`.
        /// Returns `None` when it is not such a constant, or when there is no
        /// `nth_assertion`-th formula.
        pub fn has_been_simplified(&self, nth_assertion: usize) -> Option<bool> {
            let transformed_sentence = self
                .solver
                .constraints()
                .formulas_iter()
                .nth(nth_assertion)?;
            match transformed_sentence.first_node_enum() {
                NodeEnum::Element(ElementNode::Bool(BoolElement { value })) => Some(value),
                _ => None,
            }
        }
    }

    impl<'a, T: CCSolver<'a>> Solver<'a> for CCWrapper<'a, T> {
        /// # Panics
        ///
        /// Panics if the theory cannot be lowered for `comp_core`.
        fn initialize_with_timing(
            theory: &'a Theory,
            interp_method: InterpMethod,
            timings: &mut dyn Timings,
        ) -> Self
        where
            Self: Sized,
        {
            CCWrapper {
                solver: T::initialize_with_timing(
                    theory
                        .lower()
                        .expect("failed to lower the theory for comp_core")
                        .into(),
                    &theory.structure().cc_struct,
                    interp_method,
                    timings,
                ),
                theory,
            }
        }

        fn check(&mut self) -> SatResult {
            self.solver.check()
        }

        fn get_constraints(&self) -> Constraints {
            Constraints {
                assertions: Rc::clone(self.theory.assertions_rc()),
            }
        }

        fn get_model(&mut self) -> Option<GlobModel> {
            self.solver.get_model().map(|f| GlobModel {
                constraints: self.get_constraints(),
                structure: PartialStructure {
                    type_interps: Rc::clone(self.theory.type_interps_rc()),
                    cc_struct: f.into(),
                },
            })
        }

        fn next_model(&mut self) {
            self.solver.next_model()
        }

        fn propagate(&mut self) -> Option<PartialStructure> {
            self.solver.propagate().map(|f| PartialStructure {
                type_interps: Rc::clone(self.theory.type_interps_rc()),
                cc_struct: f,
            })
        }

        fn get_range(
            &mut self,
            pfunc: PfuncRc,
            args: ArgsRef,
        ) -> Result<Option<Set>, fodot::error::GetRangeError> {
            // Interpret the pfunc's domain with this theory's type interpretations, so the
            // arguments can be converted into that context.
            let domain_full: DomainFullRc = pfunc
                .domain()
                .with_interps(Rc::clone(self.theory.type_interps_rc()))?;
            let args: ArgsRef = args.try_into_ctx((&domain_full).into())?;
            // Defensive check: the converted arguments must belong to the pfunc's domain.
            if pfunc.domain() != args.domain().as_domain() {
                return Err(DomainMismatch {
                    expected: pfunc.domain().into(),
                    found: args.domain().as_domain().into(),
                }
                .into());
            }

            self.solver
                .get_range(pfunc.to_cc(), args.domain_enum)
                .map(|range| {
                    range.map(|backing| Set {
                        backing,
                        domain: DomainFullRc::new(
                            &[pfunc.codomain_rc()],
                            Rc::clone(self.theory.type_interps_rc()),
                        )
                        .expect("failed to build the codomain domain of the pfunc"),
                    })
                })
                // Every solver-side error is reported as an infinite codomain; the original
                // error value is discarded.
                .map_err(|_| fodot::error::InfiniteCodomainError.into())
        }

        fn theory(&self) -> &Theory {
            self.theory
        }
    }
}

/// A [Solver] backed by `comp_core`'s Z3 solver, adapted through [wrapper::CCWrapper].
pub type Z3Solver<'a> = wrapper::CCWrapper<'a, comp_core::solver::z3::Z3Solver<'a>>;
