# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for layers and their parameters/variables.

This represents the collection of all layers in the approximate Fisher
information matrix to which a particular FisherBlock may belong. That is, we
might have several layer collections for one TF graph (if we have multiple K-FAC
optimizers being used, for example.)

The model and loss function are registered using the register_XXX() methods.
A subset of the layer types can be handled with the auto_register_layers()
method.

Note that the data formats in the docstrings for the register_XXX() methods
must be strictly adhered to. So for example, if a method asks for a Tensor of
shape [batch_size, ...], then the first dimension must be the batch size and
nothing else.  And the tensors must contain actual data, not a mixture of real
and fake data / zeros generated by mini-batch padding, for example.  (Padding
is only fine if it's treated as regular data by both your model and loss
function. e.g. adding "blank tokens" at the end of a sequence which the model
is still expected to predict.) If a method asks for the parameters of a layer
then they must be the actual variable object(s) for said parameters, not a
tensor formed by reshaping, re-casting, or transposing its value.

If the internal data format used by your model isn't natively supported by
this system, you shouldn't try to crow-bar the arguments of the registration
methods until they seem to fit. Although the K-FAC code tries to protect
against some common mistakes, it may often seem to run fine with incorrect
registrations, generating no exceptions or errors. But this will almost
certainly lead to (potentially severe) underperformance of the method.

If you have model code that doesn't represent tensors in the format expected
by K-FAC, one thing you can try is introducing transformations that perform the
conversion back and forth. But make sure the format that you convert to is
actually valid according to the strict specifications of the registration
function docstrings (e.g. that batch_size really is the mini-batch size, etc).

So if "x" is some data needed in the registration function that isn't of the
correct format, you can try something like the following:

x_transformed = transform(x)
lc.register_XXX(x_transformed)
x = untransform(x_transformed)
...use x in rest of model...

Note that without "x = untransform(x_transformed)" this often won't work since
x_transformed won't be part of the model's forward graph, which is something
K-FAC needs (especially for the "output" arguments of layers).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict
from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
import math

# Dependency imports
import six
import tensorflow.compat.v1 as tf

from tensorflow.python.util import nest
from kfac.python.ops import fisher_blocks as fb
from kfac.python.ops import loss_functions as lf
from kfac.python.ops import utils
from kfac.python.ops.tensormatch import graph_search

# Approximation names accepted by the register_XXX() methods and by the
# set_default_XXX_approximation() setters of LayerCollection.

# Generic (non-factored) approximations.
APPROX_FULL_NAME = "full"
APPROX_DIAGONAL_NAME = "diagonal"

# Kronecker-factored approximations, optionally with one or both factors
# replaced by a diagonal approximation.
APPROX_KRONECKER_NAME = "kron"
APPROX_KRONECKER_IN_DIAG_NAME = "kron_in_diag"
APPROX_KRONECKER_OUT_DIAG_NAME = "kron_out_diag"
APPROX_KRONECKER_BOTH_DIAG_NAME = "kron_both_diag"

# Names used by the "multi" layer types (fully connected / conv2d multi).
APPROX_KRONECKER_INDEP_NAME = "kron_indep"
APPROX_KRONECKER_INDEP_IN_DIAG_NAME = "kron_indep_in_diag"
APPROX_KRONECKER_INDEP_OUT_DIAG_NAME = "kron_indep_out_diag"
APPROX_KRONECKER_INDEP_BOTH_DIAG_NAME = "kron_indep_both_diag"
APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1"
APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2"

# Additional option accepted for conv2d layers.
APPROX_KRONECKER_SUA_NAME = "kron_sua"


# Sentinel for the 'reuse' keyword argument of the registration methods: it
# is replaced by tf.get_variable_scope().reuse at registration time.
VARIABLE_SCOPE = "VARIABLE_SCOPE"

# Module-level default LayerCollection, managed through
# get_default_layer_collection() / set_default_layer_collection().
_DEFAULT_LAYER_COLLECTION = None


def get_default_layer_collection():
  """Returns the module-level default LayerCollection.

  Raises:
    ValueError: If no default LayerCollection has been set.
  """
  current = _DEFAULT_LAYER_COLLECTION
  if current is not None:
    return current
  raise ValueError(
      "Attempted to retrieve default LayerCollection when none is set. Use "
      "LayerCollection.as_default().")


def set_default_layer_collection(layer_collection):
  """Installs or clears the module-level default LayerCollection.

  Args:
    layer_collection: The LayerCollection to make the default, or None to
      clear the current default.

  Raises:
    ValueError: If a default is already set and `layer_collection` is not
      None (a default must be cleared before it can be replaced).
  """
  global _DEFAULT_LAYER_COLLECTION

  already_set = _DEFAULT_LAYER_COLLECTION is not None
  if already_set and layer_collection is not None:
    raise ValueError("Default LayerCollection is already set.")

  _DEFAULT_LAYER_COLLECTION = layer_collection


class LayerParametersDict(OrderedDict):
  """An OrderedDict where keys are Tensors or tuples of Tensors.

  Keys given as lists are canonicalized to tuples, so `d[[a, b]]` and
  `d[(a, b)]` refer to the same entry.

  Ensures that no Tensor is associated with two different keys: every
  individual Tensor appearing in any key is tracked, and inserting a key that
  shares a Tensor with a present key (including re-assigning an existing key)
  raises ValueError.
  """

  def __init__(self, *args, **kwargs):
    # Must exist before OrderedDict.__init__, which routes any initial items
    # through __setitem__.
    self._tensors = set()
    super(LayerParametersDict, self).__init__(*args, **kwargs)

  def __setitem__(self, key, value):
    key = self._canonicalize_key(key)
    tensors = self._key_tensors(key)
    key_collisions = self._tensors.intersection(tensors)
    if key_collisions:
      raise ValueError("Key(s) already present: {}".format(key_collisions))
    self._tensors.update(tensors)
    super(LayerParametersDict, self).__setitem__(key, value)

  def __delitem__(self, key):
    key = self._canonicalize_key(key)
    # Delete the entry first so that a missing key raises KeyError without
    # touching the tracked-tensor set.
    super(LayerParametersDict, self).__delitem__(key)
    # __setitem__ tracks each Tensor of a tuple key individually, so each one
    # must be released individually (the tuple itself is never in the set).
    self._tensors.difference_update(self._key_tensors(key))

  def __getitem__(self, key):
    key = self._canonicalize_key(key)
    return super(LayerParametersDict, self).__getitem__(key)

  def __contains__(self, key):
    key = self._canonicalize_key(key)
    return super(LayerParametersDict, self).__contains__(key)

  def clear(self):
    """Removes all entries and forgets all tracked Tensors."""
    # OrderedDict.clear does not go through __delitem__, so the tracked set
    # must be reset explicitly or later inserts would spuriously collide.
    super(LayerParametersDict, self).clear()
    self._tensors.clear()

  def _canonicalize_key(self, key):
    """Converts list keys to tuples; other keys are returned unchanged."""
    if isinstance(key, (list, tuple)):
      return tuple(key)
    return key

  def _key_tensors(self, key):
    """Returns the individual Tensors making up a (canonicalized) key."""
    return key if isinstance(key, (tuple, list)) else (key,)


# TODO(b/68034464): add capability for LayerCollection to be "finalized"
# and do this when it gets used by FisherEstimator / KfacOptimizer.


class LayerCollection(object):
  """Registry of information about layers and losses.

  Note that you need to create a new one of these for each FisherEstimator or
  KfacOptimizer, as they can't be used more than once.

  The methods that you should interact with directly are:
   - register_XXX()
   - auto_register_layers()

  Additional control over the automatic registration process can be exerted by
  using the methods/properties:
   - set_default_XXX() and default_XXX
   - define_linked_parameters() and linked_parameters


  Attributes:
    fisher_blocks: a LayerParametersDict (subclass of OrderedDict) mapping layer
        parameters (Tensors or tuples of Tensors) to FisherBlock instances.
    fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
    losses: a list of LossFunction objects. The loss to be optimized is their
        sum.
    loss_colocation_ops: a dict mapping each registered LossFunction to the op
        its evaluation is colocated with. These will typically be the inputs to
        the losses.
    loss_coeffs: a dict mapping each registered LossFunction to its
        coefficient.
  """

  def __init__(self,
               graph=None,
               name="LayerCollection"):
    """Creates an empty LayerCollection.

    Args:
      graph: (OPTIONAL) The tf.Graph this collection refers to. If None,
        tf.get_default_graph() at construction time is used.
      name: (OPTIONAL) str. default_name of the variable scope that is opened
        (and immediately closed) here; the resulting scope name is stored in
        self._var_scope. (Default: "LayerCollection")
    """
    self.fisher_blocks = LayerParametersDict()
    self.fisher_factors = OrderedDict()
    # Maps frozensets of variables to an optionally specified approximation
    # name (see define_linked_parameters).
    self._linked_parameters = {}
    self._graph = graph or tf.get_default_graph()
    self._loss_dict = OrderedDict()  # {str: list of LossFunction towers}
    self._subgraph = None
    self._default_generic_approximation = APPROX_DIAGONAL_NAME
    self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
    self._default_conv2d_approximation = APPROX_KRONECKER_NAME
    self._default_fully_connected_multi_approximation = (
        APPROX_KRONECKER_INDEP_NAME)
    self._default_conv2d_multi_approximation = (
        APPROX_KRONECKER_INDEP_NAME)
    self._default_scale_and_shift_approximation = APPROX_FULL_NAME
    self.loss_colocation_ops = {}  # {LossFunction: op to colocate with}
    self.loss_coeffs = {}  # {LossFunction: coefficient}
    # Registered use count per variable; unseen variables count as 0.
    # (defaultdict(int) is the idiomatic zero-default and, unlike a lambda,
    # is picklable.)
    self._vars_to_uses = defaultdict(int)

    self._finalized = False

    with tf.variable_scope(None, default_name=name) as scope:
      self._var_scope = scope.name

    self._generic_approx_to_block_types = {
        APPROX_FULL_NAME: fb.NaiveFullFB,
        APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
    }

    self._fully_connected_approx_to_block_types = {
        APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
        APPROX_KRONECKER_IN_DIAG_NAME:
            partial(fb.FullyConnectedKFACBasicFB,
                    diagonal_approx_for_input=True),
        APPROX_KRONECKER_OUT_DIAG_NAME:
            partial(fb.FullyConnectedKFACBasicFB,
                    diagonal_approx_for_output=True),
        APPROX_KRONECKER_BOTH_DIAG_NAME:
            partial(fb.FullyConnectedKFACBasicFB,
                    diagonal_approx_for_input=True,
                    diagonal_approx_for_output=True),
        APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
    }

    self._conv2d_approx_to_block_types = {
        APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
        APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
        APPROX_KRONECKER_SUA_NAME: fb.ConvKFCBasicFB,
    }

    self._fully_connected_multi_approx_to_block_types = {
        APPROX_KRONECKER_INDEP_NAME:
            fb.FullyConnectedMultiIndepFB,
        APPROX_KRONECKER_INDEP_IN_DIAG_NAME:
            partial(fb.FullyConnectedMultiIndepFB,
                    diagonal_approx_for_input=True),
        APPROX_KRONECKER_INDEP_OUT_DIAG_NAME:
            partial(fb.FullyConnectedMultiIndepFB,
                    diagonal_approx_for_output=True),
        APPROX_KRONECKER_INDEP_BOTH_DIAG_NAME:
            partial(fb.FullyConnectedMultiIndepFB,
                    diagonal_approx_for_input=True,
                    diagonal_approx_for_output=True),
        APPROX_KRONECKER_SERIES_1_NAME:
            partial(fb.FullyConnectedSeriesFB, option=1),
        APPROX_KRONECKER_SERIES_2_NAME:
            partial(fb.FullyConnectedSeriesFB, option=2)
    }

    self._conv2d_multi_approx_to_block_types = {
        APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB
    }

    self._scale_and_shift_approx_to_block_types = {
        APPROX_FULL_NAME: fb.ScaleAndShiftFullFB,
        APPROX_DIAGONAL_NAME: fb.ScaleAndShiftDiagonalFB
    }

  @property
  def losses(self):
    """List of all LossFunction objects registered with this LayerCollection.

    The towers of every registered loss (see towers_by_loss) are flattened
    with nest.flatten into a single list, ordered first by loss registration
    order and then by tower.
    """
    return nest.flatten(self.towers_by_loss)

  @property
  def towers_by_loss(self):
    """Registered towers grouped by loss.

    Returns:
      A tuple with one entry per registered loss name, in registration order.
      Each entry is a tuple of the LossFunction objects registered as towers
      of that loss.
    """
    return tuple(map(tuple, self._loss_dict.values()))

  @property
  def registered_variables(self):
    """A tuple of all of the variables currently registered.

    Each FisherBlock key (a variable or tuple of variables) is expanded with
    utils.ensure_sequence; keys are visited in block registration order.
    """
    # Only the keys are needed, so iterate the dict directly rather than
    # unpacking (key, block) pairs with an unused block.
    return tuple(var
                 for params in self.fisher_blocks
                 for var in utils.ensure_sequence(params))

  @property
  def linked_parameters(self):
    """Groups of parameters with an optionally specified approximation.

    Linked parameters can be added using `define_linked_parameters`.
    If an approximation is specified, then this approximation will be used
    when registering a layer with exactly these parameters, unless an
    approximation is specified when calling the registration function.

    Returns:
      The internal `dict` (not a copy) mapping frozensets of parameters to an
      approximation name string, or to None if no approximation was given.
    """
    return self._linked_parameters

  @property
  def default_generic_approximation(self):
    """Default approximation name for generic variables."""
    approximation = self._default_generic_approximation
    return approximation

  def set_default_generic_approximation(self, value):
    """Sets the default approximation for generic variables.

    Args:
      value: str. Must be a key of self._generic_approx_to_block_types
        ("full" or "diagonal" as set up in __init__).

    Raises:
      ValueError: If `value` is not a supported generic approximation.
    """
    supported = self._generic_approx_to_block_types
    if value in supported:
      self._default_generic_approximation = value
      return
    raise ValueError(
        "{} is not a valid approximation for generic variables.".format(
            value))

  @property
  def default_fully_connected_approximation(self):
    """Default approximation name for fully connected layers."""
    approximation = self._default_fully_connected_approximation
    return approximation

  def set_default_fully_connected_approximation(self, value):
    """Sets the default approximation for fully connected layers.

    Args:
      value: str. Must be a key of
        self._fully_connected_approx_to_block_types ("kron", "kron_in_diag",
        "kron_out_diag", "kron_both_diag" or "diagonal" as set up in
        __init__).

    Raises:
      ValueError: If `value` is not a supported approximation.
    """
    supported = self._fully_connected_approx_to_block_types
    if value in supported:
      self._default_fully_connected_approximation = value
      return
    raise ValueError(
        "{} is not a valid approximation for fully connected layers.".format(
            value))

  @property
  def default_conv2d_approximation(self):
    """Default approximation name for 2d convolutional layers."""
    approximation = self._default_conv2d_approximation
    return approximation

  def set_default_conv2d_approximation(self, value):
    """Sets the default approximation for 2d convolutional layers.

    Args:
      value: str. Must be a key of self._conv2d_approx_to_block_types
        ("kron", "diagonal" or "kron_sua" as set up in __init__).

    Raises:
      ValueError: If `value` is not a supported approximation.
    """
    supported = self._conv2d_approx_to_block_types
    if value in supported:
      self._default_conv2d_approximation = value
      return
    raise ValueError(
        "{} is not a valid approximation for 2d convolutional layers.".format(
            value))

  @property
  def default_fully_connected_multi_approximation(self):
    """Default approximation name for fully-connected "multi" layers."""
    approximation = self._default_fully_connected_multi_approximation
    return approximation

  def set_default_fully_connected_multi_approximation(self, value):
    """Sets the default approximation for fully-connected "multi" layers.

    Args:
      value: str. Must be a key of
        self._fully_connected_multi_approx_to_block_types (the "kron_indep*"
        and "kron_series_*" names as set up in __init__).

    Raises:
      ValueError: If `value` is not a supported approximation.
    """
    supported = self._fully_connected_multi_approx_to_block_types
    if value in supported:
      self._default_fully_connected_multi_approximation = value
      return
    raise ValueError("{} is not a valid approximation for a fully-connected "
                     "multi layer.".format(value))

  @property
  def default_conv2d_multi_approximation(self):
    """Default approximation name for conv2d "multi" layers."""
    approximation = self._default_conv2d_multi_approximation
    return approximation

  def set_default_conv2d_multi_approximation(self, value):
    """Sets the default approximation for conv2d "multi" layers.

    Args:
      value: str. Must be a key of self._conv2d_multi_approx_to_block_types
        ("kron_indep" as set up in __init__).

    Raises:
      ValueError: If `value` is not a supported approximation.
    """
    supported = self._conv2d_multi_approx_to_block_types
    if value in supported:
      self._default_conv2d_multi_approximation = value
      return
    raise ValueError("{} is not a valid approximation for a conv2d "
                     "multi layer.".format(value))

  @property
  def default_scale_and_shift_approximation(self):
    """Default approximation name for scale & shift layers."""
    approximation = self._default_scale_and_shift_approximation
    return approximation

  def set_default_scale_and_shift_approximation(self, value):
    """Sets the default approximation for scale & shift layers.

    Args:
      value: str. Must be a key of
        self._scale_and_shift_approx_to_block_types ("full" or "diagonal" as
        set up in __init__).

    Raises:
      ValueError: If `value` is not a supported approximation.
    """
    supported = self._scale_and_shift_approx_to_block_types
    if value in supported:
      self._default_scale_and_shift_approximation = value
      return
    raise ValueError("{} is not a valid approximation for a scale & shift "
                     "layer.".format(value))

  def auto_register_layers(self, var_list=None, batch_size=None):
    """Registers remaining unregistered layers automatically using a scanner.

    All function / distribution registrations must be performed (manually)
    before calling this.

    Registrations are performed using the default approximation mode for each
    layer type, as if the scanner were calling the user-level registration
    functions of this LayerCollection object (which it does). These defaults
    can be overridden with the set_default_XXX_approximation methods for types
    of layers, or with the define_linked_parameters method for specific
    parameters.

    Call this only after any desired manual registrations have been
    performed, e.g. for a layer the scanner does not recognize properly, or a
    layer you want to register differently.

    Note that this is an experimental convenience feature which won't work for
    every possible model architecture. Any layers/parameters whose structure
    is not recognized will be registered as "generic", which is the worst
    curvature matrix approximation available in the system, and should be
    avoided if possible.

    See the docstring for register_layers in graph_search.py for more details.

    Args:
      var_list: A list of variables that the automatic registration should
        consider. To exclude some trainable variables (i.e. those included in
        tf.trainable_variables()), pass an explicit list.
        (Default: tf.trainable_variables())
      batch_size: An `int` representing the batch size. Needs to be specified
        if registering generic variables that don't match any layer patterns
        or if time/uses is folded. If the time/uses dimension is merged with
        batch then this is used to infer the number of uses/time-steps. NOTE:
        In the replicated context this must be the per-replica batch size, and
        not the total batch size.
    """
    candidates = tf.trainable_variables() if var_list is None else var_list
    graph_search.register_layers(self, candidates, batch_size=batch_size)

  def finalize(self):
    """Builds the loss subgraph and marks this LayerCollection as finalized.

    Once finalized, _register_block and _register_loss_function refuse
    further registrations. If _create_subgraph raises (e.g. because no loss
    is registered), the collection stays unfinalized.

    Raises:
      ValueError: If called more than once.
    """
    if self._finalized:
      raise ValueError("LayerCollection was finalized a second time, which "
                       "indicates an error. Perhaps you used the same "
                       "LayerCollection object in multiple "
                       "optimizers/estimators, which is not allowed.")
    self._create_subgraph()
    self._finalized = True

  def _register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
    """Validates and registers the layer_key associated with the fisher_block.

    Args:
      layer_key: A variable or tuple of variables. The key to check for in
          existing registrations and to register if valid.
      fisher_block: The associated `FisherBlock`.
      reuse: Method to use for inserting new `FisherBlock`s. One of True,
        False, tf.AUTO_REUSE, or 'VARIABLE_SCOPE' (meaning
        tf.get_variable_scope().reuse). With True, the previously registered
        block is returned. With tf.AUTO_REUSE, the previously registered block
        is returned if one exists, and `fisher_block` is registered otherwise.

    Raises:
      ValueError: If the LayerCollection has already been finalized.
      ValueError: If `reuse` is True but `layer_key` was not previously
        registered.
      ValueError: If an existing block is being reused but its type differs
        from the type of `fisher_block`.
      ValueError: If `layer_key` was already registered and is not being
        reused.
      ValueError: If `layer_key` shares any variables with, but is not equal
        to, a previously registered key.

    Returns:
      The `FisherBlock` registered under `layer_key`. If `layer_key` was already
      registered, this will be the previously registered `FisherBlock`.
    """
    if self._finalized:
      raise ValueError("You cannot register additional losses or layers after "
                       "LayerCollection is finalized. Finalization happens "
                       "after the estimator or optimizer object first uses "
                       "the data in the LayerCollection. For example, when "
                       "the minimize() method is called in "
                       "PeriodicInvCovUpdateKfacOpt.")

    if reuse is VARIABLE_SCOPE:
      reuse = tf.get_variable_scope().reuse

    # Reuse path: return the existing block instead of registering a new one.
    if reuse is True or (reuse is tf.AUTO_REUSE and
                         layer_key in self.fisher_blocks):

      if layer_key not in self.fisher_blocks:
        raise ValueError(
            "reuse was True for attempted registration involving variables {}, "
            "but no previously registered layer was found for these. Perhaps "
            "reuse was set to True by mistake. One way this can happen is if "
            "reuse is set to True in the surrounding variable scope."
            "".format(layer_key))

      result = self.fisher_blocks[layer_key]

      if type(result) != type(fisher_block):  # pylint: disable=unidiomatic-typecheck
        raise ValueError(
            "Attempted to register FisherBlock of type %s when existing "
            "FisherBlock has type %s." % (type(fisher_block), type(result)))
      return result
    if reuse is False and layer_key in self.fisher_blocks:
      raise ValueError("FisherBlock for %s is already in LayerCollection." %
                       (layer_key,))

    # Insert fisher_block into self.fisher_blocks.
    if layer_key in self.fisher_blocks:
      raise ValueError("Duplicate registration: {}".format(layer_key))
    # Raise an error if any variable in layer_key has been registered in any
    # other blocks. (LayerParametersDict would also reject this, but this
    # check produces a more informative message.)
    variable_to_block = {
        var: (params, block)
        for (params, block) in self.fisher_blocks.items()
        for var in utils.ensure_sequence(params)
    }
    for variable in utils.ensure_sequence(layer_key):
      if variable in variable_to_block:
        prev_key, prev_block = variable_to_block[variable]
        raise ValueError(
            "Attempted to register layer_key {} with block {}, but variable {}"
            " was already registered in key {} with block {}.".format(
                layer_key, fisher_block, variable, prev_key, prev_block))
    self.fisher_blocks[layer_key] = fisher_block
    return fisher_block

  def _register_loss_function(self,
                              loss,
                              colocation_op,
                              base_name,
                              name=None,
                              coeff=1.0,
                              reuse=VARIABLE_SCOPE):
    """Registers a LossFunction object.

    Args:
      loss: The LossFunction object.
      colocation_op: The op to colocate the loss function's computations with.
      base_name: The name to derive a new unique name from if the name argument
        is None. Only used when a new loss function is created.
      name: (OPTIONAL) str or None. Unique name for this loss function. If None,
        a new name is generated. Must be given when reusing an existing loss
        function. (Default: None)
      coeff: (OPTIONAL) a scalar. coefficient on the loss function
        (Default: 1.0)
      reuse: (OPTIONAL) bool or str.  If True, adds 'loss' as an additional
        tower for the existing loss function named 'name'. If tf.AUTO_REUSE,
        adds a tower when a loss function named 'name' exists and creates a
        new loss function otherwise. If "VARIABLE_SCOPE", uses
        tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: If the LayerCollection has already been finalized.
      ValueError: If reuse == True and name == None.
      ValueError: If reuse == True and 'coeff' differs from the coeff of the
        existing loss function.
      KeyError: If reuse == True and no existing LossFunction with 'name' found.
      KeyError: If reuse == False and existing LossFunction with 'name' found.
    """

    if self._finalized:
      raise ValueError("You cannot register additional losses or layers after "
                       "LayerCollection is finalized. Finalization happens "
                       "after the estimator or optimizer object first uses "
                       "the data in the LayerCollection. For example, when "
                       "the minimize() method is called in "
                       "PeriodicInvCovUpdateKfacOpt.")

    if reuse == VARIABLE_SCOPE:
      reuse = tf.get_variable_scope().reuse

    if reuse is tf.AUTO_REUSE:
      # Append a tower only if this named loss already exists; otherwise
      # create it (mirrors the AUTO_REUSE handling in _register_block).
      reuse = name is not None and name in self._loss_dict

    if reuse:
      # Check the name before any name is generated: reuse needs an explicit
      # name, and generating one would silently turn this into a lookup of a
      # fresh name (and consume a graph name).
      if name is None:
        raise ValueError(
            "If reuse is enabled, loss function's name must be set.")

      loss_list = self._loss_dict.get(name, None)

      if loss_list is None:
        raise KeyError(
            "Unable to find loss function named {}. Register a new loss "
            "function with reuse=False.".format(name))

      if self.loss_coeffs[loss_list[0]] != coeff:
        raise ValueError(
            "Reused loss function's coeff didn't match previous supplied "
            "value.")

    else:
      name = name or self._graph.unique_name(base_name)

      if name in self._loss_dict:
        raise KeyError(
            "Loss function named {} already exists. Set reuse=True to append "
            "another tower.".format(name))

      loss_list = []
      self._loss_dict[name] = loss_list

    loss_list.append(loss)
    self.loss_colocation_ops[loss] = colocation_op
    self.loss_coeffs[loss] = coeff

  def _get_use_count_map(self):
    """Returns the mapping from variables to their registered use counts.

    This is the internal defaultdict itself (not a copy): indexing it with an
    unregistered variable inserts a zero entry.
    """
    use_counts = self._vars_to_uses
    return use_counts

  def _add_uses(self, params, uses):
    """Adds `uses` to the registered use count of every variable in `params`.

    Args:
      params: Variable or tuple/list of Variables. Parameters for a layer.
      uses: int or float. Number of additional uses for these parameters.
    """
    if not isinstance(params, (tuple, list)):
      params = (params,)
    use_counts = self._vars_to_uses
    for variable in params:
      use_counts[variable] += uses

  def check_registration(self, variables):
    """Checks that all variable uses have been registered properly.

    This is a read-only check: it does not modify the registered use counts.

    Args:
      variables: List of variables.

    Raises:
      ValueError: If any registered variables are not included in the list,
          if any variable in the list is not registered, or if any variable in
          the list is registered with the wrong number of "uses" (vs the
          number of times that variable is actually used in the subgraph
          recorded from the registered losses). All problems found are
          reported together in a single error.
    """
    # Note that overlapping parameters (i.e. those that share variables) will
    # be caught by layer_collection.LayerParametersDict during registration.

    reg_use_map = self._get_use_count_map()

    error_messages = []

    for var in variables:
      total_uses = self.subgraph.variable_uses(var)
      # Use .get() instead of indexing: the map is a defaultdict, and indexing
      # would insert a zero entry for every unregistered variable, mutating
      # the collection's registration state as a side effect of checking it.
      reg_uses = reg_use_map.get(var, 0)

      if reg_uses == 0:
        error_messages.append("Variable {} not registered.".format(var))
      elif (not math.isinf(reg_uses)) and reg_uses != total_uses:
        error_messages.append(
            "Variable {} registered with wrong number of uses ({} uses "
            "registered vs {} uses found in sub-graph generated from "
            "registered losses).".format(var, reg_uses, total_uses))

    # Count registered variables (non-zero uses) missing from the list
    # directly, so duplicates in `variables` cannot mask them.
    listed = set(variables)
    num_unlisted = sum(1 for var, uses in reg_use_map.items()
                       if uses and var not in listed)

    if num_unlisted:
      error_messages.append("{} registered variables were not included in list."
                            .format(num_unlisted))

    if error_messages:
      error_string = "\n\t".join([
          "Found the following errors with variable registration:"
      ] + error_messages)
      raise ValueError(error_string)

  def get_blocks(self):
    """Returns the registered FisherBlocks as a tuple, in registration order."""
    blocks = self.fisher_blocks.values()
    return tuple(blocks)

  def get_factors(self):
    """Returns the values of self.fisher_factors as a tuple (insertion order)."""
    factors = self.fisher_factors.values()
    return tuple(factors)

  @property
  def graph(self):
    """The tf.Graph given at construction (or the default graph at that time)."""
    the_graph = self._graph
    return the_graph

  @property
  def subgraph(self):
    """utils.SubGraph over the registered losses' inputs.

    Built by finalize() (via _create_subgraph); None before finalization.
    """
    built_subgraph = self._subgraph
    return built_subgraph

  def define_linked_parameters(self, params, approximation=None):
    """Identify a set of parameters that should be grouped together.

    Optionally also records the approximation type string to use for this
    parameter grouping.

    During automatic graph scanning (as done by the auto_register_layers method)
    any matches containing variables that have been identified as part of a
    linked group will be filtered out unless the match parameters are exactly
    equal to the ones specified in the linked group.

    Args:
      params: A variable, or a tuple or list of variables. The variables
        to be linked.
      approximation: Optional string specifying the type of approximation to use
        for these variables. If unspecified, this layer collection's default
        approximation for the layer type will be used.

    Raises:
      ValueError: If any of the variables already belongs to a registered
        layer whose parameters are not exactly `params`, or already belongs to
        a previously linked group (even an identical one).
    """
    params = frozenset(utils.ensure_sequence(params))

    # A variable may already be part of a registered layer only if that
    # layer's parameters are exactly this group.
    for registered_params, fisher_block in self.fisher_blocks.items():
      registered_params_set = set(utils.ensure_sequence(registered_params))
      if params == registered_params_set:
        continue
      for variable in params:
        if variable in registered_params_set:
          raise ValueError(
              "Can't link parameters {}, variable {} was already registered in "
              "group {} with layer {}".format(params, variable,
                                              registered_params, fisher_block))

    # A variable may belong to at most one linked group.
    for variable in params:
      clashing_group = next(
          (group for group in self.linked_parameters if variable in group),
          None)
      if clashing_group is not None:
        raise ValueError("Can't link parameters {}, variable {} was already "
                         "linked in group {}.".format(params, variable,
                                                      clashing_group))

    self._linked_parameters[params] = approximation

  def _create_subgraph(self):
    """Builds self._subgraph from the flattened inputs of all losses.

    Raises:
      ValueError: If no loss has been registered.
    """
    registered_losses = self.losses
    if not registered_losses:
      raise ValueError("Must have at least one registered loss.")
    loss_inputs = tuple(loss.inputs for loss in registered_losses)
    self._subgraph = utils.SubGraph(nest.flatten(loss_inputs))

  def eval_losses(self, target_mode="data", coeff_mode="regular"):
    """Returns evaluated losses (colocated with inputs to losses).

    Args:
      target_mode: str. "data" evaluates each loss with loss.evaluate();
        "sample" uses loss.evaluate_on_sample(). (Default: "data")
      coeff_mode: str. "regular" multiplies each loss by its registered coeff,
        "sqrt" by the square root of that coeff, and "off" by 1.0.
        (Default: "regular")

    Returns:
      A list with one scaled loss value per entry of self.losses, in the same
      order.

    Raises:
      ValueError: If target_mode or coeff_mode is not a recognized value.
    """
    # Validate both modes up front, so a bad value is reported even when no
    # losses are registered, and before any evaluation ops are created.
    if target_mode not in ("data", "sample"):
      raise ValueError("target_mode must be in ['data', 'sample']")
    if coeff_mode not in ("regular", "sqrt", "off"):
      raise ValueError("coeff_mode must be in ['regular', 'sqrt', 'off']")

    evals = []
    for loss in self.losses:
      with tf.colocate_with(self.loss_colocation_ops[loss]):
        if target_mode == "data":
          loss_value = loss.evaluate()
        else:
          loss_value = loss.evaluate_on_sample()

        if coeff_mode == "regular":
          multiplier = self.loss_coeffs[loss]
        elif coeff_mode == "sqrt":
          multiplier = tf.sqrt(self.loss_coeffs[loss])
        else:
          multiplier = 1.0
        multiplier = tf.cast(multiplier, dtype=loss_value.dtype)
        evals.append(multiplier * loss_value)
    return evals

  def total_loss(self, coeff_mode="regular"):
    """Sum (via tf.add_n) of all losses evaluated with target_mode="data".

    Args:
      coeff_mode: str. How loss coefficients are applied; passed through to
        eval_losses. (Default: "regular")
    """
    per_loss_values = self.eval_losses(target_mode="data",
                                       coeff_mode=coeff_mode)
    return tf.add_n(per_loss_values)

  def total_sampled_loss(self, coeff_mode="regular"):
    """Sum (via tf.add_n) of all losses evaluated with target_mode="sample".

    Args:
      coeff_mode: str. How loss coefficients are applied; passed through to
        eval_losses. (Default: "regular")
    """
    per_loss_values = self.eval_losses(target_mode="sample",
                                       coeff_mode=coeff_mode)
    return tf.add_n(per_loss_values)

  def _get_linked_approx(self, params):
    """If params were linked, return their specified approximation.

    Returns:
      The approximation recorded by define_linked_parameters for exactly this
      set of parameters, or None if they were not linked (or were linked
      without an approximation).
    """
    group = frozenset(utils.ensure_sequence(params))
    return self.linked_parameters.get(group)

  def _get_block_type(self, params, approx, default, approx_to_type):
    """Resolves the approximation name and its FisherBlock constructor.

    Precedence: the explicit `approx`, then any approximation linked to
    exactly `params`, then `default`.

    Returns:
      A (block_type, approx) pair, where approx is the resolved name.

    Raises:
      ValueError: If the resolved name is not a key of `approx_to_type`.
    """
    chosen = approx
    if chosen is None:
      chosen = self._get_linked_approx(params)
    if chosen is None:
      chosen = default

    if chosen in approx_to_type:
      return approx_to_type[chosen], chosen
    raise ValueError("Bad value {} for approx.".format(chosen))

  def register_fully_connected(self,
                               params,
                               inputs,
                               outputs,
                               approx=None,
                               dense_inputs=True,
                               reuse=VARIABLE_SCOPE):
    """Registers a fully connected layer.

    Args:
      params: Variable or 2-tuple of variables corresponding to weight and
        bias parameters of this layer. Weight matrix should have shape
        [input_size, output_size]. Bias should have shape [output_size].
      inputs: Tensor. Two formats are accepted. In most cases the Tensor is
        dense inputs, with shape [batch_size, input_size]. In some cases
        the Tensor is sparse inputs, with shape [batch_size]. A typical example
        of sparse inputs is the vocab indices into an embedding matrix. Sparse
        inputs will be converted to the dense format within KFAC. For sparse
        inputs, dense_inputs should be set to False.
      outputs: Tensor of shape [batch_size, output_size]. Outputs
        produced by layer.
      approx: str or None. If not None must be one of "kron", "kron_in_diag"
        (diagonal approximation for the input kronecker factor), "kron_out_diag"
        (diagonal approximation for the output kronecker factor),
        "kron_both_diag" or "diagonal". The Fisher approximation to use. If
        None the default value is used. (Default: None)
      dense_inputs: bool. True if inputs are dense inputs. (Default: True)
      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      ValueError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """

    block_type, _ = self._get_block_type(
        params, approx, self.default_fully_connected_approximation,
        self._fully_connected_approx_to_block_types)

    has_bias = isinstance(params, (tuple, list))
    block = self._register_block(
        params, block_type(self, has_bias=has_bias), reuse=reuse)

    if not dense_inputs:
      # Sparse inputs index the rows of the weight matrix, so the one-hot
      # depth is its input dimension. With a bias, params is a
      # (weights, bias) pair and the weights come first.
      weights = params[0] if has_bias else params
      inputs.one_hot_depth = int(weights.shape[0])
    block.register_additional_tower(inputs, outputs)

    self._add_uses(params, 1)

  def register_conv1d(self,
                      params,
                      strides,
                      padding,
                      inputs,
                      outputs,
                      dilations=None,
                      approx=None,
                      reuse=VARIABLE_SCOPE,
                      sub_sample_inputs=None,
                      sub_sample_patches=None):
    """Registers a call to tf.nn.conv1d().

    Args:
      params: Variable or 2-tuple of variables corresponding to weight and
        bias parameters of this layer. Weight matrix should have shape
        [kernel_width, in_channels, out_channels].  Bias should have shape
        [out_channels].
      strides: List of 3 ints. Strides for convolution kernel.
      padding: string. see tf.nn.conv1d for valid values.
      inputs: Tensor of shape [batch_size, width, in_channels]. Inputs
        to layer.
      outputs: Tensor of shape [batch_size, width, out_channels].
        Output produced by layer.
      dilations: List of 3 ints. Dilations along each dimension.
      approx: str or None. If not None, must be "kron". The Fisher approximation
        to use. If None, the default value is used. (Default: None)
      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")
      sub_sample_inputs: `bool`. If True, then subsample the inputs from which
        the image patches are extracted. (Default: None)
      sub_sample_patches: `bool`, If `True` then subsample the extracted
        patches. (Default: None)

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """
    # Explicit check rather than `assert`, so validation survives `python -O`.
    if approx is not None and approx != APPROX_KRONECKER_NAME:
      raise ValueError("Bad value {} for approx.".format(approx))

    block = self._register_block(
        params,
        fb.ConvKFCBasicFB(
            layer_collection=self,
            params=params,
            padding=padding,
            strides=strides,
            data_format="NWC",
            dilation_rate=dilations,
            extract_patches_fn="extract_convolution_patches",
            sub_sample_inputs=sub_sample_inputs,
            sub_sample_patches=sub_sample_patches,
            use_sua_approx_for_input_factor=False),
        reuse=reuse)
    block.register_additional_tower(inputs, outputs)

    self._add_uses(params, 1)

  def register_conv2d(self,
                      params,
                      strides,
                      padding,
                      inputs,
                      outputs,
                      data_format=None,
                      dilations=None,
                      approx=None,
                      reuse=VARIABLE_SCOPE,
                      sub_sample_inputs=None,
                      sub_sample_patches=None,
                      patch_mask=None):
    """Registers a call to tf.nn.conv2d().

    Args:
      params: Variable or 2-tuple of variables corresponding to weight and
        bias parameters of this layer. Weight matrix should have shape
        [kernel_height, kernel_width, in_channels, out_channels].  Bias should
        have shape [out_channels].
      strides: List of 4 ints. Strides for convolution kernel. With
        approx="diagonal" the first and last strides must be 1.
      padding: string. see tf.nn.conv2d for valid values.
      inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
        to layer.
      outputs: Tensor of shape [batch_size, height, width, out_channels].
        Output produced by layer.
      data_format: str or None. Format of data. Must be None or 'NHWC'; if
        None, this should default to 'NHWC'. 'NCHW' is not supported.
        (Default: None)
      dilations: List of 4 ints. Dilations along each dimension.
      approx: str or None. If not None must be one of "kron" or "diagonal"
        (the Kronecker-SUA approximation, APPROX_KRONECKER_SUA_NAME, is also
        accepted). The Fisher approximation to use. If None the default value
        is used. (Default: None)
      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")
      sub_sample_inputs: `bool`. If True, then subsample the inputs from which
        the image patches are extracted. (Default: None)
      sub_sample_patches: `bool`, If `True` then subsample the extracted
        patches. (Default: None)
      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]
        or None. If not None this is multiplied against the extracted patches
        Tensor (broadcasting along the batch dimension) before statistics are
        computed. This can (and probably should) be used if the filter bank
        matrix is masked in a way that is homogeneous across the output
        channels. (Other masking patterns have no direct support.) Currently
        only works with approx="kron" or "diagonal". (Default: None)

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
      NotImplementedError: If the resolved approximation has no registration
        path in this method.
    """
    assert data_format in [None, "NHWC"]  # We don't support NCHW right now

    block_type, approx = self._get_block_type(
        params, approx, self.default_conv2d_approximation,
        self._conv2d_approx_to_block_types)

    # It feels bad to pass in configuration that has to do with the internal
    # implementation.  And then we can't use the same constructor for both
    # anymore and are thus forced to use this ugly if-statement.
    # TODO(b/74793309): Clean this up?
    if approx == APPROX_KRONECKER_NAME:
      block = self._register_block(
          params,
          block_type(
              layer_collection=self,
              params=params,
              padding=padding,
              strides=strides,
              data_format=data_format,
              dilation_rate=dilations,
              extract_patches_fn="extract_image_patches",
              sub_sample_inputs=sub_sample_inputs,
              sub_sample_patches=sub_sample_patches,
              use_sua_approx_for_input_factor=False,
              patch_mask=patch_mask),
          reuse=reuse)
    elif approx == APPROX_DIAGONAL_NAME:
      # Only spatial strides are supported by the diagonal block.
      assert strides[0] == strides[-1] == 1
      block = self._register_block(
          params,
          block_type(
              layer_collection=self,
              params=params,
              padding=padding,
              strides=strides,
              dilations=dilations,
              data_format=data_format,
              patch_mask=patch_mask),
          reuse=reuse)
    elif approx == APPROX_KRONECKER_SUA_NAME:
      block = self._register_block(
          params,
          block_type(
              layer_collection=self,
              params=params,
              padding=padding,
              use_sua_approx_for_input_factor=True),
          reuse=reuse)

    else:
      raise NotImplementedError(approx)

    block.register_additional_tower(inputs, outputs)

    self._add_uses(params, 1)

  def register_convolution(self,
                           params,
                           inputs,
                           outputs,
                           padding,
                           strides=None,
                           dilation_rate=None,
                           data_format=None,
                           approx=None,
                           reuse=VARIABLE_SCOPE):
    """Register a call to tf.nn.convolution().

    Unless you know what you are doing you should be using register_conv2d
    instead.

    Args:
      params: Variable or 2-tuple of variables corresponding to weight and
        bias parameters of this layer. Weight matrix should have shape
        [..filter_spatial_size.., in_channels, out_channels].  Bias should have
        shape [out_channels].
      inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
        Inputs to layer.
      outputs: Tensor of shape [batch_size, ..output_spatial_size..,
        out_channels].  Output produced by layer.
      padding: string. see tf.nn.conv2d for valid values.
      strides: List of ints of length len(..input_spatial_size..). Strides for
        convolution kernel in spatial dimensions.
      dilation_rate: List of ints of length len(..input_spatial_size..).
        Dilations along spatial dimension.
      data_format: str or None. Format of data.
      approx: str or None. If not None, must be "kron". The Fisher approximation
        to use. If None, the default value is used. (Default: None)
      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """
    # TODO(b/74793309): Have this use _get_block_type like the other
    # registration functions?
    # Raise (as documented) instead of `assert`, which `python -O` strips.
    if approx is not None and approx != APPROX_KRONECKER_NAME:
      raise ValueError("Bad value {} for approx.".format(approx))

    block = self._register_block(
        params,
        fb.ConvKFCBasicFB(
            layer_collection=self,
            params=params,
            padding=padding,
            strides=strides,
            dilation_rate=dilation_rate,
            data_format=data_format),
        reuse=reuse)
    block.register_additional_tower(inputs, outputs)

    self._add_uses(params, 1)

  def register_depthwise_conv2d(self,
                                params,
                                inputs,
                                outputs,
                                strides,
                                padding,
                                rate=None,
                                data_format=None,
                                approx=None,
                                reuse=VARIABLE_SCOPE):
    """Register a call to tf.nn.depthwise_conv2d().

    Note that this is an experimental feature that hasn't been experimentally
    validated or published on.

    Args:
      params: 4-D variable of shape [filter_height, filter_width, in_channels,
        channel_multiplier].  Convolutional filter.
      inputs: Tensor of shape [batch_size, input_height, input_width,
        in_channels].  Inputs to layer.
      outputs: Tensor of shape [batch_size, output_height, output_width,
        in_channels * channel_multiplier].  Output produced by depthwise conv2d.
      strides: List of ints of length 4. Strides along all dimensions.
      padding: string. see tf.nn.conv2d for valid values.
      rate: None or List of ints of length 2. Dilation rates in spatial
        dimensions.
      data_format: str or None. Format of data. Must be None or "NHWC".
      approx: str or None. If not None must be "diagonal".  The Fisher
        approximation to use. If None the default value is used. (Default: None)
      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """
    # TODO(b/74793309): Have this use _get_block_type like the other
    # registration functions?
    # Raise (as documented) instead of `assert`, which `python -O` strips.
    if approx is not None and approx != APPROX_DIAGONAL_NAME:
      raise ValueError("Bad value {} for approx.".format(approx))
    assert data_format in [None, "NHWC"]

    block = self._register_block(
        params,
        fb.DepthwiseConvDiagonalFB(
            layer_collection=self,
            params=params,
            strides=strides,
            padding=padding,
            rate=rate,
            data_format=data_format),
        reuse=reuse)
    block.register_additional_tower(inputs, outputs)

    self._add_uses(params, 1)

  def register_separable_conv2d(self,
                                depthwise_params,
                                pointwise_params,
                                inputs,
                                depthwise_outputs,
                                pointwise_outputs,
                                strides,
                                padding,
                                rate=None,
                                data_format=None,
                                approx=None,
                                reuse=VARIABLE_SCOPE):
    """Register a call to tf.nn.separable_conv2d().

    The separable convolution is registered as two layers: a depthwise conv2d
    (always using the "diagonal" approximation) followed by a 1x1 pointwise
    conv2d whose approximation is controlled by 'approx'. Because of this, the
    intermediate output between the two stages must be supplied.

    Note that this is an experimental feature that hasn't been experimentally
    validated or published on.

    Args:
      depthwise_params: 4-D variable of shape [filter_height, filter_width,
        in_channels, channel_multiplier].  Filter for depthwise conv2d.
      pointwise_params: 4-D variable of shape [1, 1, in_channels *
        channel_multiplier, out_channels].  Filter for pointwise conv2d.
      inputs: Tensor of shape [batch_size, input_height, input_width,
        in_channels].  Inputs to layer.
      depthwise_outputs: Tensor of shape [batch_size, output_height,
        output_width, in_channels * channel_multiplier].  Output produced by
        depthwise conv2d (and input to the pointwise conv2d).
      pointwise_outputs: Tensor of shape [batch_size, output_height,
        output_width, out_channels].  Output produced by pointwise conv2d.
      strides: List of ints of length 4. Strides for depthwise conv2d kernel in
        all dimensions.
      padding: string. see tf.nn.conv2d for valid values. Applies to the
        depthwise stage only; the pointwise stage uses "VALID".
      rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
        kernel in spatial dimensions.
      data_format: str or None. Format of data, shared by both stages.
      approx: str or None. If not None must be one of "kron" or "diagonal".
        The Fisher approximation to use for the pointwise stage. If None the
        default value is used. (Default: None)
      reuse: bool or str.  If True, this adds the inputs and outputs of both
        stages as an additional mini-batch/tower of data to use when
        estimating the Fisher blocks (which must have already been
        registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """
    # Arguments forwarded unchanged to both stages.
    shared_kwargs = {"data_format": data_format, "reuse": reuse}

    # Stage 1: depthwise filter, always approximated diagonally.
    self.register_depthwise_conv2d(params=depthwise_params,
                                   inputs=inputs,
                                   outputs=depthwise_outputs,
                                   strides=strides,
                                   padding=padding,
                                   rate=rate,
                                   approx=APPROX_DIAGONAL_NAME,
                                   **shared_kwargs)

    # Stage 2: 1x1 pointwise filter applied to the depthwise outputs.
    self.register_conv2d(params=pointwise_params,
                         inputs=depthwise_outputs,
                         outputs=pointwise_outputs,
                         strides=[1, 1, 1, 1],
                         padding="VALID",
                         approx=approx,
                         **shared_kwargs)

  def register_generic(self,
                       params,
                       batch_size,
                       approx=None,
                       reuse=VARIABLE_SCOPE):
    """Registers parameters with a structure-agnostic Fisher block.

    This is the approximation of last resort: prefer one of the more specific
    register_XXX() methods whenever one applies.

    Args:
      params: Variable, or tuple/list of variables, holding the parameters.
        Must be a single variable when the "diagonal" approximation is used.
      batch_size: 0-D Tensor. Size of this tower's mini-batch.
      approx: str or None. If not None, must be one of "full" or "diagonal";
        selects the Fisher approximation. If None the default value is used.
        (Default: None)
      reuse: bool or str. If True, 'batch_size' is added to the total
        mini-batch size used when estimating the Fisher block for this layer
        (which must have already been registered). If "VARIABLE_SCOPE", use
        tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
      ValueError: If approx == "diagonal" and params is a tuple or list.
    """
    block_type, approx = self._get_block_type(
        params, approx, self.default_generic_approximation,
        self._generic_approx_to_block_types)

    params_is_sequence = isinstance(params, (tuple, list))
    if params_is_sequence and approx == APPROX_DIAGONAL_NAME:
      raise ValueError(
          "Params must be a Variable if using the diagonal approximation.")

    fisher_block = block_type(self, params)
    block = self._register_block(params, fisher_block, reuse=reuse)
    block.register_additional_tower(batch_size)

    # These params are recorded with an unbounded use count.
    self._add_uses(params, float("inf"))

  def register_fully_connected_multi(self, params, inputs, outputs,
                                     num_uses=None, approx=None,
                                     dense_inputs=True, reuse=VARIABLE_SCOPE):
    """Register fully connected layers with shared parameters.

    This can handle general fully-connected layers with shared parameters, but
    has specialized approximations to deal with the case where there is a
    meaningful linear order to the shared instances (such as in an RNN).

    Note that padding is *not* supported. The arguments to this method cannot
    be zero-padded or anything of that sort.

    Args:
      params: Variable or 2-tuple of variables corresponding to weight and
        bias of this layer. Weight matrix should have shape [input_size,
        output_size]. Bias should have shape [output_size].
      inputs: A list of Tensors or a single Tensor. Inputs to this layer. If a
        list of Tensors, the list indexes each use in the model (which might
        correspond to a "time-step" in an RNN). Each Tensor in the list has
        leading dimension batch_size. If a single Tensor, the leading dimension
        would be num_uses * batch_size, which is a reshaped version of the list
        of Tensors. Similar to register_fully_connected(), two formats of
        tensors are accepted: dense inputs and sparse inputs. In most cases
        the Tensors are dense inputs, with shape [batch_size, input_size] (if a
        list) or [num_uses * batch_size , input_size] (if a single Tensor).
        In some cases the Tensors are sparse inputs, with shape [batch_size] (if
        a list) or [num_uses * batch_size] (if a single Tensor). A typical
        example of sparse inputs is the vocab indices into an embedding matrix.
        Sparse inputs will be converted to the dense format within KFAC. For
        sparse inputs, dense_inputs should be set to False.
      outputs: A list of Tensors, the same length as 'inputs', each of shape
        [batch_size, output_size]. Outputs produced by layer. The list indexes
        each use in the model (which might correspond to a "time-step" in an
        RNN). Needs to correspond with the order used in 'inputs'.  OR, can be
        a single Tensor of shape [num_uses * batch_size, output_size], which is
        a reshaped version of a Tensor of shape [num_uses, batch_size,
        output_size].
      num_uses: int or None. The number of uses/time-steps in the model where
        the layer appears. Only needed if both inputs and outputs are given in
        the single Tensor format. (Default: None)
      approx: str or None. If not None, must be one of "kron_indep",
        "kron_indep_in_diag" (diagonal approximation for the input kronecker
        factor), "kron_indep_out_diag" (diagonal approximation for the output
        kronecker factor), "kron_indep_both_diag", "kron_series_1" or
        "kron_series_2". The Fisher approximation to use. If None the default
        value is used (which starts out as "kron_indep"). (Default: None)
      dense_inputs: bool. True if inputs are dense inputs. For sparse inputs
        the depth used for the dense conversion is taken from the first
        dimension (input_size) of the weight matrix, whether 'params' is a
        single variable or a (weight, bias) tuple. (Default: True)
      reuse: bool or str.  If True, this adds inputs and outputs as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.  (Note that the
        word 'use' here has a completely different meaning to "use in the model"
        as it pertains to the 'inputs', 'outputs', and 'num_uses' arguments.)
        (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """
    block_type, approx = self._get_block_type(
        params, approx, self.default_fully_connected_multi_approximation,
        self._fully_connected_multi_approx_to_block_types)

    # TODO(b/70283649): something along the lines of find_canonical_output
    # should be added back in here (and for the other block types, arguably).
    has_bias = isinstance(params, (tuple, list))
    block = self._register_block(
        params,
        block_type(self, has_bias=has_bias, num_uses=num_uses),
        reuse=reuse)

    if isinstance(inputs, (tuple, list)):
      inputs = tuple(inputs)
    if isinstance(outputs, (tuple, list)):
      outputs = tuple(outputs)

    if not dense_inputs:
      # The weight matrix is params[0] when a bias is present; a tuple/list
      # has no `.shape`, so indexing the container directly would fail.
      weights = params[0] if has_bias else params
      one_hot_depth = int(weights.shape[0])
      if isinstance(inputs, (tuple, list)):
        for use_inputs in inputs:
          use_inputs.one_hot_depth = one_hot_depth
      else:
        inputs.one_hot_depth = one_hot_depth

    block.register_additional_tower(inputs, outputs)
    if isinstance(inputs, (tuple, list)):
      assert len(inputs) == len(outputs)
      self._add_uses(params, len(inputs))
    else:
      self._add_uses(params, 1)

  def register_conv2d_multi(self,
                            params,
                            strides,
                            padding,
                            inputs,
                            outputs,
                            num_uses=None,
                            data_format=None,
                            dilations=None,
                            approx=None,
                            reuse=VARIABLE_SCOPE):
    """Registers convolutional layers that share one set of parameters.

    Padding is *not* supported: none of the arguments may be zero-padded or
    otherwise contain fake data.

    Args:
      params: Variable or 2-tuple of variables corresponding to weight and
        bias of this layer. Weight matrix should have shape [kernel_height,
        kernel_width, in_channels, out_channels].  Bias should have shape
        [out_channels].
      strides: 1-D Tensor of length 4. Strides for convolution kernel.
      padding: string. see tf.nn.conv2d for valid values.
      inputs: Either a list of Tensors, each of shape [batch_size, height,
        width, in_channels], with one entry per use in the model (e.g. per
        "time-step" of an RNN); or a single Tensor of shape
        [num_uses * batch_size, height, width, in_channels], i.e. a reshaped
        [num_uses, batch_size, height, width, in_channels] Tensor.
      outputs: Same layout as 'inputs' but with out_channels as the last
        dimension. When given as a list it must follow the same order as
        'inputs'.
      num_uses: int or None. The number of uses/time-steps in the model where
        the layer appears. Only needed if both inputs and outputs are given in
        the single Tensor format. (Default: None)
      data_format: str or None. Format of data. Must be None or "NHWC".
      dilations: List of 4 ints. Dilations along each dimension.
      approx: str or None. If not None must be "kron_indep". The Fisher
        approximation to use. If None the default value is used (which starts
        out as "kron_indep"). (Default: None)
      reuse: bool or str.  If True, this adds inputs and outputs as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.  (Here "tower"
        reuse is unrelated to the "use in the model" meaning of 'inputs',
        'outputs' and 'num_uses'.) (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """
    assert data_format in [None, "NHWC"]  # NCHW is not supported yet.

    block_type, approx = self._get_block_type(
        params, approx, self.default_conv2d_multi_approximation,
        self._conv2d_multi_approx_to_block_types)

    fisher_block = block_type(layer_collection=self,
                              params=params,
                              padding=padding,
                              strides=strides,
                              data_format=data_format,
                              dilation_rate=dilations,
                              extract_patches_fn="extract_image_patches",
                              num_uses=num_uses)
    block = self._register_block(params, fisher_block, reuse=reuse)

    # Lists are normalised to tuples; single Tensors pass through unchanged.
    seq_types = (tuple, list)
    inputs = tuple(inputs) if isinstance(inputs, seq_types) else inputs
    outputs = tuple(outputs) if isinstance(outputs, seq_types) else outputs
    uses_given_as_list = isinstance(inputs, seq_types)

    block.register_additional_tower(inputs, outputs)
    if uses_given_as_list:
      assert len(inputs) == len(outputs)
    self._add_uses(params, len(inputs) if uses_given_as_list else 1)

  def register_scale_and_shift(self,
                               params,
                               inputs,
                               outputs,
                               approx=None,
                               reuse=VARIABLE_SCOPE):
    """Registers a scale and shift operation.

    A scale and shift operation is a parameterized operation of the form

    outputs = scale * inputs + shift ,

    where scale and shift are variables that broadcast to the shape of inputs.

    outputs and inputs must have batch dimension. scale and shift can have
    a corresponding dimension (although they don't need to), but it must
    be 1.

    These kinds of operations appear frequently in various "normalization"
    layers like Layer Normalization. Batch Normalization layers should still
    be registered as "generic".

    Note that this is an experimental feature that hasn't been experimentally
    validated or published on.

    Args:
      params: Variable or 2-tuple of Variables corresponding to the scale
        and possibly shift parameters (scale must be first).  Note that if
        these have a dimension corresponding to the batch dimension of 'inputs'
        and 'outputs', that dimension must be 1.
      inputs: Tensor of shape [batch_size, ...]. Input tensor that is multiplied
        by the scale tensor.
      outputs: Tensor of shape [batch_size, ...]. Final output produced by the
        scale and shift. Must have the same shape as 'inputs'.
      approx: str or None. If not None must be one of "full" or "diagonal".
        The Fisher approximation to use. If None the default value is used.
        (Default: None)
      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      ValueError: If the shape of scale (resp. shift) does not broadcast to
        the shape of 'inputs' (resp. 'outputs') as described above.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """
    # TODO(jamesmartens): Consider replacing some of the logic below with calls
    # to tf.broadcast_static_shape.

    def _broadcast_dims(param, tensor, param_name, tensor_name):
      """Returns the non-batch dims of `tensor` along which `param` broadcasts.

      Shapes are right-aligned. Dimensions of `tensor` that `param` lacks
      entirely (other than dim 0, the batch dimension) are included, as are
      aligned dimensions where `param` has size 1 and `tensor` is larger.

      Raises:
        ValueError: If `param` has higher rank than `tensor`, has a batch
          dimension whose size is not 1, or has an aligned dimension that is
          larger than the corresponding one of `tensor`, or smaller without
          being 1.
      """
      incompatible_msg = ("It appears that {} param and {} have "
                          "incompatible shapes. This is probably due to "
                          "misspecified arguments.").format(param_name,
                                                            tensor_name)
      start_dim = len(tensor.shape) - len(param.shape)
      if start_dim < 0:
        raise ValueError("Rank of {} cannot exceed that of {}s.".format(
            param_name, tensor_name))
      if start_dim == 0 and param.shape[0] != 1:
        raise ValueError(
            "If {} has a batch dimension its value must be 1.".format(
                param_name))
      dims = list(range(1, start_dim))
      for i in range(max(start_dim, 1), len(tensor.shape)):
        param_dim = param.shape[i - start_dim]
        tensor_dim = tensor.shape[i]
        if param_dim < tensor_dim:
          if param_dim == 1:
            dims.append(i)
          else:
            raise ValueError(incompatible_msg)
        elif param_dim > tensor_dim:
          # Previously only the shift/output check rejected this case, so a
          # mis-shaped scale was silently accepted.
          raise ValueError(incompatible_msg)
      return tuple(dims)

    if isinstance(params, (tuple, list)):
      scale = params[0]
      shift = params[1]
      has_shift = True
      broadcast_dims_shift = _broadcast_dims(
          shift, outputs, "shift", "output")
    else:
      scale = params
      has_shift = False
      broadcast_dims_shift = None

    broadcast_dims_scale = _broadcast_dims(scale, inputs, "scale", "input")

    block_type, approx = self._get_block_type(
        params, approx, self.default_scale_and_shift_approximation,
        self._scale_and_shift_approx_to_block_types)

    block = self._register_block(params, block_type(
        self,
        broadcast_dims_scale,
        broadcast_dims_shift=broadcast_dims_shift,
        has_shift=has_shift),
                                 reuse=reuse)
    block.register_additional_tower(inputs, outputs)

    self._add_uses(params, 1)

  def register_categorical_predictive_distribution(self,
                                                   logits,
                                                   seed=None,
                                                   targets=None,
                                                   name=None,
                                                   coeff=1.0,
                                                   reuse=VARIABLE_SCOPE):
    """Registers a categorical predictive distribution parameterized by logits.

    Use this registration for losses computed with
    tf.nn.sparse_softmax_cross_entropy_with_logits.

    This is a different registration from
    register_multi_bernoulli_predictive_distribution; the two should not be
    confused.

    Args:
      logits: The logits (i.e. the parameters) of the distribution. Its first
        dimension must be the batch size.
      seed: The seed for the RNG, intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. These are only needed
        when the "empirical Fisher" is used in place of the true Fisher, which
        is selected through the optimizer's 'estimation_mode'.
        (Default: None)
      name: (OPTIONAL) str or None. A unique name for this loss function. When
        None, a fresh name is generated. (Default: None)
      coeff: (OPTIONAL) a float or TF scalar. Multiplies the log prob loss of
        this distribution, and the Fisher is scaled by the same factor. This
        is NOT the same as changing the distribution's temperature, because
        the log prob is not renormalized in the objective function.
        (Default: 1.0)
      reuse: (OPTIONAL) bool or str. When True, 'logits' is added as another
        mini-batch/tower of inputs to an already-registered loss
        function/predictive distribution. When "VARIABLE_SCOPE", the value of
        tf.get_variable_scope().reuse is used instead.
        (Default: "VARIABLE_SCOPE")
    """
    distribution = lf.CategoricalLogitsNegativeLogProbLoss(
        logits, seed=seed, targets=targets)
    # Label handed to _register_loss_function; presumably used to derive a
    # name when `name` is None (see the 'name' arg above).
    loss_kind = "categorical_predictive_distribution"
    self._register_loss_function(
        distribution, logits, loss_kind, name=name, coeff=coeff, reuse=reuse)

  def register_softmax_cross_entropy_loss(self,
                                          logits,
                                          seed=None,
                                          targets=None,
                                          name=None,
                                          coeff=1.0,
                                          reuse=VARIABLE_SCOPE):
    """Registers a softmax cross-entropy loss function.

    Use this registration for losses computed with
    tf.nn.sparse_softmax_cross_entropy_with_logits.

    It is not the same as register_sigmoid_cross_entropy_loss. It mirrors
    register_categorical_predictive_distribution, building the same loss
    object, but without the explicit probabilistic interpretation. For now the
    two behave identically.

    Args:
      logits: The logits (i.e. the parameters) of the distribution. Its first
        dimension must be the batch size.
      seed: The seed for the RNG, intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. These are only needed
        when the "empirical Fisher" is used in place of the true Fisher, which
        is selected through the optimizer's 'estimation_mode'.
        (Default: None)
      name: (OPTIONAL) str or None. A unique name for this loss function. When
        None, a fresh name is generated. (Default: None)
      coeff: (OPTIONAL) a float or TF scalar. The loss function is multiplied
        by this coefficient. (Default: 1.0)
      reuse: (OPTIONAL) bool or str. When True, 'logits' is added as another
        mini-batch/tower of inputs to an already-registered loss
        function/predictive distribution. When "VARIABLE_SCOPE", the value of
        tf.get_variable_scope().reuse is used instead.
        (Default: "VARIABLE_SCOPE")
    """
    loss_fn = lf.CategoricalLogitsNegativeLogProbLoss(
        logits, seed=seed, targets=targets)
    loss_kind = "sparse_softmax_cross_entropy_loss"
    self._register_loss_function(
        loss_fn, logits, loss_kind, name=name, coeff=coeff, reuse=reuse)

  def register_normal_predictive_distribution(self,
                                              mean,
                                              var=0.5,
                                              seed=None,
                                              targets=None,
                                              name=None,
                                              coeff=1.0,
                                              reuse=VARIABLE_SCOPE):
    """Registers a normal predictive distribution with the given variance.

    Registering this distribution corresponds to a squared error loss of the
    form
       coeff/(2*var) * ||target - mean||^2

    Args:
      mean: Tensor holding the mean vector of the distribution. Its first
        dimension must be the batch size.
      var: float. The variance of the distribution.
        - The default of 0.5 yields the standard squared error loss
          coeff*||target - prediction||^2.
        - Use var=1.0 for a loss of the form
          0.5*coeff*||target - prediction||^2.
        (Default: 0.5)
      seed: The seed for the RNG, intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. These are only needed
        when the "empirical Fisher" is used in place of the true Fisher, which
        is selected through the optimizer's 'estimation_mode'.
        (Default: None)
      name: (OPTIONAL) str or None. A unique name for this loss function. When
        None, a fresh name is generated. (Default: None)
      coeff: (OPTIONAL) a float or TF scalar. Multiplies the log prob loss of
        this distribution, and the Fisher is scaled by the same factor. In
        general this is NOT the same as changing the distribution's
        temperature, though for normal distributions it may be.
        (Default: 1.0)
      reuse: (OPTIONAL) bool or str. When True, 'mean' and 'var' are added as
        another mini-batch/tower of inputs to an already-registered loss
        function/predictive distribution. When "VARIABLE_SCOPE", the value of
        tf.get_variable_scope().reuse is used instead.
        (Default: "VARIABLE_SCOPE")
    """
    loss_fn = lf.NormalMeanNegativeLogProbLoss(
        mean, var, seed=seed, targets=targets)
    loss_kind = "normal_predictive_distribution"
    self._register_loss_function(
        loss_fn, mean, loss_kind, name=name, coeff=coeff, reuse=reuse)

  def register_squared_error_loss(self,
                                  prediction,
                                  seed=None,
                                  targets=None,
                                  name=None,
                                  coeff=1.0,
                                  reuse=VARIABLE_SCOPE):
    """Registers a squared error loss function.

    The loss is assumed to be ||target - prediction||^2, averaged over the
    mini-batch. If your loss carries a factor of 0.5 (as tf.nn.l2_loss does,
    for example), you must set the "coeff" argument to account for it.

    Args:
      prediction: The network's prediction (i.e. its output). Its first
        dimension must be the batch size.
      seed: The seed for the RNG, intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. These are only needed
        when the "empirical Fisher" is used in place of the true Fisher, which
        is selected through the optimizer's 'estimation_mode'.
        (Default: None)
      name: (OPTIONAL) str or None. A unique name for this loss function. When
        None, a fresh name is generated. (Default: None)
      coeff: (OPTIONAL) a float or TF scalar. The loss function is multiplied
        by this coefficient. (Default: 1.0)
      reuse: (OPTIONAL) bool or str. When True, 'prediction' is added as
        another mini-batch/tower of inputs to an already-registered loss
        function/predictive distribution. When "VARIABLE_SCOPE", the value of
        tf.get_variable_scope().reuse is used instead.
        (Default: "VARIABLE_SCOPE")
    """
    # Modeled as a normal distribution with variance 0.5: the factor
    # coeff/(2*var) then reduces to coeff, giving coeff*||target - pred||^2.
    fixed_var = 0.5
    loss_fn = lf.NormalMeanNegativeLogProbLoss(
        prediction, var=fixed_var, seed=seed, targets=targets)
    loss_kind = "squared_error_loss"
    self._register_loss_function(
        loss_fn, prediction, loss_kind, name=name, coeff=coeff, reuse=reuse)

  def register_multi_bernoulli_predictive_distribution(self,
                                                       logits,
                                                       seed=None,
                                                       targets=None,
                                                       name=None,
                                                       coeff=1.0,
                                                       reuse=VARIABLE_SCOPE):
    """Registers a multi-Bernoulli predictive distribution.

    Use this registration for losses computed with
    tf.nn.sigmoid_cross_entropy_with_logits.

    This is a different registration from
    register_categorical_predictive_distribution; the two should not be
    confused.

    Args:
      logits: The logits (i.e. the parameters) of the distribution. Its first
        dimension must be the batch size.
      seed: The seed for the RNG, intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. These are only needed
        when the "empirical Fisher" is used in place of the true Fisher, which
        is selected through the optimizer's 'estimation_mode'.
        (Default: None)
      name: (OPTIONAL) str or None. A unique name for this loss function. When
        None, a fresh name is generated. (Default: None)
      coeff: (OPTIONAL) a float or TF scalar. Multiplies the log prob loss of
        this distribution, and the Fisher is scaled by the same factor. This
        is NOT the same as changing the distribution's temperature, because
        the log prob is not renormalized in the objective function.
        (Default: 1.0)
      reuse: (OPTIONAL) bool or str. When True, 'logits' is added as another
        mini-batch/tower of inputs to an already-registered loss
        function/predictive distribution. When "VARIABLE_SCOPE", the value of
        tf.get_variable_scope().reuse is used instead.
        (Default: "VARIABLE_SCOPE")
    """
    distribution = lf.MultiBernoulliNegativeLogProbLoss(
        logits, seed=seed, targets=targets)
    loss_kind = "multi_bernoulli_predictive_distribution"
    self._register_loss_function(
        distribution, logits, loss_kind, name=name, coeff=coeff, reuse=reuse)

  def register_sigmoid_cross_entropy_loss(self,
                                          logits,
                                          seed=None,
                                          targets=None,
                                          name=None,
                                          coeff=1.0,
                                          reuse=VARIABLE_SCOPE):
    """Registers a sigmoid cross-entropy loss function.

    Use this registration for losses computed with
    tf.nn.sigmoid_cross_entropy_with_logits.

    It is not the same as register_softmax_cross_entropy_loss. It mirrors
    register_multi_bernoulli_predictive_distribution, building the same loss
    object, but without the explicit probabilistic interpretation. For now the
    two behave identically.

    Args:
      logits: The logits tensor. Its first dimension must be the batch size.
      seed: The seed for the RNG, intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. These are only needed
        when the "empirical Fisher" is used in place of the true Fisher, which
        is selected through the optimizer's 'estimation_mode'.
        (Default: None)
      name: (OPTIONAL) str or None. A unique name for this loss function. When
        None, a fresh name is generated. (Default: None)
      coeff: (OPTIONAL) a float or TF scalar. The loss function is multiplied
        by this coefficient. (Default: 1.0)
      reuse: (OPTIONAL) bool or str. When True, 'logits' is added as another
        mini-batch/tower of inputs to an already-registered loss
        function/predictive distribution. When "VARIABLE_SCOPE", the value of
        tf.get_variable_scope().reuse is used instead.
        (Default: "VARIABLE_SCOPE")
    """
    loss_fn = lf.MultiBernoulliNegativeLogProbLoss(
        logits, seed=seed, targets=targets)
    loss_kind = "sigmoid_cross_entropy_loss"
    self._register_loss_function(
        loss_fn, logits, loss_kind, name=name, coeff=coeff, reuse=reuse)

  def make_or_get_factor(self, cls, args):
    """Returns the cached 'cls(*args)' factor, constructing it on first use.

    Instances are stored in 'self.fisher_factors' under the key '(cls, args)'.
    A later request with an equal key therefore returns the same object. A new
    instance is built inside 'tf.variable_scope(self._var_scope)', so any
    variables created in 'cls.__init__' are placed under this
    LayerCollection's scope.

    Args:
      cls: Class that implements FisherFactor.
      args: Tuple of arguments to pass into 'cls's constructor. Must be
        hashable, since it forms part of a dictionary key.

    Returns:
      Instance of 'cls' found in (or newly added to) self.fisher_factors.

    Raises:
      TypeError: If 'args' cannot be hashed.
    """
    # TODO(b/123190346): Should probably change the args list to be keyworded
    # instead of positional.  Note that this would require making changes in
    # each FisherBlock's call to make_or_get_factor.
    try:
      hash(args)
    except TypeError:
      template = ("Unable to use (cls, args) = ({}, {}) as a key in "
                  "LayerCollection.fisher_factors. The pair cannot be hashed.")
      raise TypeError(template.format(cls, args))

    key = (cls, args)
    if key in self.fisher_factors:
      return self.fisher_factors[key]
    with tf.variable_scope(self._var_scope):
      self.fisher_factors[key] = cls(*args)
    return self.fisher_factors[key]

  @contextmanager
  def as_default(self):
    """Sets this LayerCollection as the default for the duration of a block.

    Usage:
      with layer_collection.as_default():
        ...

    On entry this calls set_default_layer_collection(self). On exit it calls
    set_default_layer_collection(None), whether the body completes normally
    or raises. Exit clears the default rather than restoring whatever was
    set before entry.

    Yields:
      None.
    """
    set_default_layer_collection(self)
    try:
      yield
    finally:
      # Without the finally, an exception raised inside the `with` body would
      # skip this reset and leave `self` installed as the default collection.
      set_default_layer_collection(None)
