# SPDX-License-Identifier: BSD-3-Clause
#
# This file is part of SOL.
#
# Copyright (c) 2020 Great Scott Gadgets <info@greatscottgadgets.com>

''' Contains the gateware modules necessary to interpret and generate low-level USB packets. '''


import functools
import operator


from torii             import Array, Cat, Const, Elaboratable, Module, Signal
from torii.hdl.rec     import DIR_FANIN, DIR_FANOUT, Record

from ...interface.utmi import UTMITransmitInterface
from ..stream          import USBInStreamInterface, USBOutStreamInterface
from .                 import USBPacketID, USBSpeed

#
# Interfaces.
#


class HandshakeExchangeInterface(Record):
	''' Record that carries handshakes detected -or- generated between modules.

	Attributes
	----------
	ack: Signal()
		When connected to a generator, pulsing this strobe will trigger generation of an ACK.
		When connected to a detector, this strobe will be pulsed when an ACK is detected from the host.
	nak: Signal()
		When connected to a generator, pulsing this strobe will trigger generation of a NAK.
		When connected to a detector, this strobe will be pulsed when a NAK is detected from the host.
	stall: Signal()
		When connected to a generator, pulsing this strobe will trigger generation of a STALL.
		When connected to a detector, this strobe will be pulsed when a STALL handshake is detected.
	nyet: Signal()
		When connected to a generator, pulsing this strobe will trigger generation of a NYET.
		When connected to a detector, this strobe will be pulsed when a NYET handshake is detected.

	Parameters
	----------
	is_detector: bool
		If true, this will be considered an interface to a detector that identifies handshakes;
		the strobes are driven by the detector (``DIR_FANOUT``).
		Otherwise, this will be considered an interface to a generator that accepts handshake
		requests; the strobes are inputs to the generator (``DIR_FANIN``).
	'''

	def __init__(self, *, is_detector):
		# A detector drives these strobes, while a generator consumes them; the two
		# roles therefore need opposite record directions.
		direction = DIR_FANOUT if is_detector else DIR_FANIN

		super().__init__([
			('ack',   1, direction),
			('nak',   1, direction),
			('stall', 1, direction),
			('nyet',  1, direction),
		])



class DataCRCInterface(Record):
	''' Record connecting a client module to a USB CRC-16 generator.

	Attributes
	----------
	start: Signal(), input to CRC generator
		Strobe asking the generator to begin a fresh CRC computation.
	crc: Signal(16), output from CRC generator
		The running CRC-16 value; refreshed as each byte is sent or received.
	'''

	def __init__(self):
		# (field name, width, direction), in record layout order.
		layout = [
			('start', 1,  DIR_FANIN),
			('crc',   16, DIR_FANOUT),
		]
		super().__init__(layout)


class TokenDetectorInterface(Record):
	''' Record exposing the outputs of a USB token detector.

	Every field is driven by the detector.

	Attributes
	----------
	pid: Signal(4), detector output
		Packet ID of the most recently accepted token.
	address: Signal(7), detector output
		Device address carried by that token.
	endpoint: Signal(4), detector output
		Endpoint number carried by that token.

	new_token: Signal(), detector output
		Single-cycle strobe raised when a new token packet has been received.
	ready_for_response: Signal(), detector output
		Single-cycle strobe raised one inter-packet delay after a token packet completes;
		marks the point at which the token may be answered.

	frame: Signal(11), detector output
		The current USB frame number.
	new_frame: Signal(), detector output
		Single-cycle strobe raised when a new SOF has been received.

	is_in: Signal(), detector output
		High iff the current token is an IN.
	is_out: Signal(), detector output
		High iff the current token is an OUT.
	is_setup: Signal(), detector output
		High iff the current token is a SETUP.
	is_ping: Signal(), detector output
		High iff the current token is a PING.
	'''

	def __init__(self):
		# (name, width) pairs in layout order: token fields and strobes,
		# SOF information, then the PID-decode convenience flags.
		field_widths = (
			('pid',                4),
			('address',            7),
			('endpoint',           4),
			('new_token',          1),
			('ready_for_response', 1),

			('frame',             11),
			('new_frame',          1),

			('is_in',              1),
			('is_out',             1),
			('is_setup',           1),
			('is_ping',            1),
		)
		super().__init__([(name, width, DIR_FANOUT) for name, width in field_widths])


class InterpacketTimerInterface(Record):
	''' Record connecting a client module to our interpacket timer.

	See [USB2.0: 7.1.18] and the USBInterpacketTimer gateware for more information.

	Attributes
	----------
	start: Signal(), input to timer
		Strobe that starts the timer; usually pulsed at the end of an Rx or Tx event.

	tx_allowed: Signal(), output from timer
		Strobe that goes high when it's safe to transmit after an Rx event.
	tx_timeout: Signal(), output from timer
		Strobe that goes high when the transmit-after-receive window has passed.
	rx_timeout: Signal(), output from timer
		Strobe that goes high when the receive-after-transmit window has passed.
	'''

	def __init__(self):
		timer_outputs = ('tx_allowed', 'tx_timeout', 'rx_timeout')

		layout = [('start', 1, DIR_FANIN)]
		layout.extend((name, 1, DIR_FANOUT) for name in timer_outputs)
		super().__init__(layout)


	def attach(self, *subordinates):
		''' Attaches subordinate interfaces to the given timer interface.

		Parameters
		----------
		subordinates: [InterpacketTimerInterface, Signal]
			Each :class:`InterpacketTimerInterface` provided is fully connected to this timer
			interface: its ``start`` becomes a start condition, and this interface's timer
			outputs are copied onto it. Each ``Signal`` provided is treated as an additional
			start condition. At least one subordinate must be given.

		Returns
		-------
		list
			The connecting statements; the caller adds them to a domain. The final statement
			drives :attr:`start` with the OR of every start condition.
		'''

		timer_outputs = ('tx_allowed', 'tx_timeout', 'rx_timeout')
		starts     = []
		statements = []

		for sub in subordinates:
			# A bare signal contributes only a start condition.
			if not isinstance(sub, self.__class__):
				starts.append(sub)
				continue

			# A full interface contributes its start strobe, and receives our timer outputs.
			starts.append(sub.start)
			statements.extend(getattr(sub, name).eq(getattr(self, name)) for name in timer_outputs)

		# Any one of the start conditions starts the shared timer.
		statements.append(self.start.eq(functools.reduce(operator.__or__, starts)))
		return statements


#
# Gateware.
#


class USBTokenDetector(Elaboratable):
	''' Gateware that parses token packets and generates relevant events.

	Attributes
	----------
	interface: TokenDetectorInterface
		The interface that contains token detection events, and information about detected tokens.
	speed: Signal(2), input
		Carries a ``USBSpeed`` constant identifying the device's current operating speed.
		Forwarded to this detector's internal interpacket timer.
	address: Signal(7), input
		If :parameter:``filter_by_address`` is true, this is an input that filters our event detector so
		it only reports tokens directed at a given address.
		If ``filter_by_address`` is false, this signal is not used by this module; the address of the
		most recent token is reported on ``interface.address`` instead.


	Parameters
	----------
	utmi: UTMIInterface
		The UTMI bus to observe.
	filter_by_address: bool
		If true, this detector will only report events for the address supplied in the address[] field.
	domain_clock: float
		Passed through to the internal ``USBInterpacketTimer``; presumably the frequency, in Hz,
		of the ``usb`` clock domain -- TODO confirm against that timer's documentation.
	fs_only: bool
		Passed through to the internal ``USBInterpacketTimer``; presumably restricts its timing
		to full-speed operation -- TODO confirm.
	'''

	# PID nibble of a start-of-frame (SOF) token; SOF tokens carry a frame number
	# rather than an address/endpoint pair.
	SOF_PID      = 0b0101
	# Low two PID bits shared by the normal token PIDs (IN, OUT, SETUP, SOF).
	# PING uses a different suffix and is matched separately in READ_PID.
	TOKEN_SUFFIX = 0b01

	def __init__(self, *, utmi, filter_by_address = True, domain_clock = 60e6, fs_only = False):
		self.utmi = utmi
		self.filter_by_address = filter_by_address
		# Only used to configure the internal interpacket timer in elaborate().
		self._domain_clock = domain_clock
		self._fs_only = fs_only

		#
		# I/O port
		#
		self.interface = TokenDetectorInterface()
		self.speed     = Signal(2)
		self.address   = Signal(7)


	@staticmethod
	def _generate_crc_for_token(token):
		''' Generates a 5-bit signal equivalent to the CRC check for the provided token packet.

		Parameters
		----------
		token: Value
			The 11 token bits covered by the CRC (address + endpoint, or an SOF frame number).

		Returns
		-------
		A 5-bit ``Cat`` of XOR terms, to be compared against the received CRC field.
		'''

		# XORs together the token bits at the given indices; indices count from the
		# token's most significant bit (index 0 selects token[len(token) - 1]).
		def xor_bits(*indices):
			bits = (token[len(token) - 1 - i] for i in indices)
			return functools.reduce(operator.__xor__, bits)

		# Implements the CRC polynomial from the USB specification.
		return Cat(
			xor_bits(10, 9, 8, 5, 4, 2),
			~xor_bits(10, 9, 8, 7, 4, 3, 1),
			xor_bits(10, 9, 8, 7, 6, 3, 2, 0),
			xor_bits(10, 7, 6, 4, 1),
			xor_bits(10, 9, 6, 5, 3, 0)
		)


	def elaborate(self, platform):
		m = Module()

		# The 11 bits following the PID: address + endpoint for normal tokens,
		# or the frame number for SOF tokens.
		token_data       = Signal(11)
		# PID of the token currently being parsed; only published once the token is complete.
		current_pid      = Signal.like(self.interface.pid)

		# Instantiate a dedicated inter-packet delay timer, which
		# we'll use to generate our `ready_for_response` signal.
		#
		# Giving this unit a timer separate from a device's main
		# timer simplifies the architecture significantly; and
		# removes a primary source of timer contention.
		m.submodules.timer = USBInterpacketTimer(domain_clock = self._domain_clock, fs_only = self._fs_only)
		timer              = InterpacketTimerInterface()
		m.d.comb += m.submodules.timer.speed.eq(self.speed)

		# Generate our 'ready_for_response' signal whenever our
		# timer reaches a delay that indicates it's safe to respond to a token.
		m.submodules.timer.add_interface(timer)
		m.d.comb += self.interface.ready_for_response.eq(timer.tx_allowed)

		# Generate our convenience status signals.
		m.d.comb += [
			self.interface.is_in.eq(self.interface.pid == USBPacketID.IN),
			self.interface.is_out.eq(self.interface.pid == USBPacketID.OUT),
			self.interface.is_setup.eq(self.interface.pid == USBPacketID.SETUP),
			self.interface.is_ping.eq(self.interface.pid == USBPacketID.PING)
		]

		# Keep our strobes un-asserted unless otherwise specified.
		m.d.usb += [
			self.interface.new_frame.eq(0),
			self.interface.new_token.eq(0)
		]

		with m.FSM(domain = 'usb'):

			# IDLE -- waiting for a packet to be presented
			with m.State('IDLE'):
				with m.If(self.utmi.rx_active):
					m.next = 'READ_PID'


			# READ_PID -- read the packet's ID, and determine if it's a token.
			with m.State('READ_PID'):

				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

				with m.Elif(self.utmi.rx_valid):
					is_normal_token = (self.utmi.rx_data[0:2] == self.TOKEN_SUFFIX)
					is_ping_token   = (self.utmi.rx_data[0:4] == USBPacketID.PING)
					# A PID byte is valid when its upper nibble is the complement of its lower nibble.
					is_valid_pid    = (self.utmi.rx_data[0:4] == ~self.utmi.rx_data[4:8])

					# If we have a valid token, move to capture it.
					# Note that we have two categories of token we'll accept: normal tokens (IN, OUT, SETUP, SOF),
					# and our SPECIAL category tokens (e.g. PING), which have a separate PID suffix.
					with m.If((is_normal_token | is_ping_token) & is_valid_pid):
						m.d.usb += current_pid.eq(self.utmi.rx_data)
						m.next = 'READ_TOKEN_0'

					# Otherwise, ignore this packet as a non-token.
					with m.Else():
						m.next = 'IRRELEVANT'


			# READ_TOKEN_0 -- capture the first byte after the PID.
			with m.State('READ_TOKEN_0'):

				# If our transaction stops, discard the current read state.
				# We'll ignore token fragments, since it's impossible to tell
				# if they were e.g. for us.
				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

				# If we have a new byte, grab it, and move on to the next.
				with m.Elif(self.utmi.rx_valid):
					m.d.usb += token_data.eq(self.utmi.rx_data)
					m.next = 'READ_TOKEN_1'


			# READ_TOKEN_1 -- capture the second byte, which holds the top three
			# token bits and the 5-bit CRC.
			with m.State('READ_TOKEN_1'):

				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

				# Once we've just gotten the second core byte of our token,
				# we can validate our checksum and handle it.
				with m.Elif(self.utmi.rx_valid):
					expected_crc = self._generate_crc_for_token(
						Cat(token_data[0:8], self.utmi.rx_data[0:3]))

					# If the token has a valid CRC, capture it...
					with m.If(self.utmi.rx_data[3:8] == expected_crc):
						# token_data is 11 bits wide, so only rx_data[0:3] lands in token_data[8:11].
						m.d.usb += token_data[8:].eq(self.utmi.rx_data)
						m.next = 'TOKEN_COMPLETE'

					# ... otherwise, we'll ignore the whole token, as we can't tell
					# if this token was meant for us.
					with m.Else():
						m.next = 'IRRELEVANT'

			# TOKEN_COMPLETE: we've received a full token; and now need to wait
			# for the packet to be complete.
			with m.State('TOKEN_COMPLETE'):

				# Once our receive is complete, use the completed token,
				# and strobe our 'new token' signal.
				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

					# Special case: if this is a SOF PID, we'll extract
					# the frame number from this, rather than our typical
					# token fields.
					with m.If(current_pid == self.SOF_PID):
						m.d.usb += [
							self.interface.frame.eq(token_data),
							self.interface.new_frame.eq(1),
						]

					# Otherwise, extract the address and endpoint from the token,
					# and report the captured pid.
					with m.Else():

						# If we're filtering by address, only count this token if it's relevant to our address.
						# Otherwise, always count tokens -- we'll report the address on interface.address.
						token_applicable = (token_data[0:7] == self.address) if self.filter_by_address else True
						with m.If(token_applicable):
							m.d.usb += [
								self.interface.pid.eq(current_pid),
								self.interface.new_token.eq(1),

								# Low 7 token bits are the address; the next 4 are the endpoint.
								Cat(self.interface.address, self.interface.endpoint).eq(token_data)
							]

							# Start our interpacket-delay timer.
							m.d.comb += timer.start.eq(1)

						# If we don't count the token, clear the state so we don't act on following packets.
						with m.Else():
							m.d.usb += self.interface.pid.eq(0)


				# Otherwise, if we get more data, we've received a malformed
				# token -- which we'll ignore.
				with m.Elif(self.utmi.rx_valid):
					m.next = 'IRRELEVANT'


			# IRRELEVANT -- we've encountered a non-token or malformed packet; wait for it to end
			with m.State('IRRELEVANT'):

				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

		return m


class USBHandshakeDetector(Elaboratable):
	''' Gateware that watches a UTMI receive stream and reports handshake packets.

	A handshake is reported as a single-cycle strobe on :attr:`detected`, raised once the
	packet ends, provided the packet consisted of exactly one valid PID byte.

	Attributes
	-----------
	detected: HandshakeExchangeInterface
		Strobes indicating which handshake (ACK, NAK, STALL or NYET) was just seen.

	Parameters
	----------
	utmi: [UTMIInterface, UTMITranslator]
		The UTMI interface to listen on.
	'''

	ACK_PID   = 0b0010
	NAK_PID   = 0b1010
	STALL_PID = 0b1110
	NYET_PID  = 0b0110

	def __init__(self, *, utmi):
		self.utmi = utmi

		# I/O port: detector-side handshake strobes.
		self.detected = HandshakeExchangeInterface(is_detector = True)


	def elaborate(self, platform):
		m = Module()

		utmi     = self.utmi
		detected = self.detected

		# Each handshake strobe, paired with the PID nibble that raises it.
		strobe_pids = (
			(detected.ack,   self.ACK_PID),
			(detected.nak,   self.NAK_PID),
			(detected.stall, self.STALL_PID),
			(detected.nyet,  self.NYET_PID),
		)

		# PID nibble captured from the packet currently being received.
		active_pid = Signal(4)

		# Strobes rest low; they only rise for the single cycle after a handshake ends.
		m.d.usb += [strobe.eq(0) for strobe, _ in strobe_pids]

		# A PID byte is valid when its upper nibble is the complement of its lower nibble.
		pid_is_valid = (utmi.rx_data[0:4] == ~utmi.rx_data[4:8])

		with m.FSM(domain = 'usb'):

			# IDLE -- wait for the start of a packet.
			with m.State('IDLE'):
				with m.If(utmi.rx_active):
					m.next = 'READ_PID'

			# READ_PID -- latch the packet's PID, or give up on a corrupt one.
			with m.State('READ_PID'):
				with m.If(~utmi.rx_active):
					m.next = 'IDLE'
				with m.Elif(utmi.rx_valid):
					with m.If(pid_is_valid):
						m.d.usb += active_pid.eq(utmi.rx_data)
						m.next = 'AWAIT_COMPLETION'
					with m.Else():
						m.next = 'IRRELEVANT'

			# AWAIT_COMPLETION -- a handshake is PID-only; report it when the packet ends,
			# or discard the packet if any further byte arrives.
			with m.State('AWAIT_COMPLETION'):
				with m.If(~utmi.rx_active):
					m.d.usb += [strobe.eq(active_pid == pid) for strobe, pid in strobe_pids]
					m.next = 'IDLE'
				with m.Elif(utmi.rx_valid):
					m.next = 'IRRELEVANT'

			# IRRELEVANT -- wait out a malformed or non-handshake packet.
			with m.State('IRRELEVANT'):
				with m.If(~utmi.rx_active):
					m.next = 'IDLE'

		return m




class USBDataPacketCRC(Elaboratable):
	''' Gateware that computes a running CRC-16.

	By default, this module has no connections to the modules that use it.

	These are added using :attr:`add_interface`; this module supports an arbitrary
	number of connection interfaces; see :attr:`add_interface()` for restrictions.

	Attributes
	----------
	clear: Signal(), input
		Strobe that restarts the CRC computation. It is OR'd with the ``start`` strobe of
		every interface added via :attr:`add_interface`, so the module also elaborates
		when no interfaces have been added.

	rx_data: Signal(8), input
		Receive data input; can be carried directly from a UTMI interface.
	rx_valid: Signal(), input
		Receive validity signal; can be carried directly from a UTMI interface.

	tx_data: Signal(8), input
		Transmit data input; can be carried directly from a UTMI interface.
	tx_valid: Signal(), input
		When high, the `tx_data` input is used to update the CRC.

	crc: Signal(16)
		Not driven by :meth:`elaborate`; the computed CRC is delivered through the
		interfaces added with :attr:`add_interface`.

	Parameters
	----------
	initial_value: [int, Const]
		The initial value of the CRC shift register; the USB default is used if not provided.
	'''

	def __init__(self, initial_value = 0xFFFF):

		self._initial_value = initial_value

		# List of interfaces to work with.
		# This list is populated dynamically by calling .add_interface().
		self._interfaces    = []

		#
		# I/O port
		#
		self.clear = Signal()

		self.rx_data  = Signal(8)
		self.rx_valid = Signal()

		self.tx_data  = Signal(8)
		self.tx_valid = Signal()

		self.crc   = Signal(16, reset = initial_value)


	def add_interface(self, interface : DataCRCInterface):
		''' Adds an interface to the CRC generator module.

		Each interface can reset the CRC; and can read the current CRC value.
		No arbitration is performed; it's assumed that no more than one interface
		will be computing a running CRC at a time.

		Parameters
		----------
		interface: DataCRCInterface
			The interface to be added; accepts control signals from other modules, and
			brings CRC output to them. This method can be called multiple times to generate
			multiple CRCs.
		'''
		self._interfaces.append(interface)


	def _generate_next_crc(self, current_crc, data_in):
		''' Generates the next round of a bytewise USB CRC16.

		Parameters
		----------
		current_crc: Value
			The 16-bit running CRC register before this byte.
		data_in: Value
			The 8-bit data byte to fold into the CRC.

		Returns
		-------
		A 16-bit ``Cat`` holding the updated running CRC register.
		'''
		def xor_reduce(bits):
			return functools.reduce(operator.__xor__, bits)

		# Extracted from the USB spec's definition of the CRC16 polynomial.
		return Cat(
			xor_reduce(data_in)      ^ xor_reduce(current_crc[ 8:16]),
			xor_reduce(data_in[0:7]) ^ xor_reduce(current_crc[ 9:16]),
			xor_reduce(data_in[6:8]) ^ xor_reduce(current_crc[ 8:10]),
			xor_reduce(data_in[5:7]) ^ xor_reduce(current_crc[ 9:11]),
			xor_reduce(data_in[4:6]) ^ xor_reduce(current_crc[10:12]),
			xor_reduce(data_in[3:5]) ^ xor_reduce(current_crc[11:13]),
			xor_reduce(data_in[2:4]) ^ xor_reduce(current_crc[12:14]),
			xor_reduce(data_in[1:3]) ^ xor_reduce(current_crc[13:15]),

			xor_reduce(data_in[0:2]) ^ xor_reduce(current_crc[14:16]) ^ current_crc[0],
			data_in[0] ^ current_crc[1] ^ current_crc[15],
			current_crc[2],
			current_crc[3],
			current_crc[4],
			current_crc[5],
			current_crc[6],
			xor_reduce(data_in) ^ xor_reduce(current_crc[7:16]),
		)


	def elaborate(self, platform):
		m = Module()

		# Register that contains the running CRCs.
		crc        = Signal(16, reset = self._initial_value)

		# Signal that contains the output version of our active CRC.
		output_crc = Signal.like(crc)

		# We'll clear our CRC whenever our own `clear` strobe, or any of our interfaces,
		# request it. Seeding the reduction with `self.clear` wires up that otherwise-unused
		# port, and avoids reduce() failing on an empty sequence when no interfaces were added.
		start_signals = (interface.start for interface in self._interfaces)
		clear = functools.reduce(operator.__or__, start_signals, self.clear)

		# If we're clearing our CRC in progress, move our holding register back to
		# our initial value.
		with m.If(clear):
			m.d.usb += crc.eq(self._initial_value)

		# Otherwise, update the CRC whenever we have new data; receive data takes priority.
		with m.Elif(self.rx_valid):
			m.d.usb += crc.eq(self._generate_next_crc(crc, self.rx_data))
		with m.Elif(self.tx_valid):
			m.d.usb += crc.eq(self._generate_next_crc(crc, self.tx_data))

		# Convert from our intermediary 'running CRC' format into the current CRC-16
		# (bit-reversed and inverted)...
		m.d.comb += output_crc.eq(~crc[::-1])

		# ... and connect it to each of our interfaces.
		for interface in self._interfaces:
			m.d.comb += interface.crc.eq(output_crc)

		return m


class USBDataPacketReceiver(Elaboratable):
	''' Gateware that converts received USB data packets into data-stream packets.

	It's important to note that packet payloads are mostly directly carried over from UTMI.
	Since USB data is received -prior- to its CRC, one cannot know if a packet is valid until
	after it has been completely received. As a result, this interface will generate data of
	unknown validity, followed by a strobe on either :attr:`packet_complete` or :attr:`crc_mismatch`.
	The receiving interface must be prepared to handle :attr:`crc_mismatch` by discarding the received
	data.


	Attributes
	----------
	data_crc: DataCRCInterface
		Connection to the CRC generator.
	timer: InterpacketTimerInterface
		Connection to our interpacket timer.
	stream: USBOutStreamInterface, output
		Stream that carries captured packet data.

	active_pid: Signal(4), output
		The PID of the data currently being received.
	packet_id: Signal(4), output
		The packet ID of the most recently captured PID. Updated simultaneously with a strobe on
		:attr:`packet_complete`; it is not updated when a packet fails its CRC check.

	packet_complete: Signal(), output
		Strobe that pulses high when a new packet is delivered with a valid CRC.
	crc_mismatch: Signal(), output
		Strobe that pulses high when the given packet has a CRC mismatch; and thus the data
		received this far should be discarded.
	ready_for_response: Signal(), output
		Strobe that indicates that an inter-packet delay has passed since :attr:`packet_complete`,
		and thus we're now ready to respond with a handshake.

	Parameters
	----------
	utmi: UTMIInterface, or equivalent
		The UTMI bus to observe.

	standalone: bool
		Debug value. If True, a submodule CRC generator and interpacket timer will be created.
	speed: USBSpeed
		USBSpeed signal or constant that specifies our speed in standalone mode.
		If not provided (``None``), full speed is assumed.
	'''

	_DATA_SUFFIX = 0b11

	def __init__(self, *, utmi, standalone = False, speed = None):

		self.utmi        = utmi
		self.standalone  = standalone
		self.speed       = speed

		#
		# I/O port
		#
		self.data_crc           = DataCRCInterface()
		self.timer              = InterpacketTimerInterface()
		self.stream             = USBOutStreamInterface()

		self.active_pid         = Signal(4)

		self.packet_complete    = Signal()
		self.ready_for_response = Signal()
		self.crc_mismatch       = Signal()
		self.packet_id          = Signal(4)


	def elaborate(self, platform):
		m = Module()

		# If we're in standalone mode, create our dependencies for us.
		if self.standalone:
			m.submodules.crc = crc = USBDataPacketCRC()
			crc.add_interface(self.data_crc)

			m.submodules.timer = timer = USBInterpacketTimer()
			timer.add_interface(self.timer)

			# Only fall back to full speed when no speed was supplied at all.
			# A truthiness test would be wrong here: it would silently replace any speed
			# constant whose value is zero (presumably ``USBSpeed.HIGH``, as UTMI encodes
			# high speed as 0b00 -- confirm against ``USBSpeed``), and a ``Signal`` speed
			# cannot be meaningfully evaluated as a Python bool during elaboration.
			if self.speed is None:
				self.speed = USBSpeed.FULL

			m.d.comb += [

				# Connect our CRC generator...
				crc.rx_data.eq(self.utmi.rx_data),
				crc.rx_valid.eq(self.utmi.rx_valid),
				crc.tx_valid.eq(0),

				# ... and our timer.
				timer.speed.eq(self.speed)
			]


		# CRC-16 tracking signals.
		last_byte_crc = Signal(16)
		last_word_crc = Signal(16)

		# Keeps track of the most recently received word; for CRC comparison/removal.
		data_pipeline     = Signal(8 * 2)

		# Keep our control signals + strobes un-asserted unless otherwise specified.
		m.d.usb  += [
			self.packet_complete.eq(0),
			self.crc_mismatch.eq(0),
		]
		m.d.comb += [
			self.stream.next.eq(0),
			self.data_crc.start.eq(0),
		]


		with m.FSM(domain = 'usb'):

			# IDLE -- waiting for a packet to be presented
			with m.State('IDLE'):

				with m.If(self.utmi.rx_active):
					m.next = 'READ_PID'

			# READ_PID -- read the packet's ID.
			with m.State('READ_PID'):

				# Clear our CRC; as we're potentially about to start a new packet.
				m.d.comb += self.data_crc.start.eq(1)

				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

				with m.Elif(self.utmi.rx_valid):
					is_data      = (self.utmi.rx_data[0:2] == self._DATA_SUFFIX)
					is_valid_pid = (self.utmi.rx_data[0:4] == ~self.utmi.rx_data[4:8])

					# If this is a data packet, capture its PID.
					with m.If(is_valid_pid & is_data):
						m.d.usb += self.active_pid.eq(self.utmi.rx_data)
						m.next = 'RECEIVE_FIRST_BYTE'

					# Otherwise, ignore this packet.
					with m.Else():
						m.next = 'IRRELEVANT'


			# RECEIVE_FIRST_BYTE -- capture the first byte into our pipeline.
			# We'll always pipeline two bytes before we start emitting; as we won't want to
			# pass through the last two bytes (the CRC).
			with m.State('RECEIVE_FIRST_BYTE'):

				with m.If(self.utmi.rx_valid):
					m.d.usb += [
						data_pipeline[8:].eq(self.utmi.rx_data),
						last_byte_crc.eq(self.data_crc.crc)
					]
					m.next = 'RECEIVE_SECOND_BYTE'

				# If our packet stops before we see the first two bytes, we'll return to idle.
				# There's nothing to clean up, as we've never touched the stream.
				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'


			# RECEIVE_SECOND_BYTE-- capture the second byte into our pipeline.
			with m.State('RECEIVE_SECOND_BYTE'):

				with m.If(self.utmi.rx_valid):
					m.d.usb += [
						data_pipeline[8:].eq(self.utmi.rx_data),
						data_pipeline[0:8].eq(data_pipeline[8:]),

						last_byte_crc.eq(self.data_crc.crc),
						last_word_crc.eq(last_byte_crc),
					]
					m.next = 'RECEIVE_AND_EMIT'

				# If our packet stops before we see the first two bytes, we'll return to idle.
				# There's nothing to clean up, as we've never touched the stream.
				with m.Elif(~self.utmi.rx_active):
					m.next = 'IDLE'


			# RECEIVE_AND_EMIT -- receive bytes into our pipeline, and emit them.
			# Now that we have more than two bytes captured, we can start emitting bytes.
			# We'll always be emitting bytes that are two old -- so we can stop before our CRC.
			with m.State('RECEIVE_AND_EMIT'):
				m.d.comb += self.stream.valid.eq(1)

				with m.If(self.utmi.rx_valid):

					m.d.comb += [
						# Emit the current packet...
						self.stream.payload.eq(data_pipeline[0:8]),
						self.stream.next.eq(1),
					]

					m.d.usb += [

						# ... capture the incoming one...
						data_pipeline[8:].eq(self.utmi.rx_data),
						data_pipeline[0:8].eq(data_pipeline[8:]),

						# ... and update our cached CRCs.
						last_byte_crc.eq(self.data_crc.crc),
						last_word_crc.eq(last_byte_crc),
					]


				# Once we stop receiving data, check our CRC and finish.
				with m.If(~self.utmi.rx_active):

					# If our CRC matches, this is a valid packet!
					with m.If(last_word_crc == data_pipeline):

						# Indicate so...
						m.d.usb += [
							self.packet_id.eq(self.active_pid),
							self.packet_complete.eq(1)
						]

						# ... start counting our interpacket delay...
						m.d.comb += [
							self.timer.start.eq(1)
						]

						# ... and wait for it to complete.
						m.next = 'INTERPACKET_DELAY'


					# Otherwise, flag this as a CRC mismatch.
					with m.Else():
						m.d.usb += [
							self.crc_mismatch.eq(1)
						]

						# ... and return to IDLE.
						m.next = 'IDLE'


			# INTERPACKET_DELAY -- we've received a valid packet; wait for an
			# interpacket delay before moving back to IDLE.
			with m.State('INTERPACKET_DELAY'):

				with m.If(self.timer.tx_allowed):
					m.d.comb += self.ready_for_response.eq(1)
					m.next = 'IDLE'


			# IRRELEVANT -- we've encountered a malformed or non-DATA packet.
			with m.State('IRRELEVANT'):

				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

		return m


class USBDataPacketDeserializer(Elaboratable):
	''' Gateware that captures USB data packet contents and parallelizes them.

	Attributes
	----------
	data_crc: DataCRCInterface
		Connection to the CRC generator.

	new_packet: Signal(), output
		Strobe that pulses high for a single cycle when a new packet is delivered.
	packet_id: Signal(4), output
		The packet ID of the captured PID.

	packet: Array of ``max_packet_size`` Signal(8), output
		Packet data for the most recently received packet.
	length: Signal(range(0, max_packet_size + 1)), output
		The length of the packet data presented on the packet[] output.

	Parameters
	----------
	utmi: UTMIInterface, or equivalent
		The UTMI bus to observe.
	max_packet_size: int
		The maximum packet (payload) size to be deserialized, in bytes.
	create_crc_generator: bool
		If True, a submodule CRC generator will be created. Excellent for testing.
	'''

	_DATA_SUFFIX = 0b11

	def __init__(self, *, utmi, max_packet_size = 64, create_crc_generator = False):

		self.utmi                 = utmi
		self._max_packet_size     = max_packet_size
		self.create_crc_generator = create_crc_generator

		#
		# I/O port
		#
		self.data_crc    = DataCRCInterface()

		self.new_packet  = Signal()

		self.packet_id   = Signal(4)
		self.packet      = Array(Signal(8, name = f'packet_{i}') for i in range(max_packet_size))
		self.length      = Signal(range(0, max_packet_size + 1))


	def elaborate(self, platform):
		m = Module()

		# Payload plus the two trailing CRC bytes.
		max_size_with_crc = self._max_packet_size + 2

		# If we're creating an internal CRC generator, create a submodule
		# and hook it up.
		if self.create_crc_generator:
			m.submodules.crc = crc = USBDataPacketCRC()
			crc.add_interface(self.data_crc)

			m.d.comb += [
				crc.rx_data.eq(self.utmi.rx_data),
				crc.rx_valid.eq(self.utmi.rx_valid),
				crc.tx_valid.eq(0)
			]

		# CRC-16 tracking signals.
		last_byte_crc = Signal(16)
		last_word_crc = Signal(16)

		# Currently captured PID.
		active_pid         = Signal(4)

		# Active packet transfer.
		active_packet      = Array(Signal(8) for _ in range(max_size_with_crc))

		# Number of bytes captured so far. It must be able to hold `max_size_with_crc`
		# itself (a completely full buffer); otherwise, for sizes where that value is a
		# power of two, the counter would wrap to zero and the overflow check below
		# could never fire.
		position_in_packet = Signal(range(0, max_size_with_crc + 1))

		# Keeps track of the most recently received word; for CRC comparison.
		last_word          = Signal(16)

		# Keep our control signals + strobes un-asserted unless otherwise specified.
		m.d.usb += self.new_packet.eq(0)
		m.d.comb += self.data_crc.start.eq(0)

		with m.FSM(domain = 'usb'):

			# IDLE -- waiting for a packet to be presented
			with m.State('IDLE'):

				with m.If(self.utmi.rx_active):
					m.next = 'READ_PID'

			# READ_PID -- read the packet's ID.
			with m.State('READ_PID'):
				# Clear our CRC; as we're potentially about to start a new packet.
				m.d.comb += self.data_crc.start.eq(1)

				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

				with m.Elif(self.utmi.rx_valid):
					is_data      = (self.utmi.rx_data[0:2] == self._DATA_SUFFIX)
					is_valid_pid = (self.utmi.rx_data[0:4] == ~self.utmi.rx_data[4:8])

					# If this is a data packet, capture it.
					with m.If(is_valid_pid & is_data):
						m.d.usb += [
							active_pid.eq(self.utmi.rx_data),
							position_in_packet.eq(0)
						]
						m.next = 'CAPTURE_DATA'

					# Otherwise, ignore this packet.
					with m.Else():
						m.next = 'IRRELEVANT'


			# CAPTURE_DATA -- store each received byte (payload and CRC) into our buffer.
			with m.State('CAPTURE_DATA'):

				# If we have a new byte of data, capture it.
				with m.If(self.utmi.rx_valid):

					# If this would over-fill our internal buffer, fail out.
					with m.If(position_in_packet >= max_size_with_crc):
						# TODO: potentially signal the babble?
						m.next = 'IRRELEVANT'

					with m.Else():
						m.d.usb += [
							active_packet[position_in_packet].eq(self.utmi.rx_data),
							position_in_packet.eq(position_in_packet + 1),

							last_word.eq(Cat(last_word[8:], self.utmi.rx_data)),

							last_word_crc.eq(last_byte_crc),
							last_byte_crc.eq(self.data_crc.crc),
						]


				# If this is the end of our packet, validate our CRC and finish.
				with m.If(~self.utmi.rx_active):

					# A valid packet holds at least its two CRC bytes; with fewer, our CRC
					# trackers still hold stale state and `length` would underflow.
					packet_is_valid = (position_in_packet >= 2) & (last_word_crc == last_word)

					with m.If(packet_is_valid):
						m.d.usb += [
							self.packet_id.eq(active_pid),
							self.length.eq(position_in_packet - 2),
							self.new_packet.eq(1)
						]

						for i in range(self._max_packet_size):
							m.d.usb += self.packet[i].eq(active_packet[i])

					# Valid or not, the packet is over. Always return to IDLE, so the next
					# packet's PID is parsed (and our CRC restarted) rather than having that
					# packet appended to this one as data.
					m.next = 'IDLE'

			# IRRELEVANT -- we've encountered a malformed, oversized, or non-DATA packet
			with m.State('IRRELEVANT'):

				with m.If(~self.utmi.rx_active):
					m.next = 'IDLE'

		return m



class USBDataPacketGenerator(Elaboratable):
	''' Module that converts a FIFO-style stream into a USB data packet.

	The generator prefixes the streamed payload with the selected DATAx PID and
	follows it with the payload's CRC-16.

	A stream that pulses `last` (with valid = 1) without pulsing `first` is treated
	as a request for a zero-length packet: only the PID and the CRC are sent.

	Attributes
	----------

	data_pid: Signal(2), input
		Selects the data PID to send: 0 = DATA0, 1 = DATA1, 2 = DATA2, 3 = MDATA. Most
		endpoints can tie the MSb to zero and implement data toggling by flipping the LSb.

	crc: DataCRCInterface
		Connection to the data CRC generator.
	stream: USBInStreamInterface
		Stream carrying the raw payload to be transmitted.
	tx: UTMITransmitInterface
		UTMI-subset transmit interface.

	Parameters
	----------
	standalone: bool
		If True, this unit instantiates its own CRC generator; handy for unit testing or debugging.
	'''

	def __init__(self, standalone = False):
		self.standalone = standalone

		#
		# I/O port
		#
		self.data_pid = Signal(2)

		self.crc    = DataCRCInterface()
		self.stream = USBInStreamInterface()
		self.tx     = UTMITransmitInterface()


	def elaborate(self, platform):
		m = Module()

		# Lookup table from the two-bit `data_pid` selector to the full PID byte.
		# Each byte carries the 4-bit PID in its low nibble and the inverted PID
		# (the check field) in its high nibble, giving:
		# DATA0 = 0xC3, DATA1 = 0x4B, DATA2 = 0x87, MDATA = 0x0F.
		data_pids = Array(
			Const(pid | ((~pid & 0xF) << 4), shape = 8)
			for pid in (0b0011, 0b1011, 0b0111, 0b1111)
		)

		# PID byte for the packet in flight; sampled continuously while idle.
		current_data_pid = Signal(8)

		# Upper CRC byte, captured while the lower byte is being sent. The CRC
		# generator may keep updating once the first CRC byte goes out, so we
		# hold on to the value we need for the final byte.
		remaining_crc = Signal(8)

		# Set when the packet being sent is a zero-length packet.
		is_zlp = Signal()

		# Standalone builds carry a private CRC generator, fed from our stream.
		if self.standalone:
			m.submodules.crc = crc_unit = USBDataPacketCRC()
			crc_unit.add_interface(self.crc)

			m.d.comb += [
				crc_unit.rx_valid.eq(0),
				crc_unit.tx_data.eq(self.stream.payload),
				crc_unit.tx_valid.eq(self.tx.ready),
			]

		with m.FSM(domain = 'usb'):

			# IDLE -- wait for the stream to begin a packet.
			with m.State('IDLE'):
				# Hold off the stream; nothing is consumed until the PID has been sent.
				m.d.comb += self.stream.ready.eq(0)

				# Keep sampling the requested PID so it's current when a packet begins.
				m.d.usb += current_data_pid.eq(data_pids[self.data_pid])

				# `first` starts a normal packet; `last` without `first` requests a ZLP.
				with m.If(self.stream.valid & (self.stream.first | self.stream.last)):
					m.d.usb += is_zlp.eq(~self.stream.first)
					m.next = 'SEND_PID'


			# SEND_PID -- transmit the data PID and restart the CRC computation.
			with m.State('SEND_PID'):
				m.d.comb += [
					self.crc.start.eq(1),
					self.tx.data.eq(current_data_pid),
					self.tx.valid.eq(1),
					self.stream.ready.eq(0),
				]

				# Once the PHY takes the PID, a ZLP goes straight to its CRC;
				# every other packet proceeds to its payload.
				with m.If(self.tx.ready & is_zlp):
					m.next = 'SEND_CRC_FIRST'
				with m.Elif(self.tx.ready):
					m.next = 'SEND_PAYLOAD'


			# SEND_PAYLOAD -- pass the stream straight through to the UTMI transmitter.
			with m.State('SEND_PAYLOAD'):
				m.d.comb += self.stream.bridge_to(self.tx)

				# The payload is finished once the PHY accepts a byte while the stream
				# is either marking `last` or has run dry.
				payload_done = self.tx.ready & (self.stream.last | ~self.stream.valid)
				with m.If(payload_done):
					m.next = 'SEND_CRC_FIRST'


			# SEND_CRC_FIRST -- send the low CRC byte and stash the high byte.
			with m.State('SEND_CRC_FIRST'):
				m.d.usb += remaining_crc.eq(self.crc.crc[8:])
				m.d.comb += [
					self.tx.data.eq(self.crc.crc[0:8]),
					self.tx.valid.eq(1),
				]

				with m.If(self.tx.ready):
					m.next = 'SEND_CRC_SECOND'


			# SEND_CRC_SECOND -- send the stashed high CRC byte, then go back to idle.
			with m.State('SEND_CRC_SECOND'):
				m.d.comb += [
					self.tx.data.eq(remaining_crc),
					self.tx.valid.eq(1),
				]

				with m.If(self.tx.ready):
					m.next = 'IDLE'

		return m




class USBHandshakeGenerator(Elaboratable):
	''' Module that generates handshake packets, on request.

	Attributes
	----------
	issue_ack: Signal(), input
		Pulsed to generate an ACK handshake packet.
	issue_nak: Signal(), input
		Pulsed to generate a NAK handshake packet.
	issue_stall: Signal(), input
		Pulsed to generate a STALL handshake.
	issue_nyet: Signal(), input
		Pulsed to generate a NYET handshake. If left undriven it stays at its
		reset value of 0, so existing users are unaffected.

	tx: UTMITransmitInterface
		Interface to the relevant UTMI interface.

	If more than one request is pulsed in the same cycle, a single handshake is
	sent, chosen by priority: STALL > NAK > NYET > ACK.
	'''

	# Full contents of each handshake packet: the four-bit PID in the low nibble,
	# and the four check bits (the inverted PID) in the high nibble.
	_PACKET_ACK   = 0b11010010
	_PACKET_NAK   = 0b01011010
	_PACKET_STALL = 0b00011110
	_PACKET_NYET  = 0b10010110

	def __init__(self):

		#
		# I/O port
		#
		self.issue_ack    = Signal()
		self.issue_nak    = Signal()
		self.issue_stall  = Signal()
		self.issue_nyet   = Signal()

		self.tx           = UTMITransmitInterface()


	def elaborate(self, platform):
		m = Module()

		# Requests listed in order of increasing priority. Each one gets its own
		# independent `m.If` below; when several fire in the same cycle, the
		# assignment made last wins. That keeps the original STALL > NAK > ACK
		# ordering, with NYET slotted in between ACK and NAK.
		requests = (
			(self.issue_ack,   self._PACKET_ACK),
			(self.issue_nyet,  self._PACKET_NYET),
			(self.issue_nak,   self._PACKET_NAK),
			(self.issue_stall, self._PACKET_STALL),
		)

		with m.FSM(domain = 'usb'):

			# IDLE -- we haven't yet received a request to transmit.
			with m.State('IDLE'):
				m.d.comb += self.tx.valid.eq(0)

				# Wait for a handshake request; then latch the matching packet
				# into tx.data so it's ready to present on the next cycle.
				for strobe, packet in requests:
					with m.If(strobe):
						m.d.usb += self.tx.data.eq(packet)
						m.next = 'TRANSMIT'


			# TRANSMIT -- present the handshake until the PHY accepts it.
			with m.State('TRANSMIT'):
				m.d.comb += self.tx.valid.eq(1)

				# Once we know the transmission will be accepted, we're done!
				# Move back to IDLE.
				with m.If(self.tx.ready):
					m.next = 'IDLE'

		return m




class USBInterpacketTimer(Elaboratable):
	''' Module that tracks inter-packet timings, enforcing spec-mandated packet gaps.

	Ports other than :attr:`speed` are added dynamically via :meth:`add_interface`.

	Attributes
	----------
	speed: Signal(2), input
		The device's current operating speed. Should be a USBSpeed enumeration value --
		0 for high, 1 for full, 2 for low.

	Parameters
	----------
	domain_clock: float
		Frequency of the `usb` clock domain, in Hz. Must be a key of the delay tables below.
		60e6 is always accepted; 12e6 is only accepted when `fs_only` is set.
	fs_only: bool
		If True, only full-speed timings are generated. The timing strobes stay low
		at any other speed.
	'''

	# Per the USB 2.0 and ULPI 1.1 specifications, after receipt:
	#   - A FS/LS device needs to wait 2 bit periods before transmitting; and must
	#     respond before 6.5 bit times pass. [USB2, 7.1.18.1]
	#   - Two FS bit periods is equivalent to 10 ULPI clocks, and two LS periods is
	#     equivalent to 80 ULPI clocks. 6.5 FS bit periods is equivalent to 32 ULPI clocks,
	#     and 6.5 LS bit periods is equivalent to 260 ULPI clocks. [ULPI 1.1, Figure 18].
	#   - A HS device needs to wait 8 HS bit periods before transmitting [USB2, 7.1.18.2].
	#     Each ULPI cycle is 8 HS bit periods, so we'll only need to wait one cycle.
	#     The upper bound used here is 24 ULPI cycles (192 HS bit periods).
	#   - With a 12 MHz domain clock, one cycle is one FS bit period.
	# Each entry maps a domain clock (Hz) to (minimum, maximum) delay in domain cycles.
	_HS_RX_TO_TX_DELAY     = {60e6: (  1,  24)}
	_FS_RX_TO_TX_DELAY     = {60e6: ( 10,  32), 12e6: (2, 7)}
	_LS_RX_TO_TX_DELAY     = {60e6: ( 80, 260)}

	# Per the USB 2.0 and ULPI 1.1 specifications, after transmission:
	#   - A FS/LS can assume it won't receive a response after 16 bit times [USB2, 7.1.18.1].
	#     This is equivalent to 80 ULPI clocks (FS), or 640 ULPI clocks (LS).
	#   - A HS device can assume it won't receive a response after 736 bit times.
	#     This is equivalent to 92 ULPI clocks.
	# Each entry maps a domain clock (Hz) to a timeout in domain cycles.
	_HS_TX_TO_RX_TIMEOUT = {60e6:  92}
	_FS_TX_TO_RX_TIMEOUT = {60e6:  80, 12e6: 16}
	_LS_TX_TO_RX_TIMEOUT = {60e6: 640}


	def __init__(self, domain_clock = 60e6, fs_only = False):
		self._fs_only = fs_only

		# HS/LS timings stay None in an FS-only configuration. elaborate() only
		# reads them when `fs_only` is False.
		self._hs_rx_to_tx_delay   = None
		self._ls_rx_to_tx_delay   = None
		self._hs_tx_to_rx_timeout = None
		self._ls_tx_to_rx_timeout = None

		# Validate that we have a usable FS Rx/Tx delay.
		if domain_clock not in self._FS_RX_TO_TX_DELAY:
			raise ValueError(f'Domain clock must be in {self._FS_RX_TO_TX_DELAY.keys()}, not {domain_clock}.')

		# Capture our FS delays for the current clock speed.
		self._fs_rx_to_tx_delay   = self._FS_RX_TO_TX_DELAY[domain_clock]
		self._fs_tx_to_rx_timeout = self._FS_TX_TO_RX_TIMEOUT[domain_clock]

		# Every count the timer compares against. The counter is sized to reach the largest.
		tracked_counts = [*self._fs_rx_to_tx_delay, self._fs_tx_to_rx_timeout]

		# If we're not in a FS-only configuration, capture our other delays.
		if not self._fs_only:
			# Report the clocks that are valid for HS/LS operation. This is narrower
			# than the FS table, which also permits 12 MHz.
			if domain_clock not in self._HS_RX_TO_TX_DELAY:
				raise ValueError(f'Domain clock must be in {self._HS_RX_TO_TX_DELAY.keys()}, not {domain_clock}.')

			# Capture our HS and LS delays for the given clock speed.
			self._hs_rx_to_tx_delay   = self._HS_RX_TO_TX_DELAY[domain_clock]
			self._ls_rx_to_tx_delay   = self._LS_RX_TO_TX_DELAY[domain_clock]
			self._hs_tx_to_rx_timeout = self._HS_TX_TO_RX_TIMEOUT[domain_clock]
			self._ls_tx_to_rx_timeout = self._LS_TX_TO_RX_TIMEOUT[domain_clock]

			tracked_counts += [
				*self._hs_rx_to_tx_delay, self._hs_tx_to_rx_timeout,
				*self._ls_rx_to_tx_delay, self._ls_tx_to_rx_timeout,
			]

		self._counter_max = max(tracked_counts)

		# List of interfaces to users of this module.
		self._interfaces               = []

		#
		# I/O port
		#
		self.speed = Signal(2)


	def add_interface(self, interface: InterpacketTimerInterface):
		''' Adds a connection to a user of this module.

		This module performs no multiplexing; it's assumed only one interface will be active at a time.

		Parameters
		----------
		interface: InterpacketTimerInterface
			The InterPacketTimer interface to add to our module.
		'''
		self._interfaces.append(interface)


	def elaborate(self, platform):
		m = Module()

		# Internal signals representing each of our timeouts.
		rx_to_tx_at_min  = Signal()
		rx_to_tx_at_max  = Signal()
		tx_to_rx_timeout = Signal()

		# Create a counter that will track our interpacket delays.
		# This should be able to count up to our longest delay. We'll allow our
		# counter to be able to increment one past its maximum, and let it saturate
		# there, after the count.
		counter = Signal(range(0, self._counter_max + 2))

		# Reset our timer whenever any of our interfaces request a timer start.
		# With no interfaces attached there is nothing to OR together. reduce() would
		# raise on the empty list, so the timer simply never restarts.
		reset_signals = [interface.start for interface in self._interfaces]
		if reset_signals:
			any_reset = functools.reduce(operator.__or__, reset_signals)
		else:
			any_reset = Const(0)

		# When a reset is requested, start the counter from 0.
		with m.If(any_reset):
			m.d.usb += counter.eq(0)
		with m.Elif(counter < self._counter_max + 1):
			m.d.usb += counter.eq(counter + 1)

		#
		# Create our counter-progress strobes.
		# This could be made less repetitive, but spreading it out here
		# makes the documentation above clearer.
		#
		with m.If(self.speed == USBSpeed.HIGH):
			if not self._fs_only:
				m.d.comb += [
					rx_to_tx_at_min.eq(counter == self._hs_rx_to_tx_delay[0]),
					rx_to_tx_at_max.eq(counter == self._hs_rx_to_tx_delay[1]),
					tx_to_rx_timeout.eq(counter == self._hs_tx_to_rx_timeout)
				]
		with m.Elif(self.speed == USBSpeed.FULL):
			m.d.comb += [
				rx_to_tx_at_min.eq(counter == self._fs_rx_to_tx_delay[0]),
				rx_to_tx_at_max.eq(counter == self._fs_rx_to_tx_delay[1]),
				tx_to_rx_timeout.eq(counter == self._fs_tx_to_rx_timeout)
			]
		with m.Else():
			# Low speed: use the LS timings (previously the HS ones were used here by mistake).
			if not self._fs_only:
				m.d.comb += [
					rx_to_tx_at_min.eq(counter == self._ls_rx_to_tx_delay[0]),
					rx_to_tx_at_max.eq(counter == self._ls_rx_to_tx_delay[1]),
					tx_to_rx_timeout.eq(counter == self._ls_tx_to_rx_timeout)
				]

		# Tie our strobes to each of our consumers.
		for interface in self._interfaces:
			m.d.comb += [
				interface.tx_allowed.eq(rx_to_tx_at_min),
				interface.tx_timeout.eq(rx_to_tx_at_max),
				interface.rx_timeout.eq(tx_to_rx_timeout)
			]


		return m
