import operator
from functools import reduce

from django.db.models import Q
from rest_framework import status
from rest_framework.generics import get_object_or_404
from rest_framework.response import Response
from rest_framework.routers import APIRootView
from rest_framework.viewsets import GenericViewSet
from rest_framework.renderers import JSONRenderer, BrowsableAPIRenderer

from ipam.models import Prefix, Aggregate, IPAddress, Service
from ipam.filtersets import PrefixFilterSet, IPAddressFilterSet, AggregateFilterSet, ServiceFilterSet
from dcim.models import Device
from dcim.filtersets import DeviceFilterSet
from virtualization.models import VirtualMachine
from virtualization.filtersets import VirtualMachineFilterSet
from extras.models import Tag

from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema

from .renderers import PlainTextRenderer
from .utils import as_cidr


class ListsRootView(APIRootView):
    # Root view of the lists API. Documented with a comment rather than a
    # class docstring, since DRF derives the view description from the
    # class docstring.
    def get_view_name(self):
        """Return the fixed display name used for this root view."""
        view_name = "Lists"
        return view_name


class ValuesListViewSet(GenericViewSet):
    # Base viewset for endpoints that return a flat list of stringified
    # values. Subclasses supply ``queryset`` (and usually a filterset).
    # Comments are used instead of docstrings because DRF derives API
    # descriptions from view docstrings.
    renderer_classes = [JSONRenderer, BrowsableAPIRenderer, PlainTextRenderer]

    def list(self, request):
        # Run the configured filter backends over the base queryset, then
        # render each resulting value with str().
        filtered = self.filter_queryset(self.get_queryset())
        values = [*map(str, filtered)]
        return Response(values)


class PrefixListViewSet(ValuesListViewSet):
    # Distinct values of Prefix.prefix, filterable through PrefixFilterSet.
    filterset_class = PrefixFilterSet
    queryset = (
        Prefix.objects
        .values_list("prefix", flat=True)
        .distinct()
    )


class AggregateListViewSet(ValuesListViewSet):
    # Distinct values of Aggregate.prefix, filterable through
    # AggregateFilterSet.
    filterset_class = AggregateFilterSet
    queryset = (
        Aggregate.objects
        .values_list("prefix", flat=True)
        .distinct()
    )


class IPAddressListViewSet(ValuesListViewSet):
    # Lists the IP addresses selected by IPAddressFilterSet, either as bare
    # addresses or, when ``as_cidr`` is enabled, as whatever as_cidr() returns
    # for each address.
    queryset = IPAddress.objects.values_list("address", flat=True).distinct()
    filterset_class = IPAddressFilterSet

    @staticmethod
    def _flag_enabled(request, name):
        """Return whether the boolean query parameter ``name`` is switched on.

        An absent parameter is off. A parameter that is present is on unless
        its value is an explicit negative (``false``, ``0``, ``no``, ``off``,
        case-insensitive). A bare ``?name`` therefore still enables the flag,
        which keeps presence-only callers working, while ``?name=false`` no
        longer does.
        """
        raw = request.query_params.get(name)
        if raw is None:
            return False
        return raw.strip().lower() not in {"false", "0", "no", "off"}

    @swagger_auto_schema(manual_parameters=[
        openapi.Parameter(
            "as_cidr", in_=openapi.IN_QUERY,
            description="Return IPs as /32 or /128", type=openapi.TYPE_BOOLEAN
        )
    ])
    def list(self, request, *args, **kwargs):
        # *args/**kwargs are accepted (and ignored) so the signature matches
        # the other list views in this module.
        queryset = self.filter_queryset(self.get_queryset())

        # We de-duplicate through a set because distinct() won't work
        # for two IPs with the same address but different prefix length.
        if self._flag_enabled(request, "as_cidr"):
            return Response(list({as_cidr(i) for i in queryset}))
        return Response(list({str(i.ip) for i in queryset}))


class ServiceListviewSet(ValuesListViewSet):
    # Lists the addresses of IPs attached to the services selected by
    # ServiceFilterSet. Services without any IP address are excluded by the
    # base queryset.
    queryset = Service.objects.filter(ipaddresses__isnull=False).values_list(
        "ipaddresses__address", flat=True).distinct()
    filterset_class = ServiceFilterSet

    @staticmethod
    def _flag_enabled(request, name):
        """Return whether the boolean query parameter ``name`` is switched on.

        An absent parameter is off. A parameter that is present is on unless
        its value is an explicit negative (``false``, ``0``, ``no``, ``off``,
        case-insensitive). A bare ``?name`` therefore still enables the flag,
        which keeps presence-only callers working, while ``?name=false`` no
        longer does.
        """
        raw = request.query_params.get(name)
        if raw is None:
            return False
        return raw.strip().lower() not in {"false", "0", "no", "off"}

    @swagger_auto_schema(manual_parameters=[
        openapi.Parameter(
            "as_cidr", in_=openapi.IN_QUERY,
            description="Return IPs as /32 or /128", type=openapi.TYPE_BOOLEAN
        )
    ])
    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())

        # A set removes duplicates that distinct() cannot: the same address
        # stored with different prefix lengths collapses to one entry.
        if self._flag_enabled(request, "as_cidr"):
            return Response(list({as_cidr(i) for i in queryset}))

        return Response(list({str(i.ip) for i in queryset}))


class TagsListViewSet(GenericViewSet):
    # Returns, for one tag (looked up by slug), the de-duplicated union of the
    # object kinds enabled through query flags: prefixes, aggregates, IP
    # addresses (tagged directly, or via tagged devices / VMs) and service IPs.
    queryset = Tag.objects.all()
    lookup_field = "slug"
    renderer_classes = [JSONRenderer, BrowsableAPIRenderer, PlainTextRenderer]

    @staticmethod
    def _flag_enabled(request, name):
        """Return whether the boolean query parameter ``name`` is switched on.

        An absent parameter is off. A parameter that is present is on unless
        its value is an explicit negative (``false``, ``0``, ``no``, ``off``,
        case-insensitive). A bare ``?name`` therefore still enables the flag,
        which keeps presence-only callers working, while ``?name=false`` no
        longer does.
        """
        raw = request.query_params.get(name)
        if raw is None:
            return False
        return raw.strip().lower() not in {"false", "0", "no", "off"}

    @swagger_auto_schema(manual_parameters=[
        openapi.Parameter(
            "prefixes", in_=openapi.IN_QUERY,
            description="Include prefixes", type=openapi.TYPE_BOOLEAN
        ),
        openapi.Parameter(
            "aggregates", in_=openapi.IN_QUERY,
            description="Include aggregates", type=openapi.TYPE_BOOLEAN
        ),
        openapi.Parameter(
            "services", in_=openapi.IN_QUERY,
            description="Include services", type=openapi.TYPE_BOOLEAN
        ),
        openapi.Parameter(
            "devices", in_=openapi.IN_QUERY,
            description="Include devices", type=openapi.TYPE_BOOLEAN
        ),
        openapi.Parameter(
            "vms", in_=openapi.IN_QUERY,
            description="Include vms", type=openapi.TYPE_BOOLEAN
        ),
        openapi.Parameter(
            "ips", in_=openapi.IN_QUERY,
            description="Include IP Addresses", type=openapi.TYPE_BOOLEAN
        ),

    ])
    def retrieve(self, request, slug=None):
        # Responds 400 when no slug is supplied and 404 when no tag has it.
        if not slug:
            return Response("No slug", status.HTTP_400_BAD_REQUEST)

        tag = get_object_or_404(Tag, slug=slug)

        prefixes = []
        if self._flag_enabled(request, "prefixes"):
            prefixes = [str(i) for i in Prefix.objects.filter(
                tags=tag).values_list("prefix", flat=True).distinct()]

        aggregates = []
        if self._flag_enabled(request, "aggregates"):
            aggregates = [str(i) for i in Aggregate.objects.filter(
                tags=tag).values_list("prefix", flat=True).distinct()]

        # The three IP-related flags are OR-ed into a single IPAddress query:
        # IPs tagged directly, IPs on interfaces of tagged devices, and IPs on
        # interfaces of tagged virtual machines.
        ip_filters = []
        if self._flag_enabled(request, "ips"):
            ip_filters.append(Q(tags=tag))
        if self._flag_enabled(request, "devices"):
            ip_filters.append(Q(interface__device__tags=tag))
        if self._flag_enabled(request, "vms"):
            ip_filters.append(Q(vminterface__virtual_machine__tags=tag))

        ips = []
        if ip_filters:
            ips = [
                as_cidr(i)
                for i in IPAddress.objects.filter(reduce(operator.or_, ip_filters)).values_list(
                    "address", flat=True).distinct()
            ]

        services = []
        if self._flag_enabled(request, "services"):
            # Only services that actually have an IP address contribute.
            services = [
                as_cidr(i)
                for i in Service.objects.filter(Q(ipaddresses__isnull=False) & Q(tags=tag)).values_list(
                    "ipaddresses__address", flat=True).distinct()
            ]

        # The set drops values that appear in more than one category.
        return Response(list(set(prefixes + aggregates + ips + services)))


class PrometheusDeviceSD(GenericViewSet):
    # Serves the devices selected by DeviceFilterSet as a list of
    # {"targets": [...], "labels": {...}} groups for Prometheus service
    # discovery.
    queryset = Device.objects.all()
    filterset_class = DeviceFilterSet

    @staticmethod
    def _label_value(value) -> str:
        """Render ``value`` as a label string; ``None`` becomes ``""``."""
        return "" if value is None else str(value)

    def _sd_device(self, d: Device) -> dict:
        """Build the target group for a single device.

        The target is the device's primary IP when it has one, otherwise its
        name. Labels carry the device's core attributes plus one
        ``__meta_netbox_cf_<name>`` label per custom field.

        Every label value is converted to a string, because Prometheus HTTP
        service discovery only accepts string label values; the numeric id and
        non-string custom field values (numbers, booleans, null) would
        otherwise be emitted as raw JSON values.
        """
        raw_labels = {
            "__meta_netbox_id": d.id,
            "__meta_netbox_name": d.name,
            "__meta_netbox_status": d.status,
            "__meta_netbox_site_name": d.site.name,
            "__meta_netbox_platform_name": d.platform.name if d.platform else "",
            "__meta_netbox_primary_ip": str(d.primary_ip.address.ip) if d.primary_ip else "",
            "__meta_netbox_primary_ip4": str(d.primary_ip4.address.ip) if d.primary_ip4 else "",
            "__meta_netbox_primary_ip6": str(d.primary_ip6.address.ip) if d.primary_ip6 else "",
            "__meta_netbox_serial": d.serial
        }
        for k, v in d.custom_field_data.items():
            raw_labels[f"__meta_netbox_cf_{k}"] = v

        labels = {
            key: self._label_value(value) for key, value in raw_labels.items()
        }

        # NOTE(review): if a device has neither a primary IP nor a name, the
        # target falls back to d.name unchanged -- confirm whether such
        # devices can occur and how they should be reported.
        return {
            "targets": [str(d.primary_ip.address.ip) if d.primary_ip else d.name],
            "labels": labels
        }

    def list(self, request):
        # One target group per device that passes the filters.
        queryset = self.filter_queryset(self.get_queryset())
        return Response([self._sd_device(d) for d in queryset])


class PrometheusVirtualMachineSD(GenericViewSet):
    # Serves the virtual machines selected by VirtualMachineFilterSet as a
    # list of {"targets": [...], "labels": {...}} groups for Prometheus
    # service discovery.
    queryset = VirtualMachine.objects.filter()
    filterset_class = VirtualMachineFilterSet

    @staticmethod
    def _label_value(value) -> str:
        """Render ``value`` as a label string; ``None`` becomes ``""``."""
        return "" if value is None else str(value)

    def _sd_vm(self, vm: VirtualMachine) -> dict:
        """Build the target group for a single virtual machine.

        The target is the VM's primary IP when it has one, otherwise its name.
        Labels carry the VM's core attributes plus one
        ``__meta_netbox_cf_<name>`` label per custom field.

        Every label value is converted to a string, because Prometheus HTTP
        service discovery only accepts string label values; the numeric id and
        non-string custom field values (numbers, booleans, null) would
        otherwise be emitted as raw JSON values.
        """
        raw_labels = {
            "__meta_netbox_id": vm.id,
            "__meta_netbox_name": vm.name,
            "__meta_netbox_status": vm.status,  # TODO
            # Guarded like the other optional relations below.
            # NOTE(review): whether cluster can be unset depends on the NetBox
            # version in use -- the guard is harmless either way.
            "__meta_netbox_cluster_name": vm.cluster.name if vm.cluster else "",
            "__meta_netbox_site_name": vm.site.name if vm.site else "",
            "__meta_netbox_role_name": vm.role.name if vm.role else "",
            "__meta_netbox_platform_name": vm.platform.name if vm.platform else "",
            "__meta_netbox_primary_ip": str(vm.primary_ip.address.ip) if vm.primary_ip else "",
            "__meta_netbox_primary_ip4": str(vm.primary_ip4.address.ip) if vm.primary_ip4 else "",
            "__meta_netbox_primary_ip6": str(vm.primary_ip6.address.ip) if vm.primary_ip6 else ""
        }
        for k, v in vm.custom_field_data.items():
            raw_labels[f"__meta_netbox_cf_{k}"] = v

        labels = {
            key: self._label_value(value) for key, value in raw_labels.items()
        }

        return {
            "targets": [str(vm.primary_ip.address.ip) if vm.primary_ip else vm.name],
            "labels": labels
        }

    def list(self, request):
        # One target group per virtual machine that passes the filters.
        queryset = self.filter_queryset(self.get_queryset())
        return Response([self._sd_vm(vm) for vm in queryset])
