#! /usr/bin/python3

import sys
import os.path
import re

# Macro name -> replacement text.  getToken() splices the replacement back
# into the current line when it meets a macro name outside a directive.
macros = {}

# OpenFlow version number -> version ID used in ofp_header.
version_map = {
    "1.0": 0x01,
    "1.1": 0x02,
    "1.2": 0x03,
    "1.3": 0x04,
    "1.4": 0x05,
    "1.5": 0x06,
}
# Version ID used in ofp_header -> OpenFlow version number.
version_reverse_map = {wire_id: number
                       for number, wire_id in version_map.items()}

# Tokenizer state shared by getToken() and the parsing helpers.
token = None         # Current token; None before the first one and at EOF.
line = ""            # Not-yet-consumed remainder of the current input line.
idRe = "[a-zA-Z_][a-zA-Z_0-9]*"
tokenRe = "#?%s|[0-9]+|." % idRe
inComment = False    # True while inside an unterminated /* ... */ comment.
inDirective = False  # True from a '#...' token until the end of that line.

# Number of problems reported through error().
n_errors = 0


def open_file(fn):
    """Open FN for reading and make it the current input.

    Stores the name in the global ``fileName``, the open file object in the
    global ``inputFile``, and resets the global ``lineNumber`` to 0.
    """
    global fileName, inputFile, lineNumber
    fileName = fn
    inputFile = open(fn)
    lineNumber = 0


def tryGetLine():
    """Read the next line of ``inputFile`` into the global ``line``.

    ``lineNumber`` is incremented on every call, including the call that
    reaches end of file.

    Returns True if a line was read, False at end of input.
    """
    global line, lineNumber
    line = inputFile.readline()
    lineNumber += 1
    return line != ""


def getLine():
    """Read the next input line into ``line``; end of input is fatal."""
    if tryGetLine():
        return
    fatal("unexpected end of input")


def getToken():
    """Advance to the next token and store it in the global ``token``.

    Leading whitespace and ``/* ... */`` comments (which may span lines) are
    skipped.  A token is whatever ``tokenRe`` matches at the start of the
    remaining line: an identifier with an optional leading '#', a run of
    digits, or any single character.

    A token starting with '#' begins a directive.  While in a directive,
    macros are not expanded, and reaching the end of the line yields the
    synthetic token "$".  Outside a directive, a token found in ``macros``
    is replaced by its expansion, which is then tokenized in turn.

    Returns True when a token was produced.  At end of input, ``token`` is
    set to None and False is returned; if ``token`` is already None at that
    point (nothing was ever read, or end of input was reported before), the
    condition is fatal.
    """
    global token
    global line
    global inComment
    global inDirective
    while True:
        line = line.lstrip()
        if line != "":
            if line.startswith("/*"):
                # Comment opener: drop it and let the next iteration look
                # for the terminator.
                inComment = True
                line = line[2:]
            elif inComment:
                commentEnd = line.find("*/")
                if commentEnd < 0:
                    # Comment continues past this line; discard the rest.
                    line = ""
                else:
                    inComment = False
                    line = line[commentEnd + 2:]
            else:
                # After lstrip() the line starts with a non-space character,
                # so the catch-all '.' alternative guarantees a match.
                match = re.match(tokenRe, line)
                token = match.group(0)
                line = line[len(token):]
                if token.startswith('#'):
                    inDirective = True
                elif token in macros and not inDirective:
                    # Push the expansion back and rescan it.
                    line = macros[token] + line
                    continue
                return True
        elif inDirective:
            # End of a directive line: report it as the "$" token.
            token = "$"
            inDirective = False
            return True
        else:
            # Current line exhausted: fetch the next one, joining lines
            # that end in a backslash-newline continuation.
            global lineNumber
            line = inputFile.readline()
            lineNumber += 1
            while line.endswith("\\\n"):
                line = line[:-2] + inputFile.readline()
                lineNumber += 1
            if line == "":
                # End of input.
                if token is None:
                    fatal("unexpected end of input")
                token = None
                return False


def error(msg):
    """Report MSG on stderr, prefixed with the current file name and line
    number, and count it in the global ``n_errors``."""
    global n_errors
    location = "%s:%d" % (fileName, lineNumber)
    sys.stderr.write("%s: %s\n" % (location, msg))
    n_errors += 1


def fatal(msg):
    """Report MSG through error() and terminate with exit status 1."""
    error(msg)
    raise SystemExit(1)


def skipDirective():
    """Read tokens (at least one) until the current token is the "$" marker
    that getToken() produces at the end of a directive."""
    while True:
        getToken()
        if token == '$':
            break


def isId(s):
    """Return True if S is an identifier as described by ``idRe``."""
    pattern = "%s$" % idRe
    return bool(re.match(pattern, s))


def forceId():
    """Abort with "identifier expected" unless the current token is an
    identifier."""
    if isId(token):
        return
    fatal("identifier expected")


def forceInteger():
    """Abort with "integer expected" unless the current token consists
    solely of decimal digits."""
    if re.match(r'[0-9]+$', token) is None:
        fatal("integer expected")


def match(t):
    """If the current token equals T, advance past it and return True;
    otherwise leave the token untouched and return False."""
    if token != t:
        return False
    getToken()
    return True


def forceMatch(t):
    """Consume the token T, aborting with "T expected" if the current token
    is something else."""
    if match(t):
        return
    fatal("%s expected" % t)


def parseTaggedName():
    """Parse ``struct NAME`` or ``union NAME`` and return it as one string.

    The current token must be 'struct' or 'union' on entry.  On return the
    current token is the one that follows NAME.
    """
    assert token in ('struct', 'union')
    keyword = token
    getToken()
    forceId()
    tagged = "%s %s" % (keyword, token)
    getToken()
    return tagged


def print_enum(tag, constants, storage_class):
    """Print a C function ``<TAG>_to_string()`` to stdout.

    The generated function switches on its uint16_t argument and returns
    the name of the matching constant as a string, or NULL if none matches.

    tag            prefix for the generated function's name.
    constants      iterable of C constant names, one ``case`` each.
    storage_class  text placed directly before ``const char *`` (include
                   any trailing space yourself, e.g. "static ").
    """
    prologue = (
        "\n"
        "%(storage_class)sconst char *\n"
        "%(tag)s_to_string(uint16_t value)\n"
        "{\n"
        "    switch (value) {"
    ) % {"tag": tag, "storage_class": storage_class}
    print(prologue)
    for name in constants:
        print('    case %s: return "%s";' % (name, name))
    print("    }\n"
          "    return NULL;\n"
          "}")


def usage():
    """Print the help text to stdout and exit with status 0."""
    prog = os.path.basename(sys.argv[0])
    help_text = '''\
%(argv0)s, for extracting OpenFlow error codes from header files
usage: %(argv0)s ERROR_HEADER VENDOR_HEADER

This program reads VENDOR_HEADER to obtain OpenFlow vendor (aka
experimenter IDs), then ERROR_HEADER to obtain OpenFlow error number.
It outputs a C source file for translating OpenFlow error codes into
strings.

ERROR_HEADER should point to include/openvswitch/ofp-errors.h.
VENDOR_HEADER should point to include/openflow/openflow-common.h.
The output is suitable for use as lib/ofp-errors.inc.\
''' % {"argv0": prog}
    print(help_text)
    raise SystemExit(0)


def extract_vendor_ids(fn):
    """Fill the global ``vendor_map`` from the vendor header FN.

    Every line of FN that begins with
    ``#define <NAME>_VENDOR_ID <hex-or-decimal-number>`` contributes a
    ``NAME -> number`` entry.

    Exits with status 1 if a vendor name is defined twice, if no vendor
    definition is found at all, or if two vendors share the same ID.
    """
    global vendor_map
    vendor_map = {}
    first_seen = {}     # vendor name -> "file:line" of its definition

    open_file(fn)
    while tryGetLine():
        m = re.match(
            r'#define\s+([A-Z0-9_]+)_VENDOR_ID\s+(0x[0-9a-fA-F]+|[0-9]+)',
            line
        )
        if m is None:
            continue

        vendor, number = m.group(1, 2)
        vendor_id = int(number, 0)

        if vendor in vendor_map:
            error("%s: duplicate definition of vendor" % vendor)
            sys.stderr.write("%s: Here is the location of the previous "
                             "definition.\n" % first_seen[vendor])
            sys.exit(1)

        vendor_map[vendor] = vendor_id
        first_seen[vendor] = "%s:%d" % (fileName, lineNumber)

    if len(vendor_map) == 0:
        fatal("%s: no vendor definitions found" % fn)

    inputFile.close()

    # Vendor names are unique by now; make sure the IDs are as well.
    names_by_id = {}
    for vendor, vendor_id in vendor_map.items():
        earlier = names_by_id.setdefault(vendor_id, vendor)
        if earlier != vendor:
            fatal("%s: duplicate vendor id for vendors %s and %s"
                  % (vendor_id, earlier, vendor))


def extract_ofp_errors(fn):
    """Parse ``enum ofperr`` in FN and print the C tables derived from it.

    FN is scanned for a line beginning with ``enum ofperr``.  Inside the
    enum, lines that are blank or that start with ``/*`` in the first
    column are skipped, and a line starting with ``}`` ends the enum.
    Every error is described by an indented ``/* ... */`` comment (possibly
    spanning several lines) followed by a line of the form
    ``OFPERR_<NAME>,`` (optionally with ``= OFPERR_OFS``).

    The comment reads ``<destinations>.  <description>``.  Destinations are
    separated by ", " and each has the form
    ``<VENDOR><version>[+ | -<version>](<type>[,<code>])``, for example
    ``OF1.0+(1,2)``.  VENDOR must be a key of the global ``vendor_map``, so
    extract_vendor_ids() has to run first.  Text in square brackets is
    removed from the description before it is emitted.

    A comment of the form ``Expected: <message>.`` whitelists one
    "means both" duplicate-encoding message; a whitelist entry that is
    never used is itself an error.

    On success the generated C source (OFPERR_N_ERRORS, struct
    ofperr_domain, the error name and comment arrays, and one decode
    function plus domain table per OpenFlow version) is printed to stdout.
    If any error was reported, nothing is printed and the process exits
    with status 1.
    """
    # error() counts problems in the module-level n_errors.  Share that
    # counter rather than shadowing it with a local, so that duplicates
    # reported through error() also make the run fail.
    global n_errors

    comments = []
    names = []
    domain = {}     # version ID -> vendor -> type -> code -> (enum, file, line)
    reverse = {}    # version ID -> enum -> (vendor, type, code)
    for domain_name in version_map.values():
        domain[domain_name] = {}
        reverse[domain_name] = {}

    expected_errors = {}

    open_file(fn)

    # Skip ahead to the start of the enum.
    while True:
        getLine()
        if re.match(r'enum ofperr', line):
            break

    while True:
        getLine()
        if line.startswith('/*') or not line or line.isspace():
            continue
        elif re.match(r'}', line):
            break

        if not line.lstrip().startswith('/*'):
            fatal("unexpected syntax between errors")

        # Collect the (possibly multi-line) comment, without its delimiters.
        comment = line.lstrip()[2:].strip()
        while not comment.endswith('*/'):
            getLine()
            if line.startswith('/*') or not line or line.isspace():
                fatal("unexpected syntax within error")
            comment += ' %s' % line.lstrip('* \t').rstrip(' \t\r\n')
        comment = comment[:-2].rstrip()

        m = re.match(r'Expected: (.*)\.$', comment)
        if m:
            expected_errors[m.group(1)] = (fileName, lineNumber)
            continue

        # Split "<destinations>.  <description>" at the first ".  ".
        m = re.match(r'((?:.(?!\.  ))+.)\.\s+(.*)$', comment)
        if not m:
            fatal("unexpected syntax between errors")

        dsts, comment = m.groups()

        getLine()
        m = re.match(r'\s+(?:OFPERR_([A-Z0-9_]+))(\s*=\s*OFPERR_OFS)?,',
                     line)
        if not m:
            fatal("syntax error expecting OFPERR_ enum value")

        enum = m.group(1)
        if enum in names:
            fatal("%s specified twice" % enum)

        comments.append(re.sub(r'\[[^]]*\]', '', comment))
        names.append(enum)

        for dst in dsts.split(', '):
            m = re.match(
                r'([A-Z]+)([0-9.]+)(\+|-[0-9.]+)?\((\d+)(?:,(\d+))?\)$',
                dst
            )
            if not m:
                fatal("%r: syntax error in destination" % dst)
            vendor_name = m.group(1)
            version1_name = m.group(2)
            version2_name = m.group(3)
            type_ = int(m.group(4))
            if m.group(5):
                code = int(m.group(5))
            else:
                code = None

            if vendor_name not in vendor_map:
                fatal("%s: unknown vendor" % vendor_name)
            vendor = vendor_map[vendor_name]

            if version1_name not in version_map:
                fatal("%s: unknown OpenFlow version" % version1_name)
            v1 = version_map[version1_name]

            if version2_name is None:
                v2 = v1
            elif version2_name == "+":
                v2 = max(version_map.values())
            elif version2_name[1:] not in version_map:
                fatal("%s: unknown OpenFlow version" % version2_name[1:])
            else:
                v2 = version_map[version2_name[1:]]

            if v2 < v1:
                fatal("%s%s: %s precedes %s"
                      % (version1_name, version2_name,
                         version2_name, version1_name))

            if vendor == vendor_map['OF']:
                # All standard OpenFlow errors have a type and a code.
                if code is None:
                    fatal("%s: %s domain requires code" % (dst, vendor_name))
            elif vendor == vendor_map['NX']:
                # Before OpenFlow 1.2, OVS used a Nicira extension to
                # define errors that included a type and a code.
                #
                # In OpenFlow 1.2 and later, Nicira extension errors
                # are defined using the OpenFlow experimenter error
                # mechanism that includes a type but not a code.
                if v1 < version_map['1.2'] or v2 < version_map['1.2']:
                    if code is None:
                        fatal("%s: NX1.0 and NX1.1 domains require code" % dst)
                if v1 >= version_map['1.2'] or v2 >= version_map['1.2']:
                    if code is not None:
                        fatal("%s: NX1.2+ domains do not have codes" % dst)
            else:
                # Experimenter extension error for OF1.2+ only.
                if v1 < version_map['1.2']:
                    fatal("%s: %s domain not supported before OF1.2"
                          % (dst, vendor_name))
                if code is not None:
                    fatal("%s: %s domains do not have codes"
                          % (dst, vendor_name))
            if code is None:
                code = 0

            for version in range(v1, v2 + 1):
                domain[version].setdefault(vendor, {})
                domain[version][vendor].setdefault(type_, {})
                if code in domain[version][vendor][type_]:
                    msg = "%#x,%d,%d in OF%s means both %s and %s" % (
                        vendor, type_, code, version_reverse_map[version],
                        domain[version][vendor][type_][code][0], enum)
                    if msg in expected_errors:
                        del expected_errors[msg]
                    else:
                        error("%s: %s." % (dst, msg))
                        sys.stderr.write(
                            "%s:%d: %s: Here is the location "
                            "of the previous definition.\n"
                            % (domain[version][vendor][type_][code][1],
                               domain[version][vendor][type_][code][2],
                               dst)
                        )
                else:
                    domain[version][vendor][type_][code] = (enum, fileName,
                                                            lineNumber)

                # Two destinations of one error must not cover the same
                # version.  This is input validation, so report it properly
                # instead of relying on assert (which -O strips).
                if enum in reverse[version]:
                    fatal("%s: %s has more than one encoding in OF%s"
                          % (dst, enum, version_reverse_map[version]))
                reverse[version][enum] = (vendor, type_, code)

    inputFile.close()

    for exp_file, exp_line in expected_errors.values():
        sys.stderr.write("%s:%d: expected duplicate not used.\n"
                         % (exp_file, exp_line))
        n_errors += 1

    if n_errors:
        sys.exit(1)

    print("""\
/* Generated automatically; do not modify!     -*- buffer-read-only: t -*- */

#define OFPERR_N_ERRORS %d

struct ofperr_domain {
    const char *name;
    uint8_t version;
    enum ofperr (*decode)(uint32_t vendor, uint16_t type, uint16_t code);
    struct triplet errors[OFPERR_N_ERRORS];
};

static const char *error_names[OFPERR_N_ERRORS] = {
%s
};

static const char *error_comments[OFPERR_N_ERRORS] = {
%s
};\
""" % (len(names),
       '\n'.join('    "%s",' % name for name in names),
       '\n'.join('    "%s",' % re.sub(r'(["\\])', r'\\\1', comment)
                 for comment in comments)))

    def output_domain(enum_map, name, description, version):
        """Print the decode function and domain table for one version.

        enum_map maps enum name -> (vendor, type, code) for that version.
        """
        print("""
static enum ofperr
%s_decode(uint32_t vendor, uint16_t type, uint16_t code)
{
    switch (((uint64_t) vendor << 32) | ((uint32_t) type << 16) | code) {"""
               % name)
        # Only the first enum seen for a given encoding gets a case label;
        # C does not allow duplicate case values.
        found = set()
        for enum in names:
            if enum not in enum_map:
                continue
            vendor, type_, code = enum_map[enum]
            value = (vendor << 32) | (type_ << 16) | code
            if value in found:
                continue
            found.add(value)
            if vendor:
                vendor_s = "(%#xULL << 32) | " % vendor
            else:
                vendor_s = ""
            print("    case %s ((uint32_t) %d << 16) | %d:" % (vendor_s,
                                                               type_, code))
            print("        return OFPERR_%s;" % enum)
        print("""\
    }

    return 0;
}""")

        print("""
static const struct ofperr_domain %s = {
    "%s",
    %d,
    %s_decode,
    {""" % (name, description, version, name))
        for enum in names:
            if enum in enum_map:
                vendor, type_, code = enum_map[enum]
                if code is None:
                    code = -1
                print("        { %#8x, %2d, %3d }, /* %s */" % (vendor, type_,
                                                                code, enum))
            else:
                print("        {       -1, -1,  -1 }, /* %s */" % enum)
        print("""\
    },
};""")

    for version_name, id_ in version_map.items():
        var = 'ofperr_of' + re.sub(r'[^A-Za-z0-9_]', '', version_name)
        description = "OpenFlow %s" % version_name
        output_domain(reverse[id_], var, description, id_)


if __name__ == '__main__':
    # Command line: ERROR_HEADER VENDOR_HEADER.  The vendor header is read
    # first because parsing the error header needs the vendor IDs.
    argv = sys.argv
    if '--help' in argv:
        usage()
    elif len(argv) == 3:
        error_header, vendor_header = argv[1:]
        extract_vendor_ids(vendor_header)
        extract_ofp_errors(error_header)
    else:
        sys.stderr.write("exactly two non-options arguments required; "
                         "use --help for help\n")
        sys.exit(1)
