use smallvec::Array;

use crate::{
    all_imports::Shape,
    hashmaps::{AHashSet as HashSet, FxHashMap as HashMap},
    smallvec::Sv,
    util::cumsum,
};
use std::iter::zip;

use super::{
    algebraic_rewrite::{get_removable_axes, remove_axes},
    prelude::*,
    Add, Concat, Einsum, Index, Rearrange, ScalarConstant, Scatter,
};
use crate::pyo3_prelude::*;
use crate::{
    rearrange_spec::RearrangeSpec,
    tensor_util::{
        compose, uslices_shrink_base, uslices_to_index, TensorAxisIndex, TensorIndex, USlice,
    },
    util::filter_out_idx,
};
use std::hash::Hash;

/// Fuse `Scatter(Scatter(x))` into a single `Scatter(x)`.
///
/// The two scatter indices are combined with `compose`, passing the inner
/// (lower) scatter's index first and this scatter's index second. The fused
/// scatter keeps this scatter's shape and name.
///
/// Returns `None` when the child of `scatter` is not itself a `Scatter`.
///
/// # Panics
/// Panics if `Scatter::try_new` rejects the composed index.
#[pyfunction]
pub fn scatter_fuse(scatter: &Scatter) -> Option<Scatter> {
    let Circuit::Scatter(inner) = &**scatter.node else {
        return None;
    };
    let fused_index = compose(&inner.index, &scatter.index);
    let fused = Scatter::try_new(
        inner.node.clone(),
        fused_index,
        scatter.info().shape.clone(),
        scatter.name_cloned(),
    )
    .unwrap();
    Some(fused)
}

pub fn uslice_map_get_index<A: Array>(
    ints: &Sv<A>,
    int_slices: &HashMap<A::Item, USlice>,
) -> TensorIndex
where
    A::Item: Eq + Hash,
{
    TensorIndex(
        ints.iter()
            .map(|i| {
                int_slices
                    .get(i)
                    .map(|s| (*s).into())
                    .unwrap_or(TensorAxisIndex::IDENT)
            })
            .collect(),
    )
}

/// Pull `Scatter` arguments of an einsum out through the einsum.
///
/// Every label used by a `Scatter` argument is restricted to the intersection
/// of that argument's slices. A label's range starts from its full size
/// (`shape_map`) and is narrowed by each scatter that uses it. Then:
/// - if no argument is a `Scatter`, returns `None`;
/// - if some restricted range is empty (`start == stop`), returns an all-zero
///   `ScalarConstant` with the einsum's shape;
/// - otherwise returns `Scatter(Einsum(...))`. Scatter arguments are unwrapped
///   and indexed relative to their own slices (`uslices_shrink_base`). Other
///   arguments are indexed by the restricted ranges. The outer scatter places
///   the smaller einsum back into the original shape.
///
/// # Panics
/// Panics if a scatter argument's index is not all `USlice`s, or if building
/// any of the new nodes fails.
#[pyfunction]
pub fn einsum_pull_scatter(einsum: &Einsum) -> Option<CircuitRc> {
    let int_sizes = einsum.shape_map().unwrap();

    // Per label, intersect the slices of every Scatter argument using it.
    let mut saw_scatter = false;
    let mut int_slices: HashMap<u8, USlice> = HashMap::new();
    for (arg, ints) in &einsum.args {
        let Circuit::Scatter(scatter) = &***arg else {
            continue;
        };
        saw_scatter = true;
        for (slice, label) in zip(scatter.index.all_uslices().unwrap(), ints) {
            let range = int_slices.entry(*label).or_insert(USlice {
                start: 0,
                stop: int_sizes[label],
            });
            *range = range.intersection(&slice);
        }
    }

    if !saw_scatter {
        return None;
    }
    if int_slices.values().any(|range| range.start == range.stop) {
        return Some(ScalarConstant::new(0.0, einsum.info().shape.clone(), None).rc());
    }

    let new_args = einsum
        .args
        .iter()
        .map(|(arg, ints)| {
            let indexed = if let Circuit::Scatter(scatter) = &***arg {
                // The scatter's child only spans the scatter's own slices, so the
                // restricted ranges are re-expressed relative to those slices.
                let restricted = ints.iter().map(|label| int_slices[label]).collect();
                let relative =
                    uslices_shrink_base(&restricted, &scatter.index.all_uslices().unwrap());
                Index::try_new(scatter.node.clone(), uslices_to_index(&relative), None)
            } else {
                Index::try_new(arg.clone(), uslice_map_get_index(ints, &int_slices), None)
            };
            (indexed.unwrap().rc(), ints.clone())
        })
        .collect();
    let pulled = Einsum::try_new(new_args, einsum.out_axes.clone(), einsum.name_cloned())
        .unwrap()
        .rc();
    let placement = uslice_map_get_index(&einsum.out_axes, &int_slices);
    Some(
        Scatter::try_new(pulled, placement, einsum.info().shape.clone(), None)
            .unwrap()
            .rc(),
    )
}

/// Pull a common `Scatter` out of an `Add` whose operands are all `Scatter`s.
///
/// For each output axis, this combines with `USlice::union` the slices of every
/// operand that is not broadcast there (its size equals the `Add`'s size on
/// that axis). Then:
/// - if every combined slice spans its whole axis, returns `None`;
/// - otherwise indexes each operand down to the combined slices (broadcast axes
///   are left as identity), adds the indexed operands, and scatters the sum
///   back into the `Add`'s shape.
///
/// Returns `None` if any operand is not a `Scatter`.
///
/// NOTE(review): each axis accumulator is seeded with `0..0` before the unions.
/// Whether that pins the combined `start` to 0 depends on how `USlice::union`
/// treats empty slices. TODO confirm.
///
/// # Panics
/// Panics if an operand's scatter index is not all `USlice`s or if building
/// any new node fails.
#[pyfunction]
pub fn add_pull_scatter(add: &Add) -> Option<Scatter> {
    if add
        .nodes
        .iter()
        .any(|x| !matches!(&***x, Circuit::Scatter(_)))
    {
        return None;
    }

    let out_shape = &add.info().shape;
    let operands = add.nodes_and_rank_differences();

    // Union, per output axis, the slices of the operands not broadcast there.
    let mut slices = vec![USlice { start: 0, stop: 0 }; out_shape.len()];
    for (operand, offset) in operands.iter() {
        let scatter = operand.as_scatter().unwrap();
        let operand_slices = scatter.index.all_uslices().unwrap();
        for (axis, (slice, len)) in zip(operand_slices, &scatter.info().shape).enumerate() {
            let out_axis = axis + offset;
            if *len == out_shape[out_axis] {
                slices[out_axis] = slices[out_axis].union(&slice);
            }
        }
    }

    let spans_everything = slices
        .iter()
        .zip(out_shape.iter())
        .all(|(s, len)| s.start == 0 && s.stop == *len);
    if spans_everything {
        return None;
    }

    let new_operands = operands
        .iter()
        .map(|(node, offset)| {
            let per_axis = node
                .info()
                .shape
                .iter()
                .enumerate()
                .map(|(axis, len)| {
                    let out_axis = axis + offset;
                    if *len == out_shape[out_axis] {
                        slices[out_axis].into()
                    } else {
                        TensorAxisIndex::IDENT
                    }
                })
                .collect();
            Index::try_new(node.clone(), TensorIndex(per_axis), None)
                .unwrap()
                .rc()
        })
        .collect();
    let summed = Add::try_new(new_operands, add.name_cloned()).unwrap().rc();
    let placement = uslices_to_index(&slices.iter().copied().collect());
    Some(Scatter::try_new(summed, placement, add.info().shape.clone(), None).unwrap())
}

/// Replace a `Scatter` by its child when `Scatter::is_identity` holds.
///
/// Returns `None` for any non-identity scatter.
#[pyfunction]
pub fn scatter_elim_identity(scatter: &Scatter) -> Option<CircuitRc> {
    scatter.is_identity().then(|| scatter.node.clone())
}

/// Rewrite `Index(Einsum(..))` as `Index(Scatter(Einsum(Index(arg)..)))` when
/// a repeated output label (e.g. `ii`) is indexed with different ranges on
/// its axes.
///
/// Each axis of the (canonicalized) index is widened to its containing `USlice`
/// (`USlice::containing_uslice`). Axes that share an output label are then
/// intersected. This rewrite treats a repeated label as populated only where
/// its axes coincide, so an empty intersection yields zeros.
///
/// Returns:
/// - `None` if the indexed node is not an `Einsum`, if some axis index has no
///   containing `USlice`, or if no repeated label had axes with differing
///   ranges (nothing to gain);
/// - an all-zero `ScalarConstant` with the index's output shape if some
///   label's intersected range is empty;
/// - otherwise the rewritten circuit. Every einsum argument is indexed down to
///   the intersected ranges, the result is scattered into the containing-slice
///   region, and a final `Index` re-selects `Single`-indexed axes at 0.
///
/// # Panics
/// Panics if constructing any of the new nodes fails.
#[pyfunction]
pub fn index_einsum_to_scatter(node: &Index) -> Option<CircuitRc> {
    let Circuit::Einsum(inner) = &**node.node else {
        return None;
    };
    let canon_index = node.index.canonicalize(&inner.info().shape);
    let containing_uslices: Vec<Option<USlice>> = canon_index
        .0
        .iter()
        .map(USlice::containing_uslice)
        .collect();

    // Intersect the ranges of all output axes sharing a label. The merged range
    // is always stored, and a change is flagged when it differs from EITHER
    // side. The result is then independent of axis order: `ii` indexed
    // `[0:10, 2:5]` narrows `i` to `2:5` exactly like `[2:5, 0:10]` does.
    let mut did_anything = false;
    let mut int_slices: HashMap<u8, USlice> = HashMap::new();
    for (uslice_here, int) in zip(&containing_uslices, &inner.out_axes) {
        let uslice_here = (*uslice_here)?;
        let merged = match int_slices.get(int) {
            Some(prev) => {
                let merged = uslice_here.intersection(prev);
                if merged != uslice_here || merged != *prev {
                    did_anything = true;
                }
                merged
            }
            None => uslice_here,
        };
        int_slices.insert(*int, merged);
    }

    if !did_anything {
        return None;
    }
    if int_slices.values().any(|slice| slice.length() == 0) {
        return Some(ScalarConstant::new(0.0, node.info().shape.clone(), None).rc());
    }

    let new_args = inner
        .args
        .iter()
        .map(|(arg, ints)| {
            (
                Index::try_new(arg.clone(), uslice_map_get_index(ints, &int_slices), None)
                    .unwrap()
                    .rc(),
                ints.clone(),
            )
        })
        .collect();
    let new_einsum = Einsum::try_new(new_args, inner.out_axes.clone(), inner.name_cloned())
        .unwrap()
        .rc();
    // Place the narrowed einsum inside the region covered by the containing
    // slices, with each intersected range expressed relative to that region.
    let new_scatter = Scatter::try_new(
        new_einsum,
        TensorIndex(
            inner
                .out_axes
                .iter()
                .enumerate()
                .map(|(i, int)| {
                    int_slices
                        .get(int)
                        .map(|s| (s.shrink_base(&containing_uslices[i].unwrap())).into())
                        .unwrap_or(TensorAxisIndex::IDENT)
                })
                .collect(),
        ),
        containing_uslices
            .iter()
            .map(|x| x.unwrap().stop - x.unwrap().start)
            .collect(),
        None,
    )
    .unwrap()
    .rc();
    // `Single` axes of the original index are re-selected at position 0 of the
    // scatter (presumably their containing slice has length 1 — TODO confirm).
    // All other axes were already narrowed by the scatter, so they stay identity.
    Some(
        Index::try_new(
            new_scatter,
            TensorIndex(
                node.index
                    .0
                    .iter()
                    .map(|x| match x {
                        TensorAxisIndex::Single(_x) => TensorAxisIndex::Single(0),
                        _ => TensorAxisIndex::IDENT,
                    })
                    .collect(),
            ),
            None,
        )
        .unwrap()
        .rc(),
    )
}

/// Pull removable axes of the scatter's child out through the `Scatter`.
///
/// An axis qualifies when `get_removable_axes` reports it for the child and
/// the scatter's slice on that axis spans the whole axis (`0..len`). The
/// qualifying axes are removed from the child (`remove_axes`) and from the
/// scatter's index and shape. A `Rearrange` built with
/// `RearrangeSpec::unremove_axes` then restores the original shape.
///
/// Returns `None` if no axis qualifies.
///
/// # Panics
/// Panics if the scatter index is not all `USlice`s or if building a new node fails.
#[pyfunction]
pub fn scatter_pull_removable_axes(scatter: &Scatter) -> Option<Rearrange> {
    let candidates = get_removable_axes(&scatter.node);
    let uslices = scatter.index.all_uslices().unwrap();
    let shape = &scatter.info().shape;
    let spans_axis = |axis: usize| uslices[axis] == USlice { start: 0, stop: shape[axis] };
    let removable_axes: HashSet<usize> = candidates
        .iter()
        .copied()
        .filter(|&axis| spans_axis(axis))
        .collect();
    if removable_axes.is_empty() {
        return None;
    }

    let slimmed = remove_axes(&scatter.node, &removable_axes).unwrap();
    let kept_index = TensorIndex(filter_out_idx(&scatter.index.0, &removable_axes));
    let kept_shape = filter_out_idx(&shape.iter().cloned().collect::<Vec<_>>(), &removable_axes)
        .into_iter()
        .collect();
    let slim_scatter = Scatter::try_new(slimmed, kept_index, kept_shape, None)
        .unwrap()
        .rc();
    Some(
        Rearrange::try_new(
            slim_scatter,
            RearrangeSpec::unremove_axes(&removable_axes, shape),
            None,
        )
        .unwrap(),
    )
}

/// Express a `Scatter` as nested `Concat`s with zero `ScalarConstant` padding.
///
/// Axis by axis, the running result is concatenated along that axis between
/// a zero block of `start` elements and a zero block of `len - stop` elements.
/// The final circuit therefore has the scatter's full shape. A pad block has
/// size 0 on an axis where the slice already touches that end.
///
/// # Panics
/// Panics if the scatter index is not all `USlice`s or if a `Concat` fails to build.
#[pyfunction]
pub fn scatter_to_concat(scatter: &Scatter) -> CircuitRc {
    let seed = scatter.node.clone();
    let placements = scatter.index.all_uslices().unwrap();
    zip(placements, &scatter.info().shape)
        .enumerate()
        .fold(seed, |acc, (axis, (slice, &full_len))| {
            // Zero block shaped like `acc`, except `pad_len` long along `axis`.
            let zeros = |pad_len: usize| {
                let mut pad_shape: Shape = acc.info().shape.clone();
                pad_shape[axis] = pad_len;
                ScalarConstant::new(0.0, pad_shape, None).rc()
            };
            let before = zeros(slice.start);
            let after = zeros(full_len - slice.stop);
            Concat::try_new(vec![before, acc, after], axis, None)
                .unwrap()
                .rc()
        })
}

#[pyfunction]
pub fn concat_to_scatter(concat: &Concat) -> Option<Scatter> {
    let pre_zeros = concat
        .nodes
        .iter()
        .take_while(|node| {
            if let Circuit::ScalarConstant(sc) = &****node && sc.value==0.0{
            true
        }else{
            false
        }
        })
        .count();
    let post_zeros = concat
        .nodes
        .iter()
        .rev()
        .take_while(|node| {
            if let Circuit::ScalarConstant(sc) = &****node && sc.value==0.0{
            true
        }else{
            false
        }
        })
        .count();
    let end = concat.nodes.len() - post_zeros;
    if pre_zeros == 0 && post_zeros == 0 || (pre_zeros + post_zeros >= concat.nodes.len()) {
        return None;
    }
    let starts = cumsum(&concat.get_sections());
    let cslice = TensorAxisIndex::new_plain_slice(starts[pre_zeros], starts[end]);
    let scatter_index = TensorIndex::new_single(cslice, concat.axis, concat.info().rank())
        .canonicalize(&concat.info().shape);
    let new_concat = Concat::nrc(
        concat.nodes[pre_zeros..end].to_vec(),
        concat.axis,
        concat.name_cloned(),
    );
    if new_concat.info().numel() == concat.info().numel() {
        return None;
    }
    Some(Scatter::try_new(new_concat, scatter_index, concat.info().shape.clone(), None).unwrap())
}
