from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from uuid import UUID

import torch

# Alias for shapes: a variable-length tuple of ints. Used by ``Circuit.shape``,
# ``Symbol``, ``ScalarConstant`` and ``PyGFSpecShapeGetter`` below.
Shape = Tuple[int, ...]

class EinsumSpec:
    """Einsum specification written with integer axis labels instead of letters.

    ``input_ints`` holds one label list per operand and ``output_ints`` the
    labels of the result; ``to_einsum_string`` and ``string_to_ints`` convert
    between this form and an einsum string. ``optimize`` / ``optimize_dp``
    return a list of int lists -- presumably a contraction path (compare the
    ``path`` argument of ``einsum_nest_path``).

    NOTE(review): the ``int_sizes`` attribute is annotated ``List[int]`` but
    ``__init__`` takes ``Dict[int, int]``; presumably the native side turns the
    label->size mapping into a list indexed by label -- confirm and align.
    """

    # Axis labels of each input operand.
    input_ints: List[List[int]]
    # Axis labels of the output.
    output_ints: List[int]
    # Per-label sizes (see the NOTE above about the constructor's Dict form).
    int_sizes: List[int]

    def __init__(
        self,
        input_ints: List[List[int]],
        output_ints: List[int],
        int_sizes: Dict[int, int],
    ): ...

    # --- inspection -----------------------------------------------------
    def validate(self) -> bool: ...
    def normalize(self) -> "EinsumSpec": ...
    def shapes(self) -> Tuple[List[List[int]], List[int]]: ...
    def flops(self) -> int: ...

    # --- path search (both variants share one signature) -----------------
    def optimize(
        self,
        check_outer: bool,
        mem_limit: Optional[int] = None,
        hash_limit: Optional[int] = None,
    ) -> List[List[int]]: ...
    def optimize_dp(
        self,
        check_outer: bool,
        mem_limit: Optional[int] = None,
        hash_limit: Optional[int] = None,
    ) -> List[List[int]]: ...

    # --- string conversion ----------------------------------------------
    def to_einsum_string(self) -> str: ...
    @staticmethod
    def string_to_ints(string: str) -> Tuple[List[List[int]], List[int]]: ...

def optimize_einsum_spec_cached(
    spec: EinsumSpec,
    check_outer: bool,
    mem_limit: Optional[int] = None,
    hash_limit: Optional[int] = None,
) -> List[List[int]]:
    """Module-level function taking an ``EinsumSpec`` plus the same extra
    parameters and return type as ``EinsumSpec.optimize``.

    Presumably a memoized form of that method, as the name suggests -- confirm
    against the native implementation.
    """

class RearrangeSpec(object):
    """Rearrange specification over groups of integer axis labels.

    Each entry of ``input_ints`` / ``output_ints`` is presumably the group of
    labels making up one input / output axis (einops-style, cf.
    ``to_einops_string``); ``int_sizes`` gives a size per label.
    ``apply`` runs the spec on a ``torch.Tensor``.

    References to ``RearrangeSpec`` inside this class body are quoted (as the
    file already does for ``"EinsumSpec"`` and ``"ArrayConstant"``): an
    unquoted self-reference is evaluated while the class is still being
    defined and raises ``NameError`` if the body runs as ordinary Python.
    """

    # One label group per input axis.
    input_ints: List[List[int]]
    # One label group per output axis.
    output_ints: List[List[int]]
    # Size of each label; ``None`` presumably means "not fixed" -- confirm.
    int_sizes: List[Optional[int]]
    def __init__(
        self,
        input_ints: List[List[int]],
        output_ints: List[List[int]],
        int_sizes: List[Optional[int]],
    ): ...
    def is_identity(self) -> bool: ...
    def is_permute(self) -> bool: ...
    # Presumably (input shape, output shape) -- confirm.
    def shapes(self) -> Tuple[List[int], List[int]]: ...
    def is_valid(self) -> bool: ...
    def to_einops_string(self) -> str: ...
    def to_einops_string_and_letter_sizes(
        self,
    ) -> Tuple[str, List[Tuple[str, int]]]: ...
    def apply(self, tensor: torch.Tensor) -> torch.Tensor: ...
    @staticmethod
    def fuse(inner: "RearrangeSpec", outer: "RearrangeSpec") -> "RearrangeSpec": ...
    def canonicalize(self, special_case_ones: bool = True) -> "RearrangeSpec": ...
    def fill_empty_ints(self, allow_rust_invalid: bool) -> "RearrangeSpec": ...
    def conform_to_input_shape(self, shape: Tuple[int, ...], coerce: bool) -> "RearrangeSpec": ...
    @staticmethod
    def ident(rank: int) -> "RearrangeSpec": ...
    def to_py_rearrange_spec(self, shape: Tuple[int, ...]) -> Any: ...
    @staticmethod
    def flatten(rank: int) -> "RearrangeSpec": ...
    @staticmethod
    def unflatten(shape: List[int]) -> "RearrangeSpec": ...

class Circuit:
    """Base class of the circuit node types declared in this stub.

    Exposes structural information (``shape``, ``name``, ``hash``,
    ``children``), cost/size statistics (``self_flops``, ``total_flops``,
    ``max_non_input_size``, ``numel``, ``rank``), debug printing, and
    ``evaluate`` to a ``torch.Tensor``.

    References to ``Circuit`` inside this class body are quoted: an unquoted
    self-reference is evaluated before the class name is bound and raises
    ``NameError`` if the body runs as ordinary Python.
    """

    @property
    def shape(self) -> Shape: ...
    @property
    def is_constant(self) -> bool: ...
    @property
    def is_explicitly_computable(self) -> bool: ...
    @property
    def can_be_sampled(self) -> bool: ...
    @property
    def name(self) -> str: ...
    @property
    def hash(self) -> bytes: ...
    def children(self) -> List["Circuit"]: ...
    def self_flops(self) -> int: ...
    def total_flops(self) -> int: ...
    def max_non_input_size(self) -> int: ...
    # Presumably applies ``f`` to each child and rebuilds this node -- confirm.
    def apply_fn_to_sub(self, f: Callable[["Circuit"], "Circuit"]) -> "Circuit": ...
    def compiler_print(self) -> None: ...
    def print_stats(self) -> None: ...
    def compiler_repr(self) -> str: ...
    def numel(self) -> int: ...
    def rank(self) -> int: ...
    def to_py(self) -> Any: ...
    def evaluate(self) -> torch.Tensor: ...

class ArrayConstant(Circuit):
    """Circuit node wrapping a concrete ``torch.Tensor`` value.

    Alternative constructors: ``randn`` / ``randn_seeded`` (presumably
    normally-distributed random values) and ``new_named_axes``, which also takes
    one optional name per axis -- confirm.
    """

    def __init__(self, value: torch.Tensor, name: str = "ArrayConstant") -> None: ...

    @property
    def value(self) -> torch.Tensor: ...

    @property
    def uuid(self) -> UUID: ...

    @staticmethod
    def randn(*args: int) -> "ArrayConstant": ...

    @staticmethod
    def randn_seeded(shape: List[int], seed: int) -> "ArrayConstant": ...

    @staticmethod
    def new_named_axes(
        value: torch.Tensor,
        name: str,
        named_axes: List[Optional[str]],
    ) -> "ArrayConstant": ...

class Symbol(Circuit):
    """Circuit node identified by a ``UUID`` and declared with an explicit shape.

    Presumably a value-less placeholder, in contrast to ``ArrayConstant`` -- confirm.
    """

    @property
    def uuid(self) -> UUID: ...

    def __init__(self, shape: Shape, uuid: UUID, name: str = "Symbol") -> None: ...

class ScalarConstant(Circuit):
    """Circuit node holding a single float ``value`` with a given ``shape``.

    ``shape`` defaults to ``()``; for a non-empty shape the value is presumably
    repeated across every element -- confirm.
    """

    def __init__(
        self,
        value: float,
        shape: Shape = (),
        name: str = "ScalarConstant",
    ) -> None: ...

    @property
    def value(self) -> float: ...

    # Predicates on ``value``.
    def is_zero(self) -> bool: ...
    def is_one(self) -> bool: ...

class Einsum(Circuit):
    """Einsum node over one or more child circuits.

    Each positional argument pairs a child circuit with a tuple of integer axis
    labels; the keyword-only ``out_axes`` lists the labels of the result.
    ``all_input_circuits`` / ``all_input_axes`` presumably return the two
    halves of ``args`` -- confirm.

    The ``-> "Einsum"`` return annotations are quoted: an unquoted self-reference
    is evaluated before the class name is bound and raises ``NameError`` if the
    body runs as ordinary Python.
    """

    def __init__(
        self,
        *args: Tuple[Circuit, Tuple[int, ...]],
        out_axes: Tuple[int, ...],
        name: Optional[str] = None,
    ) -> None: ...
    @property
    def args(self) -> List[Tuple[Circuit, Tuple[int, ...]]]: ...
    @property
    def out_axes(self) -> Tuple[int, ...]: ...
    def all_input_circuits(self) -> List[Circuit]: ...
    def all_input_axes(self) -> List[Tuple[int, ...]]: ...
    @staticmethod
    def from_einsum_string(string: str, nodes: List[Circuit], name: Optional[str] = None) -> "Einsum": ...
    @staticmethod
    def from_spec(spec: EinsumSpec, circuits: List[Circuit], name: Optional[str] = None) -> "Einsum": ...
    @staticmethod
    def new_diag(node: Circuit, ints: List[int], name: Optional[str] = None) -> "Einsum": ...
    @staticmethod
    def new_trace(node: Circuit, ints: List[int], name: Optional[str] = None) -> "Einsum": ...

class Add(Circuit):
    """``Add`` node over the input ``nodes``.

    ``has_broadcast`` and ``nodes_and_rank_differences`` suggest that operands
    may differ in rank -- confirm the broadcasting rules natively.
    """

    def __init__(self, nodes: List[Circuit], name: Optional[str] = None) -> None: ...

    @property
    def nodes(self) -> List[Circuit]: ...

    # Pairs each operand with an int (presumably its rank deficit vs. the result).
    def nodes_and_rank_differences(self) -> List[Tuple[Circuit, int]]: ...
    def has_broadcast(self) -> bool: ...

class Rearrange(Circuit):
    """Circuit node combining a single child ``node`` with a ``RearrangeSpec``."""

    def __init__(
        self,
        node: Circuit,
        spec: RearrangeSpec,
        name: Optional[str] = None,
    ) -> None: ...

    @property
    def spec(self) -> RearrangeSpec: ...

    @property
    def node(self) -> Circuit: ...

class Index(Circuit):
    """Circuit node indexing a single child ``node``.

    Each entry of ``index`` is an ``int``, a ``slice`` or a ``torch.Tensor``;
    presumably one entry per leading axis of the child -- confirm.
    """

    def __init__(
        self,
        node: Circuit,
        index: Tuple[Union[int, slice, torch.Tensor], ...],
        name: Optional[str] = None,
    ) -> None: ...

    @property
    def index(self) -> Tuple[Union[int, slice, torch.Tensor], ...]: ...

    @property
    def node(self) -> Circuit: ...

class Scatter(Circuit):
    """Circuit node built from a child ``node``, an ``index`` and an output ``shape``.

    ``index`` entries have the same types as ``Index.index``; presumably the
    child's values are placed at ``index`` inside a result of ``shape`` -- confirm.
    """

    def __init__(
        self,
        node: Circuit,
        index: Tuple[Union[int, slice, torch.Tensor], ...],
        shape: Tuple[int, ...],
        name: Optional[str] = None,
    ) -> None: ...

    @property
    def index(self) -> Tuple[Union[int, slice, torch.Tensor], ...]: ...

    @property
    def node(self) -> Circuit: ...

class GeneralFunctionSpec:
    """Specification consumed by ``GeneralFunction`` nodes.

    Bundles the callables describing the function (``function``, ``get_shape``
    and an optional ``get_jacobian``) with batching metadata and a ``name``.
    ``is_batchable`` presumably derives from ``input_batchability`` -- confirm.
    """

    function: Callable
    get_shape: Callable
    # May be None.
    get_jacobian: Optional[Callable]
    num_non_batchable_output_dims: int
    # Presumably one flag per input -- confirm.
    input_batchability: List[bool]
    name: str

    def __init__(
        self,
        function: Callable,
        get_shape: Callable,
        get_jacobian: Optional[Callable],
        num_non_batchable_output_dims: int,
        input_batchability: List[bool],
        name: str,
    ) -> None: ...

    def is_batchable(self) -> bool: ...

class PyGFSpecShapeGetter:
    """Callable mapping a list of input shapes to an optional output shape.

    Parameterised by ``num_non_batchable``; presumably meant for use as a
    ``GeneralFunctionSpec.get_shape`` -- confirm.
    """

    num_non_batchable: int

    def __init__(self, num_non_batchable: int) -> None: ...

    def __call__(self, args: List[Shape]) -> Optional[Shape]: ...

class GeneralFunction(Circuit):
    """Circuit node over ``nodes`` described by a ``GeneralFunctionSpec``.

    ``new_by_name`` builds one from a spec identified by ``spec_name``
    (presumably a registry of built-in specs on the native side -- confirm).

    The ``-> "GeneralFunction"`` return annotation is quoted: an unquoted
    self-reference is evaluated before the class name is bound and raises
    ``NameError`` if the body runs as ordinary Python.
    """

    def __init__(
        self,
        nodes: List[Circuit],
        spec: GeneralFunctionSpec,
        name: Optional[str] = None,
    ) -> None: ...
    @property
    def nodes(self) -> List[Circuit]: ...
    @property
    def spec(self) -> GeneralFunctionSpec: ...
    @staticmethod
    def new_by_name(nodes: List[Circuit], spec_name: str, name: Optional[str] = None) -> "GeneralFunction": ...

class Concat(Circuit):
    """Circuit node joining ``nodes`` along dimension ``axis``."""

    def __init__(
        self,
        nodes: List[Circuit],
        axis: int,
        name: Optional[str] = None,
    ) -> None: ...

    @property
    def axis(self) -> int: ...

    @property
    def nodes(self) -> List[Circuit]: ...

class TorchDeviceDtype:
    """A (device, dtype) pair, both given as strings; consumed by ``cast_circuit``.

    The exact string formats accepted (e.g. ``"cpu"`` / ``"float32"``) are an
    assumption -- confirm against the native side.
    """

    def __init__(self, device: str, dtype: str) -> None: ...

    @property
    def dtype(self) -> str: ...

    @property
    def device(self) -> str: ...

class OptimizationSettings(object):
    """Settings passed to ``optimize_circuit``, ``scheduled_evaluate``,
    ``optimize_and_evaluate(_many)`` and ``(deep_)maybe_distribute``.

    Each constructor parameter shares its name and type with an attribute
    below (presumably stored on it -- confirm). The parameters are annotated so
    type checkers validate call sites instead of treating them as ``Any``.
    """

    # Verbosity level (default 0).
    verbose: int
    # Memory budget, default 9_000_000_000 (presumably bytes -- confirm).
    max_memory: int
    scheduling_num_mem_chunks: int
    # Presumably a size threshold for distributing einsums -- confirm.
    distribute_min_size: int
    scheduling_naive: bool
    scheduling_simplify: bool
    # Bounds below are presumably only used when adjust_numerical_scale is True.
    adjust_numerical_scale: bool
    numerical_scale_min: float
    numerical_scale_max: float
    capture_and_print: bool
    def __init__(
        self,
        verbose: int = 0,
        max_memory: int = 9_000_000_000,
        scheduling_num_mem_chunks: int = 200,
        distribute_min_size: int = 600_000_000,
        scheduling_naive: bool = False,
        scheduling_simplify: bool = True,
        adjust_numerical_scale: bool = False,
        numerical_scale_min: float = 1e-8,
        numerical_scale_max: float = 1e8,
        capture_and_print: bool = False,
    ) -> None: ...

class PyOOMError(Exception):
    """Exception type exported by the extension.

    Presumably raised when an operation exceeds a memory budget such as
    ``OptimizationSettings.max_memory`` -- confirm.
    """

# --- Rewrite rules on ``Add`` nodes --------------------------------------------
# Each takes an ``Add`` and returns an ``Optional`` result; presumably ``None``
# means the rule did not apply -- confirm against the native implementation.
def add_collapse_scalar_inputs(add: Add) -> Optional[Add]: ...
def add_deduplicate(add: Add) -> Optional[Add]: ...
def add_flatten_once(add: Add) -> Optional[Add]: ...
def remove_add_few_input(add: Add) -> Optional[Add]: ...
# Unlike the rules above, may return a non-``Add`` circuit (``Optional[Circuit]``).
def add_pull_removable_axes(add: Add, remove_non_common_axes: bool) -> Optional[Circuit]: ...
# --- Rewrite rules on ``Einsum`` nodes (``None`` presumably = not applied) ---
def einsum_flatten_once(einsum: Einsum) -> Optional[Einsum]: ...
# May return a non-``Einsum`` circuit.
def einsum_elim_identity(einsum: Einsum) -> Optional[Circuit]: ...
# --- Rewrite rules on ``Index`` nodes (``None`` presumably = not applied) ----
def index_merge_scalar(index: Index) -> Optional[Circuit]: ...
def index_elim_identity(index: Index) -> Optional[Circuit]: ...
# Always yields an ``Index`` when it applies.
def index_fuse(index: Index) -> Optional[Index]: ...
# --- Rewrite rules on ``Rearrange`` nodes (``None`` presumably = not applied) -
# NOTE(review): this parameter is named ``node`` while its siblings use
# ``rearrange``; keyword callers depend on the exact name, so it is left as-is.
def rearrange_fuse(node: Rearrange) -> Optional[Rearrange]: ...
def rearrange_merge_scalar(rearrange: Rearrange) -> Optional[Circuit]: ...
def rearrange_elim_identity(rearrange: Rearrange) -> Optional[Circuit]: ...
# --- Rewrite rules on ``Concat`` nodes (``None`` presumably = not applied) ---
def concat_elim_identity(concat: Concat) -> Optional[Circuit]: ...
def concat_merge_uniform(concat: Concat) -> Optional[Concat]: ...
def concat_pull_removable_axes(concat: Concat) -> Optional[Circuit]: ...
# --- "pull removable axes" variants for further node types, plus an ``Add``
# broadcast rule. Each returns ``Optional[...]``; presumably ``None`` means the
# rule did not apply -- confirm.
def generalfunction_pull_removable_axes(node: GeneralFunction) -> Optional[Circuit]: ...
def einsum_pull_removable_axes(node: Einsum) -> Optional[Circuit]: ...
def scatter_pull_removable_axes(node: Scatter) -> Optional[Circuit]: ...
def add_make_broadcasts_explicit(node: Add) -> Optional[Add]: ...
# --- Einsum-level transforms ---------------------------------------------------
# ``distribute`` / ``distribute_all`` turn an ``Einsum`` into an ``Add``;
# presumably distributing over ``Add`` operands (``operand_idx`` selecting one)
# -- confirm.
def distribute(node: Einsum, operand_idx: int, do_broadcasts: bool) -> Optional[Add]: ...
def distribute_all(node: Einsum) -> Optional[Add]: ...
def einsum_of_permute_merge(node: Einsum) -> Optional[Einsum]: ...
# Takes a ``Rearrange`` (not an ``Einsum``) and yields an ``Einsum``.
def permute_of_einsum_merge(node: Rearrange) -> Optional[Einsum]: ...
def einsum_elim_zero(node: Einsum) -> Optional[ScalarConstant]: ...
def einsum_merge_scalars(node: Einsum) -> Optional[Einsum]: ...
# --- Index push-down and zero elimination ---------------------------------
def push_down_index(node: Index) -> Optional[Circuit]: ...
# Whole-circuit variant: always returns a ``Circuit``; ``min_size`` defaults to None.
def deep_push_down_index(node: Circuit, min_size: Optional[int] = None) -> Circuit: ...
def index_split_axes(node: Index, top_axes: Set[int]) -> Optional[Index]: ...
def add_elim_zeros(node: Add) -> Optional[Add]: ...
# --- Simplification drivers and canonicalization -----------------------------
# All take any ``Circuit``; only ``compiler_simp_step`` may return ``None``
# (presumably when no simplification step applies -- confirm).
def compiler_simp(node: Circuit) -> Circuit: ...
def compiler_simp_step(node: Circuit) -> Optional[Circuit]: ...
def compiler_simp_until_same(node: Circuit) -> Circuit: ...
def strip_names(node: Circuit) -> Circuit: ...
def deep_canonicalize(node: Circuit) -> Circuit: ...
def canonicalize_node(node: Circuit) -> Circuit: ...
# --- Settings-driven distribution and einsum nesting --------------------------
def deep_maybe_distribute(node: Circuit, settings: OptimizationSettings) -> Circuit: ...
def maybe_distribute(node: Einsum, settings: OptimizationSettings) -> Optional[Circuit]: ...
def einsum_nest_optimize(node: Einsum) -> Einsum: ...
def deep_optimize_einsums(node: Circuit) -> Circuit: ...
# ``path`` has the same type as EinsumSpec.optimize's return value.
def einsum_nest_path(node: Einsum, path: List[List[int]]) -> Einsum: ...
# --- Scatter-related rewrite rules (``None`` presumably = not applied) -------
def scatter_elim_identity(node: Scatter) -> Optional[Circuit]: ...
def einsum_pull_scatter(node: Einsum) -> Optional[Circuit]: ...
def add_pull_scatter(node: Add) -> Optional[Circuit]: ...
def index_einsum_to_scatter(node: Index) -> Optional[Circuit]: ...
# --- Whole-circuit optimization, casting and evaluation entry points ---------
def optimize_circuit(circuit: Circuit, settings: OptimizationSettings) -> Circuit: ...
def cast_circuit(circuit: Circuit, device_dtype: TorchDeviceDtype) -> Circuit: ...
# Non-optional: always returns a Circuit.
def scatter_to_concat(scatter: Scatter) -> Circuit: ...
def count_nodes(circuit: Circuit) -> int: ...
def scheduled_evaluate(circuit: Circuit, settings: OptimizationSettings) -> torch.Tensor: ...
def optimize_and_evaluate(circuit: Circuit, settings: OptimizationSettings) -> torch.Tensor: ...
# Presumably one tensor per input circuit, in order -- confirm.
def optimize_and_evaluate_many(circuits: List[Circuit], settings: OptimizationSettings) -> List[torch.Tensor]: ...
# --- Concat construction and concat-related rewrites -------------------------
def flat_concat(circuits: List[Circuit]) -> Concat: ...
def deep_heuristic_nest_adds(circuit: Circuit) -> Circuit: ...
def generalfunction_pull_concat(circuit: GeneralFunction) -> Optional[Concat]: ...
def concat_fuse(circuit: Concat) -> Optional[Concat]: ...
def concat_drop_size_zero(circuit: Concat) -> Optional[Concat]: ...
def index_concat_drop_unreached(circuit: Index) -> Optional[Index]: ...
def einsum_pull_concat(circuit: Einsum) -> Optional[Circuit]: ...
def add_pull_concat(circuit: Add) -> Optional[Circuit]: ...
# ``sections`` is a list of ints along ``axis`` (presumably section sizes -- confirm).
def split_to_concat(circuit: Circuit, axis: int, sections: List[int]) -> Optional[Circuit]: ...
# NOTE(review): ``min_size`` is Optional but has no default here, unlike
# ``deep_push_down_index``; callers must pass it explicitly.
def deep_pull_concat_messy(circuit: Circuit, min_size: Optional[int]) -> Circuit: ...
def deep_pull_concat(circuit: Circuit, min_size: Optional[int]) -> Circuit: ...
# --- Batching: split a circuit into ``num_batches`` pieces (presumably via
# concat, given the surrounding helpers -- confirm). ``min_size`` is required.
def batch_inputs_axis_len(circuit: Circuit, axis_len: int, num_batches: int, min_size: Optional[int]) -> Circuit: ...
def batch_largest(circuit: Circuit, num_batches: int, min_size: Optional[int]) -> Circuit: ...
# --- Named axes, traversal and miscellaneous rewrites --------------------------
# ``named_axes`` holds one optional name per axis (cf. ArrayConstant.new_named_axes).
def set_named_axes(circuit: Circuit, named_axes: List[Optional[str]]) -> Circuit: ...
def propagate_named_axes(circuit: Circuit, named_axes: List[Optional[str]]) -> Circuit: ...
# Presumably returns the nodes in topological order -- confirm direction.
def toposort_circuit(circuit: Circuit) -> List[Circuit]: ...
def einsum_push_down_trace(circuit: Einsum) -> Optional[Einsum]: ...
def add_pull_diags(circuit: Add) -> Optional[Circuit]: ...
# NOTE(review): name says "to_rearrange" but the declared return is
# Optional[Concat]; verify against the native signature.
def concat_repeat_to_rearrange(circuit: Concat) -> Optional[Concat]: ...
def add_outer_product_broadcasts_on_top(circuit: Add) -> Optional[Add]: ...
def replace_all_randn_seeded(circuit: Circuit) -> Circuit: ...
