#!/usr/bin/python

# split-repo - split a repository into parts
# 
# Copyright 2003 Progeny Linux Systems.
# 
# This script splits a repository into N parts, each approximately the
# same size, such that adding each part to an apt sources.list is the
# same as adding the original whole.  Options exist for adding an apt
# CD label, for specifying the architecture, and for providing a
# package order list, among other things.  "picax --help" or the man
# page should give a good overview.
#
# This script relies on hashfile.py, a custom module that should be
# found in CVS with this script.  Make sure it is available to be
# imported into Python, either by copying into the same directory as
# this script or by copying it into the Python import path.
#
# Due to the use of several new features, this script requires Python
# 2.2 or later.

import sys
import os
import string
import re
import traceback
import md5
import sha
import hashfile
import gzip

import apt_pkg

import picax.config
import picax.apt
import picax.installer
import picax.media

import pdb

# Class definitions

class Package:
    """This class encapsulates a Package as represented in an apt
    Packages file.  Don't create one of your own; instead, rely on the
    PackageFactory class below to create these for you.

    Field values parsed from the index stanza live in self.fields;
    extra metadata (distribution, component, lazily calculated values
    such as Package-Size) lives in self.meta.  Item lookup checks the
    parsed fields first, then the metadata."""

    # Meta keys that are calculated on demand, mapped to the name of
    # the method that fills them in.  Subclasses supply the real
    # _get_package_size implementation.
    _calc_meta = { "Package-Size": "_get_package_size" }

    def __init__(self, base_path, fn, start_pos, section, distro, component):
        # base_path: repository root.  fn: path of the Packages or
        # Sources file this stanza came from.  start_pos: byte offset
        # of the stanza in that file, used by get_lines() to re-read
        # it later.
        self.base_path = base_path
        self.fn = fn
        self.start_pos = start_pos
        self.lines = []
        self.meta = {}

        # Copy the parsed section; the parser reuses its buffer as it
        # steps through the file, so the values must be captured now.
        self.fields = {}
        for key in section.keys():
            self.fields[key] = section[key]

        self.meta["distribution"] = distro
        self.meta["component"] = component

    def get_lines(self):
        """Return the raw stanza lines for this package, reading and
        caching them from the index file on first use."""

        if not self.lines:
            tagfile = open(self.fn)
            tagfile.seek(self.start_pos)
            line = tagfile.readline()
            # A blank line terminates the stanza.
            while line.strip():
                self.lines.append(line)
                line = tagfile.readline()
            tagfile.close()

        return self.lines[:]

    def link(self, dest_path):
        """Hard-link this package's file(s) under dest_path; only
        meaningful in subclasses."""
        raise RuntimeError("invoked link() on base class")

    def _get_package_size(self):
        raise RuntimeError("size calculation not available in base class")

    def has_key(self, key):
        """Return true if key is a parsed field, a calculable meta
        key, or an already-set meta key."""
        return (key in self.fields
                or key in self._calc_meta
                or key in self.meta)

    def __getitem__(self, key):
        if key in self.fields:
            return self.fields[key]

        # Fill in lazily calculated metadata on first access.
        if key not in self.meta and key in self._calc_meta:
            func = getattr(self, self._calc_meta[key])
            func(key)

        return self.meta[key]

    def __setitem__(self, key, value):
        # Parsed fields are read-only; only metadata may be set.
        if key in self.fields:
            raise KeyError("meta field already defined in Packages file")
        self.meta[key] = value

class BinaryPackage(Package):
    """A binary package stanza from a Packages file."""

    def link(self, dest_root_path):
        """Hard-link this package's .deb from the source repository
        into the same relative location under dest_root_path, creating
        intermediate directories as needed.  Skips with a warning if
        the destination already exists."""

        pkg_path = self["Filename"]
        src_path = self.base_path + "/" + pkg_path
        dest_path = dest_root_path + "/" + pkg_path

        if os.path.exists(dest_path):
            # Bug fix: this used to read pkg["Package"], where "pkg"
            # is an undefined name, so any duplicate-link attempt
            # raised NameError instead of warning.
            sys.stderr.write("W: binary package %s already copied once, skipping\n"
                             % (self["Package"],))
            return

        pkg_dir = os.path.dirname(dest_path)
        if not os.path.exists(pkg_dir):
            os.makedirs(pkg_dir)

        os.link(src_path, dest_path)

    def _get_package_size(self, key):
        """Calculate the on-disk size of the .deb and cache it in
        self.meta under key."""
        self.meta[key] = os.stat(self.base_path + "/"
                                 + self["Filename"]).st_size

    def get_source_info(self):
        """Return (source_name, source_version) for this binary.

        The Source field may carry an explicit version in parentheses;
        when the field or the version is absent, the binary's own
        name/version stand in."""

        if "Source" in self.fields:
            source = self["Source"].strip().split()
            if len(source) == 1:
                return (source[0], self["Version"])
            else:
                # Strip the surrounding parentheses from "(version)".
                return (source[0], source[1][1:-1])
        else:
            return (self["Package"], self["Version"])

class SourcePackage(Package):
    """A source package stanza from a Sources file."""

    def _get_file_list(self):
        """Return (and cache) the repository-relative paths of all
        files belonging to this source package, built from the
        Directory and Files fields."""

        if not hasattr(self, "file_list"):
            self.file_list = []
            file_dir = self["Directory"]
            for file_line in self["Files"].split("\n"):
                file_line = file_line.strip()
                if not file_line:
                    continue
                # Each Files entry is "md5sum size filename".
                (md5sum, size, fn) = file_line.split()
                self.file_list.append(file_dir + "/" + fn)

        return self.file_list

    def link(self, dest_path):
        """Hard-link all of this source package's files under
        dest_path, creating directories as needed.

        If any file is already present, the whole package is assumed
        to have been copied before and the remaining files are
        skipped (early return, matching the long-standing behavior)."""

        for path in self._get_file_list():
            dest_file_path = dest_path + "/" + path
            if os.path.exists(dest_file_path):
                sys.stderr.write("W: source package %s already copied once, skipping\n"
                                 % (self["Package"],))
                return

            pkg_dir = os.path.dirname(dest_file_path)
            if not os.path.exists(pkg_dir):
                os.makedirs(pkg_dir)

            os.link(self.base_path + "/" + path, dest_file_path)

    def _get_package_size(self, key):
        """Calculate the combined size of all files in the source
        package and cache it in self.meta under key."""
        total_size = 0
        for path in self._get_file_list():
            total_size = total_size + os.stat(self.base_path + "/"
                                              + path).st_size
        self.meta[key] = total_size

class PackageFactory:
    """Produce Package objects from a Packages or Sources stream.

    Besides the explicit get_next_package()/get_packages() calls, a
    PackageFactory can be used directly as a (Python 2) iterator."""

    def __init__(self, package_file_stream, base_path, distro, component):
        self.package_file = package_file_stream
        self.package_parser = apt_pkg.ParseTagFile(package_file_stream)
        self.base_path = base_path
        self.distro = distro
        self.component = component
        self.eof = False
        # Byte offset of the parser's current position; snapshotted
        # into last_pos before each step so the produced Package can
        # re-read its own stanza later.
        self.current_pos = self.package_parser.Offset()

    def _next_package(self):
        """Advance the parser one stanza, tracking offsets and EOF."""
        if self.eof:
            return
        self.last_pos = self.current_pos
        self.eof = not self.package_parser.Step()
        self.current_pos = self.package_parser.Offset()

    def get_next_package(self):
        """Return the next Package from the stream, or None at EOF.

        Stanzas with a Binary field come from a Sources index and
        yield SourcePackage; everything else yields BinaryPackage."""

        if self.eof:
            return None

        self._next_package()
        if self.eof:
            return None

        if self.package_parser.Section.has_key("Binary"):
            pkg_class = SourcePackage
        else:
            pkg_class = BinaryPackage
        return pkg_class(self.base_path, self.package_file.name,
                         self.last_pos, self.package_parser.Section,
                         self.distro, self.component)

    def get_packages(self):
        """Read the remainder of the stream and return it as a list."""
        result = []
        while True:
            pkg = self.get_next_package()
            if pkg is None:
                break
            result.append(pkg)
        return result

    def __iter__(self):
        return self

    def next(self):
        pkg = self.get_next_package()
        if pkg is None:
            raise StopIteration
        return pkg

# Functions

def _get_path_size_helper(arg, path, names):
    global size_total

    size = reduce(lambda x, y: x + y,
                  map(lambda x: os.path.getsize(path + "/" + x), names))
    size_total = size_total + size

def _get_path_size(path):
    global size_total

    size_total = 0
    os.path.walk(path, _get_path_size_helper, None)
    return size_total

# Main function

def main():
    """Split the configured repository into media-sized parts.

    Reads the picax configuration from the command line, scans the
    binary (and optionally source) package indexes, assigns packages
    to media groups by cumulative size, then writes each group's
    package pool, Packages/Sources indexes, and Release data into
    base_path/bin<N> directories.  Exits with status 1 on any error.
    """

    # Get the command-line options provided, and init the rest
    # of the system.

    try:
        picax.config.handle_args(sys.argv[1:])
    except picax.config.ConfigError, e:
        sys.stderr.write(str(e) + "\n")
        picax.config.usage(sys.stderr)
        sys.exit(1)

    debug = picax.config.get_config()["debug"]
    if debug:
        import pdb
        pdb.set_trace()

    picax.apt.init()

    # Build the media.

    try:
        # Read often-used configuration values from the global
        # configuration and sanity-check them.

        global_conf = picax.config.get_config()

        base_path = global_conf["base_path"]
        repo_list = global_conf["repository_list"]
        arch = global_conf["arch"]
        num_parts = global_conf["num_parts"]
        part_size = global_conf["part_size"]
        source_type = global_conf["source"]
        bootstrap_dist = None

        # Read the packages file for the source distro, calculating
        # total size of all packages.

        print "Reading package list..."

        package_list = []
        read_packages = {}
        total_binary_size = 0

        for (dist, section) in repo_list:
            # The first repository's distribution doubles as the
            # distribution used for debootstrap and for the toplevel
            # Release file below.
            if bootstrap_dist is None:
                bootstrap_dist = dist

            packages_path = "%s/dists/%s/%s/binary-%s/Packages" \
                            % (base_path, dist, section, arch)
            packages_file = open(packages_path)

            factory = PackageFactory(packages_file, base_path, dist, section)
            for pkg in factory:
                # Exact (name, version) duplicates are dropped; same
                # name with a new version is kept but warned about.
                if read_packages.has_key(pkg["Package"]):
                    if pkg["Version"] in read_packages[pkg["Package"]]:
                        sys.stderr.write("W: binary package %s is a duplicate, skipping\n"
                                         % (pkg["Package"],))
                        continue
                    else:
                        sys.stderr.write("W: multiple versions of binary package %s exist\n"
                                         % (pkg["Package"],))

                total_binary_size = total_binary_size + pkg["Package-Size"]

                package_list.append(pkg)
                if not read_packages.has_key(pkg["Package"]):
                    read_packages[pkg["Package"]] = []
                read_packages[pkg["Package"]].append(pkg["Version"])

            packages_file.close()

        # If we're doing source, read the source index as well.

        total_source_size = 0
        if source_type != "none":
            print "Reading source index..."

            # NOTE(review): read_source is only defined on this branch;
            # the later code that reads it is guarded by source_type
            # checks, so that is consistent today but fragile.
            read_source = {}

            for (dist, section) in repo_list:
                source_path = "%s/dists/%s/%s/source/Sources" \
                              % (base_path, dist, section)
                source_file = open(source_path)

                factory = PackageFactory(source_file, base_path, dist, section)
                for pkg in factory:
                    if read_source.has_key(pkg["Package"]):
                        sys.stderr.write("W: multiple versions of source package %s exist\n"
                                         % (pkg["Package"],))

                    # NOTE(review): bare except deliberately turns any
                    # failure to size the source files into "skip this
                    # package", but it will also hide unrelated bugs.
                    try:
                        pkg_size = pkg["Package-Size"]
                    except:
                        sys.stderr.write("W: problem with source package %s, skipping\n"
                                         % (pkg["Package"],))
                        continue

                    total_source_size = total_source_size + pkg_size

                    if not read_source.has_key(pkg["Package"]):
                        read_source[pkg["Package"]] = []
                    read_source[pkg["Package"]].append(pkg)

        # If we're using a media module, load it and grab its part_size.

        media_handler = None
        if global_conf.has_key("media_component"):
            if part_size != 0 or num_parts != 0:
                raise RuntimeError, "cannot set part size or number with a media module"
            picax.media.set_media(global_conf["media_component"])
            part_size = picax.media.get_part_size()
            media_handler = picax.media.MediaBuilder()

        # Do the math for the total size, number of parts, and
        # size of each part.  Use whatever's given (total size
        # and either num_parts or part_size) to calculate the
        # missing field.  If we're doing source type "separate",
        # things are a little strange, so handle that.

        if num_parts == 0 and part_size == 0:
            raise RuntimeError, "must specify media type, part size, or number of parts"

        total_size = total_binary_size + total_source_size
        if source_type == "separate":
            if part_size == 0:
                raise RuntimeError, "cannot use separate source and num_parts together"

            if num_parts == 0:
                # Binary and source parts are sized independently when
                # source media are kept separate.
                num_bin_parts = int(total_binary_size / part_size) + 1
                num_src_parts = int(total_source_size / part_size) + 1
                num_parts = num_bin_parts + num_src_parts

        else:
            if part_size == 0:
                part_size = total_size / num_parts

            if num_parts == 0:
                num_parts = int(total_size / part_size) + 1

        # If an installer was specified, put it on the first part,
        # and count the space used by the installer against the
        # first part.

        if global_conf.has_key("installer_component"):
            print "Writing installer to first media..."

            picax.installer.set_installer(global_conf["installer_component"])

            first_part_loc = "%s/bin1" % (base_path,)
            os.makedirs(first_part_loc)
            picax.installer.install(first_part_loc)

            # first_part_space shrinks the budget of the first group
            # only; it is reset to 0 once that group is closed.
            first_part_space = _get_path_size(first_part_loc)

            if media_handler is not None:
                # The installer may supply its own media builder,
                # which takes precedence over the generic one.
                inst_media_handler = picax.installer.get_media_builder()
                if inst_media_handler is not None:
                    media_handler = inst_media_handler
        else:
            first_part_space = 0

        # Create an order list.  Put explicitly requested packages first.

        print "Calculating order..."

        raw_order_list = []
        if global_conf.has_key("order_pkgs"):
            raw_order_list.extend(global_conf["order_pkgs"])

        # Include debootstrap packages in the order.

        if not global_conf["no_debootstrap"]:
            pkg_names = []
            # Best-effort: if debootstrap is missing or fails, fall
            # through with an empty list and warn below.
            try:
                debootstrap_pipe = os.popen("/usr/sbin/debootstrap --print-debs %s"
                                            % (bootstrap_dist,))
                for line in debootstrap_pipe:
                    pkg_names.extend(string.split(string.strip(line)))
                debootstrap_pipe.close()
            except:
                pass

            if len(pkg_names) == 0:
                sys.stderr.write("W: debootstrap could not report packages\n")
            else:
                raw_order_list.extend(pkg_names)

        # Include installer packages in the order.

        if global_conf.has_key("installer_component"):
            raw_order_list.extend(picax.installer.get_package_requests())

##         for pkg in package_list:
##             if pkg["Package"] not in raw_order_list:
##                 raw_order_list.append(pkg["Package"])

        # Expand the request list to include dependencies, in
        # dependency order.
        order_list = picax.apt.resolve_package_list(raw_order_list)

        # Figure out which packages belong in which parts, putting
        # them into the package_group[part] arrays.

        print "Separating packages into media groups..."

        package_group = []
        current_group = []
        binary_list = []
        source_added = []
        current_size = 0

        for pkg_name in order_list:
            found_pkg = None
            for pkg in package_list:
                if pkg["Package"] == pkg_name:
                    found_pkg = pkg
                    break

            if found_pkg:
                pkg_size = found_pkg["Package-Size"]
                pkg_list = [found_pkg]

                # For mixed source, the matching source package rides
                # along with its binary and counts toward the same
                # part's size.
                if source_type == "mixed":
                    (src_name, src_version) = found_pkg.get_source_info()
                    if (src_name, src_version) not in source_added:
                        if read_source.has_key(src_name):
                            found_src = None
                            for srcpkg in read_source[src_name]:
                                if srcpkg["Version"] == src_version:
                                    found_src = srcpkg
                                    break
                            if found_src is not None:
                                pkg_list.append(found_src)
                                pkg_size = pkg_size + found_src["Package-Size"]
                            else:
                                sys.stderr.write("W: proper source package for binary package %s not found\n"
                                                 % (found_pkg["Package"],))
                        else:
                            sys.stderr.write("W: no source for binary package %s\n"
                                             % (found_pkg["Package"],))

                # Close the current group when this package would
                # overflow it; the first group's budget is reduced by
                # the installer's size.
                if (current_size + pkg_size) >= (part_size - first_part_space):
                    package_group.append(current_group)
                    current_group = []
                    current_size = 0
                    first_part_space = 0

                package_list.remove(found_pkg)
                for pkg in pkg_list:
                    current_group.append(pkg)
                current_size = current_size + pkg_size
                if source_type == "mixed":
                    source_added.append((src_name, src_version))
                elif source_type in ("immediate", "separate"):
                    # Remember the binary so its source can be added
                    # at the end.
                    binary_list.append(found_pkg)

        # Packages not mentioned in one of the order lists are tacked
        # on to the end unless specifically requested otherwise.

        if not global_conf["short_package_list"]:
            while (len(package_list)):
                found_pkg = package_list[0]
                pkg_size = found_pkg["Package-Size"]
                pkg_list = [found_pkg]

                if source_type == "mixed":
                    (src_name, src_version) = found_pkg.get_source_info()
                    # NOTE(review): this condition is inverted relative
                    # to the ordered loop above ("in source_added" vs
                    # "not in"); as written it only attaches a source
                    # that has already been added once.  TODO confirm
                    # intent.
                    if (src_name, src_version) in source_added:
                        if read_source.has_key(src_name):
                            found_src = None
                            for srcpkg in read_source[src_name]:
                                if srcpkg["Version"] == src_version:
                                    found_src = srcpkg
                                    break
                            if found_src is not None:
                                pkg_list.append(found_src)
                                pkg_size = pkg_size + found_src["Package-Size"]
                            else:
                                sys.stderr.write("W: proper source package for binary package %s not found\n"
                                                 % (found_pkg["Package"],))
                        else:
                            sys.stderr.write("W: no source for binary package %s\n"
                                             % (found_pkg["Package"],))

                if (current_size + pkg_size) >= (part_size - first_part_space):
                    package_group.append(current_group)
                    current_group = []
                    current_size = 0
                    first_part_space = 0

                del package_list[0]
                current_size = current_size + pkg_size
                for pkg in pkg_list:
                    current_group.append(pkg)
                if source_type == "mixed":
                    source_added.append((src_name, src_version))
                elif source_type in ("immediate", "separate"):
                    binary_list.append(found_pkg)

        # For separate source, start with a new package group.

        if source_type == "separate" and len(current_group):
            package_group.append(current_group)
            current_group = []
            current_size = 0
            first_part_space = 0

        # If the source type is one that requires tacking the source
        # for the current binaries to the end, do that.

        if source_type in ("immediate", "separate"):
            for pkg in binary_list:
                (src_name, src_version) = pkg.get_source_info()
                if (src_name, src_version) in source_added:
                    continue

                if not read_source.has_key(src_name):
                    sys.stderr.write("W: no source for binary package %s\n"
                                     % (pkg["Package"],))
                    continue

                found_pkg = None
                for srcpkg in read_source[src_name]:
                    if srcpkg["Version"] == src_version:
                        found_pkg = srcpkg
                        break

                if found_pkg is None:
                    sys.stderr.write("W: proper source version for binary package %s not found\n"
                                     % (pkg["Package"],))
                    continue

                pkg_size = found_pkg["Package-Size"]

                if (current_size + pkg_size) >= (part_size - first_part_space):
                    package_group.append(current_group)
                    current_group = []
                    current_size = 0
                    first_part_space = 0

                source_added.append((src_name, src_version))
                current_group.append(found_pkg)
                current_size = current_size + pkg_size

        # Pick up the last straggler package group if there is one.

        if len(current_group):
            package_group.append(current_group)

        # Read the binary-level Release file.

        release_info = {}

        # NOTE(review): the per-repo hashes are saved in release_info,
        # but the Release-writing code further down uses the bare
        # release_md5/release_sha/release_data names -- i.e. whatever
        # this loop left behind for the *last* repository.  With more
        # than one repository the wrong Release data can be written.
        # TODO confirm and fix.
        for repo in repo_list:
            release_file = hashfile.open("%s/dists/%s/%s/binary-%s/Release"
                                         % (base_path, repo[0], repo[1], arch))
            release_md5 = md5.new()
            release_sha = sha.new()
            release_file.add_hash(release_md5)
            release_file.add_hash(release_sha)
            release_data = release_file.read()
            release_file.close()

            release_info[repo] = [(release_md5, release_sha, release_data)]

        # Read the toplevel Release file.

        if os.path.exists("%s/dists/%s/Release" % (base_path, bootstrap_dist)):
            toprelease_file = open("%s/dists/%s/Release"
                                   % (base_path, bootstrap_dist))
            toprelease_lines = toprelease_file.readlines()
            toprelease_file.close()
        else:
            sys.stderr.write("W: no toplevel Release file\n")
            toprelease_lines = None

        # Get the suite name.
        # XXX: Commented out while we figure out multi-repo.

##         suite_name = None
##         for line in toprelease_lines:
##             if line[:6] == "Suite:":
##                 suite_name = string.strip(line[6:])

        # Write the packages information into the parts.

        print "Writing packages into media groups:"

        hash_parts = []
        for part in range(0, len(package_group)):
            print "  group %d..." % (part + 1,)

            current_group = package_group[part]
            top_path = "%s/bin%d" % (base_path, part + 1)

            # Open per-repo Packages/Sources output files, wrapped in
            # hashfile objects so MD5/SHA1 digests accumulate as the
            # stanzas are written.
            repo_files = {}
            repo_file_info = []
            for repo in repo_list:
                (dist, section) = repo
                dist_path = "%s/dists/%s/%s/binary-%s" \
                            % (top_path, dist, section, arch)
                src_dist_path = "%s/dists/%s/%s/source" \
                                % (top_path, dist, section)
                os.makedirs(dist_path)
                os.makedirs(src_dist_path)

                # XXX: commented out while solving multi-repo
##                 if suite_name and suite_name != dist:
##                     os.symlink(dist, "%s/dists/%s" % (top_path, suite_name))

                part_pkgs_file = hashfile.open(dist_path + "/Packages", "w")
                part_pkgs_md5 = md5.new()
                part_pkgs_sha = sha.new()
                part_pkgs_file.add_hash(part_pkgs_md5)
                part_pkgs_file.add_hash(part_pkgs_sha)

                part_srcs_file = hashfile.open(src_dist_path + "/Sources", "w")
                part_srcs_md5 = md5.new()
                part_srcs_sha = sha.new()
                part_srcs_file.add_hash(part_srcs_md5)
                part_srcs_file.add_hash(part_srcs_sha)

                repo_files[repo] = (part_pkgs_file, part_srcs_file)
                repo_file_info.extend(((dist, section, "Packages", dist_path,
                                        part_pkgs_md5, part_pkgs_sha),
                                       (dist, section, "Sources",
                                        src_dist_path, part_srcs_md5,
                                        part_srcs_sha)))

            for pkg in current_group:
                # Hard-link the package files into this part's pool
                # and append the stanza to the right index file; a
                # Binary field marks a source stanza (see
                # PackageFactory).
                pkg.link(top_path)

                repo_file_list = repo_files[(pkg["distribution"],
                                             pkg["component"])]
                if pkg.has_key("Binary"):
                    part_file = repo_file_list[1]
                else:
                    part_file = repo_file_list[0]

                for line in pkg.get_lines():
                    part_file.write(line)
                part_file.write("\n")

            for repo in repo_files.keys():
                for repo_file in repo_files[repo]:
                    repo_file.close()

            hash_info = []

            for (dist, section, fn, fn_path, fn_md5, fn_sha) in repo_file_info:
                fn_size = os.stat(fn_path + "/" + fn).st_size
                if fn_size > 0:
                    hash_info.append((dist, section, fn, fn_size,
                                      fn_md5.hexdigest(), fn_sha.hexdigest()))

                    # Write a gzipped copy of the index, again through
                    # a hashfile so its digests come for free.
                    gzip_hash = hashfile.open(fn_path + "/" + fn + ".gz", "w")
                    gzip_md5 = md5.new()
                    gzip_sha = sha.new()
                    gzip_hash.add_hash(gzip_md5)
                    gzip_hash.add_hash(gzip_sha)
                    gzip_file = gzip.GzipFile(fn, "w", 9,
                                              gzip_hash)

                    part_file = open(fn_path + "/" + fn)
                    gzip_file.write(part_file.read())
                    part_file.close()
                    gzip_file.close()

                    gzip_size = os.stat(fn_path + "/" + fn + ".gz").st_size
                    hash_info.append((dist, section, fn + ".gz", gzip_size,
                                      gzip_md5.hexdigest(),
                                      gzip_sha.hexdigest()))

                    # NOTE(review): release_data/release_md5/release_sha
                    # here are leftovers from the release_info loop
                    # above (last repo wins); see the note there.
                    release_out = open(fn_path + "/Release", "w")
                    release_out.write(release_data)
                    release_out.close()

                    release_size = os.stat(fn_path + "/Release").st_size
                    hash_info.append((dist, section, "Release", release_size,
                                      release_md5.hexdigest(),
                                      release_sha.hexdigest()))
                else:
                    # Empty index: remove the file and its directory
                    # so the part doesn't carry dead paths.
                    for emptyfn in os.listdir(fn_path):
                        os.unlink(fn_path + "/" + emptyfn)
                    os.rmdir(fn_path)

            # Write the toplevel Release files.

            if toprelease_lines:
                toprelease_out = open("%s/dists/%s/Release"
                                      % (top_path, bootstrap_dist),
                                      "w")
                # Copy only the header lines (continuation lines of
                # the old checksum sections start with whitespace and
                # are dropped), inserting fresh checksum sections.
                for line in toprelease_lines:
                    if line[0] in string.whitespace:
                        continue

                    toprelease_out.write(line)
                    # NOTE(review): the entry path is always built with
                    # "binary-%s", even for Sources files that live
                    # under .../source/ -- TODO confirm.
                    if line[:6] == "MD5Sum":
                        for (dist, section, fn, size,
                             ck_md5, ck_sha) in hash_info:
                            toprelease_out.write(" %s %18d %s\n"
                                                 % (ck_md5, size,
                                                    "%s/binary-%s/%s"
                                                    % (section, arch, fn)))
                    elif line[:4] == "SHA1":
                        for (dist, section, fn, size,
                             ck_md5, ck_sha) in hash_info:
                            toprelease_out.write(" %s %18d %s\n"
                                                 % (ck_sha, size,
                                                    "%s/binary-%s/%s"
                                                    % (section, arch, fn)))

                toprelease_out.close()

            # Write the CD label, if there is one.

            if global_conf.has_key("cd_label"):
                os.mkdir("%s/.disk" % (top_path,))
                info_file = open("%s/.disk/info" % (top_path,), "w")
                info_file.write("%s (%d/%d)\n" % (global_conf["cd_label"],
                                                  part + 1, len(package_group)))
                info_file.close()

            # Save the hashes.

            hash_parts.append(hash_info)

        # After installing the parts, run the post-installation
        # setup for the installer, if any.

        if global_conf.has_key("installer_component"):
            print "Running installer post-install setup..."
            picax.installer.post_install(first_part_loc)

        # If a media module was set, build media now.

        if media_handler:
            media_handler.create_media()

    except Exception, e:

        # Generic error handling code for any problems.

        if debug:
            traceback.print_exc(None, sys.stderr)
            pdb.post_mortem(sys.exc_info()[2])
        else:
            sys.stderr.write("error: " + str(e) + "\n")
            sys.stderr.write("\n")
            picax.config.usage(sys.stderr)

        sys.exit(1)

    # For now, report on the files generated.  Technically, this isn't
    # necessary, but it helps to debug when things go wrong.

    for index in range(0, len(hash_parts)):
        for item in hash_parts[index]:
            print "%d %s/%s/%s: size %d\n  md5 %s\n  sha %s" \
                  % ((index + 1,) + item)

# Script entry point.
if __name__ == "__main__":
    main()
