"""Visitor(s) for walking ASTs.

This module contains broadly useful basic visitors. Visitors that are more
specialized to pytype are in visitors.py. If you see a visitor there that you'd
like to use, feel free to propose moving it here.
"""

from pytype.pytd import base_visitor
from pytype.pytd import pytd


class CanonicalOrderingVisitor(base_visitor.Visitor):
  """Rewrites an AST so that its order-insensitive parts appear sorted.

  This yields a canonical form for comparing pytd trees. Some orderings are
  meaningful and are deliberately left untouched:
    * the constants of classes carrying a dataclass-like decorator, and
    * function signatures, unless `sort_signatures` is True.
  """

  def __init__(self, sort_signatures=False):
    super().__init__()
    self.sort_signatures = sort_signatures

  @staticmethod
  def _SortedTuple(items):
    """Returns the elements of `items` in sorted order, as a tuple."""
    return tuple(sorted(items))

  def VisitTypeDeclUnit(self, node):
    ordered = self._SortedTuple
    return pytd.TypeDeclUnit(
        name=node.name,
        constants=ordered(node.constants),
        type_params=ordered(node.type_params),
        functions=ordered(node.functions),
        classes=ordered(node.classes),
        aliases=ordered(node.aliases))

  def VisitClass(self, node):
    # Classes built by a dataclass-like decorator must keep their attributes
    # in declaration order; otherwise inheritance will not work correctly.
    keeps_declaration_order = any(
        decorator.name in ("attr.s", "dataclasses.dataclass")
        for decorator in node.decorators)
    if keeps_declaration_order:
      constants = tuple(node.constants)
    else:
      constants = self._SortedTuple(node.constants)
    if node.slots is None:
      slots = None
    else:
      slots = self._SortedTuple(node.slots)
    return pytd.Class(
        name=node.name,
        metaclass=node.metaclass,
        parents=node.parents,
        methods=self._SortedTuple(node.methods),
        constants=constants,
        decorators=self._SortedTuple(node.decorators),
        classes=self._SortedTuple(node.classes),
        slots=slots,
        template=node.template)

  def VisitFunction(self, node):
    # Signature order normally decides lookup order, so it is preserved by
    # default. Some pytd (e.g. inference output) has no meaningful signature
    # order; callers opt into sorting it via sort_signatures.
    if not self.sort_signatures:
      return node
    return node.Replace(signatures=self._SortedTuple(node.signatures))

  def VisitSignature(self, node):
    ordered = self._SortedTuple
    return node.Replace(template=ordered(node.template),
                        exceptions=ordered(node.exceptions))

  def VisitUnionType(self, node):
    return pytd.UnionType(self._SortedTuple(node.type_list))


class ClassTypeToNamedType(base_visitor.Visitor):
  """Replaces every ClassType in the tree with a NamedType of the same name.

  Only the name survives; the class object referenced by the ClassType is not
  carried over to the new node.
  """

  def VisitClassType(self, node):
    type_name = node.name
    return pytd.NamedType(type_name)


class CollectTypeParameters(base_visitor.Visitor):
  """Gathers the distinct type parameters encountered during a walk.

  Results accumulate in `params`, in first-seen order. A parameter whose name
  has already been recorded is skipped, so each name appears at most once.
  """

  def __init__(self):
    super().__init__()
    self.params = []
    # Names of the parameters already stored in self.params.
    self._seen = set()

  def EnterTypeParameter(self, p):
    name = p.name
    if name in self._seen:
      return
    self._seen.add(name)
    self.params.append(p)


class ExtractSuperClasses(base_visitor.Visitor):
  """Collects the class hierarchy (each class's direct parents) of a module.

  Visiting a TypeDeclUnit returns a dict mapping the key of every class to the
  list of keys of its parents, where keys are produced by `_Key`. With the
  default `_Key` the keys are the nodes themselves, i.e. the result maps
  pytd.Class to lists of pytd.Type. Parents whose key is None are omitted.
  """

  def __init__(self):
    super().__init__()
    self._superclasses = {}

  def _Key(self, node):
    # Hook meant to be overridden by subclasses: maps a node to the value
    # stored in the result. A parent mapped to None is dropped. The default
    # is the identity.
    return node

  def VisitTypeDeclUnit(self, module):
    del module  # Unused: the mapping was already built up in EnterClass.
    return self._superclasses

  def EnterClass(self, cls):
    parent_keys = (self._Key(parent) for parent in cls.parents)
    self._superclasses[self._Key(cls)] = [
        key for key in parent_keys if key is not None]


class RenameModuleVisitor(base_visitor.Visitor):
  """Renames a TypeDeclUnit and the module-local names it contains.

  A name of the form "<old_module_name>.<member>" (with no further "." in
  <member>) becomes "<new_module_name>.<member>", or just "<member>" when
  new_module_name is empty. The TypeDeclUnit itself is renamed to
  new_module_name.
  """

  def __init__(self, old_module_name, new_module_name):
    """Constructor.

    Args:
      old_module_name: The old name of the module as a string,
        e.g. "foo.bar.module1"
      new_module_name: The new name of the module as a string,
        e.g. "barfoo.module2". If empty, module-local names lose their
        module prefix entirely.

    Raises:
      ValueError: If old_module_name is empty (or otherwise falsy), or if
        either module name ends with ".".
    """
    super().__init__()
    if not old_module_name:
      raise ValueError("old_module_name must be a non empty string.")
    # Explicit checks instead of asserts so validation is not stripped by -O;
    # a trailing dot would otherwise yield a prefix like "foo.." that never
    # matches anything.
    if old_module_name.endswith("."):
      raise ValueError("old_module_name must not end with '.'.")
    if new_module_name.endswith("."):
      raise ValueError("new_module_name must not end with '.'.")
    self._module_name = new_module_name
    # old_module_name is known to be non-empty here, so it always gets a dot.
    self._old = old_module_name + "."
    self._new = new_module_name + "." if new_module_name else ""

  def _MaybeNewName(self, name):
    """Decides if a name should be replaced.

    Args:
      name: A name for which a prefix should be changed. Falsy values (e.g.
        an empty string or None) are returned unchanged.

    Returns:
      If name is local to the module described by old_module_name, i.e. it is
      "<old_module_name>.<member>" with no "." in <member>, the
      old_module_name prefix is replaced by new_module_name and the result is
      returned; otherwise name is returned unchanged.
    """
    if not name or not name.startswith(self._old):
      return name
    member = name[len(self._old):]
    # Only direct members of the module are renamed; dotted remainders are
    # left untouched.
    if "." in member:
      return name
    return self._new + member

  def _ReplaceModuleName(self, node):
    new_name = self._MaybeNewName(node.name)
    if new_name != node.name:
      return node.Replace(name=new_name)
    else:
      return node

  def VisitClassType(self, node):
    new_name = self._MaybeNewName(node.name)
    if new_name != node.name:
      return pytd.ClassType(new_name, node.cls)
    else:
      return node

  def VisitTypeDeclUnit(self, node):
    return node.Replace(name=self._module_name)

  def VisitTypeParameter(self, node):
    new_scope = self._MaybeNewName(node.scope)
    if new_scope != node.scope:
      return node.Replace(scope=new_scope)
    return node

  VisitConstant = _ReplaceModuleName  # pylint: disable=invalid-name
  VisitAlias = _ReplaceModuleName  # pylint: disable=invalid-name
  VisitClass = _ReplaceModuleName  # pylint: disable=invalid-name
  VisitFunction = _ReplaceModuleName  # pylint: disable=invalid-name
  VisitStrictType = _ReplaceModuleName  # pylint: disable=invalid-name
  VisitNamedType = _ReplaceModuleName  # pylint: disable=invalid-name
