#!/usr/bin/python3
import argparse
import glob
import os
import re
import subprocess
import sys
import time


def get_tboot_menu_entry(grub_cfg="/boot/grub/grub.cfg"):
    """Return the title of the first GRUB submenu whose title mentions tboot.

    Scans *grub_cfg* line by line for ``submenu "<title>"`` or
    ``submenu '<title>'`` declarations. Returns the first title that contains
    the substring ``tboot``.

    If no such submenu exists, prints a message to stderr and exits the
    process with status 1, because the test cannot continue without it.
    """
    # Accept either quote character. The backreference \1 requires the title
    # to be closed by the same quote that opened it.
    submenu_pattern = re.compile(r"^\s*submenu\s+([\"'])(.+?)\1")
    with open(grub_cfg, encoding="utf-8") as f:
        for line in f:
            m = submenu_pattern.match(line)
            if m and "tboot" in m.group(2):
                return m.group(2)
    print("Unable to find tboot menu entry", file=sys.stderr)
    sys.exit(1)


def set_tboot_as_default():
    """Make the tboot GRUB submenu the default boot entry and regenerate grub.cfg.

    Steps:
      * Look up the tboot submenu title with get_tboot_menu_entry() and print it.
      * Rewrite every ``GRUB_DEFAULT=`` assignment in /etc/default/grub to that
        title, appending a new assignment if the file has none.
      * Run ``update-grub2``, discarding its stderr.

    Raises:
        subprocess.CalledProcessError: if ``update-grub2`` exits non-zero.
    """
    tboot_menu_entry = get_tboot_menu_entry()
    print(tboot_menu_entry)

    # /etc/default/grub is shell syntax, so wrap the title in single quotes.
    # Any embedded single quote becomes '\'' so the assignment stays valid.
    quoted_entry = "'" + tboot_menu_entry.replace("'", "'\\''") + "'"
    new_line = f"GRUB_DEFAULT={quoted_entry}"

    with open("/etc/default/grub") as f:
        grub_default = f.read()

    # With MULTILINE, ^/$ anchor at every line, so an assignment on the first
    # line of the file is replaced too. A callable replacement stops
    # backslashes in the title from being read as regex escapes or group
    # references.
    updated_grub_default, count = re.subn(r"^GRUB_DEFAULT=.*$", lambda _m: new_line, grub_default, flags=re.MULTILINE)
    if count == 0:
        # No existing assignment: append one rather than silently leaving the
        # default unchanged.
        if updated_grub_default and not updated_grub_default.endswith("\n"):
            updated_grub_default += "\n"
        updated_grub_default += new_line + "\n"

    with open("/etc/default/grub", "w") as f:
        f.write(updated_grub_default)
    subprocess.check_call(["update-grub2"], stderr=subprocess.DEVNULL)


# Exit status that autopkgtest reports as "skipped" for tests declaring the
# "skippable" restriction. Used when this hardware cannot run the test.
_EXIT_SKIP = 77

# sysfs locations that may hold the TPM 1.2 "pcrs" file, tried in order. The
# first glob is the original search location. The /sys/class/tpm paths are
# fallbacks, since the file's location varies with kernel version (TODO
# confirm on target kernels).
_TPM1_PCRS_GLOBS = (
    "/sys/devices/*/*/pcrs",
    "/sys/class/tpm/tpm0/pcrs",
    "/sys/class/tpm/tpm0/device/pcrs",
)


def _skip(message):
    """Print *message* to stderr and exit with the autopkgtest skip status."""
    print(message, file=sys.stderr)
    sys.exit(_EXIT_SKIP)


def _fail(message):
    """Print *message* to stderr and exit with status 1 (test failure)."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _parse_args():
    """Parse the command-line options of this test."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--on-unsupported-hardware", action="store_true", help="Indicates that the test is running on unsupported hardware and working measured boot is not expected. However, the system should still boot normally")
    parser.add_argument("--tpm", type=int, choices=[1, 2], help="TPM version [1|2]")
    parser.add_argument("--boot", choices=["bios", "efi"], help="Boot mode [bios|efi]")
    return parser.parse_args()


def _check_hardware(args):
    """Skip the test unless this machine can perform a measured launch.

    Requirements:
      * The CPU advertises the ``smx`` flag.
      * The firmware boot mode matches --boot. If --boot is omitted, any
        detected mode is accepted.
      * A TPM is present and its major version matches --tpm. If --tpm is
        omitted, the detected version is stored in ``args.tpm`` so the later
        TPM-specific steps still run.
    """
    with open("/proc/cpuinfo", "r") as f:
        cpuinfo = f.read()
    # Match "smx" as a whole word so it cannot hit inside a longer token.
    if not re.search(r"\bsmx\b", cpuinfo):
        _skip("Your CPU does not support SMX, measured launch cannot work")

    efi = os.path.exists("/sys/firmware/efi")
    if args.boot is not None and (args.boot == "efi") != efi:
        _skip(f"Boot mode {args.boot} is not possible on this hardware")

    tpm_version_path = "/sys/class/tpm/tpm0/tpm_version_major"
    if not os.path.exists(tpm_version_path):
        _skip("You do not seem to have a TPM, measured launch cannot work")
    with open(tpm_version_path) as f:
        tpm_version = int(f.read())
    if args.tpm is None:
        args.tpm = tpm_version
    elif args.tpm != tpm_version:
        _skip(f"Testing TPM {args.tpm} is not possible on this hardware")


def _read_tpm1_pcrs():
    """Return the text of the first TPM 1.2 sysfs ``pcrs`` file found, or fail."""
    for pattern in _TPM1_PCRS_GLOBS:
        matches = glob.glob(pattern)
        if matches:
            with open(matches[0]) as f:
                return f.read()
    _fail("Unable to find the TPM 1.2 pcrs file in sysfs")


def _run_before_reboot(args):
    """Run phase mark-1: check that no tboot launch happened, then reboot.

    Steps:
      * Check that tboot has not launched yet and TXT reports no error.
      * For TPM 1.2, prepare tcsd.
      * Make tboot the default boot entry.
      * Request a reboot into phase mark-2.
    """
    txt_stat = subprocess.check_output("txt-stat", encoding="utf-8")
    if "unable to find TBOOT log" not in txt_stat:
        _fail(f"Expected to see 'unable to find TBOOT log' in txt-stat output but got {txt_stat}")

    txt_parse_err = subprocess.check_output("txt-parse_err", encoding="utf-8")
    if "no error" not in txt_parse_err:
        _fail(f"Expected to see 'no error' in txt-parse_err output but got {txt_parse_err}")

    if args.tpm == 1:
        # Install the example system.data shipped with trousers and restart
        # tcsd. The sleep presumably gives tcsd time to come up (TODO confirm).
        subprocess.check_call(["cp", "/usr/share/doc/trousers/examples/system.data.auth", "/var/lib/tpm/system.data"])
        subprocess.check_call(["systemctl", "restart", "tcsd"])
        time.sleep(4)
    set_tboot_as_default()
    # After the reboot, autopkgtest re-runs this script with
    # AUTOPKGTEST_REBOOT_MARK=mark-2.
    subprocess.check_call(["/tmp/autopkgtest-reboot", "mark-2"])


def _run_after_reboot(args):
    """Run phase mark-2: verify that tboot performed a measured launch.

    Exits 0 on success. On unsupported hardware, only checks that tboot ran.
    """
    txt_stat = subprocess.check_output("txt-stat", encoding="utf-8")
    if "unable to find TBOOT log" in txt_stat:
        _fail(f"Expected to not see 'unable to find TBOOT log' in txt-stat output after reboot, got {txt_stat}")

    if not args.on_unsupported_hardware:
        if "TBOOT: no SINIT AC module found" in txt_stat:
            _skip("Your BIOS does not seem to provide an SINIT AC module, measured launch cannot work")

        if "SINIT matches platform" not in txt_stat:
            _fail("Somehow the provided SINIT AC does not match your platform?")

        # If PCR 17 still holds all 0xFF bytes, no measured launch ever
        # extended it.
        if args.tpm == 2:
            pcrs = subprocess.check_output(["tpm2_pcrread"], encoding="utf-8")
            if "17: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" in pcrs:
                _fail("Got uninitialized PCR17 value, something went wrong")
        elif args.tpm == 1:
            pcrs = _read_tpm1_pcrs()
            if "PCR-17: FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF" in pcrs:
                _fail("Got uninitialized PCR17 value, something went wrong")

    sys.exit(0)


def main():
    """Run the two-phase tboot measured-launch test.

    Phase selection:
      * The phase comes from AUTOPKGTEST_REBOOT_MARK; if unset, it defaults
        to mark-1.
      * mark-1 prepares tboot and reboots.
      * mark-2 verifies the measured launch.

    Exit status:
      * 77 (skip) when the hardware cannot run the test.
      * 1 on failure, including an unknown reboot mark.
    """
    args = _parse_args()

    if not args.on_unsupported_hardware:
        _check_hardware(args)

    reboot_mark = os.environ.get("AUTOPKGTEST_REBOOT_MARK", "mark-1")
    if reboot_mark == "mark-1":
        _run_before_reboot(args)
    elif reboot_mark == "mark-2":
        _run_after_reboot(args)
    else:
        _fail(f"Got unexpected reboot mark {reboot_mark}")


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
