#!/usr/bin/python3

# Copyright 2014..2022, Martin <debacle@debian.org>
# License: AGPL-3+

# Python standard modules
import argparse
import collections
import configparser
import email.mime.text
import email.utils
import hashlib
import html
import os
import smtplib
import socket
import subprocess
import sys
import textwrap

# additional modules
import apt
import prettytable
import slixmpp

# Human-readable program name; used as XMPP nick and X-Mailer header value.
longname = "Pain in the APT"
# Short program name; used to build the default config and stamp file paths.
shortname = "painintheapt"
version = "0.20220226"

# Column headers of the update table.  Their lower-cased forms double as the
# field names of the Package record (name, installed, candidate).
columns = ["Name", "Installed", "Candidate"]
Package = collections.namedtuple("Package", [header.lower() for header in columns])


def getargs():
    """Build the command-line parser and return the parsed arguments.

    Recognised options: -c/--configfile, -d/--debug, -f/--force,
    -s/--stampfile, -t/--testmessage and -v/--version.  Defaults are shown
    in the --help output.
    """
    parser = argparse.ArgumentParser(
        description="Pester people about available package updates by email or jabber.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    option = parser.add_argument

    def switch(short, long_, helptext):
        # Boolean flags are off unless given on the command line.
        option(short, long_, default=False, action="store_true", help=helptext)

    option("-c", "--configfile", default="/etc/%s.conf" % shortname, help="configuration file")
    switch("-d", "--debug", "print debug output to stderr")
    switch("-f", "--force", "send message, even if updates did not change")
    option("-s", "--stampfile", help="stamp file", default="/var/lib/%s/stamp" % shortname)
    switch("-t", "--testmessage", "send a test message only")
    option("-v", "--version", action="version", version="%(prog)s " + version)
    return parser.parse_args()


def update():
    """Refresh the APT package lists and compute a full (dist-)upgrade.

    Returns ``(cache, updates)``: the opened ``apt.Cache`` with the upgrade
    marked, and one ``Package`` row (name, installed version, candidate
    version) per package the upgrade would change.  "-" stands in for a
    missing installed or candidate version.
    """
    cache = apt.Cache()
    cache.update()
    # Re-open so the cache uses the package lists just downloaded.
    cache.open()
    cache.upgrade(dist_upgrade=True)
    # Use the changed package objects directly.  Looking them up again by
    # their architecture-less short name (c._pkg.name) returns the
    # native-architecture package for Multi-Arch foreign packages, which
    # reported the wrong versions.  ``name`` carries the ":arch" suffix
    # only for non-native packages.
    updates = [
        Package(
            pkg.name,
            pkg.installed.version if pkg.installed else "-",
            pkg.candidate.version if pkg.candidate else "-",
        )
        for pkg in cache.get_changes()
    ]
    return cache, updates


def wrap(text, maxwid):
    """Fill paragraph.

    Re-flow *text* into lines of at most *maxwid* characters joined by
    newlines (over-long words are split); an empty *text* gives "".
    """
    return textwrap.fill(text, width=maxwid)


# Memoized result of get_changelogs(); None means "not computed yet".
_changes = None


def get_changelogs(cache, send_changes):
    """Download changelogs. Beware: This is very slow.

    Identical changelogs for different binary packages are combined: each
    distinct changelog text is preceded by the wrapped, sorted list of the
    package names sharing it, and the entries are sorted and separated by
    a line of dashes.

    Returns "" when *cache* is None or *send_changes* is not exactly True.
    The first computed result is memoized for the process lifetime, so
    several senders share one download.
    """
    global _changes
    if cache is None or send_changes is not True:
        return ""
    # "is not None": an empty result is a valid cached value too.
    if _changes is not None:
        return _changes
    changelogs = collections.defaultdict(list)
    # Use the changed package objects directly instead of re-looking them up
    # by their architecture-less short name, which picks the native
    # architecture for Multi-Arch foreign packages.
    for pkg in cache.get_changes():
        changelogs[pkg.get_changelog().strip()].append(pkg.name)
    # now do some very fancy formatting
    maxwid = 79
    separator = "\n" + "-" * maxwid + "\n"
    _changes = separator.join(
        sorted(
            wrap(", ".join(sorted(names)), maxwid) + ":\n\n" + changelog
            for changelog, names in changelogs.items()
        )
    )
    return _changes


def maketable(lst):
    """Render package updates as a left-aligned text table sorted by name.

    *lst* holds Package records; every cell is wrapped to at most 23
    characters per line.
    """
    cellwidth = 23
    table = prettytable.PrettyTable(columns)
    table.align = "l"
    table.sortby = columns[0]
    for pkg in lst:
        cells = (pkg.name, pkg.installed, pkg.candidate)
        table.add_row([wrap(cell, cellwidth) for cell in cells])
    return table.get_string()


class JabberBot(slixmpp.ClientXMPP):
    """One-shot XMPP client that posts the update report and disconnects.

    Once the session starts, the report goes out:
    - as a chat message to every JID in *to*;
    - as a groupchat message to *room*, if set;
    - as an Atom entry published to *pubsub_node* on *pubsub_service*, if
      both are set.

    The MUC (xep_0045) and pubsub (xep_0060) plugins must be registered by
    the caller whenever those targets are used.
    """

    def __init__(
        self,
        jid,
        password,
        to,
        room,
        pubsub_service,
        pubsub_node,
        nick,
        subject,
        table,
        changes,
    ):
        slixmpp.ClientXMPP.__init__(self, jid, password)
        self.to = to  # iterable of recipient JIDs for direct chat messages
        self.room = room  # MUC room JID; falsy to skip the room
        self.pubsub_service = pubsub_service
        self.pubsub_node = pubsub_node
        self.nick = nick  # nickname used when joining the room
        self.add_event_handler("session_start", self.start)
        self.subject = subject
        self.table = table
        self.changes = changes

    def _body(self):
        """Return the plain-text report: subject, fenced table, changelogs."""
        # The subject is not shown by all clients, and groupchats have no
        # per-message subject, so it is repeated as the first body line.
        fence = "```"
        return "\n".join([self.subject, fence, self.table, fence, "\n", self.changes])

    def _atom_entry(self):
        """Return the report as an Atom <entry> string with XHTML content."""
        changes = html.escape(self.changes)
        # One paragraph per line.  Non-breaking spaces keep runs of spaces
        # (e.g. changelog indentation) from being collapsed by renderers.
        changes = changes.replace("\n", "</p>\n<p>").replace(" ", "&#160;")
        return (
            '<entry xmlns="http://www.w3.org/2005/Atom"><title>'
            + html.escape(self.subject)
            + '</title><content type="xhtml">'
            # RFC 4287, 3.1.1.3: xhtml content is a single XHTML div.  The
            # namespace therefore goes on the div, so it covers the <pre>
            # and <p> children too, not only the <pre>.
            + '<div xmlns="http://www.w3.org/1999/xhtml"><pre>'
            + html.escape(self.table)
            + "</pre><p>"
            + changes
            + "</p></div></content></entry>"
        )

    def start(self, event):
        """Send the report to all configured targets, then disconnect."""
        self.get_roster()
        self.send_presence()
        body = self._body()
        for to in self.to:
            self.send_message(
                mto=to,
                msubject=self.subject,
                mbody=body,
                mtype="chat",
            )
        if self.room:
            self.plugin["xep_0045"].join_muc(self.room, self.nick)
            self.send_message(mto=self.room, mbody=body, mtype="groupchat")
        if self.pubsub_service and self.pubsub_node:
            self.plugin["xep_0060"].publish(
                self.pubsub_service,
                self.pubsub_node,
                payload=slixmpp.xmlstream.ET.fromstring(self._atom_entry()),
            )
        self.disconnect(wait=True)


def read_password(config, config_dir):
    """Return the password configured for a sender section.

    The file named by ``password_file`` is preferred.  It is resolved
    relative to *config_dir*, or used as is when absolute, and its content
    is stripped.  Otherwise the deprecated inline ``password`` option is
    returned ("" if unset).  A deprecation warning goes to stderr only when
    that inline option is actually set, so password-less setups stay quiet.
    """
    password_file = config.get("password_file", "").strip()
    if password_file:
        filename = os.path.join(config_dir, password_file)
        with open(filename) as f:
            return f.read().strip()

    password = config.get("password", "")
    if password:
        print("password deprecated, use password_file instead", file=sys.stderr)
    return password


def _split_jids(value):
    """Return the comma-separated JIDs in *value*, stripped, blanks dropped."""
    return [jid.strip() for jid in value.split(",") if jid.strip()]


def sendxmpp(config, config_dir, table, count, host, debug, changes):
    """Send the report via XMPP: chat messages, a MUC room and/or pubsub.

    *config* is the XMPP configuration section.  It supplies jid,
    password/password_file, to, room, pubsub_service and pubsub_node.
    *debug* is currently unused here.
    """
    jid = config.get("jid", "")
    password = read_password(config, config_dir)
    # "to" may be unset when only a room or pubsub node is configured.
    # Neither that nor "a@x, b@y" must yield "" or " b@y" as a recipient.
    to = _split_jids(config.get("to", ""))
    room = config.get("room")
    pubsub_service = config.get("pubsub_service", "").strip()
    pubsub_node = config.get("pubsub_node", "").strip()
    subject = "%d package update(s) for %s" % (count, host)
    xmpp = JabberBot(
        jid,
        password,
        to,
        room,
        pubsub_service,
        pubsub_node,
        longname,
        subject,
        table,
        changes,
    )
    xmpp.register_plugin("xep_0030")  # service discovery
    if room:
        xmpp.register_plugin("xep_0045")  # multi-user chat
    if pubsub_service and pubsub_node:
        xmpp.register_plugin("xep_0060")  # pubsub
    xmpp.register_plugin("xep_0199")  # XMPP ping

    xmpp.connect()
    # Run the event loop until the bot disconnects after sending.
    xmpp.process(forever=False)


def _smtp_recipients(to, cc):
    """Return the unique, non-empty addresses listed in *to* and *cc*, in order.

    Each argument is a header-style, comma-separated address list; an empty
    string contributes nothing.
    """
    # Parse only the non-empty fields.  Joining them as to + "," + cc leaves
    # a trailing comma when cc is empty.  Python versions carrying the
    # CVE-2023-27043 fix parse strictly and turn that into [("", "")],
    # i.e. an empty recipient.
    fields = [field for field in (to, cc) if field]
    addresses = (addr for _, addr in email.utils.getaddresses(fields) if addr)
    return list(dict.fromkeys(addresses))


def sendsmtp(config, config_dir, table, count, host, debug, changes):
    """Send email by SMTP to whomsoever it may concern.

    *config* is the SMTP configuration section.  It supplies server
    (default "localhost"), port (default 25), username,
    password/password_file, from, to and cc.  STARTTLS is always used.
    Login happens only when a username or password is set.  smtplib
    exceptions propagate to the caller.
    """
    server = config.get("server", "localhost")
    port = config.getint("port", 25)
    username = config.get("username", "")
    password = read_password(config, config_dir)
    from_ = config.get("from", username)
    to = config.get("to", username)
    cc = config.get("cc", "")

    msg = email.mime.text.MIMEText("\n\n".join([table, changes]).strip(), "plain", "utf-8")
    msg["From"] = from_
    msg["To"] = to
    msg["Subject"] = "%d package update(s) for %s" % (count, host)
    # RFC 5322 requires a Date header; neither MIMEText nor smtplib add it.
    msg["Date"] = email.utils.formatdate(localtime=True)
    msg["Message-ID"] = email.utils.make_msgid(domain=host)
    msg["X-Mailer"] = longname

    if cc:
        msg["Cc"] = cc

    # The context manager sends QUIT and closes the socket even when
    # STARTTLS, login or delivery raises.
    with smtplib.SMTP(host=server, port=port) as s:
        if debug:
            s.set_debuglevel(True)
        # NOTE(review): without an explicit ssl context, smtplib does not
        # verify the server certificate; consider ssl.create_default_context().
        s.starttls()
        s.ehlo_or_helo_if_needed()
        if username or password:
            s.login(username, password)
        s.sendmail(from_, _smtp_recipients(to, cc), msg.as_string())


def _mailx_a_attaches(mailx):
    """Return True if the mailx at path *mailx* treats ``-a`` as "attach file".

    heirloom-mailx and its successor s-nail do.  bsd-mailx and GNU mailutils
    use ``-a`` to add a header line instead.
    """
    return os.path.basename(os.path.realpath(mailx)) in ("heirloom-mailx", "s-nail")


def sendmailx(config, config_dir, table, count, host, debug, changes):
    """Send email by mailx to whomsoever it may concern.

    *config* supplies from (default "root"), cc and to (default "root").
    Raises subprocess.CalledProcessError if mailx exits with a non-zero
    status.
    """
    mailx = "/usr/bin/mailx"
    cmd = [
        mailx,
        "-r",
        config.get("from", "root"),
        "-s",
        "%d package update(s) for %s" % (count, host),
    ]
    cc = config.get("cc", "")
    if cc:
        cmd += ["-c", cc]
    # this is taken from apticron
    if _mailx_a_attaches(mailx):
        # Here "-a" would try to attach a file named like the header, so no
        # extra headers (X-Mailer included) are passed.
        cmd += ["-S", "ttycharset=utf-8"]
    else:
        cmd += [
            "-a",
            "X-Mailer: " + longname,
            "-a",
            "MIME-Version: 1.0",
            "-a",
            "Content-type: text/plain; charset=UTF-8",
            "-a",
            "Content-transfer-encoding: 8bit",
        ]
    to = config.get("to", "root")
    body = "\n\n".join([table, changes]).strip()
    # The pipe to mailx is binary, so encode the body as the UTF-8 declared above.
    subprocess.run(cmd + [to], input=body.encode("utf-8"), check=True)


def has_changed(configfile, table, stampfile):
    """Tell whether the configuration or the update table changed since last time.

    The hash is SHA-1 over the configuration file's lines (read as text,
    re-encoded as UTF-8) followed by *table*.  It is compared with the
    first line of *stampfile*.  A missing, unreadable or undecodable stamp
    file counts as a change.

    Returns ``(changed, newhash)``.  *newhash* is the hex digest the caller
    stores in the stamp file after notifying.
    """
    hashsum = hashlib.sha1()
    with open(configfile) as f:
        for line in f:
            hashsum.update(line.encode("utf-8"))
    hashsum.update(table.encode("utf-8"))
    newhash = hashsum.hexdigest()
    try:
        with open(stampfile) as f:
            oldhash = f.readline().strip()
    except (OSError, ValueError):
        # No stamp yet (first run) or a broken one: treat as changed.
        oldhash = None
    return oldhash != newhash, newhash


class AcquireProgress(apt.progress.text.AcquireProgress):
    """Text download progress, shown on stderr only in debug mode.

    Without *debug*, the progress output is written to /dev/null.
    """

    def __init__(self, debug):
        if debug:
            sink = sys.stderr
        else:
            sink = open("/dev/null", "w")
        super().__init__(outfile=sink)


def main():
    """Check for package updates, notify the configured channels, return the exit status.

    The status is 1 if any configured sender raised, else 0.  In
    test-message mode only a fixed test text is sent; the stamp file is
    left alone and nothing is downloaded.  Otherwise the stamp file is
    rewritten whenever a notification was due (changed or forced).  The
    upgrade's package archives are then fetched.
    """
    args = getargs()
    config = configparser.ConfigParser()
    config.read(args.configfile)
    config_dir = os.path.dirname(args.configfile)

    fqdn = socket.getfqdn()
    # workaround for dodgy /etc/hosts
    if fqdn in ["localhost", "localhost.localdomain"]:
        fqdn = socket.gethostname() or fqdn

    if not args.testmessage:
        cache, updates = update()
        count = len(updates)
        table = maketable(updates) if count else ""
        change, newhash = has_changed(args.configfile, table, args.stampfile)
    else:
        cache, count, change = None, 0, True
        table = "this is a test message from painintheapt"

    notify = change or args.force
    status = 0
    senders = (("XMPP", sendxmpp), ("SMTP", sendsmtp), ("MAILX", sendmailx))
    for name, send in senders:
        if not notify or name not in config.sections():
            continue
        # A failing channel must not keep the remaining ones from running.
        try:
            section = config[name]
            wants_changes = section.getboolean("send_changes", True)
            send(
                section,
                config_dir,
                table,
                count,
                fqdn,
                args.debug,
                get_changelogs(cache, wants_changes),
            )
        except Exception as err:
            print(str(err), file=sys.stderr)
            status = 1

    if args.testmessage:
        return status

    if notify:
        with open(args.stampfile, "wb") as stamp:
            stamp.write(newhash.encode("utf-8"))

    cache.fetch_archives(progress=AcquireProgress(args.debug))
    return status


if __name__ == "__main__":
    sys.exit(main())
