#!/usr/bin/env python3
"""
Extract subtree of NCBI taxonomy under a node

Usage:
  ntmirror-extract-subtree [options] <dbuser> <dbpass> <dbname> <dbsocket> <root>

Arguments:
  dbuser:       database user to use
  dbpass:       password of the database user
  dbname:       database name
  dbsocket:     connection socket file
  root:         tax id of the root

Options:
{common}
  --dbecho         turn on SQLAlchemy engine echo
                   (default: off)
"""

import os
from sqlalchemy import create_engine
from ntmirror.dbschema import NtNode
from schema import Use, And
from sqlalchemy.orm import sessionmaker, aliased
from sqlalchemy.engine.url import URL
import snacli
from ntmirror import scripts_helper

def main(args):
  """
  Print the tax id of every node in the subtree rooted at <root>.

  Connects to the MySQL database on localhost through the given unix socket
  and collects the subtree with a recursive CTE over NtNode: the anchor
  member is the node whose tax_id equals <root>, the recursive member adds
  every NtNode whose parent_tax_id is a tax_id already in the subtree.
  One tax id is printed per line on standard output; the row order is
  whatever the database returns.

  args is the arguments mapping; the keys read here are "<dbuser>",
  "<dbpass>", "<dbname>", "<dbsocket>", "<root>" and "--dbecho".

  The session is closed and the engine disposed on exit, also when the
  query raises.
  """
  connection_string = URL.create(drivername="mysql+mysqldb",
                    username=args["<dbuser>"],
                    password=args["<dbpass>"],
                    database=args["<dbname>"],
                    host="localhost",
                    query={"unix_socket":args["<dbsocket>"]})
  engine = create_engine(connection_string, echo=args["--dbecho"],
                         future=True)
  try:
    with sessionmaker(bind=engine)() as session:
      # anchor member of the recursive CTE: the root node itself
      subtree = session.query(NtNode)
      subtree = subtree.filter(NtNode.tax_id==args["<root>"])
      subtree = subtree.cte(name="subtree", recursive=True)
      # recursive member: nodes whose parent is already part of the subtree;
      # UNION (not UNION ALL) so that duplicate rows are discarded
      parent = aliased(subtree, name="parent")
      children = aliased(NtNode, name="children")
      subtree = subtree.union(
          session.query(children).\
              filter(children.parent_tax_id == parent.c.tax_id))
      # only the tax id is needed, so select that CTE column directly
      # instead of loading full NtNode entities (this also avoids the
      # deprecated Query.select_entity_from())
      for row in session.query(subtree.c.tax_id):
        print(row.tax_id)
  finally:
    # release the pooled DB connections even if the query failed
    engine.dispose()

def check_db_args(args):
  """
  Check that the database connection arguments are usable.

  The checks run in a fixed order and the first failing one raises:
  socket given, socket path exists, user given, password given,
  database name given.

  Raises RuntimeError with a message describing the first problem found;
  returns None if all checks pass.
  """
  socket_path = args["<dbsocket>"]
  if not socket_path:
    raise RuntimeError("DB unix socket not provided")
  if not os.path.exists(socket_path):
    raise RuntimeError(f"DB unix socket does not exist: {socket_path}")

  user = args["<dbuser>"]
  if not user:
    raise RuntimeError("Database user name not provided")
  if not args["<dbpass>"]:
    # the user name is included so it is clear whose password is missing
    raise RuntimeError(
        "Database password for user '{}' not provided".format(user))
  if not args["<dbname>"]:
    raise RuntimeError("Database name not provided")

def validated(args):
  """
  Validate the script arguments and return the validated mapping.

  First runs check_db_args (which raises RuntimeError on unusable database
  arguments), then hands the arguments to scripts_helper.validate together
  with the common ARGS_SCHEMA and the schema of the script-specific
  arguments: all of them must be non-empty strings, <root> is additionally
  converted to int and <dbsocket> must be an existing path.
  """
  check_db_args(args)
  script_schema = {
      "<root>": And(str, len, Use(int)),
      "<dbuser>": And(str, len),
      "<dbpass>": And(str, len),
      "<dbname>": And(str, len),
      "<dbsocket>": And(str, len, os.path.exists),
  }
  return scripts_helper.validate(args, scripts_helper.ARGS_SCHEMA,
                                 script_schema)

# Script entry: snacli provides the arguments described in the module
# docstring ({common} is filled in from scripts_helper.ARGS_DOC).
# NOTE(review): params/config/input presumably name where each argument
# comes from when the script is driven by a workflow - confirm in snacli docs.
with snacli.args(
    params=["<root>"],
    config=["<dbuser>", "<dbpass>", "<dbname>"],
    input=["<dbsocket>"],
    docvars={"common": scripts_helper.ARGS_DOC},
    version="1.0",
) as args:
  # nothing is run when snacli yields a falsy value
  if args:
    main(validated(args))
