#!/usr/bin/env python3
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import logging
import os
import subprocess
import sys
import tarfile
import tempfile

import weather_sp

if __name__ == '__main__':
    # Launcher script: make sure the `splitter_pipeline` subpackage is
    # importable, then either show the CLI help or build a source tarball of
    # the subpackage and hand it to the CLI via `--extra_package`.
    site_pkg = weather_sp.__path__[0]

    try:
        from splitter_pipeline import cli
    except ImportError:
        # Install the subpackage. The command is an argument list rather than a
        # whitespace-split string so that an interpreter or package path
        # containing spaces is passed through intact.
        subprocess.check_call(
            [sys.executable, '-m', 'pip', '-q', 'install', '-e', site_pkg])

        # Re-load sys.path so the freshly installed package becomes visible.
        import site
        from importlib import reload

        reload(site)

        # Re-attempt import. If this fails, the user probably has an older version of
        # the package already installed on their machine that breaks this process.
        # If that's the case, it's best to start from a clean virtual environment.
        try:
            from splitter_pipeline import cli
        except ImportError as e:
            raise ImportError(
                'please re-install package in a clean python environment.') from e

    if '-h' in sys.argv or '--help' in sys.argv or len(sys.argv) == 1:
        cli()
    else:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Convert subpackage to a tarball. `cwd=` runs setup.py from inside
            # the package directory without changing this process's working
            # directory, so nothing needs restoring if the build fails.
            subprocess.check_call(
                [sys.executable, './setup.py', '-q', 'sdist', '--dist-dir', tmpdir],
                cwd=site_pkg,
            )

            # Set tarball as extra packages for Beam.
            archives = glob.glob(os.path.join(tmpdir, '*.tar.gz'))
            if not archives:
                raise FileNotFoundError(
                    f'sdist did not produce a .tar.gz archive in {tmpdir}')
            pkg_archive = archives[0]

            # Explicit check instead of `assert`, which is skipped under
            # `python -O`.
            with tarfile.open(pkg_archive, 'r') as tar:
                if not any(name.endswith('.py') for name in tar.getnames()):
                    raise ValueError('extra_package must include python files!')

            # cleanup memory to prevent pickling error.
            tar = None
            weather_sp = None

            cli(['--extra_package', pkg_archive])
