#!/usr/bin/env python

"""
This script automatically generates address select logic for IPbus-based firmware designs.

The script takes a uHAL-compliant XML address file and writes a VHDL module containing the
corresponding address decoder function to a file named 'ipbus_decode_<addr_table_name>.vhd'.

Note that full address decoding is not performed (would be very inefficient with 32b address
space), so slaves will appear at many locations.
"""

from __future__ import print_function

import argparse
import logging
import itertools
import math
import os.path
import sys
import time

# Python 2 compatibility: rebind 'range' to 'xrange' so that the range() calls
# in this file iterate lazily on both major Python versions.
if sys.version_info[0] < 3:
    range = xrange


       
class BitArray(object):
    """
    Fixed-width (32 bit) array of bits supporting bitwise AND, OR and NOT.

    Bits are stored least-significant first in 'self.bits', i.e. 'self.bits[i]'
    (and 'self[i]') is the bit of weight 2**i.
    """

    def __init__(self, value=0):
        """
        Build the bit array from the integer 'value'.

        Raises ValueError (an Exception subclass) if 'value' is larger than the
        biggest number representable in 'self.length' bits.
        """
        self.length = 32
        if value > (1 << self.length) - 1:
            raise ValueError("Value '%d' too big. It does not fit in %d bits\n" % (value, self.length))

        self.bits = [(value >> i) & 1 for i in range(self.length)]

    def __getitem__(self, index):
        """Return bit number 'index' (index 0 is the least-significant bit)."""
        return self.bits[index]

    def __and__(self, other):
        """Return a new BitArray holding the bitwise AND of self and 'other'."""
        if not isinstance(other, BitArray):
            raise TypeError("BitArray.__and__ method requires BitArray")

        result = BitArray()
        result.bits = [x & y for (x, y) in zip(self.bits, other.bits)]
        return result

    def __or__(self, other):
        """Return a new BitArray holding the bitwise OR of self and 'other'."""
        if not isinstance(other, BitArray):
            # The message used to name __and__ (copy/paste slip), which pointed
            # at the wrong operator when debugging.
            raise TypeError("BitArray.__or__ method requires BitArray")

        result = BitArray()
        result.bits = [x | y for (x, y) in zip(self.bits, other.bits)]
        return result

    def __invert__(self):
        """Return a new BitArray with every bit flipped."""
        result = BitArray()
        result.bits = [x ^ 1 for x in self.bits]
        return result

    def __str__(self):
        """Return the bits as a string of '0'/'1', most-significant bit first."""
        return "".join("1" if x else "0" for x in reversed(self.bits))

    def uint(self):
        """Return the value of the bit array as an unsigned integer."""
        return sum(x << i for i, x in enumerate(self.bits))

    def hex(self):
        """Return the value formatted as '0x' followed by 8 hex digits."""
        return "0x%08x" % self.uint()


class node(object):

    """
    One entry of the address tree.

    Stores the node's name, base address, endpoint flag, size and (optional)
    address width, plus the list of nodes attached directly below it ('desc').
    """

    def __init__(self, name, addr = 0, flag = False, size = 0, width = None):
        self.desc = []
        self.name = name
        self.addr = addr
        self.flag = flag
        self.size = size
        self.width = width

    def __str__(self):
        # Width is optional, so it is rendered as the literal text 'None' when unset
        if self.width is None:
            width_text = 'None'
        else:
            width_text = hex(self.width)
        return " ".join([self.name, hex(self.addr), width_text, str(self.flag)])

    def sort(self):
        """
        Put the direct descendants of this node into ascending address order
        """
        def by_address(child):
            return child.addr

        self.desc = sorted(self.desc, key = by_address)
        
class nodetree(object):

    """
    Tree of address-table nodes hanging below a synthetic root node named "TOP".

    Nodes are objects providing the attributes 'name', 'addr', 'flag', 'size',
    'width' and 'desc' and a 'sort()' method (see the 'node' class).
    """

    # Loggers are looked up by name, so this is the same logger object that main()
    # configures. Fetching it here means the class does not depend on a global
    # that only exists once main() has run.
    _log = logging.getLogger("main")

    def __init__(self):
        self.root = node("TOP")

    def add(self, n):
        """
        Add a node to the tree.

        The dotted name of 'n' is resolved one prefix at a time, starting at the
        root: while the current level holds a node named like the prefix, we
        descend into it; 'n' is attached at the first level where the prefix is
        not found. If every prefix (including the full name) already exists,
        'n' is not added.
        """
        parent = self.root
        parts = n.name.split('.')
        for depth in range(1, len(parts) + 1):
            prefix = '.'.join(parts[:depth])
            names = [x.name for x in parent.desc]
            if prefix not in names:
                parent.desc.append(n)
                break
            parent = parent.desc[names.index(prefix)]

    def __str__(self):
        return self.dump(self.root, "")

    def dump(self, n, prefix):
        """
        Dump the tree below 'n' in human readable format.

        Returns one line per node, each starting with 'prefix'; every tree level
        adds two spaces of indentation.
        """
        pieces = []
        for child in n.desc:
            pieces.append(prefix + str(child) + "\n")
            if child.desc:
                pieces.append(self.dump(child, prefix + "  "))
        return "".join(pieces)

    def assign_widths(self):
        """
        Assign widths to all endpoint nodes, starting from the root.

        Returns True if no errors were detected, False otherwise (the errors
        themselves are written to the log).
        """
        errors = []
        self._assign_width(self.root, errors)
        return len(errors) == 0

    def _assign_width(self, n, errors_detected):
        """
        Check node 'n' and, if it has no explicit width, derive one from the
        addresses of its endpoint descendants.

        The name of every offending node is appended to 'errors_detected'.
        Returns the (possibly newly assigned) width of 'n'.
        """
        log = self._log
        n.sort()
        endpoints = self.get_flagged_nodes(n)
        if not endpoints: # This is a bottom-layer endpoint node
            if n.width is None: # Error, we must have an explicit width
                log.warning("Node <<" + n.name + ">> has no endpoint descendants and no width parameter - assuming width 0")
                n.width = 0
            if n.size > 2 ** n.width: # Check our own size
                log.error("Endpoint node <<" + n.name + ">> has declared size " + hex(n.size) + " that exceeds its address width " + hex(n.width))
                errors_detected.append(n.name)
            for i in self.get_desc(n): # Check that our node content does not overflow our address range
                if i.addr + i.size - 1 >= n.addr + 2 ** n.width:
                    log.error("Endpoint node <<" + n.name + ">> (base address " + hex(n.addr) + ", width " + str(n.width) + ") has a sub-node <<" + i.name + ">> (base address " + hex(i.addr) + ", size " + hex(i.size) + ") that overflows its address space")
                    errors_detected.append(i.name)
        else: # This is an intermediate-layer endpoint node that may or may not have a width assigned explicitly
            if self.get_leaves(n):
                log.error("Endpoint node <<" + n.name + ">> has a mixture of endpoint and non-endpoint descendants")
                errors_detected.append(n.name)
            # Find the maximum span of our endpoint descendants. The list is not
            # guaranteed to be in address order (endpoints reached through
            # non-endpoint nodes keep their insertion order), so take a true
            # maximum rather than the value of the last entry.
            max_addr = 0
            for i in endpoints:
                max_addr = max(max_addr, i.addr + 2 ** self._assign_width(i, errors_detected))
            if n.width is None:
                span = max_addr - n.addr
                if span > 0:
                    # ceil(log2(span)) in exact integer arithmetic; the floating
                    # point form can round up for exact powers of two.
                    n.width = (span - 1).bit_length()
                else:
                    log.error("Endpoint node <<" + n.name + ">> (base address " + hex(n.addr) + ") lies at or above the maximum address of its descendants (" + hex(max_addr) + ")")
                    errors_detected.append(n.name)
                    n.width = 0
            elif max_addr > n.addr + 2 ** n.width:
                log.error("Endpoint node <<" + n.name + ">> (base address " + hex(n.addr) + ", width " + str(n.width) + ") has insufficient width to contain its descendants (maximum address " + hex(max_addr) + ")")
                errors_detected.append(n.name)
        if n.addr % pow(2, n.width) != 0: # Check our alignment
            log.error("Node <<" + n.name + ">> at " + hex(n.addr) + ": Width, " + hex(n.width) + ", produces address mask " + hex(pow(2,n.width)) + " that is not aligned with node's base address. The address must be a multiple of 2^width.")
            errors_detected.append(n.name)
        return n.width

    def get_desc(self, n):
        """
        Get all descendants of a node (depth first, parent before children).
        """
        found = list()
        for child in n.desc:
            found.append(child)
            found.extend(self.get_desc(child))
        return found

    def get_leaves(self, n):
        """
        Get all bottom-layer leaves below a node.

        Only non-flagged nodes without descendants are returned; the search does
        not descend into flagged (endpoint) nodes.
        """
        found = list()
        for child in n.desc:
            if not child.flag:
                if not child.desc:
                    found.append(child)
                else:
                    found.extend(self.get_leaves(child))
        return found

    def get_flagged_nodes(self, n):
        """
        Get the nearest endpoint descendants of a node.

        The search passes through non-flagged nodes and stops at the first
        flagged node on each branch.
        """
        found = list()
        for child in n.desc:
            if child.flag:
                found.append(child)
            else:
                found.extend(self.get_flagged_nodes(child))
        return found

#===========================================================================================

# Process exit codes used by main()
EXIT_CODE_ARG_PARSING_ERROR   = 1  # uHAL raised an exception while parsing the address table
EXIT_CODE_NODE_ADDRESS_ERRORS = 2  # node width, alignment or overlap errors were detected
EXIT_CODE_IMPORT_ERROR        = 3  # the 'uhal' Python module could not be imported

def main():
    """
    Command-line entry point.

    Parses the command line, asks uHAL to load the address table, builds the
    node tree, assigns and checks endpoint widths, derives the significant
    address bits and - unless a dry run was requested - fills the decoder
    template and writes 'ipbus_decode_<addr_table_name>.vhd' to the current
    directory.

    Exits the process with EXIT_CODE_IMPORT_ERROR if 'uhal' cannot be imported,
    EXIT_CODE_ARG_PARSING_ERROR if the address table cannot be loaded, and
    EXIT_CODE_NODE_ADDRESS_ERRORS if node errors are detected.
    """
    logging.basicConfig(level=logging.WARNING, format='%(levelname)s\t: %(message)s')

    # configure logger (published as a module-level global so that code outside
    # main() can log through it)
    global log
    log = logging.getLogger("main")

    try:
        import uhal
    except ImportError as e:
        print('ERROR: ' + str(e))
        sys.exit(EXIT_CODE_IMPORT_ERROR)

    uhal.setLogLevelTo(uhal.LogLevel.WARNING)

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter
        )
    parser.add_argument('-v', '--verbose', help="increase output verbosity (default: %(default)s)", action="store_true")
    parser.add_argument('-d', '--debug', help="enable debug messages (default: %(default)s)", action="store_true")
    parser.add_argument('-n', '--dry-run', help="Dry run (default: %(default)s)", action="store_true", default=False)
    parser.add_argument('-t', '--template', help="Decoder template path (default: %(default)s)", default='/opt/cactus/etc/uhal/tools/ipbus_addr_decode.vhd')
    parser.add_argument('--no-timestamp', help="Do not include timestamp in comments (default: %(default)s)", action="store_true", default=False)
    parser.add_argument('addrtab', help="Address table file")
    args = parser.parse_args()

    if args.verbose:
        log.setLevel(logging.INFO)
        uhal.setLogLevelTo(uhal.LogLevel.INFO)
    if args.debug:
        log.setLevel(logging.DEBUG)
        uhal.setLogLevelTo(uhal.LogLevel.DEBUG)

    # Ask the API to read and parse the address tree
    try:
        device = uhal.getDevice("dummy","ipbusudp-1.3://localhost:12345","file://" + args.addrtab)
    except Exception as e:
        log.error("Exception thrown when parsing address table '{}'".format(args.addrtab))
        log.error(str(e))
        sys.exit(EXIT_CODE_ARG_PARSING_ERROR)

    # Build the node tree (only nodes with a full 32-bit mask are considered)

    t = nodetree()
    for i in device.getNodes():
        d = device.getNode(i)
        if d.getMask() == 0xffffffff:
            f = d.getFirmwareInfo()
            p = "type" in f and f["type"] == "endpoint"
            s = d.getSize() if d.getMode() == uhal.BlockReadWriteMode.INCREMENTAL else 1
            w = int(f["width"]) if "width" in f else None
            t.add(node(i, d.getAddress(), p, s, w))

    for l in str(t).splitlines():
        log.debug(l)

    # Assign widths

    if not t.assign_widths():
        log.error("Node errors detected, exiting early before writing output")
        sys.exit(EXIT_CODE_NODE_ADDRESS_ERRORS)

    # Check for overlaps in address ranges of different endpoints. Each unordered
    # pair is visited once and put into address order, so two endpoints sharing
    # the same base address are reported once rather than twice.

    error_count = 0
    endpoint_nodes = t.get_flagged_nodes(t.root)
    for first, second in itertools.combinations(endpoint_nodes, 2):
        n1, n2 = (first, second) if first.addr <= second.addr else (second, first)
        if (n1.addr + 2 ** n1.width - 1) >= n2.addr:
            log.error("Node <<{0[0]}>> (address range 0x{0[1]:x} to 0x{0[2]:x}) overlaps with node <<{1[0]}>> (0x{1[1]:x} to 0x{1[2]:x})".format(*[(x.name, x.addr, x.addr + 2 ** x.width - 1) for x in [n1,n2]]))
            error_count += 1

    # Figure out address masks: a bit is significant for decoding if it is set in
    # the base address of at least one slave and clear in that of another.

    or_prod = BitArray(0x0)
    and_prod = BitArray(0xffffffff)
    slaves = []

    for n, i in enumerate(endpoint_nodes):
        addr_bits = BitArray(i.addr)
        mask_bits = BitArray(2 ** i.width - 1)
        log.info("uHAL slave %s,%s,0b%s,0b%s" % (n, i.name, addr_bits, mask_bits))
        if (addr_bits & mask_bits).uint() != 0:
            log.error("Node <<" + i.name + ">> has a non-aligned address " + str(addr_bits) + " with respect to address bit mask " + str(mask_bits))
            error_count += 1
        or_prod = or_prod | (addr_bits & ~mask_bits)
        and_prod = and_prod & (addr_bits | mask_bits)
        slaves.append((n, i.name, addr_bits, mask_bits))

    addr_mask = or_prod & ~and_prod
    log.info("Significant address bits 0b%s" % str(addr_mask))

    if error_count > 0:
        log.error("Node errors detected, exiting early before writing output")
        sys.exit(EXIT_CODE_NODE_ADDRESS_ERRORS)

    # FIXME: this algorithm leads to more than the required number of address bits being
    # compared in the VHDL in some circumstances. Could add a loop here to calculate the
    # number of address bits needed for a unique match for a given slave.

    # Generate VHDL code

    moduleName = os.path.splitext(os.path.basename(args.addrtab))[0]
    if args.dry_run:
        return

    timestamp_suffix = "" if args.no_timestamp else (" (" + time.asctime() + ")" )

    # generate vhdl snippet with symbolic constants
    numSlaves = len(slaves)
    slaveIds = {}
    snippet1 = "-- START automatically generated VHDL" + timestamp_suffix + "\n"
    for n, slave_name, addr_bits, mask_bits in slaves:
        slaveIds[n] = "N_SLV_" + slave_name.upper().replace(".", "_")
        snippet1 += "  constant " + slaveIds[n] + ": integer := %i;\n" % n
    snippet1 += "  constant N_SLAVES: integer := %i;\n" % (numSlaves)
    snippet1 += "-- END automatically generated VHDL"

    # generate vhdl snippet with selection code
    snippet2 = "-- START automatically generated VHDL" + timestamp_suffix + "\n"
    if numSlaves == 0:
        snippet2 += "    sel := ipbus_sel_t(to_unsigned(0, IPBUS_SEL_WIDTH));\n"
    elif numSlaves == 1:
        snippet2 += "    sel := ipbus_sel_t(to_unsigned(" + slaveIds[0] + ", IPBUS_SEL_WIDTH));\n"
    else:
        for n, slave_name, addr_bits, mask_bits in slaves:
            # Only compare significant bits that lie above this slave's own address range
            mask = addr_mask & ~mask_bits
            if n == 0:
                snippet2 += "    if    "
            else:
                snippet2 += "    elsif "

            # std_match pattern, MSB first: '-' is "don't care", otherwise the address bit
            pattern = "".join(
                "-" if not mask[bit] else ("1" if addr_bits[bit] else "0")
                for bit in range(31, -1, -1)
            )
            snippet2 += "std_match(addr, \"" + pattern + "\") then\n"

            snippet2 += "      sel := ipbus_sel_t(to_unsigned(" + slaveIds[n] + ", IPBUS_SEL_WIDTH)); -- " + str(slave_name) + " / base " + addr_bits.hex() + " / mask " + mask.hex() + "\n"
        snippet2 += "    else\n"
        snippet2 += "        sel := ipbus_sel_t(to_unsigned(N_SLAVES, IPBUS_SEL_WIDTH));\n"
        snippet2 += "    end if;\n"
    snippet2 += "-- END automatically generated VHDL\n"

    # calculate required selector bit width: the selector must be able to hold
    # every value up to N_SLAVES itself (used above as the "no match" value).
    # bit_length() is floor(log2(n)) + 1 in exact integer arithmetic.
    ipbus_sel_width = numSlaves.bit_length() if numSlaves else 1
    snippet3 = "-- START automatically generated VHDL" + timestamp_suffix + "\n"
    snippet3 += "  constant IPBUS_SEL_WIDTH: positive := " + str(ipbus_sel_width) + ";\n"
    snippet3 += "-- END automatically generated VHDL\n"

    # print vhdl code. The template is read before the output file is opened, so
    # that an unreadable template does not leave behind an empty output file.
    vhdlfilename = "ipbus_decode_" + moduleName + ".vhd"
    with open(args.template) as templatefile:
        template = templatefile.read()
    vhdl = (template
            .replace("-- INSERT_SYMBOLIC_CONSTANTS_HERE", snippet1)
            .replace("-- INSERT_SELECTION_HERE", snippet2)
            .replace("-- INSERT_SEL_WIDTH_HERE", snippet3)
            .replace("PACKAGENAME", moduleName))
    with open(vhdlfilename, "w") as vhdlfile:
        vhdlfile.write(vhdl)
    print("VHDL decode file saved: ", vhdlfilename)
    
# Run the generator only when this file is executed as a script (not on import)
if __name__ == '__main__':
    main()
  
