# Procaine is a REST client library for AICore.
# Copyright (C) 2022 Roman Kindruk

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import json
import re
import requests
from datetime import datetime
from dataclasses import dataclass
from urllib.parse import urlparse, urljoin

from authlib.integrations.requests_client import OAuth2Session


class Client:
    """Initializes the connection to AI API and obtains a Bearer token.

    :param ai_api_url: A URL to the AI API instance.
    :param auth: Dictionary with credentials to obtain a Bearer token.
        Should have ``url``, ``clientid`` and ``clientsecret`` keys.
    :param rg: (optional) A name of a ``resource group`` to use with every call.

    The token is fetched once, in the constructor, and is never refreshed.
    Every response with an HTTP error status raises
    :class:`requests.HTTPError`.

    Usage::

      >>> from procaine import aicore

      >>> auth = dict(url=AUTH_URL, clientid=CLIENT_ID, clientsecret=CLIENT_SECRET)
      >>> api = aicore.Client(AI_API_URL, auth)

      >>> api.healthz()
      {'message': 'OK', 'status': 'READY'}
    """

    def __init__(self, ai_api_url, auth, rg="default"):
        self.url = ai_api_url
        self.rg = rg
        self.authurl = urlparse(auth["url"]).netloc

        # Exchange the client id/secret for an access token. A trailing slash
        # in the auth URL must not turn into "//oauth/token".
        token = OAuth2Session(auth["clientid"], auth["clientsecret"]).fetch_token(
            auth["url"].rstrip("/") + "/oauth/token"
        )
        self.sess = requests.Session()
        self.sess.headers.update(
            {
                "Authorization": f"Bearer {token['access_token']}",
                "Content-Type": "application/json",
                "AI-Resource-Group": self.rg,
            }
        )
        # Raise requests.HTTPError on every 4xx/5xx response, so each API
        # method below can simply return ``.json()``.
        self.sess.hooks["response"] = [
            lambda resp, *args, **kwargs: resp.raise_for_status()
        ]

    def __repr__(self):
        client, _, _ = self.authurl.partition(".")
        return f"<AIAPI client='{client}' api='{self.url}' rg='{self.rg}'>"

    def _url(self, path):
        """Resolves ``path`` against the AI API base URL via ``urljoin``.

        An absolute ``path`` (leading ``/``) replaces any path part of the
        base URL.
        """
        return urljoin(self.url, path)

    @staticmethod
    def _id(obj):
        """Returns ``obj["id"]`` when ``obj`` is an API object (dict), else ``obj``."""
        return obj["id"] if isinstance(obj, dict) else obj

    @staticmethod
    def _application_name(repo_name, path):
        """Derives a sync application name from a repository name and a path.

        Runs of non-alphanumeric characters in ``path`` become single dashes;
        leading/trailing dashes are dropped so that e.g. ``"./src"`` gives
        ``"<repo>-src"`` and ``"/"`` gives just ``"<repo>"``.
        """
        if not path or path == ".":
            return repo_name
        suffix = re.sub(r"[^0-9a-zA-Z]+", "-", path).strip("-")
        return f"{repo_name}-{suffix}" if suffix else repo_name

    @staticmethod
    def _s3_to_ai_url(url, secrets):
        """Rewrites an ``s3://`` URL into ``ai://<secret-name>/<rest>``.

        The first secret whose ``s3://<bucket>/<pathPrefix>`` location contains
        ``url`` (matching on whole path segments) is used. Secrets without a
        bucket in their metadata are skipped. Returns ``url`` unchanged when no
        secret matches.

        :param url: An object URL, e.g. ``s3://bucket/prefix/file.csv``.
        :param secrets: Secret objects as returned by :meth:`list_s3_secrets`.
        """
        for secret in secrets:
            meta = secret.get("metadata") or {}
            bucket = (meta.get("storage.ai.sap.com/bucket") or "").strip("/")
            if not bucket:
                continue
            prefix = (meta.get("storage.ai.sap.com/pathPrefix") or "").strip("/")
            root = f"s3://{bucket}/{prefix}" if prefix else f"s3://{bucket}"
            # Require a segment boundary so "s3://b/data" does not match
            # "s3://b/database/...".
            if url == root or url.startswith(root + "/"):
                return f"ai://{secret['name']}{url[len(root):]}"
        return url

    def healthz(self):
        """
        Checks AI API service status.
        """
        return self.sess.get(self._url("/v2/lm/healthz")).json()

    def meta(self):
        """
        Returns AI API service metadata.
        """
        return self.sess.get(self._url("/v2/lm/meta")).json()

    def kpi(self):
        """
        Provides usage statistics.
        """
        return self.sess.get(self._url("/v2/analytics/kpis")).json()

    def list_s3_secrets(self):
        """
        Lists registered S3 secrets.
        """
        return self.sess.get(self._url("/v2/admin/objectStoreSecrets")).json()[
            "resources"
        ]

    def register_s3_secret(
        self,
        name,
        access_key_id,
        secret_access_key,
        bucket=None,
        prefix=None,
        endpoint=None,
        region=None,
    ):
        """
        Registers credentials to access S3 bucket.

        :param name: The name of a secret.
        :param access_key_id: Access key ID to access the bucket.
        :param secret_access_key: Secret access key to access the bucket.
        :param bucket: (optional) A bucket name.
        :param prefix: (optional) A key name prefix.
        :param endpoint: (optional) S3 service endpoint.
        :param region: (optional) A region of the bucket.
        """
        data = {
            "name": name,
            "type": "S3",
            "data": {
                "AWS_ACCESS_KEY_ID": access_key_id,
                "AWS_SECRET_ACCESS_KEY": secret_access_key,
            },
        }
        # Only send the optional fields that were actually provided.
        optional = {
            "bucket": bucket,
            "endpoint": endpoint,
            "pathPrefix": prefix,
            "region": region,
        }
        data.update({key: value for key, value in optional.items() if value})
        return self.sess.post(
            self._url("/v2/admin/objectStoreSecrets"), json=data
        ).json()

    def delete_s3_secret(self, name):
        """
        Deletes an object-store secret.
        """
        return self.sess.delete(
            self._url(f"/v2/admin/objectStoreSecrets/{name}")
        ).json()

    def list_docker_registry_secrets(self):
        """
        Lists registered docker registry secrets.
        """
        return self.sess.get(self._url("/v2/admin/dockerRegistrySecrets")).json()[
            "resources"
        ]

    def register_docker_registry_secret(self, name, registry, username, password):
        """
        Registers credentials to access a docker registry.

        :param name: The name of a secret.
        :param registry: A docker registry address.
        :param username: Username to access the registry.
        :param password: Password to access the registry.
        """
        # The secret payload is a serialized docker config.json document.
        dockerconfig = {
            "auths": {
                registry: {
                    "username": username,
                    "password": password,
                }
            }
        }
        secret = {
            "name": name,
            "data": {".dockerconfigjson": json.dumps(dockerconfig)},
        }
        return self.sess.post(
            self._url("/v2/admin/dockerRegistrySecrets"), json=secret
        ).json()

    def delete_docker_registry_secret(self, name):
        """
        Deletes a docker registry secret.
        """
        return self.sess.delete(
            self._url(f"/v2/admin/dockerRegistrySecrets/{name}")
        ).json()

    def list_git_repositories(self):
        """
        Lists registered git repositories.
        """
        return self.sess.get(self._url("/v2/admin/repositories")).json()[
            "resources"
        ]

    def register_git_repository(self, name, url, username=None, password=None):
        """
        Registers credentials to access a git repository.

        :param name: The name of a secret.
        :param url: A URL to the git repository.
        :param username: (optional) Username to access the repository.
        :param password: (optional) Password to access the repository.
        """
        repo = {
            "name": name,
            "url": url,
            "username": username,
            "password": password,
        }
        return self.sess.post(self._url("/v2/admin/repositories"), json=repo).json()

    def delete_git_repository(self, name):
        """
        Deletes a git repository secret.
        """
        return self.sess.delete(self._url(f"/v2/admin/repositories/{name}")).json()

    def list_applications(self):
        """Lists created sync applications."""
        return self.sess.get(self._url("/v2/admin/applications")).json()[
            "resources"
        ]

    def application_status(self, app):
        """Returns a sync status of the application.

        :param app: Application object, name or ID
        """
        if isinstance(app, dict):
            # Prefer the ID; fall back to the name for objects without one.
            app = app["id"] if "id" in app else app["applicationName"]
        return self.sess.get(
            self._url(f"/v2/admin/applications/{app}/status")
        ).json()

    def create_application(self, repo_name, path=".", ref="HEAD"):
        """Creates an application to sync git repository with AICore.

        The application name is the repository name, suffixed with a
        dash-separated form of ``path`` when a sub-directory is given.

        :param repo_name: Name of a registered git repository.
        :param path: (optional) Directory in a repo to sync.
        :param ref: (optional) Git reference, i.e. branch, tag, etc.
        """
        app = {
            "applicationName": self._application_name(repo_name, path),
            "repositoryName": repo_name,
            "revision": ref,
            "path": path,
        }
        return self.sess.post(self._url("/v2/admin/applications"), json=app).json()

    def delete_application(self, name):
        """Deletes an application."""
        return self.sess.delete(self._url(f"/v2/admin/applications/{name}")).json()

    def list_resource_groups(self):
        """Lists resource groups."""
        return self.sess.get(self._url("/v2/admin/resourceGroups")).json()[
            "resources"
        ]

    def list_scenarios(self):
        """Lists available scenarios."""
        return self.sess.get(self._url("/v2/lm/scenarios")).json()["resources"]

    def list_templates(self, scenario_id=None):
        """Lists templates (executables).

        :param scenario_id: (optional) Returns only templates belonging to the scenario.
            Without it, templates of every scenario are listed (one request
            per scenario).
        """

        def get_executables(scenario_id):
            url = self._url(f"/v2/lm/scenarios/{scenario_id}/executables")
            return self.sess.get(url).json()["resources"]

        if scenario_id:
            return get_executables(scenario_id)
        return [t for s in self.list_scenarios() for t in get_executables(s["id"])]

    def template(self, name, scenario_id=None):
        """Returns information about a template.

        :param name: Name of the template (executable ID).
        :param scenario_id: (optional) Template's scenario ID. Without it,
            every scenario is searched in turn.
        :returns: The template object, or ``None`` if no scenario has it.
        """

        def get_executable(scenario_id, executable_id):
            url = self._url(
                f"/v2/lm/scenarios/{scenario_id}/executables/{executable_id}"
            )
            return self.sess.get(url).json()

        if scenario_id:
            return get_executable(scenario_id, name)
        for s in self.list_scenarios():
            try:
                return get_executable(s["id"], name)
            except requests.HTTPError as err:
                # Not in this scenario: keep looking; any other error is real.
                if err.response.status_code == 404:
                    continue
                raise
        return None

    def _create_configuration(self, name, executable, parameters=None, artifacts=None):
        """Creates a configuration binding parameters/artifacts to an executable.

        :param name: Configuration name.
        :param executable: Template object (needs ``id`` and ``scenarioId``).
        :param parameters: (optional) Dict of ``param-name`` -> ``value``.
        :param artifacts: (optional) Dict of ``binding-name`` -> artifact or artifact ID.
        """
        conf = {
            "name": name,
            "executableId": executable["id"],
            "scenarioId": executable["scenarioId"],
        }
        if parameters:
            conf["parameterBindings"] = [
                {"key": k, "value": v} for k, v in parameters.items()
            ]
        if artifacts:
            conf["inputArtifactBindings"] = [
                {"key": binding, "artifactId": self._id(artifact)}
                for binding, artifact in artifacts.items()
            ]
        return self.sess.post(self._url("/v2/lm/configurations"), json=conf).json()

    def execute_flow(self, template, parameters=None, artifacts=None):
        """Starts a workflow execution.

        :param template: Template name or object returned from :meth:`Client.template` or :meth:`Client.list_templates`.
        :param parameters: (optional) Dict of ``param-name``-> ``value``.
        :param artifacts: (optional) Dict of ``binding-name`` -> ``S3 object URL or artifact or artifact ID``.
            ``s3://`` URLs are registered as ``dataset`` artifacts first.
        :raises ValueError: If a template given by name cannot be found.
        """
        if isinstance(template, str):
            template_name = template
            template = self.template(template_name)
            if template is None:
                raise ValueError(f"template '{template_name}' not found")

        def make_artifact(obj):
            if isinstance(obj, str) and obj.startswith("s3://"):
                return self.register_artifact(obj, "dataset", template["scenarioId"])
            return obj

        if artifacts:
            artifacts = {k: make_artifact(v) for k, v in artifacts.items()}

        cfg = self._create_configuration(
            template["id"], template, parameters, artifacts
        )
        return self.sess.post(
            self._url("/v2/lm/executions"), json={"configurationId": cfg["id"]}
        ).json()

    def abort_execution(self, execution):
        """Stops a running flow.

        :param execution: Execution object or ID.
        """
        return self.sess.patch(
            self._url(f"/v2/lm/executions/{self._id(execution)}"),
            json={"targetStatus": "STOPPED"},
        ).json()

    def execution(self, execution):
        """Shows an execution status.

        :param execution: Execution object or ID.
        """
        return self.sess.get(
            self._url(f"/v2/lm/executions/{self._id(execution)}")
        ).json()

    def execution_logs(self, execution):
        """Shows pods' logs of an execution.

        :param execution: Execution object or ID.
        :returns: All ``msg`` fields of the log entries concatenated.
        """
        logs = self.sess.get(
            self._url(f"/v2/lm/executions/{self._id(execution)}/logs")
        ).json()
        return "".join(entry["msg"] for entry in logs["data"]["result"])

    def register_artifact(self, url, kind, scenario, description="", name=None):
        """Registers an artifact.

        ``s3://`` URLs that lie inside a registered S3 secret's
        bucket/prefix are rewritten to the ``ai://<secret>/...`` form.

        :param url: Object's URL.
        :param kind: One of the: ``model``, ``dataset``, ``resultset`` or ``other``.
        :param scenario: Scenario object or ID.
        :param description: (optional) Artifact's description.
        :param name: (optional) Artifact's name; defaults to the scenario ID.
        """
        scenario = self._id(scenario)
        if url.startswith("s3://"):
            url = self._s3_to_ai_url(url, self.list_s3_secrets())
        artifact = {
            "url": url,
            "name": scenario if name is None else name,
            "scenarioId": scenario,
            "kind": kind,
            "description": description,
        }
        return self.sess.post(self._url("/v2/lm/artifacts"), json=artifact).json()

    def list_artifacts(self):
        """Lists all registered artifacts."""
        return self.sess.get(self._url("/v2/lm/artifacts")).json()["resources"]
