# coding=utf-8
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

u"""The variational ansatz class."""

from __future__ import absolute_import
from typing import Iterable, Optional, Sequence, Tuple

import numpy

import cirq
from cirq import abc
from itertools import izip


class VariationalAnsatz(object):
    u"""A variational ansatz.

    A variational ansatz is a parameterized circuit. Its parameters are
    Symbol objects: named placeholders that can appear in a circuit and whose
    numerical values are determined at run time (see `param_resolver`).
    Subclasses supply the parameter Symbols through the `params` method, the
    gates that use them through the `operations` method, and a default set of
    qubits through the `_generate_qubits` method.

    Attributes:
        circuit: The ansatz circuit, built once in `__init__` from
            `operations(qubits)` using the EARLIEST insert strategy.
        qubits: The qubits used by the ansatz circuit.
    """
    # The metaclass declaration must come *after* the docstring; otherwise the
    # string above is not recorded as the class's __doc__. (This attribute is
    # the Python 2 spelling and has no effect under Python 3.)
    __metaclass__ = abc.ABCMeta

    def __init__(self, qubits=None):
        u"""
        Args:
            qubits: Qubits to be used by the ansatz circuit. If not specified
                (or empty), then qubits will automatically be generated by the
                `_generate_qubits` method.
        """
        self.qubits = qubits or self._generate_qubits()

        # Generate the ansatz circuit
        self.circuit = cirq.Circuit.from_ops(
                self.operations(self.qubits),
                strategy=cirq.InsertStrategy.EARLIEST)

    @abc.abstractmethod
    def params(self):
        u"""The parameters of the ansatz.

        Returns:
            An iterable of the parameter Symbols used by `operations`. Each
            item must expose a `name` attribute, because `param_resolver`
            keys parameter values by these names. This method is called
            afresh by `param_resolver` and `default_initial_params`, so it
            should yield the same parameters in the same order every time.
        """
        pass

    def param_bounds(self):
        u"""Optional bounds on the parameters.

        Returns a list of tuples of the form (low, high), where low and high
        are lower and upper bounds on a parameter. The order of the tuples
        corresponds to the order of the parameters as yielded by the
        `params` method. The default implementation returns None (no bounds).
        """
        return None

    def param_resolver(self, param_values):
        u"""Interprets parameters input as an array of real numbers.

        Args:
            param_values: A sequence of numerical values, one per parameter,
                in the order the parameters are yielded by `params`.

        Returns:
            A cirq.ParamResolver mapping each parameter's name to the
            corresponding entry of `param_values`. If the lengths differ, the
            extra entries on the longer side are silently ignored.
        """
        # Default: leave the parameter values unchanged. Subclasses may
        # override this to transform the raw values before resolution.
        # Builtin zip (rather than itertools.izip) works on Python 2 and 3.
        names = (param.name for param in self.params())
        return cirq.ParamResolver(dict(zip(names, param_values)))

    def default_initial_params(self):
        u"""Suggested initial parameter settings.

        Returns:
            A numpy array of zeros with one entry per parameter yielded by
            `params`.
        """
        # Default: zeros
        return numpy.zeros(len(list(self.params())))

    @abc.abstractmethod
    def operations(self, qubits):
        u"""Produce the operations of the ansatz circuit.

        The operations should use Symbols produced by the `params` method
        of the ansatz.

        Args:
            qubits: The qubits the operations should act on.
        """
        pass

    @abc.abstractmethod
    def _generate_qubits(self):
        u"""Produce qubits that can be used by the ansatz circuit."""
        pass

    # TODO also need to consider mode permutation
    def qubit_permutation(self, qubits):
        u"""The qubit permutation induced by the ansatz circuit.

        An ansatz circuit may induce a permutation on its qubits. For example,
        an ansatz that applies interactions using an odd number of swap networks
        will reverse the order of the qubits. Keeping track of the qubit
        ordering is important for composing circuit primitives and calculating
        properties of the final state.

        Args:
            qubits: The qubits in their order before the circuit is applied.

        Returns:
            The qubits in their order after the circuit is applied. The
            default (identity permutation) returns `qubits` itself.
        """
        # Default: identity permutation
        return qubits
