from tabulate import tabulate
from .Processor import Processor
from .Operations import MESIOps as Ops
from .States import MESIState as State


class Protocol:
    """Base simulator for a snooping cache-coherence protocol (MESI states/ops).

    Holds one ``Processor`` per cache, the shared memory content, running
    statistics, and a step-by-step history table rendered with ``tabulate``.
    The base ``on_event`` only updates statistics and returns ``None`` (no bus
    transaction); the actual state machine is presumably supplied by a
    protocol-specific subclass — TODO confirm.
    """

    def __init__(self, n_processors=2, memory_content=None):
        self.n_processors = n_processors
        # Processor ids are 1-based; ``processor(pid)`` maps them to list positions.
        self.processors = [Processor(pid=i + 1) for i in range(n_processors)]
        self.memory_content = memory_content
        # Statistical parameters
        self.n_reads = 0
        self.n_writes = 0
        self.n_state_changes = 0
        self.n_invalidations = 0
        self.n_bus_latency = 0  # count of bus transactions (requester's + snoopers' responses)
        self.n_cache_misses = 0
        self.n_cache_to_cache_transfers = 0

        # Additional private variables
        # Never assigned in this class; when set (presumably by a subclass) it takes
        # precedence over cache/memory content in perform_processor_operation.
        self.modified_value = None
        self.history = []  # one row per executed instruction, laid out as get_headers()

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return tabulate(self.history, headers=self.get_headers(), stralign="center", numalign="center")

    def on_event(self, pid, event, function):
        """Update read/write, miss and transfer statistics for ``event`` on processor ``pid``.

        :param pid: 1-based processor id.
        :param event: a processor operation (``Ops.PrRd``/``Ops.PrWr``) or a snooped
            bus transaction (whatever the requester's ``on_event`` returned).
        :param function: the instruction's value transform; unused here.
        :return: ``None`` — the base class issues no bus transaction.
        """
        if event == Ops.PrRd:
            self.n_reads += 1
        elif event == Ops.PrWr:
            self.n_writes += 1
        else:
            # Snooped bus transactions are not accesses by this processor,
            # so they can neither miss nor trigger a cache-to-cache transfer.
            return None

        if self.get_processor_state(pid) != State.I:
            return None

        states = {p.state for p in self.other_processors_not(pid=pid)}
        # Miss served by memory: no other cache holds the line (this includes the
        # single-processor case where ``states`` is empty), or only shared copies exist.
        if states <= {State.I} or State.S in states:
            self.n_cache_misses += 1
        # Another cache owns the line exclusively, so it supplies the data.
        if states & {State.M, State.E}:
            self.n_cache_to_cache_transfers += 1
        return None

    def get_headers(self):
        """Return the column headers of the history table."""
        headers = ["Step"]
        # Columns are numbered by list position (1-based), matching save_history's layout.
        for number in range(1, len(self.processors) + 1):
            headers.extend(["P{} State".format(number), "P{} Cache".format(number)])
        headers.extend(["Memory Content", "Bus Transaction", "Modified"])
        return headers

    @staticmethod
    def _display(value):
        """Render ``None`` as ``'-'`` while keeping falsy-but-real values such as ``0``."""
        return '-' if value is None else value

    def save_history(self, step_name, transaction):
        """Append a snapshot row (same layout as ``get_headers``) to ``self.history``."""
        fields = [step_name]
        for p in self.processors:
            fields.extend([p.state, self._display(p.cache_content)])
        fields.extend([
            self._display(self.memory_content),
            self._display(transaction),
            self._display(self.modified_value),
        ])
        self.history.append(fields)

    def processor(self, pid):
        """Return the processor with 1-based id ``pid``.

        :raises IndexError: if ``pid`` is outside ``1..len(self.processors)``; without
            this check ``pid=0`` would silently select the last processor.
        """
        if not 1 <= pid <= len(self.processors):
            raise IndexError("processor id {} out of range 1..{}".format(pid, len(self.processors)))
        return self.processors[pid - 1]

    def other_processors_not(self, pid):
        """Gets the list of other processors which are not equal to pid"""
        return [p for p in self.processors if p.pid != pid]

    def get_processor_state(self, pid):
        return self.processor(pid).state

    def set_processor_state(self, pid, state):
        """Set processor ``pid`` to ``state``, counting real transitions and invalidations."""
        processor = self.processor(pid)
        if processor.state != state:
            self.n_state_changes += 1
            # Compare against the State constant, consistent with on_event.
            if state == State.I:
                self.n_invalidations += 1
        processor.state = state

    def get_processor_cache_content(self, pid):
        return self.processor(pid).cache_content

    def set_processor_cache_content(self, pid, cache_content):
        self.processor(pid).cache_content = cache_content

    def perform_processor_operation(self, pid, function):
        """Apply ``function`` to the current value and store the result in ``pid``'s cache.

        The input is the first value that is not ``None`` among ``modified_value``, the
        processor's cache content, and memory content. ``is None`` is used rather than
        truthiness so that a value of ``0`` is not skipped. Non-callable
        ``function`` values are ignored.
        """
        if function and callable(function):
            candidates = (self.modified_value, self.get_processor_cache_content(pid), self.memory_content)
            current = next((value for value in candidates if value is not None), None)
            self.set_processor_cache_content(pid, function(current))

    def flush(self, value):
        """
        Request that indicates that a whole cache block is being written back to the memory.
        Places the value of the cache line on the bus and updates the memory.
        """
        self.memory_content = value

    def _perform_instruction(self, pid, event, function):
        """Run one instruction: the requester's event, then every other cache snoops it."""
        # The requester's on_event return value is the bus transaction it issues.
        main_transaction = self.on_event(pid, event, function)
        if main_transaction:
            self.n_bus_latency += 1
        for processor in self.other_processors_not(pid):
            # Captured before snooping so a Flush writes back the value held when the request arrived.
            cache_content = self.get_processor_cache_content(processor.pid)
            transaction = self.on_event(processor.pid, main_transaction, function)
            if transaction == Ops.Flush:
                self.flush(cache_content)
                # Guard against a missing main transaction (None + str would raise TypeError).
                main_transaction = (main_transaction + "/" + Ops.Flush) if main_transaction else Ops.Flush
            if transaction:
                self.n_bus_latency += 1

        # Save the current state into history for printing
        step_name = self.format_instruction_name(pid, event)
        self.save_history(step_name, main_transaction)

    def perform_instructions(self, instructions, display=True):
        """Execute ``(pid, operation[, function])`` tuples in order; optionally print the table."""
        for instruction in instructions:
            pid, operation, *rest = instruction
            function = rest[0] if rest else None
            self._perform_instruction(pid, operation, function)
        if display:
            print(self, '\n')

    ###########################
    # Statistical Methods
    ###########################
    def statistics(self):
        """Return the accumulated counters keyed by display label."""
        stats = {
            "Instructions": self.step_count(),
            "Reads": self.n_reads,
            "Writes": self.n_writes,
            "Bus Transactions": self.n_bus_latency,
            "States": self.n_state_changes,
            "Invalidations": self.n_invalidations,
            "Cache-to-Cache Transfers": self.n_cache_to_cache_transfers,
            "Cache Misses": self.n_cache_misses,
        }
        return stats

    ###########################
    # Additional Helper Methods
    ###########################
    def step_count(self):
        return len(self.history)

    def format_instruction_name(self, processor_id, event):
        """Label for the next history row, e.g. ``"3. P2 PrWr"``."""
        return "{}. P{} {}".format(len(self.history) + 1, processor_id, event)

    @staticmethod
    def format_instruction_set(func_name):
        """Extract the lambda text from the source line defining an instruction tuple.

        Fragile by design: it assumes the callable's first source line looks like
        ``(pid, op, lambda ...)`` with no extra commas or parentheses.
        """
        if not func_name:
            return ""
        import inspect
        code = inspect.getsourcelines(func_name)[0][0].strip()
        code = code.replace("(", "").replace(")", "").split(",")[2].strip()
        code = code.split("lambda ")[1]
        return code

    def print(self, _filter=None):
        """Print the history, optionally reduced by a named column filter ("cache"/"transaction")."""
        _filter = self.filter_dict().get(_filter, lambda x: x)
        headers = _filter(self.get_headers())
        history = [_filter(x) for x in self.history]
        print(tabulate(history, headers=headers, stralign="center", numalign="center"))

    def api(self):
        """Protocol API Interface; ``results`` rows omit the last ("Modified") column."""
        return {"title": self.get_headers(), "results": [item[:-1] for item in self.history], "table": str(self), "stats": self.statistics()}

    def filter_dict(self):
        return {"cache": self._field_cache_filter, "transaction": self._field_transaction_filter}

    @staticmethod
    def _field_cache_filter(array):
        # Keep "Step" plus every "P<n> State" column (every other entry before the last three).
        return array[:1] + array[1:][:-3:2]

    @staticmethod
    def _field_transaction_filter(array):
        # As _field_cache_filter, plus the "Bus Transaction" column.
        return array[:1] + array[1:][:-3:2] + [array[-2]]
