#!/usr/bin/env python3
import json
from redbaron import RedBaron
import traceback
import sys
from collections.abc import Iterable
import os
import re
from pathlib import Path

# Not read by any code visible in this file.
ABI = False

# Scratch variables registered so far: scope name -> {variable name: True}.
# The 'root' scope is filled by add_root_scratches(); per-def scopes are
# added by assigns() and keyed by the def's name.
scratches_added = {"root":{}}
# def name -> list of transaction-kind argument names collected by txnargs().
txnargs_ = {}
# def name -> list of (def name, argument name, position) tuples collected by refargs().
refargs_ = {}

def iterable(obj):
  """Report whether ``obj`` is an instance of ``collections.abc.Iterable``."""
  is_iterable = isinstance(obj, Iterable)
  return is_iterable
    
def catchdump(x):
  """Return ``x.dumps()``, or ``None`` when dumping fails.

  This is a best-effort mapper: any ordinary error raised while dumping
  (including ``x`` having no ``dumps`` attribute) yields ``None``.
  """
  try:
    return x.dumps()
  except Exception:
    # Only Exception is caught so KeyboardInterrupt / SystemExit still
    # propagate instead of being silently swallowed.
    return None

def hasMultiStmt(node):
  """Report whether ``node`` holds more than one statement-like child.

  Each element of ``node`` is asked (non-recursively) for its
  'ifelseblock', 'atomtrailers', 'return' and 'while' children; the
  function returns ``True`` when more than one such child is found in
  total.

  Returns ``False`` when ``node`` cannot be walked this way (not iterable,
  an element without ``find_all``, ...).  Callers only test the result for
  truthiness.
  """
  found = []
  try:
    for child in node:
      hits = child.find_all(['ifelseblock','atomtrailers', 'return', 'while'], recursive=False)
      found.extend(hits.filter(lambda x: x is not None))
  except Exception:
    # Best-effort probe: an unwalkable node simply has no multi-statement
    # body.  KeyboardInterrupt / SystemExit are deliberately not caught.
    return False
  return len(found) > 1

def getScratch(red):
  """Build the newline-joined source of scratch-related nodes under ``red``.

  Collects the 'assign' nodes of ``red`` plus every 'name' node whose
  source is exactly 'maybe' or contains 'asset_bal', and returns their
  dumped source joined with newlines.  Returns '' when ``red`` has no
  ``find_all`` or when nothing was collected.
  """
  if not hasattr(red, 'find_all'):
    return ''

  assign_nodes = red.find_all('assign')
  # NOTE(review): the ScratchVar filter only runs when find_all returned a
  # real ``list`` subclass, and it tests membership on the node itself rather
  # than on its dumped source -- confirm this is intended.
  if isinstance(assign_nodes, list):
    assign_nodes = assign_nodes.filter(lambda x: 'ScratchVar' in x)

  name_nodes = red.find_all('name').filter(
    lambda x: x.dumps() == 'maybe' or 'asset_bal' in x.dumps())

  if not (assign_nodes or name_nodes):
    return ''

  collected = assign_nodes if assign_nodes else []
  if name_nodes:
    collected = collected + name_nodes
  return '\n'.join(collected.map(lambda x: x.dumps()))

def removeblanks(ml):
  """Return ``ml`` with empty and whitespace-only lines removed."""
  kept = [line for line in ml.splitlines() if line.strip() != '']
  return '\n'.join(kept)

def hasScratch(red, varname):
  """Report whether ``varname`` is a known scratch variable for ``red``'s scope.

  The scope is the enclosing def of ``red`` (found with
  ``red.parent_find('def')`` and cached on ``red.parentdef``), or the
  'root' scope when there is no enclosing def.  A name registered in
  ``scratches_added`` for that scope or for 'root' counts as a scratch
  variable.

  The special name 'maybe' is additionally treated as a scratch variable
  when the enclosing def's source has no line consisting only of 'maybe'.

  Returns a bool.
  """
  if hasattr(red, 'parentdef'):
    parentdef = red.parentdef
  else:
    parentdef = red.parent_find('def')
    # Cache the lookup on the node so repeated queries skip the tree walk.
    red.parentdef = parentdef

  # The previous fallback was a dict accessed with ``.name``, which raised
  # AttributeError for every node outside a def; use the scope name directly.
  scope = 'root' if parentdef is None else parentdef.name

  if varname in scratches_added.get(scope, {}):
    return True
  if varname in scratches_added['root']:
    return True
  if varname == 'maybe':
    if parentdef is None:
      # No def source to inspect at module level.
      return False
    return '\nmaybe\n' not in parentdef.dumps()
  return False

def add_root_scratches(red):
  """Register module-level ScratchVar nodes in ``scratches_added['root']``.

  Every non-def node whose parent is the tree root and whose source
  contains 'ScratchVar' is recorded under ``str(node.name)``.
  """
  def is_root_statement(v):
    return v.root is v.parent and v.type != 'def'

  for node in red.find_all('g:*', is_root_statement):
    if 'ScratchVar' not in node.dumps():
      continue
    scratches_added['root'][str(node.name)] = True
    
def findmap(red, which, how, def_=None):
  """Apply ``how`` to every node returned by ``red.find_all(which)``.

  ``def_`` is accepted for backward compatibility but is ignored: the
  branch that forwarded it to ``how`` was disabled with ``if True or ...``
  and could never run, so it has been removed.  Returns ``None``.
  """
  red.find_all(which).map(how)

def decs(d):
  """Swap a decorator whose value dumps to 'bytes' for a bytes Subroutine.

  When ``d.value.dumps()`` is exactly 'bytes', the last decorator of
  ``d.parent`` is popped and '@Subroutine(TealType.bytes)' is appended.
  """
  if d.value.dumps() != 'bytes':
    return
  owner = d.parent
  owner.decorators.pop()
  owner.decorators.append('@Subroutine(TealType.bytes)')

def all(red, x = ''):
  """Rewrite the statements under ``red`` in place.

  NOTE: this definition shadows the builtin ``all`` for the whole module.

  Steps, in order:
    1. If ``red`` has a ``type``: give undecorated defs (other than 'app'
       and 'sig') a Subroutine decorator, then run the findmap() rewrite
       passes (assign, boolean/binary operators, ints, strings, if/else,
       while, calls, names) over ``red``.
    2. Collect the dumped source of the direct 'ifelseblock',
       'atomtrailers', 'return' and 'while' children of each element of
       ``red``.
    3. If more than one was collected, set ``red.value`` to a ``Seq(...)``
       of them (prefixed by the getScratch() text and 'return' when ``red``
       is a def).  If exactly one was collected and the conditions below
       hold, set ``red.value`` to the scratch text plus 'return <stmt>'.

  ``x`` is not used.  Every exception is written to stderr and swallowed;
  the function always returns ``None``.
  """
  try:
    hasMulti = hasMultiStmt(red)
    if hasattr(red,'type'):
      # Undecorated defs other than the 'app'/'sig' entry points become
      # subroutines; the return type is guessed from whether the source
      # mentions 'return'/'Return' at all.
      if red.type == 'def' and red.name != 'app' and red.name != 'sig':
        if len(red.decorators) == 0:
          if 'return' in red.dumps() or 'Return' in red.dumps():
            red.decorators.append("@Subroutine(uint64)")
          else:
            red.decorators.append("@Subroutine(TealType.none)")
      # defnode is assigned here but not read again in this function.
      defnode = None
      if red.type == 'def': defnode = red
      stat('assign..')
      findmap(red, 'assign', assigns, red)
      stat('bool..')
      findmap(red, 'boolean_operator', bools)
      stat('binary_operator..')
      findmap(red, 'binary_operator', concats)
      if hasMulti or red.type == 'while':
        stat('return..')
        # NOTE(review): the query string is 'return..' (same text as the
        # progress label above), not 'return' -- confirm whether this pass
        # is meant to match return nodes.
        findmap(red, 'return..', returns)
      findmap(red, 'int', ints)
      findmap(red, 'string', strings)
      stat('ifelseblock..')
      findmap(red, 'ifelseblock', ifs)
      # A while node is not searched for nested whiles here; whiles() calls
      # all() on the while node itself.
      if red.type != 'while':
        findmap(red, 'while', whiles)
      findmap(red, 'call', calls)
      findmap(red, 'name', names)
    strs_ = []
    #hasMulti = hasMultiStmt(red)
    scratch = getScratch(red)
    # Name nodes have no statement body to wrap.
    if hasattr(red,'type') and red.type == 'name':
      return
    try:
      # Objects with a 'title' attribute are skipped (presumably plain
      # strings), as are binary operators and anything not iterable.
      if not hasattr(red, 'title') and not (red.type == 'binary_operator') and hasattr(red, '__iter__'):
        for nd in red:
          if not hasattr(nd, 'find_all'): continue
          nodes = nd.find_all(['ifelseblock','atomtrailers', 'return', 'while'], recursive=False)
          if hasMulti:
            findmap(nodes, 'return', returns)

          # catchdump yields None for nodes that fail to dump; drop those.
          strs = nodes.map(catchdump)
          strs = strs.filter(lambda x: x != None)
          strs_.extend(strs)
    except BaseException as err:
      sys.stderr.write('looperr')
      sys.stderr.write(f"Unexpected {err=}, {type(err)=}")
      sys.stderr.write(traceback.format_exc())
      pass
    if len(strs_) > 1:
      strs_ = filter(lambda x: x != '', strs_)
      try:
        # NOTE(review): strs_ is a filter object at this point, while
        # removeblanks() calls .splitlines() on its argument; this call
        # presumably always raises and is swallowed -- confirm.
        strs_ = removeblanks(strs_)
      except:
        pass
      strlist = ',\n\t'.join(strs_)
      retstr = ''
      # Only defs get the scratch declarations and a leading 'return'.
      if red.type == 'def':
        retstr = f"{scratch}\nreturn "
      red.value = f"{retstr} Seq(\n\t{strlist} )\n"
      red.value = removeblanks(red.value.dumps())

      #####
    elif len(strs_) == 1:
      #if True or red.find('ifelseblock') or red.find('while') or red.find('assert'):

      ########################3

      # A single statement is wrapped in 'return' only when the body has no
      # 'return' yet and contains an if/else, a while or an assignment (or
      # red is an assert); the replace() chain collapses doubled keywords.
      if not ('return' in red.value.dumps()) and ( red.find('ifelseblock') or red.find('while') or red.type == 'assert' or red.find('assign')  ):
        red.value =f"{scratch}\nreturn {strs_[0]}\n".replace('return return', 'return').replace('return Return','return').replace('maybe\n  maybe','maybe\n')
  except BaseException as err:
    sys.stderr.write(f"Unexpected {err=}, {type(err)=}")
    sys.stderr.write(traceback.format_exc())
    pass

def bools(boolOp):
  """Rewrite a boolean operator node as a function-call expression.

  Both operands are first processed with this module's ``all()`` (which
  shadows the builtin).  The node is then replaced by
  ``<Op>( <first>, <second> )`` where ``<Op>`` is the title-cased
  ``boolOp.value``.  A failed replacement leaves the node untouched.
  """
  all(boolOp.first)
  all(boolOp.second)
  try:
    boolOp.replace(f"{boolOp.value.title()}( {boolOp.first}, {boolOp.second} )")
  except Exception:
    # Best-effort rewrite; KeyboardInterrupt / SystemExit are not swallowed.
    pass

def unitaries(un):
  """Rewrite a unitary operator node as a function-call expression.

  ``un.value`` is passed through this module's ``all()`` (which shadows
  the builtin), then the node is replaced by ``<Op>( <target> )`` where
  ``<Op>`` is the title-cased ``un.value`` and ``<target>`` is the dumped
  source of ``un.target``.  A failed replacement leaves the node untouched.
  """
  all(un.value)
  try:
    un.replace(f"{un.value.title()}( {un.target.dumps()} )")
  except Exception:
    # Best-effort rewrite; KeyboardInterrupt / SystemExit are not swallowed.
    pass
  
def concats(ct):
  """Rewrite a binary operator on byte-like operands as ``Concat(a,b)``.

  Both operands are first processed with this module's ``all()`` (which
  shadows the builtin).  The node is replaced when either operand is a
  string node, and again (best-effort) when the leading element of either
  operand is named 'Bytes' or 'Concat'.
  """
  all(ct.first)
  all(ct.second)
  if ct.first.type == 'string' or ct.second.type == 'string':
    ct.replace(f"Concat({ct.first.dumps()},{ct.second.dumps()})")
  try:
    if (ct.first.value[0].value == 'Bytes' or ct.second.value[0].value == 'Bytes' or
       ct.first.value[0].value == 'Concat' or ct.second.value[0].value == 'Concat'):
      ct.replace(f"Concat({ct.first.dumps()},{ct.second.dumps()})")
  except Exception:
    # Operands without an indexable ``value`` are simply not concatenations;
    # KeyboardInterrupt / SystemExit are not swallowed.
    pass

def assigns(asn, parentdef_ = None):
  """Rewrite ``name = expr`` inside a def as ``name.store(expr)``.

  Args:
    asn: the assignment node to rewrite.
    parentdef_: the enclosing def node; looked up with
      ``asn.parent_find('def')`` when ``None``.

  Left untouched are assignments whose value mentions 'ScratchVar',
  'StringArray' or 'DynamicArray', whose target mentions 'maybe', that are
  not inside a def, or whose value is a boolean / comparison /
  boolean-operator node or an 'And' / 'Or' / 'Not' value.

  When the target is not yet a known scratch variable (see hasScratch), a
  ``<name> = ScratchVar(<type>)`` line is inserted at the top of the def
  and the name is registered in ``scratches_added``.  The type is
  TealType.bytes when the value's source contains a quote character,
  otherwise TealType.uint64.
  """
  dumps = asn.value.dumps()
  if 'ScratchVar' in dumps or 'StringArray' in dumps or 'DynamicArray' in dumps or 'maybe' in asn.target.dumps():
    return

  parentdef = asn.parent_find('def') if parentdef_ is None else parentdef_
  if parentdef is None:
    return

  parval = asn.value
  if hasattr(parval, 'filtered'):
    parval = parval.filtered()[0]
  if hasattr(parval, 'type'):
    # Compare by equality: the previous ``in 'boolean_operator'`` was a
    # substring test against the literal.
    if parval.type in ('boolean', 'comparison', 'boolean_operator'):
      return
    if hasattr(parval, 'value'):
      if parval.value in ['And', 'Or', 'Not']:
        return

  is_str = "'" in dumps or '"' in dumps
  if not hasScratch(asn, asn.target.value):
    typ = 'TealType.bytes' if is_str else 'TealType.uint64'
    makescratch = f"{asn.target.value} = ScratchVar({typ})\n"
    parentdef.insert(0,makescratch)
    scratches_added.setdefault(parentdef.name, {})[asn.target.value] = True

  asn.replace(f"{asn.target.value}.store({asn.value.dumps()})")

def calls(nd):
  """Prefix inner-transaction helper calls with ``InnerTxnBuilder.``.

  For a 'call' node whose parent's first element is named 'Begin',
  'SetField', 'SetFields' or 'Submit', that first element is replaced by
  ``InnerTxnBuilder.<name>``.  Nodes that do not have this shape are left
  untouched.
  """
  try:
    if nd.type != 'call':
      return
    head = nd.parent[0]
    if head.value in ['Begin','SetField','SetFields','Submit']:
      head.replace(f"InnerTxnBuilder.{head.dumps()}")
  except Exception:
    # Best-effort rewrite; KeyboardInterrupt / SystemExit are not swallowed.
    pass

def txnargs(df):
  """Strip transaction-kind arguments from ``df`` and record them.

  Arguments named after a transaction kind are appended to
  ``txnargs_[df.name]`` ('optin' is recorded as 'axfer') and then removed
  from ``df.arguments``.
  """
  txn_kinds = ['pay', 'axfer', 'optin', 'acfg', 'txn', 'keyreg', 'afrz', 'appl']
  stripped = []
  for arg in df.arguments:
    kind = arg.name.value
    if kind not in txn_kinds:
      continue
    recorded = 'axfer' if kind == 'optin' else kind
    txnargs_.setdefault(df.name, []).append(recorded)
    stripped.append(arg)
  # Removal happens after the walk so the argument list is not mutated
  # while it is being iterated.
  for arg in stripped:
    df.arguments.remove(arg)

def refargs(df):
  """Record reference-type arguments of ``df`` in ``refargs_``.

  For every argument named 'account', 'asset' or 'application', the tuple
  ``(df.name, argument name, position)`` is appended to
  ``refargs_[df.name]``.  The arguments themselves are left in place.
  """
  ref_kinds = ['account', 'asset', 'application']
  for position, arg in enumerate(df.arguments):
    kind = arg.name.value
    if kind in ref_kinds:
      refargs_.setdefault(df.name, []).append((df.name, kind, position))

def prints(p):
  """Replace a print node with ``Log`` followed by its dumped value."""
  payload = p.value.dumps()
  p.replace(f"Log{payload}")

def names(nd):
  """Rewrite a name node into its PyTeal form, when one applies.

  In order:
    * assignment targets are left alone;
    * ``False`` / ``True`` become ``Int(0)`` / ``Int(1)``;
    * 'maybe', and names whose parent source already contains
      ``<name>.load()``, ``<name>.store(`` or 'maybe', are left alone;
    * known scratch variables (see hasScratch) and the name 'size' become
      ``<name>.load()``;
    * a listed transaction/global field name whose parent starts with
      'Txn', 'Gtxn' or 'Global' gets ``()`` appended to the parent, unless
      it is already called there or the grandparent's third element is a
      getitem.

  Ordinary errors during the last two steps are swallowed (best-effort).
  """
  if nd.parent.type == 'assignment' and nd.parent.target == nd:
    return

  f = nd.find('name', lambda n: n.value=='False')
  if f:
    f.replace('Int(0)')
    return
  t = nd.find('name', lambda n: n.value=='True')
  if t:
    t.replace('Int(1)')
    return
  m = nd.find('name', lambda n: n.value=='maybe')
  if m:
    return

  if nd.dumps()+'.load()' in nd.parent.dumps() or 'maybe' in nd.parent.dumps():
    return
  if nd.dumps()+'.store(' in nd.parent.dumps():
    return

  varname = nd.dumps()
  field_names = ['fee','sender','first_valid','last_valid','receiver','note','lease', 'round',
                 'amount', 'close_remainder_to', 'vote_pk', 'type', 'type_enum', 'xfer_asset', 'asset_amount',
                 'asset_sender', 'asset_receiver', 'asset_close_to', 'group_index', 'tx_id', 'application_id',
                 'on_completion', 'rekey_to', 'config_asset', 'config_asset_total', 'config_asset_decimals',
                 'config_asset_default_frozen', 'config_asset_unit_name','config_asset_name', 'config_asset_url',
                 'config_asset_manager', 'created_asset_id', 'created_application_id', 'current_application_address',
                 'group_size', 'application_args', 'latest_timestamp']

  try:
    if hasScratch(nd, varname) or varname == 'size':
      nd.replace(f"{varname}.load()")
      return

    if nd.value not in field_names:
      return
    if nd.parent.value[0].value not in ['Txn', 'Gtxn', 'Global']:
      return
    try:
      if nd.parent.parent.value[2].type == 'getitem':
        return
    except Exception:
      # No third element on the grandparent: nothing to exclude.
      pass
    if f"{varname}()" not in nd.parent.dumps():
      nd.parent.replace(nd.parent.dumps()+'()')
  except Exception:
    # Best-effort rewrite; KeyboardInterrupt / SystemExit are not swallowed.
    pass

def whiles(w):
  all(w)
  w.replace(f"While( {w.test.dumps()}).Do(\n{w.value.dumps()} )")
  
def ifs(if_):
  """Rewrite an if/else block node as an ``If( test, then[, else] )`` call.

  The first branch, and the second when present, are processed with this
  module's ``all()`` (which shadows the builtin).  When the block has
  exactly two branches, returns inside it are rewritten with returns()
  and the second branch's body becomes the else argument.  A failed
  assignment to ``if_.value`` leaves the node untouched.
  """
  all(if_.value[0], 'if')
  try:
    all(if_.value[1], 'if')
  except Exception:
    # No second branch (plain ``if`` without ``else``).
    pass

  else_ = ''
  if len(if_.value) == 2:
    findmap(if_.value,'return',returns)
    else_ = ', ' + if_.value[1].value.dumps()

  try:
    if_.value = f"If( {if_.value[0].test.dumps()}, {if_.value[0].value.dumps()}{else_} )\n\n \n"
  except Exception:
    # Best-effort rewrite; KeyboardInterrupt / SystemExit are not swallowed.
    pass

def returns(ret):
  """Replace a return node with ``Return( <value> )``.

  The value is first processed with this module's ``all()`` (which shadows
  the builtin); any 'return ' text inside the dumped value is removed.
  """
  all(ret.value)
  inner = ret.value.dumps().replace('return ','')
  ret.replace(f"Return( {inner} )")

def fixreturns(d):
  """Turn a leading ``Return(...)`` statement of ``d`` into ``return ...``.

  When the first element of ``d``'s first statement is named 'Return', the
  statement is replaced by 'return ' plus the dumped second element.
  Bodies that do not have this shape are left untouched.
  """
  try:
    first = d[0]
    if first[0].value == 'Return':
      d[0] = f"return {first[1].dumps()}"
  except Exception:
    # Best-effort rewrite; KeyboardInterrupt / SystemExit are not swallowed.
    pass

def addreturns(d):
  """Prefix the first statement of ``d`` with 'return ' when it lacks one.

  When the first element of ``d``'s first statement is not named 'return',
  the statement is replaced by 'return ' plus its dumped source.  Bodies
  whose first statement cannot be inspected this way are left untouched.
  """
  try:
    first = d[0]
    if first[0].value != 'return':
      d[0] = f"return {first.dumps()}"
  except Exception:
    # Best-effort rewrite; KeyboardInterrupt / SystemExit are not swallowed.
    pass

def ints(i):
  """Wrap an int literal node as ``Int(<value>)`` unless its context forbids it.

  The node is left alone when its parent is a def argument or a getitem,
  when the first element of its great-grandparent's value is named 'Int',
  'Arg', 'application_args' or 'Gtxn', or when the second element of its
  grandparent's value is named 'application_args' or 'Gtxn'.  If inspecting
  the ancestors raises, the error is written to stderr and the node is left
  alone as well.
  """
  try:
    holder = i.parent
    if holder.type == 'def_argument' or holder.type == 'getitem':
      return

    great = holder.parent.parent
    if hasattr(great, 'value') and hasattr(great.value, 'pop'):
      parname = great.value[0].value
      if parname in ('Int', 'Arg', 'application_args', 'Gtxn'):
        return

    try:
      grand = holder.parent
      if hasattr(grand, 'value') and hasattr(grand.value, 'pop'):
        if grand.value[1].value in ['application_args', 'Gtxn']:
          return
    except BaseException:
      # A grandparent without a second element is not an excluded context.
      pass
  except BaseException as err:
    sys.stderr.write(' int outer exception ' + i.dumps())
    sys.stderr.write(f"Unexpected {err=}, {type(err)=}")
    sys.stderr.write(traceback.format_exc())
    return

  i.replace(f"Int({i.value})")

def strings(i):
  """Wrap a string literal node as ``Bytes(<literal>)`` unless already wrapped.

  The node is left alone when its parent is a def argument, or when the
  first element of its great-grandparent's value is named 'Bytes' or
  'Addr'.  If the ancestors cannot be inspected, the node is wrapped.
  """
  try:
    if i.parent.type == 'def_argument':
      return
    parval = i.parent.parent.parent.value[0].value
    if parval == 'Bytes' or parval == 'Addr':
      return
  except Exception:
    # Missing ancestors mean the literal is not inside Bytes()/Addr();
    # KeyboardInterrupt / SystemExit are not swallowed.
    pass
  i.replace(f"Bytes({i.dumps()})")

# Module-level placeholder; not read by any code visible in this file.
progroot = 0


def abimethod(d):
  """Mark a def that has a return annotation as a static ABI method.

  The last decorator whose source mentions 'Subroutine' is removed (if
  any), then '@staticmethod' and '@ABIMethod' are inserted at positions 0
  and 1.  Defs without a return annotation are left untouched.
  """
  if not hasattr(d, 'return_annotation') or d.return_annotation is None:
    return

  subroutine_at = -1
  for position, decorator in enumerate(d.decorators):
    if 'Subroutine' in decorator.dumps():
      subroutine_at = position

  if subroutine_at >= 0:
    del d.decorators[subroutine_at]
  d.decorators.insert(0,'@staticmethod')
  d.decorators.insert(1,'@ABIMethod')

def app_ops(d):
  """Re-decorate an application lifecycle def as a static uint64 subroutine.

  For defs named 'delete', 'create', 'update', 'optIn', 'closeOut' or
  'clearState', the first decorator is removed and '@staticmethod' plus
  '@Subroutine(TealType.uint64)' are inserted at positions 0 and 1.
  """
  if d.name not in ['delete', 'create', 'update', 'optIn', 'closeOut', 'clearState']:
    return
  del d.decorators[0]
  d.decorators.insert(0,'@staticmethod')
  d.decorators.insert(1, '@Subroutine(TealType.uint64)')

def isSigDef(nd):
  """Report whether ``nd`` is a def node named 'sig'.

  Objects that lack ``type`` / ``name`` attributes are simply not sig defs,
  so ordinary errors while inspecting ``nd`` yield ``False``.
  """
  try:
    return bool(nd.type == 'def' and nd.name == 'sig')
  except Exception:
    # KeyboardInterrupt / SystemExit are deliberately not swallowed.
    return False

def stat(s):
  """Emit one progress dot on stderr; the label ``s`` is not printed."""
  err = sys.stderr
  err.write(".")
  err.flush()


# Def names that dump_nonabi_defs() excludes and dump_app_ops() selects.
opslist = ['delete', 'create', 'update', 'optIn', 'closeOut', 'clearState']

def dumpglobals(red):
  """Return the newline-joined source of ``red``'s non-def root-level nodes.

  A node qualifies when its parent is the tree root and its type is not
  'def'.
  """
  def is_root_statement(v):
    return v.root is v.parent and v.type != 'def'

  root_nodes = red.find_all('g:*', is_root_statement)
  sources = root_nodes.map(lambda n: n.dumps())
  return "\n".join(sources)

def dump_nonabi_defs(red):
  """Return the newline-joined source of defs that are neither ABI methods nor app ops.

  A def qualifies when it has no return annotation and its name is not in
  ``opslist``.
  """
  def is_plain_def(v):
    return v.return_annotation is None and v.name not in opslist

  plain_defs = red.find_all('def', is_plain_def)
  sources = plain_defs.map(lambda n: n.dumps())
  return "\n".join(sources)
  
def dump_abi_defs(red):
  """Return the newline-joined source of defs that have a return annotation."""
  def is_abi_def(v):
    return v.return_annotation is not None

  abi_defs = red.find_all('def', is_abi_def)
  sources = abi_defs.map(lambda n: n.dumps())
  return "\n".join(sources)

def dump_app_ops(d):
  """Return the newline-joined source of defs whose name is in ``opslist``."""
  def is_app_op(v):
    return v.name in opslist

  op_defs = d.find_all('def', is_app_op)
  sources = op_defs.map(lambda n: n.dumps())
  return "\n".join(sources)

def dropReturns(currsource):
  """Return ``currsource`` unchanged.

  This is a deliberate no-op kept so callers need not change: the regex
  pass that used to strip ``Return(...)`` wrappers sat after an
  unconditional ``return`` and could never run, so it has been removed.
  """
  return currsource

def abiapp(red):
  """Print the ABI-style PyTeal program for ``red`` to stdout.

  Output, in order:
    1. a fixed import header;
    2. the module-level non-def statements (dumpglobals);
    3. the non-ABI defs, passed through dropReturns and simpleReplaces;
    4. a ``set_tx_args(...)`` call carrying ``txnargs_`` as JSON;
    5. the ``class ABIApp(DefaultApprove):`` header, followed by the ABI
       defs and the app-op defs, dumped after ``red`` has been indented;
    6. a ``__main__`` footer in which ``{fname}`` is replaced by the stem
       of ``sys.argv[1]``.

  Side effect: ``red`` is re-indented in place by
  ``red.increase_indentation(2)``.
  """
  print("""
import json
from typing import Tuple

from pyteal import *

from pytealutils.applications import ABIMethod, DefaultApprove
from pytealutils.applications.application import set_tx_args

from pytealutils import abi

globals().update(TealType.__members__)

""")

  print(dumpglobals(red))
  
  print(simpleReplaces(dropReturns(dump_nonabi_defs(red))))

  print('set_tx_args(' + json.dumps(txnargs_) + ')\n')
  
  print("class ABIApp(DefaultApprove):")
  # Indent the tree so the defs dumped below sit inside the class body;
  # this must happen after the module-level dumps above.
  red.increase_indentation(2)

  dfs = dropReturns(dump_abi_defs(red))
  
  print(dfs)

  print(dump_app_ops(red))

  # The input file's stem names the JSON interface file used by the footer.
  fname = Path(sys.argv[1]).stem
  
  # Plain (non-f) string: '{fname}' is substituted with str.replace below so
  # the literal braces in the template (e.g. ``currint = {}``) stay intact.
  code = """

  
if __name__ == "__main__":

  app = ABIApp()

  currint = {}
  try:
    currjson = open('{fname}.json').read()
    currint = json.loads(currjson)    
  except:
    pass
    
  currint['methods'] = app.get_interface().dictify()['methods']
    
  with open("{fname}.json", "w") as f:
    f.write(json.dumps(currint, indent=4))

  print(compileTeal(app.handler(), mode=Mode.Application, version=5, assembleConstants=True))
  """
  code = code.replace('{fname}', fname)
  
  print(code)

def usesabi(red):
  """Report whether any def under ``red`` carries a return annotation."""
  return any(
    hasattr(nd, 'return_annotation') and nd.return_annotation is not None
    for nd in red.find_all('def')
  )

def simpleReplaces(source):
  """Apply the fixed textual substitutions to ``source`` and return the result.

  The substitutions are plain substring replacements, applied in the order
  listed below.
  """
  substitutions = (
    ('len(', 'Len('),
    ('lput(', 'App.localPut(0,'),
    ('gput(', 'App.globalPut('),
    ('holding(', 'AssetHolding.balance('),
    ('app_address', 'Global.current_application_address()'),
    ('optin', 'axfer'),
  )
  for old, new in substitutions:
    source = source.replace(old, new)
  return source

def convert(fname):
  """Convert the Python source file ``fname`` to PyTeal and print it to stdout.

  The file is read, passed through simpleReplaces(), parsed with RedBaron
  and rewritten in place by a fixed sequence of findmap() passes; progress
  dots go to stderr.  If any def has a return annotation (usesabi) the
  ABI-style program is printed via abiapp(); otherwise the rewritten source
  is printed with a PyTeal import header and a ``__main__`` footer that
  compiles ``sig()`` in Signature mode when a def named 'sig' exists, else
  ``app()`` in Application mode.

  Args:
    fname: path of the source file to convert.
  """
  sys.stderr.write("genpyteal version 3.0.0\n")

  mode = 'Application'
  entry = 'app'
  # Read through a context manager so the handle is closed even when the
  # read fails (the handle was previously never closed).
  with open(fname, "r") as handle:
    source = handle.read()
  source = simpleReplaces(source)
  red = RedBaron(source)
  stat('Parsed file..')

  isabi = usesabi(red)

  sig = red.filter(isSigDef)
  if sig:
    mode = 'Signature'
    entry = 'sig'
    stat('LogicSig')
  else:
    stat('Application')

  # The order of the passes below matters: each one rewrites the tree that
  # the following passes search.
  findmap(red, 'def', txnargs)

  add_root_scratches(red)

  findmap(red, 'binary_operator', concats)

  stat('prints..')
  findmap(red, 'print', prints)
  findmap(red, 'name', names)
  stat('unitary_operator..')
  findmap(red, 'unitary_operator', unitaries)
  stat('return..')
  findmap(red, 'return', returns)
  stat('def..')
  findmap(red, 'def', all)
  stat('assign..')
  findmap(red, 'assign', assigns)
  stat('def..')

  findmap(red, 'def', fixreturns)
  stat('def..')
  findmap(red, 'def', addreturns)
  stat('decorator..')
  findmap(red, 'decorator', decs)
  stat('strings..')
  findmap(red, 'string', strings)
  sys.stderr.write("\n")

  if isabi:
    findmap(red, 'def', abimethod)
    findmap(red, 'def', app_ops)
    abiapp(red)
    return

  print("from pyteal import *\n")
  print("globals().update(TealType.__members__)\n")

  print(simpleReplaces(dropReturns(red.dumps())))

  print(f"if __name__ == \"__main__\":\n    print(compileTeal({entry}(), mode=Mode.{mode}, version=5))")

# Script entry point: convert the file named by the first command-line argument.
# NOTE(review): there is no ``__main__`` guard, so this also runs on import.
convert(sys.argv[1])
