#!/usr/bin/env python3
import json
import re
import sys
import traceback
from collections.abc import Iterable

from redbaron import RedBaron

def iterable(obj):
    """Report whether *obj* is an instance of ``collections.abc.Iterable``."""
    is_iterable = isinstance(obj, Iterable)
    return is_iterable
    
def catchdump(x):
  """Return ``x.dumps()``, or None when *x* cannot be dumped.

  Callers map this over matched nodes and then drop the None results.
  Only ordinary exceptions are swallowed; KeyboardInterrupt/SystemExit
  propagate.
  """
  try:
    return x.dumps()
  except Exception:
    return None

def hasMultiStmt(node):
  """Report whether more than one statement of interest lives under *node*.

  For each element of *node*, the ``ifelseblock``/``atomtrailers``/``return``/
  ``while`` nodes found non-recursively by ``find_all`` are counted.

  Returns:
    True when the total count exceeds one. Returns False otherwise, including
    when *node* cannot be iterated or searched. Previously that case returned
    an implicit None; False is equivalent in every boolean context.
  """
  found = []
  try:
    for child in node:
      matches = child.find_all(['ifelseblock', 'atomtrailers', 'return', 'while'], recursive=False)
      found.extend(matches.filter(lambda m: m is not None))
  except Exception:
    return False
  return len(found) > 1

def getScratch(red):
  """Return the source of the assignments found under *red*, one per line.

  Returns '' when *red* has no ``find_all`` or nothing is found.

  NOTE(review): the ScratchVar filter only runs when ``find_all`` returns a
  real ``list`` subclass. If RedBaron's NodeList is not one (TODO confirm), every
  assignment is returned, not just ScratchVar declarations. Also,
  ``'ScratchVar' in node`` tests membership on the node rather than on its
  source text. Callers may rely on the current behaviour, so it is preserved.
  """
  if not hasattr(red, 'find_all'):
    return ''
  assignments = red.find_all('assign')
  if isinstance(assignments, list):
    assignments = assignments.filter(lambda node: 'ScratchVar' in node)
  if not assignments:
    return ''
  rendered = assignments.map(lambda node: node.dumps())
  return '\n'.join(rendered)

def removeblanks(ml):
  """Drop empty and whitespace-only lines from *ml*; rejoin the rest with '\\n'."""
  kept = [line for line in ml.splitlines() if line.strip()]
  return '\n'.join(kept)

def hasScratch(red, varname):
  """Return True if the def enclosing *red* already declares ``varname = ScratchVar``.

  The enclosing def is looked up with ``parent_find('def')`` once and cached
  on *red* as the ``parentdef`` attribute. Returns False when there is no
  enclosing def.
  """
  if hasattr(red, 'parentdef'):
    parentdef = red.parentdef
  else:
    parentdef = red.parent_find('def')
    red.parentdef = parentdef
  if parentdef is None:
    return False
  declarations = parentdef.find_all('assign').dumps()
  # Anchor the name so that `idx = ScratchVar` or `self.x = ScratchVar` is not
  # mistaken for a declaration of `x`. A plain substring test matched those.
  pattern = r"(?<![\w.])" + re.escape(str(varname)) + r" = ScratchVar"
  return re.search(pattern, declarations) is not None

def findmap(red, which, how):
  """Call *how* on every node that ``red.find_all(which)`` returns.

  The result of ``find_all`` is obtained before any callback runs, and the
  callbacks' return values are discarded. Returns None.
  """
  for match in red.find_all(which):
    how(match)

def decs(d):
  """Replace a ``@bytes`` decorator with ``@Subroutine(TealType.bytes)``.

  NOTE(review): ``pop()`` removes the *last* decorator of the parent def, which
  is only ``d`` itself when ``@bytes`` is listed last. Confirm whether stacked
  decorators need handling.
  """
  if d.value.dumps() != 'bytes':
    return
  decorators = d.parent.decorators
  decorators.pop()
  decorators.append('@Subroutine(TealType.bytes)')

def all(red, x = ''):
  """Recursively rewrite the RedBaron subtree *red* into PyTeal, in place.

  NOTE: this shadows the builtin ``all`` for the whole module. The name is
  kept because the rewrite helpers in this file call it.

  Steps:
    1. A ``def`` other than the ``app``/``sig`` entry points with no decorator
       gets ``@Subroutine(uint64)`` if its source mentions ``return`` or
       ``Return``, otherwise ``@Subroutine(TealType.none)``.
    2. The per-node-type rewrite helpers are applied to matching descendants.
    3. ``ifelseblock``/``atomtrailers``/``return``/``while`` nodes found
       non-recursively under each element of *red* are collected as source text.
       - Two or more become the body ``Seq([...])``. For a def, the body is
         preceded by the ``getScratch`` output and ``return``.
       - A single one may become ``return <stmt>``, under the conditions
         checked below.

  Args:
    red: node (or node list) to rewrite.
    x: caller tag. ``'if'`` (passed by ``ifs``) makes failures while
      collecting the statements get reported on stderr.

  Returns None. Errors are written to stderr and otherwise swallowed so that
  one bad construct does not abort the whole conversion.
  """
  try:
    hasMulti = hasMultiStmt(red)
    if hasattr(red,'type'):
      if red.type == 'def' and red.name != 'app' and red.name != 'sig':
        if len(red.decorators) == 0:
          if 'return' in red.dumps() or 'Return' in red.dumps():
            red.decorators.append("@Subroutine(uint64)")
          else:
            red.decorators.append("@Subroutine(TealType.none)")
      findmap(red, 'assign', assigns)
      findmap(red, 'boolean_operator', bools)
      findmap(red, 'binary_operator', concats)
      if hasMulti or red.type == 'while':
        findmap(red, 'return', returns)
      findmap(red, 'int', ints)
      findmap(red, 'string', strings)
      findmap(red, 'ifelseblock', ifs)
      if red.type != 'while':
        findmap(red, 'while', whiles)
      findmap(red, 'call', calls)
      findmap(red, 'name', names)
    strs_ = []
    scratch = getScratch(red)
    if hasattr(red,'type') and red.type == 'name':
      return
    try:
      for nd in red:
        nodes = nd.find_all(['ifelseblock','atomtrailers', 'return', 'while'], recursive=False)
        if hasMulti:
          findmap(nodes, 'return', returns)
        strs = nodes.map(catchdump)
        strs = strs.filter(lambda s: s is not None)
        strs_.extend(strs)
    except Exception as err:
      # `err` must be bound here. Without `as err` the diagnostic below raised
      # NameError and aborted the rest of this rewrite.
      if x == 'if':
        sys.stderr.write('looperr')
        sys.stderr.write(f"Unexpected {err=}, {type(err)=}")
        sys.stderr.write(traceback.format_exc())
    if len(strs_) > 1:
      strlist = ',\n\t'.join(s for s in strs_ if s != '')
      retstr = ''
      if red.type == 'def':
        retstr = f"{scratch}\nreturn "
      red.value = f"{retstr} Seq([\n\t{strlist} ])\n"
      red.value = removeblanks(red.value.dumps())
    elif len(strs_) == 1:
      if 'return' not in red.value.dumps() and ( red.find('ifelseblock') or red.find('while') or red.type == 'assert' or red.find('assign') ):
        red.value = f"{scratch}\nreturn {strs_[0]}\n".replace('return return', 'return').replace('return Return','return')
  except Exception as err:
    sys.stderr.write(f"Unexpected {err=}, {type(err)=}")
    sys.stderr.write(traceback.format_exc())

def bools(boolOp):
  """Rewrite ``a and b`` / ``a or b`` into ``And( a, b )`` / ``Or( a, b )``.

  Both operands are converted first via ``all``. The operator text is
  title-cased to form the call name. If the replacement fails, the node is
  left unchanged.
  """
  all(boolOp.first)
  all(boolOp.second)
  try:
    # Render operands with dumps(), as the sibling rewrites do, rather than
    # relying on str(node) matching the source text.
    boolOp.replace(f"{boolOp.value.title()}( {boolOp.first.dumps()}, {boolOp.second.dumps()} )")
  except Exception:
    pass

def unitaries(un):
  """Rewrite a unary operator such as ``not x`` into a call like ``Not( x )``.

  The operand (``un.target``) is converted first via ``all``. The operator
  text (``un.value``) is title-cased to form the call name. If the replacement
  fails, the node is left unchanged.
  """
  # The operand lives in `target`. `value` is the operator string itself, so
  # converting `value` (as before) did nothing.
  all(un.target)
  try:
    un.replace(f"{un.value.title()}( {un.target.dumps()} )")
  except Exception:
    pass
  
def _is_bytes_call(node):
  """Return True if *node*'s first element is the name ``Bytes`` (e.g. ``Bytes("a")``)."""
  try:
    return node.value[0].value == 'Bytes'
  except Exception:
    return False

def concats(ct):
  """Rewrite ``a + b`` on byte strings into ``Concat(a,b)``.

  Both operands are converted first via ``all``. The node is rewritten only
  for the ``+`` operator, and only when at least one operand is a string
  literal or a ``Bytes(...)`` expression. Each operand is inspected
  independently, so an operand that cannot be inspected no longer hides a
  ``Bytes`` call on the other side.
  """
  all(ct.first)
  all(ct.second)
  if ct.value != '+':
    # Only `+` concatenates; e.g. `Bytes("a") * 2` must not become Concat.
    return
  first, second = ct.first, ct.second
  if (first.type == 'string' or second.type == 'string'
      or _is_bytes_call(first) or _is_bytes_call(second)):
    ct.replace(f"Concat({first.dumps()},{second.dumps()})")

def assigns(asn):
  """Turn an assignment inside a def into a ScratchVar store.

  ``x = expr`` becomes ``x.store(expr)``. If no ``x = ScratchVar(...)``
  declaration exists yet in the enclosing def, one is inserted before the
  store. Its type is ``TealType.bytes`` when the value's source contains a
  quote character, otherwise ``TealType.uint64``.

  Left untouched:
    - ScratchVar declarations themselves;
    - assignments outside any def;
    - boolean, comparison and boolean-operator values;
    - values headed by ``And``/``Or``/``Not``, whether a bare name or a call
      such as ``And( a, b )``. These stay as inline expressions instead of
      being stored.
  """
  if 'ScratchVar' in asn.value.dumps():
    return
  if asn.parent_find('def') is None:
    return
  parval = asn.value
  if hasattr(parval, 'filtered'):
    parval = parval.filtered()[0]
  if hasattr(parval, 'type'):
    # Exact membership; the old `in 'boolean_operator'` was a substring test.
    if parval.type in ('boolean', 'comparison', 'boolean_operator'):
      return
    if hasattr(parval, 'value'):
      head = parval.value
      if not isinstance(head, str):
        # A call like `And( a, b )` is a node list whose first element names
        # the callee; comparing the whole list never matched.
        try:
          head = head[0].value
        except Exception:
          head = None
      if head in ('And', 'Or', 'Not'):
        return
  target = asn.target.value
  is_str = "'" in asn.value.dumps() or '"' in asn.value.dumps()
  if not hasScratch(asn, target):
    typ = 'TealType.bytes' if is_str else 'TealType.uint64'
    asn.insert_before(f"{target} = ScratchVar({typ})\n")
  asn.replace(f"{target}.store({asn.value.dumps()})")

def calls(nd):
  """Qualify builder calls, e.g. ``Begin()`` -> ``InnerTxnBuilder.Begin()``.

  *nd* should be a ``call`` node. The first element of its parent is taken as
  the callee name, and ``Begin``/``SetField``/``SetFields``/``Submit`` get the
  ``InnerTxnBuilder.`` prefix. Anything else, including nodes that cannot be
  inspected, is ignored.
  """
  try:
    if nd.type != 'call':
      return
    callee = nd.parent[0]
    if callee.value in ('Begin', 'SetField', 'SetFields', 'Submit'):
      callee.replace(f"InnerTxnBuilder.{callee.dumps()}")
  except Exception:
    pass

def prints(p):
  """Replace a ``print`` node with ``Log`` followed by the same argument text.

  NOTE(review): the argument source is appended directly after ``Log``, so
  this assumes ``p.value`` dumps with its own parentheses, e.g. ``("hi")``.
  """
  argument = p.value.dumps()
  p.replace("Log" + argument)

# Field names that get turned into method calls (`Txn.sender` -> `Txn.sender()`)
# when accessed on Txn/Gtxn/Global. Built once instead of on every call.
_TXN_FIELD_NAMES = frozenset({
  'fee', 'sender', 'first_valid', 'last_valid', 'receiver', 'note', 'lease', 'round',
  'amount', 'close_remainder_to', 'vote_pk', 'type', 'type_enum', 'xfer_asset', 'asset_amount',
  'asset_sender', 'asset_receiver', 'asset_close_to', 'group_index', 'tx_id', 'application_id',
  'on_completion', 'rekey_to', 'config_asset', 'config_asset_total', 'config_asset_decimals',
  'config_asset_default_frozen', 'config_asset_unit_name', 'config_asset_name', 'config_asset_url',
  'config_asset_manager', 'created_asset_id', 'created_application_id', 'current_application_address',
  'group_size', 'application_args',
})

def names(nd):
  """Rewrite a name node into its PyTeal form.

  - The target of an assignment is left alone.
  - ``False``/``True`` become ``Int(0)``/``Int(1)``.
  - A name already followed by ``.load()`` or ``.store(`` is left alone.
  - A name with a ScratchVar declaration in the enclosing def becomes
    ``name.load()``.
  - A field from ``_TXN_FIELD_NAMES`` accessed on ``Txn``/``Gtxn``/``Global``
    gets ``()`` appended to the whole access. The exception is when the
    grandparent's third element is a subscript.

  Nodes whose structure cannot be inspected are left unchanged.
  """
  if nd.parent.type == 'assignment' and nd.parent.target == nd:
    return

  f = nd.find('name', lambda n: n.value=='False')
  if f:
    f.replace('Int(0)')
    return
  t = nd.find('name', lambda n: n.value=='True')
  if t:
    t.replace('Int(1)')
    return

  varname = nd.dumps()
  parent_src = nd.parent.dumps()
  if varname + '.load()' in parent_src:
    return
  if varname + '.store(' in parent_src:
    return

  try:
    if hasScratch(nd, varname):
      nd.replace(f"{varname}.load()")
      return
    if nd.value not in _TXN_FIELD_NAMES:
      return
    if nd.parent.value[0].value not in ('Txn', 'Gtxn', 'Global'):
      return
    try:
      if nd.parent.parent.value[2].type == 'getitem':
        return
    except Exception:
      # No third element to inspect: fall through and add the call parens.
      pass
    if f"{varname}()" not in nd.parent.dumps():
      nd.parent.replace(nd.parent.dumps() + '()')
  except Exception:
    pass

def whiles(w):
  """Rewrite a ``while`` loop into a ``While( cond).Do(\\nbody )`` expression.

  The loop subtree is converted with ``all`` first, so the condition and body
  text spliced into the replacement are already rewritten.
  """
  all(w)
  condition = w.test.dumps()
  body = w.value.dumps()
  w.replace("While( " + condition + ").Do(\n" + body + " )")
  
def ifs(if_):
  """Rewrite an if/elif/else block into a (nested) ``If(...)`` expression.

  Each branch is converted first with ``all``. An ``elif`` becomes a nested
  ``If`` in the else position, so ``if a: X elif b: Y else: Z`` turns into
  ``If( a, X, If( b, Y, Z ) )``. Previously an ``elif`` was treated as an
  ``else``, losing its condition, and branches after the second were
  dropped.

  When the block has more than one branch, its ``return`` statements are
  rewritten with ``returns`` first. If building the replacement fails, the
  block is left unchanged.
  """
  branches = list(if_.value)
  for branch in branches:
    all(branch, 'if')
  if len(branches) > 1:
    findmap(if_.value, 'return', returns)
  try:
    expr = None
    conditional = branches
    if len(branches) > 1 and branches[-1].type == 'else':
      expr = branches[-1].value.dumps()
      conditional = branches[:-1]
    # Build from the innermost (last) branch outwards.
    for branch in reversed(conditional):
      else_ = f", {expr}" if expr is not None else ''
      expr = f"If( {branch.test.dumps()}, {branch.value.dumps()}{else_} )"
    if_.value = f"{expr}\n\n"
  except Exception:
    pass

def returns(ret):
  """Rewrite a ``return`` statement into a ``Return(...)`` call.

  A bare ``return`` (no value) becomes ``Return()``. Previously it crashed on
  ``None.dumps()``. Otherwise the value is converted with ``all`` first.
  """
  if ret.value is None:
    ret.replace("Return()")
    return
  all(ret.value)
  # NOTE(review): this strips every 'return ' substring from the value's
  # source, including inside string literals.
  ret.replace(f"Return( {ret.value.dumps().replace('return ','')} )")

def fixreturns(d):
  """Turn a leading ``Return...`` statement of def *d* into a Python ``return``.

  If the first element of *d*'s first statement is named ``Return``, the
  statement is replaced with ``return`` plus the source of its second element
  (presumably the call's parenthesised argument list). Statements that cannot
  be inspected are left alone.
  """
  try:
    first = d[0]
    if first[0].value == 'Return':
      d[0] = f"return {first[1].dumps()}"
  except Exception:
    pass

def addreturns(d):
  """Prefix the first statement of def *d* with ``return`` unless it already is one.

  The check looks at the first element of the first statement. Statements
  that cannot be inspected are left alone.
  """
  try:
    first = d[0]
    if first[0].value != 'return':
      d[0] = f"return {first.dumps()}"
  except Exception:
    pass

# Callee names whose integer arguments stay raw Python ints.
# (Tuples, not sets: probed values may be unhashable node lists.)
_RAW_INT_CALLEES = ('Int', 'Arg', 'application_args', 'Gtxn')
# Attribute names whose subscript index stays a raw Python int.
_RAW_INT_SUBSCRIPTED = ('application_args', 'Gtxn')

def _probe(getter):
  """Return ``getter()``, or None if navigating the syntax tree raises."""
  try:
    return getter()
  except Exception:
    return None

def ints(i):
  """Wrap an integer literal in ``Int(...)`` unless it must stay a raw int.

  Left alone:
    - default values of def arguments;
    - integers whose great-grandparent's first element is one of
      ``_RAW_INT_CALLEES`` (e.g. already inside ``Int(...)``);
    - integers whose grandparent's second element is one of
      ``_RAW_INT_SUBSCRIPTED`` (e.g. ``Txn.application_args[0]``).

  Each check is probed independently. Previously, a failure navigating the
  great-grandparent skipped the subscript check as well.
  """
  if _probe(lambda: i.parent.type) == 'def_argument':
    return
  callee = _probe(lambda: i.parent.parent.parent.value[0].value)
  if callee in _RAW_INT_CALLEES:
    return
  subscripted = _probe(lambda: i.parent.parent.value[1].value)
  if subscripted in _RAW_INT_SUBSCRIPTED:
    return
  i.replace(f"Int({i.value})")

def strings(i):
  """Wrap a string literal in ``Bytes(...)``.

  Literals that are def-argument defaults, or whose great-grandparent's first
  element is ``Bytes`` or ``Addr`` (i.e. already an argument of such a call),
  are left alone. Ancestors that cannot be navigated count as "not inside
  Bytes/Addr".
  """
  try:
    if i.parent.type == 'def_argument':
      return
    callee = i.parent.parent.parent.value[0].value
    if callee in ('Bytes', 'Addr'):
      return
  except Exception:
    pass
  i.replace(f"Bytes({i.dumps()})")

# NOTE(review): module-level placeholder; nothing in this file reads it.
progroot = 0

def isSigDef(nd):
  """Return True if *nd* is a ``def`` node named ``sig`` (the Signature-mode entry point).

  Objects lacking ``type``/``name`` attributes yield False, without a blanket
  try/except.
  """
  return bool(getattr(nd, 'type', None) == 'def' and getattr(nd, 'name', None) == 'sig')

def convert(fname):
  """Translate the Python program in *fname* to PyTeal and print it to stdout.

  Signature mode with entry point ``sig`` is used when a top-level
  ``def sig`` exists; otherwise Application mode with entry point ``app``.
  The printed program starts with the PyTeal imports and ends with a
  ``__main__`` block that compiles the entry point at TEAL version 5.
  """
  # Close the source file deterministically; Python source is UTF-8 by default.
  with open(fname, "r", encoding="utf-8") as source:
    red = RedBaron(source.read())
  if red.filter(isSigDef):
    mode, entry = 'Signature', 'sig'
  else:
    mode, entry = 'Application', 'app'
  # The passes run in this fixed order; each sees the output of the previous ones.
  findmap(red, 'print', prints)
  findmap(red, 'unitary_operator', unitaries)
  findmap(red, 'return', returns)
  findmap(red, 'def', all)
  findmap(red, 'assign', assigns)
  findmap(red, 'def', fixreturns)
  findmap(red, 'def', addreturns)
  findmap(red, 'decorator', decs)
  print("from pyteal import *\n")
  print("globals().update(TealType.__members__)\n")

  print(red.dumps())
  print(f"if __name__ == \"__main__\":\n    print(compileTeal({entry}(), mode=Mode.{mode}, version=5))")

# Run only when executed as a script, so importing this module (e.g. from
# tests) has no side effects.
if __name__ == "__main__":
  if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} SOURCE_FILE")
  convert(sys.argv[1])
