#!/usr/bin/env python3
############################################################################
# Copyright (C) SchedMD LLC.
############################################################################
import os, re, sys

# SchedMD
sys.path.append(sys.path[0] + "/src/lib")
from db.test_db import (
    get_new_run_list,
    reset_suite_run_ids,
    setup_db_if_new,
    update_records_from_dirs,
)
from cli import (
    get_args,
    test_parser,
)
from utils.cmds import (
    perform,
)
from utils.conf import (
    get_vars_from_conf,
)
from utils.fs import (
    create_dir,
    file_exists,
    exit_if_not_exists,
)
from utils.log import (
    log,
    log_header,
    log_new_line,
    print_pretty_dict,
)
from utils.ps import (
    is_tool,
)
from utils.test.test_list import (
    get_unit_tl_from_build_dir,
    get_tl_from_dir,
)
from test_runners.unit_runner import (
    get_unit_run_stats,
    run_unit_tests,
)
from test_runners.regressions_runner import (
    get_regressions_run_stats,
    run_regressions_tests,
)

# Minimum Python version ("major.minor") needed to run this script.
# main() compares it against sys.version_info before starting the run.
REQUIRED_PYTHON_VERSION = "3.8"


def main():
    """Check the interpreter version, then start the test run.

    Exits through ``sys.exit`` with an error message when the running
    interpreter is older than REQUIRED_PYTHON_VERSION. Otherwise run_app is
    handed to perform() for execution.
    """
    req_ver = tuple(int(x) for x in REQUIRED_PYTHON_VERSION.split("."))

    # Compare as tuples: checking major and minor independently would reject
    # a newer major release whose minor number is smaller (e.g. 4.0 vs 3.8).
    if sys.version_info[: len(req_ver)] < req_ver:
        sys.exit(f"""
*** PYTHON VERSION ERROR ***

runtests requires python3 version {REQUIRED_PYTHON_VERSION} or greater, got {sys.version}.

- Exiting -""")

    perform(
        "Executing ./run-tests",
        run_app,
    )


def run_app():
    """Configure the environment, sync the test database and run the tests.

    Steps, in order:
      1. Resolve SLURM_TESTSUITE_CONF (default: ./testsuite.conf) and read
         the Slurm source/build/install/config directories from it.
      2. Create the results/<cluster> directories and the log output dir.
      3. Parse the include/exclude selections and exit early when a tool
         needed by the included tests (expect, pytest-3) is not installed.
      4. Set up the local test database and update its records from the
         tests currently found in the test directories.
      5. Run the unit tests, then the expect/python regression tests unless
         the unit run reported "ABORTED".
      6. Print the run summary and the log location.

    Exits through ``sys.exit`` on configuration or missing-tool errors.
    """
    # Set defaults
    APP_DIR = os.getcwd()
    SRC_DIR = f"{APP_DIR}/src"
    SEED_FILE = f"{SRC_DIR}/seed_data"
    CLUSTER_DIR_PREFIX = "results"
    CLUSTER_DEFAULT_NAME = "default-run"
    TS_CONF_DEFAULT = f"{APP_DIR}/testsuite.conf"

    # Handle envars
    test_env = os.environ.copy()
    test_env.setdefault("SLURM_TESTSUITE_CONF", TS_CONF_DEFAULT)
    using_default_conf = test_env["SLURM_TESTSUITE_CONF"] == TS_CONF_DEFAULT
    if using_default_conf and not file_exists(TS_CONF_DEFAULT):
        msg = """
*** CONFIG ERROR ***

SLURM_TESTSUITE_CONF environment variable not set and the default testsuite.conf
file doesn't exist.

Please set SLURM_TESTSUITE_CONF to the appropriate testsuite.conf or create a
testsuite.conf file in this directory as outlined in the README.

- Exiting - """
        sys.exit(msg)

    # Retrieve slurm directory data from ts_conf
    ts_conf = test_env["SLURM_TESTSUITE_CONF"]
    ts_conf_vars = get_vars_from_conf(ts_conf)
    slurm_src_dir = ts_conf_vars["SlurmSourceDir"]
    slurm_build_dir = ts_conf_vars["SlurmBuildDir"]
    slurm_install_dir = ts_conf_vars["SlurmInstallDir"]
    slurm_config_dir = ts_conf_vars["SlurmConfigDir"]

    # Set cluster to default if not exists
    test_env.setdefault("SLURM_CLUSTER", CLUSTER_DEFAULT_NAME)
    my_cluster = test_env["SLURM_CLUSTER"]

    # Set dynamic dirs
    CLUSTER_DIR = f"{APP_DIR}/{CLUSTER_DIR_PREFIX}/{my_cluster}"
    TS_DIR = f"{slurm_src_dir}/testsuite"
    UNIT_DIR = f"{slurm_build_dir}/testsuite/slurm_unit"
    GLOBALS_LOCAL_DEFAULT = f"{TS_DIR}/expect/globals.local"
    args = get_args()

    # Create directories for chosen my_cluster in ts_conf
    create_dir(f"{CLUSTER_DIR}/src")
    output_dir = args.output_dir or f"{CLUSTER_DIR}/log"
    create_dir(output_dir)
    db_name = f"{CLUSTER_DIR}/src/test_database.db"

    # Handle globals.local for expect suite
    test_env.setdefault("SLURM_LOCAL_GLOBALS_FILE", GLOBALS_LOCAL_DEFAULT)

    # Process chosen tests from args
    # TODO see if adding fails_first option is wanted
    # fails_first = args.fails_first or False
    fails_first = False
    all_unit_tests = get_unit_tl_from_build_dir(UNIT_DIR)

    incl_dict = test_parser(args.include, all_unit_tests)
    excl_dict = test_parser(args.exclude, all_unit_tests)

    # Exit early if programs for chosen tests aren't installed
    if incl_dict["expect"] and not is_tool("expect"):
        sys.exit(
            "The expect tests require expect (TCL) to be installed in order to run."
        )

    if incl_dict["python"] and not is_tool("pytest-3"):
        sys.exit("The python tests require pytest-3 to be installed in order to run.")

    # Collect tests lists from the dirs (to be current with new or deleted tests)
    # Raw strings keep the regex escapes (\d, \.) from being treated as
    # (invalid) Python string escapes.
    all_expect_tests = get_tl_from_dir(
        f"{TS_DIR}/expect", r"(test\d+\.\d+)$", "expect/"
    )
    all_python_tests = get_tl_from_dir(
        f"{TS_DIR}/python/tests", r"(test_\d+_\d+\.py)$", "python/tests/"
    )

    unit_dir_data = create_new_test_data(all_unit_tests, "slurm_unit")
    expect_dir_data = create_new_test_data(all_expect_tests, "expect")
    python_dir_data = create_new_test_data(all_python_tests, "python")

    # Setup or update the local db with all test data
    start_fresh = args.reset
    log_header("UPDATING RECORDS")
    log(f"DB: {db_name}")
    perform(
        "Setting up local test database",
        setup_db_if_new,
        db_name,
        SEED_FILE,
        start_fresh,
    )
    perform(
        "Updating records based on test dirs",
        update_records_from_dirs,
        db_name,
        [unit_dir_data, expect_dir_data, python_dir_data],
    )

    # Log configuration
    log_header("SETTINGS: ENV VARS")
    log(f"SLURM_TESTSUITE_CONF={test_env['SLURM_TESTSUITE_CONF']}")
    log(f"SLURM_LOCAL_GLOBALS_FILE={test_env['SLURM_LOCAL_GLOBALS_FILE']}")
    log(f"SLURM_CLUSTER={test_env['SLURM_CLUSTER']}")

    log_header("SLURM_TESTSUITE_CONF VARS")
    log(f"SlurmSourceDir = {slurm_src_dir}")
    log(f"SlurmBuildDir = {slurm_build_dir}")
    log(f"SlurmInstallDir = {slurm_install_dir}")
    log(f"SlurmConfigDir = {slurm_config_dir}")

    # Log user include / exclude options
    #    log_new_line(f"Included tests: = ")
    #    print_pretty_dict(incl_dict)

    #    log(f"Excluded tests: = ")
    #    print_pretty_dict(excl_dict)

    # Format user choices to valid lists
    incl_dict = filter_tests(
        incl_dict, all_unit_tests, all_expect_tests, all_python_tests
    )
    excl_dict = filter_tests(
        excl_dict, all_unit_tests, all_expect_tests, all_python_tests
    )

    # Begin Testing
    log_header("**** BEGIN TESTING ****")
    if not args.resume:
        perform(
            "Reseting run_ids",
            reset_suite_run_ids,
            db_name,
            ["slurm_unit", "expect", "python"],
            verbose=False,
        )

    # Run unit tests
    aborted = False
    run_unit_list = list(set(incl_dict["slurm_unit"]) - set(excl_dict["slurm_unit"]))
    skip_unit = len(run_unit_list) < 1
    sorted_unit_data = (
        []
        if skip_unit
        else get_new_run_list(db_name, ["slurm_unit"], run_unit_list, fails_first)
    )

    if not skip_unit:
        perform(
            "Running unit tests",
            run_unit_tests,
            db_name,
            sorted_unit_data,
            UNIT_DIR,
            output_dir,
            args.resume,
            new_line=True,
        )

        if get_unit_run_stats()["status"] == "ABORTED":
            aborted = True

    # Run regressions tests
    run_regressions_list = list(
        set(incl_dict["expect"]) - set(excl_dict["expect"])
    ) + list(set(incl_dict["python"]) - set(excl_dict["python"]))

    skip_reg = len(run_regressions_list) < 1
    sorted_reg_data = (
        []
        if skip_reg
        else get_new_run_list(
            db_name, ["expect", "python"], run_regressions_list, fails_first
        )
    )

    # An aborted unit run also cancels the regression run
    if not aborted and not skip_reg:
        perform(
            "Running regressions tests",
            run_regressions_tests,
            db_name,
            sorted_reg_data,
            TS_DIR,
            output_dir,
            test_env,
            args.resume,
            new_line=True,
        )

        if get_regressions_run_stats()["status"] == "ABORTED":
            aborted = True

    # Print summary data
    print_run_summary(
        [
            ("Unit Test Summary", get_unit_run_stats()),
            ("Regressions Summary", get_regressions_run_stats()),
        ]
    )

    print(f"Logs written to:\n\n{output_dir}\n")
    if aborted:
        msg = """Run ended early, to resume this run next time use the '-r' option
* Note: running without '-r' next time will reset all 'resume' history
        """
        print(msg)
    os.chdir(APP_DIR)


def print_run_summary(run_tup_list):
    """Print a results table for each run that recorded any activity.

    Args:
        run_tup_list: iterable of (title, stats) pairs. A non-empty stats
            mapping must provide the keys "completions", "passes", "skips",
            "fails" (a sequence of test names), "total" and "status".
    """
    log_header("RESULTS")
    row_format = "{:<20s}{:>10s}"
    stat_keys = ("completions", "passes", "skips", "fails", "total", "status")

    for title, stats in run_tup_list:
        # Empty stats: nothing to report for this suite
        if len(stats.items()) == 0:
            continue

        completions, passes, skips, fails, total, _status = (
            stats[key] for key in stat_keys
        )
        fail_count = len(fails)

        # Only print a table when at least one test was processed
        if completions + skips + fail_count > 0:
            rows = [
                (f"{title}:", ""),
                ("Completions:", f"{completions}/{total}"),
                ("Passed:", passes),
                ("Failed:", fail_count),
                ("Skipped:", skips),
                ("Failed tests:", ""),
            ]
            for label, value in rows:
                print_frmt(row_format, label, value)

            print(", ".join(fails))
            print(f"{('-' * 20)}\n")


def print_frmt(text_format, msg1, msg2):
    """Print msg1 and msg2 as text, laid out by the two-field text_format."""
    left, right = str(msg1), str(msg2)
    line = text_format.format(left, right)
    print(line)


def create_new_test_data(test_name_list, suite, duration=10000, status="", run_id=0):
    """Build one record tuple per test name.

    Returns:
        A list of (test_name, suite, duration, status, run_id) tuples, in the
        same order as test_name_list.
    """
    return [(name, suite, duration, status, run_id) for name in test_name_list]


def filter_tests(t_dict, all_unit_tests, all_expect_tests, all_python_tests):
    """Replace each suite's user selection in t_dict with valid test names.

    For every non-empty selection: a leading "all" selects the suite's whole
    test list; otherwise the selection is resolved with intersect(). Empty
    selections are left untouched. Keys other than "slurm_unit" and "expect"
    are resolved against the python test list.

    Returns:
        The same t_dict object, updated in place.
    """
    suite_tests = {
        "slurm_unit": all_unit_tests,
        "expect": all_expect_tests,
    }

    for suite, chosen in t_dict.items():
        if len(chosen) == 0:
            continue

        # Anything that is not a unit/expect suite falls back to python
        available = suite_tests.get(suite, all_python_tests)
        if chosen[0] == "all":
            t_dict[suite] = available
        else:
            t_dict[suite] = intersect(chosen, available)

    return t_dict


def intersect(user_list, all_list):
    """Return the names in all_list selected by the patterns in user_list.

    Each user pattern is converted to a regular expression in which "." is
    literal and "*" stands for one or more digits (this handles the "*." /
    ".*" and "*_" / "_*" options). The expression is anchored at the end only,
    so it selects every valid name that ends with the pattern. Any other
    regex metacharacters in a pattern are passed through unchanged.

    Args:
        user_list: iterable of user-supplied test name patterns.
        all_list: iterable of valid test names to match against.

    Returns:
        A list of the unique matching names, in first-match order.
    """
    matches = []
    for pattern in user_list:
        # Raw strings: "\." and "\d" are not valid Python string escapes and
        # trigger a warning when written in a normal string literal.
        regex_src = pattern.replace(".", r"\.").replace("*", r"\d+")
        name_re = re.compile(f"{regex_src}$")
        matches.extend(name for name in all_list if name_re.search(name))

    # dict.fromkeys removes duplicates while keeping a deterministic order
    return list(dict.fromkeys(matches))


# Run only when executed as a script; main()'s return value is handed to
# sys.exit as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
