use std::{
    collections::BTreeMap,
    iter::{self, zip},
};

use anyhow::{bail, Context, Result};
use macro_rules_attribute::apply;
use pyo3::{exceptions::PyValueError, prelude::*};
use rr_util::{
    pycall, python_error_exception,
    rearrange_spec::{ExpandToSpecOrShape, OpShape, OpSize, RearrangeSpec},
    sv,
    symbolic_size::SymbolicSizeProduct,
    tensor_util::{broadcast_shapes, Shape, TensorAxisIndex, TensorIndex},
    util::EinsumAxes,
};
use rustc_hash::FxHashMap as HashMap;
use smallvec::SmallVec;
use thiserror::Error;

use crate::{
    deep_map_fallible_pre_new_children, prelude::*, Add, Concat, Einsum, GeneralFunction, Index,
    Module, ModuleArgSpec, ModuleSpec, Rearrange, Scatter, SetSymbolicShape, Symbol, Tag,
};

#[pyfunction]
#[pyo3(name = "expand_node")]
pub fn expand_node_py(circuit: CircuitRc, inputs: Vec<CircuitRc>) -> Result<CircuitRc> {
    expand_node(circuit, &inputs, |_, _, _| {
        bail!("fancy module expanding not supported via python expand node")
    })
}

/// Rebuild `circuit` on top of `inputs` (its replacement children), propagating any extra
/// leading "batch" dimensions that the new children carry.
///
/// For every child, the number of leading dims `inputs[i]` has beyond the old child is its batch
/// rank. Those leading batch shapes are broadcast together (via `broadcast_shapes`), and the
/// rebuilt node is constructed so that its rank is `circuit`'s rank plus the broadcast batch rank
/// (checked by the assertion at the end; the empty-`Module` shortcut returns before it). Where a
/// node kind needs consistent sizes (Rearrange, Concat, Add, Einsum), symbolic sizes are pinned
/// with `SetSymbolicShape::some_set_and_symbolic_neq`.
///
/// If `inputs` equal the current children, `circuit` itself is returned.
///
/// `rerun_with_extra_replacements` is only called for a `Module` whose new batch dims must be
/// moved into its spec circuit. It receives the original spec circuit, a map from each changed
/// arg symbol to its widened replacement, and the constant `0`, and must return the re-expanded
/// spec circuit.
///
/// # Errors
/// Returns an [`ExpandError`] (or a contextual `anyhow` error) if the child count is wrong, an
/// input has lower rank than the child it replaces, the batch shapes don't broadcast, a node
/// cannot be batched/expanded as requested, or rebuilding the node fails.
///
/// # Panics
/// Panics if internal batching invariants are violated (see the `assert!`s below).
pub fn expand_node(
    circuit: CircuitRc,
    inputs: &Vec<CircuitRc>,
    rerun_with_extra_replacements: impl FnOnce(
        CircuitRc,
        HashMap<CircuitRc, CircuitRc>,
        usize,
    ) -> Result<CircuitRc>,
) -> Result<CircuitRc> {
    if inputs.len() != circuit.children().count() {
        bail!(ExpandError::WrongNumChildren {
            expected: circuit.children().count(),
            got: inputs.len(),
        });
    }
    if inputs
        .iter()
        .zip(circuit.children())
        .all(|(new, old)| new == &old)
    {
        return Ok(circuit); // in all identical case we can return same circuit for efficiency
    }
    // Extra leading (batch) dims of each new child relative to the child it replaces.
    // `checked_sub` drops children whose rank shrank; the length check below reports that.
    let batch_ranks: Vec<usize> = zip(circuit.children(), inputs)
        .filter_map(|(old, new)| new.info().rank().checked_sub(old.info().rank()))
        .collect();
    if batch_ranks.len() != inputs.len() {
        bail!(ExpandError::BatchingRankTooLow {
            default: circuit.children().map(|x| x.info().rank()).collect(),
            got: inputs.iter().map(|x| x.info().rank()).collect(),
        });
    }

    let batch_shapes: Vec<&[usize]> = zip(&batch_ranks, inputs)
        .map(|(br, new)| &new.info().shape[0..*br])
        .collect();

    // TODO: maybe we should allow for inconsistent symbolic batch shapes?
    // (probably not...)
    let broadcasted_batch_shape =
        broadcast_shapes(&batch_shapes).context("batch shapes couldn't be broadcast together")?;
    let batch_rank = broadcasted_batch_shape.len();

    // Left-pad each batchable input with the missing leading batch dims so every input carries
    // the full broadcast batch shape. On failure, returns the index of the first input that is
    // marked non-batchable but nonetheless received batch dims.
    let get_expanded_inputs = |batchable: &[bool]| -> Result<Vec<CircuitRc>, usize> {
        let br = batch_rank;
        assert_eq!(batchable.len(), batch_ranks.len());
        assert_eq!(inputs.len(), batch_ranks.len());
        inputs
            .iter()
            .zip(&batch_ranks)
            .zip(batchable)
            .enumerate()
            .map(|(i, ((x, &r), batchable_x))| {
                assert!(br >= r);
                if !batchable_x && r > 0 {
                    return Err(i);
                }
                if br == r || !batchable_x {
                    return Ok(x.clone());
                }
                Ok(Rearrange::nrc(
                    x.clone(),
                    RearrangeSpec::prepend_batch_shape(
                        SmallVec::from_slice(&broadcasted_batch_shape[..br - r]),
                        x.ndim(),
                    )
                    .unwrap(),
                    x.name().map(|x| format!("{} rep_for_batch", x)),
                ))
            })
            .collect()
    };

    // Representative size for a set of sizes that must agree: the first non-symbolic size if
    // there is one, otherwise the first size. Panics on an empty slice.
    let find_non_symbolic = |sizes: &[usize]| {
        sizes
            .iter()
            .find_map(|&s| (!SymbolicSizeProduct::has_symbolic(s)).then_some(s))
            .unwrap_or(sizes[0])
    };

    let out = match &**circuit {
        Circuit::SetSymbolicShape(_) => Ok(inputs[0].clone()), /* just return circuit, symbolic shapes will be added as needed! (TODO: maybe should keep?) */
        Circuit::Symbol(_) | Circuit::Scalar(_) => Ok(circuit.clone()),
        Circuit::Rearrange(node) => {
            let input_shape_non_batch = inputs[0].info().shape[batch_ranks[0]..]
                .iter()
                .cloned()
                .collect();

            // TODO: test me better!
            // The new (non-batch) input shape may be incompatible with the old spec, so report
            // that as an error rather than panicking.
            let (new_input, expanded_spec) = match node
                .spec
                .expand_to_spec_or_shape(&input_shape_non_batch, false)
                .with_context(|| {
                    format!(
                        "failed to expand rearrange spec to new input shape in expand, old rearrange={:?}",
                        node
                    )
                })? {
                ExpandToSpecOrShape::Spec(spec) => (inputs[0].clone(), spec),
                ExpandToSpecOrShape::SetShape(forced_shape) => {
                    let new_input = SetSymbolicShape::some_set_and_symbolic_neq(
                        inputs[0].clone(),
                        iter::repeat(OpSize::NONE)
                            .take(batch_ranks[0])
                            .chain(forced_shape)
                            .collect(),
                        None,
                    )
                    .with_context(|| {
                        format!(
                            "failed to set symbolic in expand rearrange, old rearrange={:?}",
                            node
                        )
                    })?;
                    let expanded_spec = node
                        .spec
                        .expand_to(new_input.shape())
                        .context("failed to expand rearrange to input shape in expand")?;
                    (new_input, expanded_spec)
                }
            };

            let new_spec = expanded_spec.add_batch_dims(batch_ranks[0]);
            Rearrange::try_new(new_input, new_spec, circuit.name_cloned()).map(|x| x.rc())
        }
        // I think we just don't do any symbolic shape munging in Index/Scatter?
        Circuit::Index(node) => {
            // for now non-batch non-identity dims can't change
            for i in 0..node.node.info().rank() {
                if node.node.info().shape[i] != inputs[0].info().shape[i + batch_ranks[0]]
                    && node.index.0[i] != TensorAxisIndex::IDENT
                {
                    bail!(ExpandError::FixedIndex {
                        index: node.index.clone(),
                        old_shape: node.node.info().shape.clone(),
                        new_shape: inputs[0].info().shape.clone(),
                    });
                }
            }
            Ok(Index::nrc(
                inputs[0].clone(),
                TensorIndex(
                    vec![TensorAxisIndex::IDENT; batch_ranks[0]]
                        .into_iter()
                        .chain(node.index.0.iter().cloned())
                        .collect(),
                ),
                node.name_cloned(),
            ))
        }
        Circuit::Scatter(node) => {
            // for now non-batch non-identity dims can't change
            for i in 0..node.node.info().rank() {
                if node.node.info().shape[i] != inputs[0].info().shape[i + batch_ranks[0]]
                    && node.index.0[i] != TensorAxisIndex::IDENT
                {
                    bail!(ExpandError::FixedIndex {
                        index: node.index.clone(),
                        old_shape: node.node.info().shape.clone(),
                        new_shape: inputs[0].info().shape.clone(),
                    });
                }
            }
            Ok(Scatter::nrc(
                inputs[0].clone(),
                TensorIndex(
                    vec![TensorAxisIndex::IDENT; batch_ranks[0]]
                        .into_iter()
                        .chain(node.index.0.iter().cloned())
                        .collect(),
                ),
                inputs[0].info().shape[0..batch_ranks[0]]
                    .iter()
                    .cloned()
                    .chain(node.info().shape.iter().cloned())
                    .collect(),
                node.name_cloned(),
            ))
        }
        Circuit::Concat(node) => {
            // TODO: error on concat axis having symbolic size? (or maybe concat axis is fine???
            let br = batch_rank;

            let inputs = get_expanded_inputs(&vec![true; inputs.len()])
                .expect("all concat inputs are batchable, so expansion cannot fail");

            // The concat axis itself must keep its size; report the first offending pair.
            let new_axis = node.axis + br;
            if let Some((old, new)) = zip(&node.nodes, &inputs)
                .find(|(old, new)| old.info().shape[node.axis] != new.info().shape[new_axis])
            {
                bail!(ExpandError::ConcatAxis {
                    axis: node.axis,
                    old_shape: old.info().shape.clone(),
                    new_shape: new.info().shape.clone(),
                });
            }

            let inputs = if inputs.iter().all(|x| x.ndim() == inputs[0].ndim()) {
                // if statment just so that if we're going to fail anyway, we skip this case
                let new_shape: OpShape = (0..inputs[0].ndim())
                    .map(|i| {
                        if i == new_axis {
                            return None;
                        }

                        Some(find_non_symbolic(
                            &inputs.iter().map(|x| x.shape()[i]).collect::<Vec<_>>(),
                        ))
                    })
                    .map(|x| x.into())
                    .collect();

                inputs
                    .iter()
                    .map(|inp| {
                        SetSymbolicShape::some_set_and_symbolic_neq(
                            inp.clone(),
                            new_shape.clone(),
                            None,
                        )
                    })
                    .collect::<Result<_>>()
                    .with_context(|| {
                        format!(
                            "failed to set symbolic in expand concat, old concat={:?}",
                            node
                        )
                    })?
            } else {
                inputs.clone()
            };

            Concat::try_new(inputs, new_axis, node.name_cloned()).map(|x| x.rc())
        }
        Circuit::Add(node) => {
            let inputs = if let Some(max) = inputs.iter().map(|x| x.ndim()).max() {
                // Per right-aligned output dim, pick a representative among the non-1 sizes
                // (1 broadcasts); `None` when every input is 1 or absent there.
                let full_shape: OpShape = (0..max)
                    .map(|i| {
                        let sizes: Vec<_> = inputs
                            .iter()
                            .filter_map(|x| {
                                i.checked_sub(max - x.ndim()).and_then(|i| {
                                    let size = x.shape()[i];
                                    (size != 1).then_some(size)
                                })
                            })
                            .collect();

                        if sizes.is_empty() {
                            None
                        } else {
                            Some(find_non_symbolic(&sizes))
                        }
                        .into()
                    })
                    .collect();
                inputs
                    .iter()
                    .map(|inp| {
                        SetSymbolicShape::some_set_and_symbolic_neq(
                            inp.clone(),
                            full_shape[max - inp.ndim()..].iter().cloned().collect(),
                            None,
                        )
                    })
                    .collect::<Result<_>>()
                    .with_context(|| {
                        format!("failed to set symbolic in expand add, old add={:?}", node)
                    })?
            } else {
                inputs.clone()
            };

            Add::try_new(inputs, node.name_cloned()).map(|x| x.rc())
        }
        Circuit::GeneralFunction(node) => {
            // TODO: symbolic? (annoying to handle, would require changes to spec)
            // We could optimize out these input expansions in various cases.
            // TODO: we could run into strange issues with general functions and expand.
            // Maybe resolve with an assert...
            // TODO: maybe alow for batching over individual inputs like:
            // (This is a bit annoying in various ways and probably shouldn't be supported.
            // - [b_0, b_1, b_2, (stuff)]
            // - [b_2, (other stuff)]
            GeneralFunction::try_new(
                get_expanded_inputs(&node.input_batchability).map_err(|input_i| {
                    ExpandError::GeneralFunctionTriedToBatchNonBatchableInput {
                        input_i,
                        batched_inputs: inputs.clone(),
                        general_function: node.clone(),
                    }
                })?,
                node.spec.clone(),
                node.name_cloned(),
            )
            .map(|x| x.rc())
        }
        Circuit::Einsum(node) => {
            // New batch dims get fresh einsum axes `next_axis..end_axis`; each arg is labelled
            // with the trailing `r` of them (right-aligned broadcasting).
            let br = batch_rank;
            let next_axis = node.next_axis();
            assert!(next_axis as usize + br <= u8::MAX as usize);
            let end_axis = next_axis + br as u8;
            let out_axes = (next_axis..end_axis)
                .chain(node.out_axes.iter().cloned())
                .collect();
            let new_args: Vec<_> = node
                .args
                .iter()
                .zip(inputs)
                .zip(batch_ranks)
                .map(|(((_child, ints), inp), r)| {
                    assert!(br >= r);
                    (
                        inp.clone(),
                        (next_axis + (br - r) as u8..end_axis)
                            .chain(ints.iter().cloned())
                            .collect::<EinsumAxes>(),
                    )
                })
                .collect();

            let mut shape_map_many = HashMap::default();
            for (circ, axes) in &new_args {
                for (&circuit_shape, axis) in circ.info().shape.iter().zip(axes) {
                    shape_map_many
                        .entry(*axis)
                        .or_insert(Vec::new())
                        .push(circuit_shape);
                }
            }

            let axis_to_set_shape: HashMap<_, _> = shape_map_many
                .into_iter()
                .map(|(axis, sizes)| (axis, find_non_symbolic(&sizes)))
                .collect();

            let new_args = new_args
                .into_iter()
                .map(|(x, axes)| {
                    assert_eq!(
                        x.ndim(),
                        axes.len(),
                        "should be true due to above batching code"
                    );
                    Ok((
                        SetSymbolicShape::some_set_and_symbolic_neq(
                            x,
                            axes.iter()
                                .map(|x| Some(axis_to_set_shape[x]).into())
                                .collect(),
                            None,
                        )?,
                        axes,
                    ))
                })
                .collect::<Result<_>>()
                .with_context(|| {
                    format!(
                        "failed to set symbolic in expand einsum, old einsum={:?}",
                        node
                    )
                })?;

            Einsum::try_new(new_args, out_axes, node.name_cloned()).map(|x| x.rc())
        }
        Circuit::Module(node) => {
            // Children layout: [spec_circuit, sym_0, node_0, sym_1, node_1, ...].
            assert_eq!(inputs.len(), node.children().count());
            let mut children = inputs.clone();
            let rest = children.split_off(1);
            let spec_circuit = children.pop().unwrap();

            if node.nodes.is_empty() {
                // handle empty case so we can assume non-empty below
                return Ok(Module::nrc(
                    vec![],
                    ModuleSpec {
                        circuit: spec_circuit,
                        arg_specs: vec![],
                    },
                    node.name_cloned(),
                ));
            }

            let new_nodes = rest
                .chunks_exact(2)
                .zip(&node.spec.arg_specs)
                .map(|(sym_inp, arg_spec)| {
                    if let [sym, inp] = sym_inp {
                        if sym != &arg_spec.symbol.crc() {
                            bail!(ExpandError::ModuleArgSpecSymbolChangedInExpand {
                                old_symbol: arg_spec.symbol.clone(),
                                new_symbol_circ: sym.clone(),
                                module: node.clone(),
                            });
                        }
                        Ok(inp.clone())
                    } else {
                        unreachable!()
                    }
                })
                .collect::<Result<Vec<_>>>()?;

            // Approach:
            // suppose that orig node shapes are:
            // [     b_1, b_2, | x_0, x_1]
            // [          b_2, | y_0]
            // [b_0, b_1, b_2, | z_0, z_1]
            // Where x/y/z is the part which matches up with the input symbol
            // (with a | to indicate the division) and the `b_i` are batched
            // over by the module itself. (I aligned batch shapes in the
            // diagram for clarity, the number of batch dims is different for
            // each input)
            //
            // Suppose this yields an overall module shape of
            // [b_0, b_1, b_2, | w_0, w_1, w_2, w_3]
            // Where w_0, ..., w_3 are from the spec circuit itself and b_0, b_1, b_2
            // are the batching done by the module.
            //
            //
            // Now suppose we get new shapes:
            // [               n_3,      b_1, b_2, x_0, x_1]
            // [          n_2, n_3,           b_2, y_0]
            // [n_0, n_1, n_2, n_3, b_0, b_1, b_2, z_0, z_1]
            // (n is for 'new')
            //
            // Then we'll first move new dims into the symbol part as needed to
            // make batching 'contiguous' (via a rearrange)
            // [               b_1, b_2, | n_3, x_0, x_1]
            // [                    b_2, | n_2, n_3, y_0]
            // [n_0, n_1, b_0, b_1, b_2, | n_2, n_3, z_0, z_1]
            //
            // Note this requires changing the symbol shapes - we'll have to prepend the extra new dims.
            //
            // Now the module will originally have the following shape because we shifted dims around:
            // [n_0, n_1, b_0, b_1, b_2, | n_2, n_3, w_0, w_1, w_2, w_3]
            //
            // So we'll 'unshift' these dims to yield:
            // [n_0, n_1, n_2, n_3, b_0, b_1, b_2,  w_0, w_1, w_2, w_3]

            let current_node_batch_ranks: Vec<_> = node
                .nodes
                .iter()
                .zip(&node.spec.arg_specs)
                .map(|(node, arg_spec)| {
                    node.ndim()
                        .checked_sub(arg_spec.symbol.ndim())
                        .expect("existing module nodes have rank >= their symbol's rank")
                })
                .collect();
            let current_batch_rank = *current_node_batch_ranks.iter().max().unwrap();

            let new_node_batch_ranks: Vec<_> = new_nodes
                .iter()
                .zip(&node.nodes)
                .map(|(node, orig_node)| node.ndim().checked_sub(orig_node.ndim()).unwrap())
                .collect();

            let spec_circuit_batch_rank = batch_ranks[0]; // spec circuit is 0th child

            // how many dims do we have to move to make batching 'contiguous'
            let rank_to_push_into_spec_circuit = current_node_batch_ranks
                .iter()
                .zip(&new_node_batch_ranks)
                .map(|(this_current_batch_rank, this_new_batch_rank)| {
                    if *this_current_batch_rank == current_batch_rank {
                        0
                    } else {
                        *this_new_batch_rank
                    }
                })
                .chain(iter::once(spec_circuit_batch_rank))
                .max()
                .unwrap();

            let (spec_circuit, arg_specs, new_nodes) = if rank_to_push_into_spec_circuit > 0 {
                let to_unzip = node
                    .spec
                    .arg_specs
                    .iter()
                    .zip(new_node_batch_ranks)
                    .zip(new_nodes)
                    .zip(current_node_batch_ranks)
                    .map(
                        |(((arg_spec, this_new_batch_rank), node), this_current_batch_rank)| {
                            let this_sym_extra_rank =
                                rank_to_push_into_spec_circuit.min(this_new_batch_rank);
                            // The symbol absorbs the *last* `this_sym_extra_rank` new batch dims
                            // (right-aligned, matching the diagram above).
                            let new_sym = Symbol::new(
                                node.shape()[this_new_batch_rank - this_sym_extra_rank
                                    ..this_new_batch_rank]
                                    .iter()
                                    .chain(arg_spec.symbol.shape())
                                    .cloned()
                                    .collect(),
                                arg_spec.symbol.uuid,
                                arg_spec.symbol.name_cloned(),
                            );
                            let node = if this_current_batch_rank > 0 && this_sym_extra_rank > 0 {
                                // node layout: [new batch | current batch | old symbol dims].
                                // Move exactly the new batch dims that `new_sym` absorbed (the
                                // last `this_sym_extra_rank` ones) to sit right before the
                                // symbol dims; the leading new dims stay batched by the module.
                                let old_sym_rank = arg_spec.symbol.rank();
                                let node_rank = node.rank();
                                let up_to_old_sym_rank = node_rank - old_sym_rank;
                                let kept_new_batch_rank = this_new_batch_rank - this_sym_extra_rank;
                                let shift_name = node.name().map(|x| format!("{} shift_batch", x));
                                let spec = RearrangeSpec::new_permute(
                                    (0..kept_new_batch_rank)
                                        .chain(this_new_batch_rank..up_to_old_sym_rank)
                                        .chain(kept_new_batch_rank..this_new_batch_rank)
                                        .chain(up_to_old_sym_rank..node_rank)
                                        .collect(),
                                )
                                .unwrap();
                                assert!(!spec.is_identity()); // if statement should be checking for identity
                                Rearrange::nrc(node, spec, shift_name)
                            } else {
                                node
                            };

                            (
                                (arg_spec.symbol.crc(), new_sym.crc()),
                                ModuleArgSpec {
                                    symbol: new_sym,
                                    ..arg_spec.clone()
                                },
                                node,
                            )
                        },
                    );
                let (replacements, new_arg_specs, new_nodes): (Vec<_>, _, _) =
                    itertools::multiunzip(to_unzip);
                let replacements = replacements // this is just for efficiency
                    .into_iter()
                    .filter_map(|(a, b)| (a != b).then_some((a, b)))
                    .collect();

                let new_spec_circuit =
                    rerun_with_extra_replacements(node.spec.circuit.clone(), replacements, 0)
                        .context(
                            "failed to rerun batching for spec circuit with expanded symbols! (TODO: info)",
                        )?;

                assert_eq!(
                    new_spec_circuit.ndim(),
                    rank_to_push_into_spec_circuit + node.spec.circuit.ndim(),
                    "this assumption should hold due to guarantees that batching makes!"
                );

                (new_spec_circuit, new_arg_specs, new_nodes)
            } else {
                (spec_circuit, node.spec.arg_specs.clone(), new_nodes)
            };

            let spec = ModuleSpec {
                circuit: spec_circuit,
                arg_specs,
            };

            let out = Module::try_new(new_nodes, spec, node.name_cloned())?.rc();

            // unshift dims
            let final_out = if current_batch_rank > 0 && rank_to_push_into_spec_circuit > 0 {
                let orig_spec_ndim = node.spec.circuit.ndim();
                assert_eq!(
                    out.ndim(),
                    current_batch_rank + batch_rank + orig_spec_ndim,
                    "this assumption should hold due to guarantees that batching makes!"
                );
                assert!(batch_rank >= rank_to_push_into_spec_circuit);
                let batch_rank_done_via_mod = batch_rank - rank_to_push_into_spec_circuit;
                let total_new_old_batch_via_mod = batch_rank_done_via_mod + current_batch_rank;

                // starting to get a bit complicated : )
                let new_batch_dims_via_mod = 0..batch_rank_done_via_mod;
                let end_all_batching = total_new_old_batch_via_mod + rank_to_push_into_spec_circuit;
                let new_batch_dims_via_push_into_spec_circuit =
                    total_new_old_batch_via_mod..end_all_batching;
                let old_batch_dims = batch_rank_done_via_mod..total_new_old_batch_via_mod;
                let remaining_spec_circuit_dims =
                    end_all_batching..end_all_batching + orig_spec_ndim;
                assert_eq!(end_all_batching + orig_spec_ndim, out.ndim());

                let spec = RearrangeSpec::new_permute(
                    new_batch_dims_via_mod
                        .chain(new_batch_dims_via_push_into_spec_circuit)
                        .chain(old_batch_dims)
                        .chain(remaining_spec_circuit_dims)
                        .collect(),
                )
                .unwrap();
                assert!(!spec.is_identity()); // if statement should be checking for identity

                let shift_name = out.name().map(|x| format!("{} shift_batch", x));
                let out_name = out.name_cloned();
                Rearrange::nrc(out.rename(shift_name), spec, out_name)
            } else {
                out
            };

            Ok(final_out)
        }
        Circuit::Tag(node) => Ok(Tag::new(inputs[0].clone(), node.uuid, node.name_cloned()).rc()),
        _ => {
            if inputs[..] == circuit.children().collect::<Vec<_>>()[..] {
                Ok(circuit.clone())
            } else {
                bail!(ExpandError::NodeUnhandledVariant {
                    variant: circuit.variant_string(),
                })
            }
        }
    }?;
    assert_eq!(
        out.ndim(),
        circuit.ndim() + batch_rank,
        "batching assumption violated!"
    );
    Ok(out)
}

#[apply(python_error_exception)]
#[base_error_name(Expand)]
#[base_exception(PyValueError)]
#[derive(Error, Debug, Clone)]
pub enum ExpandError {
    #[error("expand wrong number of children, expected {expected} got {got} ({e_name})")]
    WrongNumChildren { expected: usize, got: usize },

    #[error("Batching Rank Too Low ({e_name})")]
    BatchingRankTooLow {
        default: Vec<usize>,
        got: Vec<usize>,
    },

    #[error("Trying to expand fixed index, index {index:?} old shape{old_shape:?} new shape {new_shape:?} ({e_name})")]
    FixedIndex {
        index: TensorIndex,
        old_shape: Shape,
        new_shape: Shape,
    },

    #[error(
        "Trying to expand concat axis, index {axis} old shape{old_shape:?} new shape {new_shape:?} ({e_name})"
    )]
    ConcatAxis {
        axis: usize,
        old_shape: Shape,
        new_shape: Shape,
    },

    // error could be improved...
    #[error("input_i={input_i} is not batchable, but tried to batch\nbatched_inputs={batched_inputs:?} general_function={general_function:?}\n ({e_name})")]
    GeneralFunctionTriedToBatchNonBatchableInput {
        input_i: usize,
        batched_inputs: Vec<CircuitRc>,
        general_function: GeneralFunction,
    },

    #[error("trying to expand node, unknown variant {variant} ({e_name})")]
    NodeUnhandledVariant { variant: String },

    #[error("Not currently supported! old_symbol={old_symbol:?} != new_symbol_circ={new_symbol_circ:?}\nmodule={module:?} ({e_name})")]
    ModuleArgSpecSymbolChangedInExpand {
        old_symbol: Symbol,
        new_symbol_circ: CircuitRc,
        module: Module,
    },

    #[error("node_rank={node_rank} < symbol_rank={symbol_rank}, arg_spec={arg_spec:?} node_shape={node_shape:?} spec_circuit={spec_circuit:?} ({e_name})")]
    ModuleRankReduced {
        node_rank: usize,
        symbol_rank: usize,
        arg_spec: ModuleArgSpec,
        node_shape: Shape,
        spec_circuit: CircuitRc,
    },

    #[error("node_rank={node_rank} > symbol_rank={symbol_rank} (which indicates batching) and arg_spec={arg_spec:?} spec_circuit={spec_circuit:?} ({e_name})")]
    ModuleTriedToBatchUnbatchableInput {
        node_rank: usize,
        symbol_rank: usize,
        arg_spec: ModuleArgSpec,
        spec_circuit: CircuitRc,
    },

    #[error("node_shape={node_shape:?} symbol_shape={symbol_shape:?} arg_spec={arg_spec:?} spec_circuit={spec_circuit:?} ({e_name})")]
    ModuleTriedToExpandUnexpandableInput {
        node_shape: Shape,
        symbol_shape: Shape,
        arg_spec: ModuleArgSpec,
        spec_circuit: CircuitRc,
    },
    #[error("new_size={fancy_new_size} != old_size={old_size} and old_size not symbolic at dim={dim}\n{}\n({e_name})",
        format!("node_shape={:?}, arg_spec={:?} spec_circuit={:?}", node_shape, arg_spec, spec_circuit),
        fancy_new_size=SymbolicSizeProduct::from(*new_size))]
    ModuleTriedToExpandOnNonSymbolicSizeAndBanned {
        new_size: usize,
        old_size: usize,
        dim: usize,
        node_shape: Shape,
        arg_spec: ModuleArgSpec,
        spec_circuit: CircuitRc,
    },
}

#[pyfunction]
#[pyo3(name = "replace_expand_bottom_up_dict")]
pub fn replace_expand_bottom_up_dict_py(
    circuit: CircuitRc,
    dict: HashMap<CircuitRc, CircuitRc>,
) -> Result<CircuitRc> {
    // Nodes found in `dict` are swapped for their mapped value; all others are re-expanded
    // around their (possibly replaced) children.
    replace_expand_bottom_up(circuit, |node| dict.get(&node).map(Clone::clone))
}

#[pyfunction]
#[pyo3(name = "replace_expand_bottom_up")]
pub fn replace_expand_bottom_up_py(circuit: CircuitRc, f: PyObject) -> Result<CircuitRc> {
    // The Python callable `f` decides, node by node, whether to substitute a replacement.
    replace_expand_bottom_up(circuit, |node| pycall!(f, (node.clone(),)))
}

pub fn replace_expand_bottom_up<F>(circuit: CircuitRc, replacer: F) -> Result<CircuitRc>
where
    F: Fn(CircuitRc) -> Option<CircuitRc>,
{
    replace_expand_bottom_up_impl(circuit, &replacer, Default::default())
}

/// Core of `replace_expand_bottom_up`.
///
/// For each node visited by `deep_map_fallible_pre_new_children`, the user `replacer` is
/// consulted first, then `extra_replacements`; if neither matches, the node is rebuilt with
/// `expand_node` around its new children. When `expand_node` asks to rerun expansion on a
/// module spec circuit, this function recurses with `extra_replacements` extended by the
/// requested replacements (entries chained later override earlier ones on duplicate keys).
/// The recursion is uncached.
pub fn replace_expand_bottom_up_impl<F>(
    circuit: CircuitRc,
    replacer: &F,
    extra_replacements: BTreeMap<CircuitRc, CircuitRc>,
) -> Result<CircuitRc>
where
    F: Fn(CircuitRc) -> Option<CircuitRc>,
{
    let rebuild = |node: CircuitRc, new_children: &Vec<CircuitRc>| -> Result<CircuitRc> {
        let direct_hit =
            replacer(node.clone()).or_else(|| extra_replacements.get(&node).cloned());
        if let Some(hit) = direct_hit {
            return Ok(hit);
        }
        expand_node(node, new_children, |spec_circuit, more_replacements, _child_idx| {
            // This recursion is a bit gross + uncached (caching requires caching by
            // extra_replacements ofc, so a bit more complex) : /
            let combined: BTreeMap<CircuitRc, CircuitRc> = extra_replacements
                .iter()
                .map(|(from, to)| (from.clone(), to.clone()))
                .chain(more_replacements)
                .collect();
            replace_expand_bottom_up_impl(spec_circuit, replacer, combined)
        })
    };
    deep_map_fallible_pre_new_children(circuit, rebuild)
}

#[test]
fn check_map_override() {
    // Collecting into a `BTreeMap` keeps the *last* value seen for a duplicated key.
    // `replace_expand_bottom_up_impl` relies on this so that newly chained replacements
    // override inherited ones.
    let build = |pairs: &[(i32, i32)]| pairs.iter().copied().collect::<BTreeMap<_, _>>();

    assert_eq!(build(&[(1, 2), (3, 4), (5, 7), (3, 9)])[&3], 9);

    let overridden_once = build(&[(1, 2), (3, 4), (5, 7), (1, 9)]);
    assert_eq!(overridden_once[&1], 9);
    assert_eq!(overridden_once[&5], 7);

    let overridden_many = build(&[(1, 2), (3, 4), (5, 7), (1, 9), (5, 4), (5, 8), (3, 7)]);
    assert_eq!(overridden_many[&1], 9);
    assert_eq!(overridden_many[&5], 8);
    assert_eq!(overridden_many[&3], 7);
}
