#! /usr/bin/env python
# Copyright 2015 Martin C. Frith

import math, optparse, os, random, signal, subprocess, sys, tempfile

def writeWords(outFile, words):
    """Write the strings in `words` to outFile as one space-separated line."""
    line = " ".join(words)
    outFile.write(line + "\n")

def seqInput(fileNames):
    """Yield (sequence, quality) pairs from FASTA or FASTQ files.

    The format of each file is decided by its first line that starts with
    ">" (fasta) or "@" (fastq); any lines before that are skipped.  Fasta
    records are yielded with an empty quality string.
    """
    for name in fileNames:
        with open(name) as handle:
            fmt = None  # unknown until the first header line is seen
            for line in handle:
                first = line[0]
                if fmt is None:
                    if first == ">":
                        fmt = "fasta"
                        pieces = []
                    elif first == "@":
                        fmt = "fastq"
                        phase = 1  # position within the 4-line record
                elif fmt == "fasta":
                    if first == ">":
                        # a new header ends the previous record
                        yield "".join(pieces), ""
                        pieces = []
                    else:
                        pieces.append(line.rstrip())
                else:  # fastq
                    if phase == 1:
                        letters = line.rstrip()
                    elif phase == 3:
                        yield letters, line.rstrip()
                    phase = (phase + 1) % 4
            # the last fasta record has no following header to flush it
            if fmt == "fasta":
                yield "".join(pieces), ""

def randomNumbersAndLengths(lengths, sampleSize):
    """Yield (index, sampledLength) pairs for randomly chosen sequences.

    The sequences are visited in shuffled order.  Each one contributes its
    whole length until the remaining budget is used up; the last one chosen
    is truncated so that the lengths yielded sum to sampleSize (or to the
    total length, if that is smaller).
    """
    shuffled = list(enumerate(lengths))
    random.shuffle(shuffled)
    remaining = sampleSize
    for index, length in shuffled:
        if length >= remaining:
            # this sequence exhausts the budget: take only what is left
            yield index, remaining
            return
        yield index, length
        remaining -= length

def writeSeqSample(numbersAndLengths, seqInput, outfile):
    """Write the selected sequences, truncated, in fasta or fastq format.

    numbersAndLengths holds (index, length) pairs in increasing index order.
    Each selected record is named by its index and cut to the given length;
    records with a quality string are written as fastq, the others as fasta.
    """
    wanted = iter(numbersAndLengths)
    target = next(wanted, None)
    for index, record in enumerate(seqInput):
        if target is None or target[0] != index:
            continue
        seqlen = target[1]
        seq, qual = record
        name = str(index)
        if qual:
            text = "@" + name + "\n" + seq[:seqlen] + "\n+\n" + qual[:seqlen]
        else:
            text = ">" + name + "\n" + seq[:seqlen]
        outfile.write(text + "\n")
        target = next(wanted, None)

def getSeqSample(opts, queryFiles, outfile):
    lengths = (len(i[0]) for i in seqInput(queryFiles))
    sampleSize = int(round(opts.sample))
    numbersAndLengths = sorted(randomNumbersAndLengths(lengths, sampleSize))
    writeSeqSample(numbersAndLengths, seqInput(queryFiles), outfile)
    print "# sampled sequences:", len(numbersAndLengths)

def scaleFromHeader(lines):
    """Return the number after "t=" in the first line that has such a word.

    Reading stops at that line, so `lines` can be read further by the caller.
    Raises an Exception if no line has a word starting with "t=".
    """
    for line in lines:
        scaleWords = [w for w in line.split() if w.startswith("t=")]
        if scaleWords:
            return float(scaleWords[0][2:])
    raise Exception("couldn't read the scale")

def scoreMatrixFromHeader(lines):
    """Read a block of score-matrix rows and return them as lists of words.

    A row is a line with more than 2 words whose second word is a single
    character; its first word is dropped.  Lines before the first row are
    skipped, and reading stops at the first non-row line after the block.
    """
    rows = []
    for line in lines:
        words = line.split()
        isRow = len(words) > 2 and len(words[1]) == 1
        if isRow:
            rows.append(words[1:])
        elif rows:
            break
    return rows

def scaledMatrix(matrix, scaleIncrease):
    """Return a copy of a lettered matrix with its scores multiplied.

    The first row, and the first item of every other row, are kept as they
    are.  All other items are converted to int and multiplied by
    scaleIncrease.
    """
    headRows = matrix[0:1]
    bodyRows = []
    for row in matrix[1:]:
        scores = [int(score) * scaleIncrease for score in row[1:]]
        bodyRows.append(row[0:1] + scores)
    return headRows + bodyRows

def countsFromLastOutput(lines, opts):
    """Tally substitution and gap counts from alignment output lines.

    Lines starting with "s" supply the strand (their 5th word); the strand
    from the most recent "s" line applies to the next "c" line.  Lines
    starting with "c" supply numbers: a square matrix of letter-pair counts
    (row-major) followed by 10 more values, of which the first five are
    used as matches, deletes, inserts, delOpens and insOpens.
    (Presumably this is the count output of "lastal -j7" - confirm against
    the LAST documentation.)

    Alignments with more than opts.pid percent identity are skipped.
    Unless the strand is "+" or opts.S is "0", the pair counts are added to
    the matrix in reverse-complement orientation.

    Returns (matrix, gapCounts), where matrix is a list of lists of floats
    (empty if there were no "c" lines) and gapCounts is the tuple
    (matches, deletes, inserts, delOpens, insOpens, alignments).
    """
    matrix = []
    # use +1 pseudocounts as a kludge to mitigate numerical problems:
    matches = 1.0
    deletes = 2.0  # 1 open + 1 extension
    inserts = 2.0  # 1 open + 1 extension
    delOpens = 1.0
    insOpens = 1.0
    alignments = 0  # no pseudocount here
    for line in lines:
        if line[0] == "s":
            strand = line.split()[4]  # slow?
        elif line[0] == "c":
            # Build a real list: the result of map() can only be indexed
            # and measured with len() on Python 2.
            c = [float(word) for word in line.split()[1:]]
            if not matrix:
                matrixSize = int(math.sqrt(len(c) - 10))
                matrix = [[1.0] * matrixSize for i in range(matrixSize)]
            r = range(matrixSize)
            identities = sum(c[i * matrixSize + i] for i in r)
            alignmentLength = c[-10] + c[-9] + c[-8]
            if 100 * identities > opts.pid * alignmentLength: continue
            isForward = (strand == "+" or opts.S == "0")
            for i in r:
                for j in r:
                    if isForward:
                        matrix[i][j]       += c[i * matrixSize + j]
                    else:
                        matrix[-1-i][-1-j] += c[i * matrixSize + j]
            matches += c[-10]
            deletes += c[-9]
            inserts += c[-8]
            delOpens += c[-7]
            insOpens += c[-6]
            alignments += 1
    gapCounts = matches, deletes, inserts, delOpens, insOpens, alignments
    return matrix, gapCounts

def scoreFromProb(scale, prob):
    """Return scale * ln(prob), rounded to an integer.

    For prob <= 0, where the logarithm is undefined, -800 is used as the
    logarithm.
    """
    # exp(-800) is exactly zero, on my computer
    logProb = math.log(prob) if prob > 0 else -800
    return int(round(scale * logProb))

def costFromProb(scale, prob):
    """Return the cost for a probability: the negative of its score."""
    score = scoreFromProb(scale, prob)
    return -score

def guessAlphabet(matrixSize):
    """Return the letters assumed for a matrix with 4 or 20 rows.

    Raises an Exception for any other size.
    """
    knownAlphabets = {4: "ACGT", 20: "ACDEFGHIKLMNPQRSTVWY"}
    if matrixSize in knownAlphabets:
        return knownAlphabets[matrixSize]
    raise Exception("can't handle unusual alphabets")

def matrixWithLetters(matrix):
    """Return the matrix with a header of letters and a letter on each row.

    The first item of the result is the alphabet string itself; each
    following item is a list: the row's letter, then the row's values.
    """
    alphabet = guessAlphabet(len(matrix))
    lettered = [alphabet]
    for letter, row in zip(alphabet, matrix):
        lettered.append([letter] + row)
    return lettered

def writeMatrixHead(outFile, prefix, alphabet, formatString):
    """Write the column-heading line of a matrix.

    The line starts with the prefix plus a space (in place of a row letter),
    followed by each letter formatted with formatString.
    """
    headings = [formatString % letter for letter in alphabet]
    writeWords(outFile, [prefix + " "] + headings)

def writeMatrixBody(outFile, prefix, alphabet, matrix, formatString):
    """Write one line per matrix row.

    Each line starts with the prefix plus the row's letter, followed by
    each value formatted with formatString.
    """
    for letter, row in zip(alphabet, matrix):
        cells = [formatString % value for value in row]
        writeWords(outFile, [prefix + letter] + cells)

def writeCountMatrix(outFile, matrix, prefix):
    """Write a matrix of counts, in left-justified columns of width 14."""
    cellFormat = "%-14s"
    letters = guessAlphabet(len(matrix))
    writeMatrixHead(outFile, prefix, letters, cellFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def writeProbMatrix(outFile, matrix, prefix):
    """Write a matrix of probabilities, in left-justified columns of width 14.

    The values are written in "%g" style.
    """
    letters = guessAlphabet(len(matrix))
    headFormat = "%-14s"
    bodyFormat = "%-14g"
    writeMatrixHead(outFile, prefix, letters, headFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, bodyFormat)

def writeScoreMatrix(outFile, matrix, prefix):
    """Write a matrix of scores, in right-justified columns of width 6."""
    cellFormat = "%6s"
    letters = guessAlphabet(len(matrix))
    writeMatrixHead(outFile, prefix, letters, cellFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def writeMatrixWithLetters(outFile, matrix, prefix):
    """Write a matrix that carries its own column and row letters.

    The first row of the matrix holds the column letters, and the first
    item of each other row is that row's letter.
    """
    columnLetters = matrix[0]
    rowLetters = []
    scores = []
    for row in matrix[1:]:
        rowLetters.append(row[0])
        scores.append(row[1:])
    writeMatrixHead(outFile, prefix, columnLetters, "%6s")
    writeMatrixBody(outFile, prefix, rowLetters, scores, "%6s")

def matProbsFromCounts(counts, opts):
    r = range(len(counts))
    if opts.revsym:  # add complement (reverse strand) substitutions
        counts = [[counts[i][j] + counts[-1-i][-1-j] for j in r] for i in r]
    if opts.matsym:  # symmetrize the substitution matrix
        counts = [[counts[i][j] + counts[j][i] for j in r] for i in r]
    identities = sum(counts[i][i] for i in r)
    total = sum(map(sum, counts))
    probs = [[j / total for j in i] for i in counts]

    print "# substitution percent identity: %g" % (100 * identities / total)
    print
    print "# count matrix (query letters = columns, reference letters = rows):"
    writeCountMatrix(sys.stdout, counts, "# ")
    print
    print "# probability matrix (query letters = columns, reference letters = rows):"
    writeProbMatrix(sys.stdout, probs, "# ")
    print

    return probs

def gapProbsFromCounts(counts, opts):
    """Estimate gap probabilities from gap counts, and report them.

    counts is the tuple (matches, deletes, inserts, delOpens, insOpens,
    alignments).  If opts.gapsym is set, insertions and deletions are
    pooled so that both get the same probabilities.  The counts and the
    estimated probabilities are printed to standard output.

    Returns (delExistProb, insExistProb, delExtendProb, insExtendProb).
    Raises an Exception if the number of alignments is zero.
    """
    matches, deletes, inserts, delOpens, insOpens, alignments = counts
    if not alignments: raise Exception("no alignments")
    gaps = deletes + inserts
    gapOpens = delOpens + insOpens
    denominator = matches + gapOpens + (alignments + 1)  # +1 pseudocount
    if opts.gapsym:
        # pooled estimate: each gap type gets half of the total open
        # probability, and both share one extension probability
        delOpenProb = gapOpens / denominator / 2
        insOpenProb = gapOpens / denominator / 2
        delExtendProb = (gaps - gapOpens) / gaps
        insExtendProb = (gaps - gapOpens) / gaps
    else:
        delOpenProb = delOpens / denominator
        insOpenProb = insOpens / denominator
        delExtendProb = (deletes - delOpens) / deletes
        insExtendProb = (inserts - insOpens) / inserts

    # NOTE: the values printed here are the ones from before the
    # adjustment of the extension probabilities further down.
    print "# aligned letter pairs:", matches
    print "# deletes:", deletes
    print "# inserts:", inserts
    print "# delOpens:", delOpens
    print "# insOpens:", insOpens
    print "# alignments:", alignments
    print "# mean delete size: %g" % (deletes / delOpens)
    print "# mean insert size: %g" % (inserts / insOpens)
    print "# delOpenProb: %g" % delOpenProb
    print "# insOpenProb: %g" % insOpenProb
    print "# delExtendProb: %g" % delExtendProb
    print "# insExtendProb: %g" % insExtendProb
    print

    # probability of a gap of length exactly 1: open, then close at once
    delCloseProb = 1 - delExtendProb
    insCloseProb = 1 - insExtendProb
    firstDelProb = delOpenProb * delCloseProb
    firstInsProb = insOpenProb * insCloseProb

    # If we define "an alignment" to mean "a set of indistinguishable
    # paths", then:
    #delExtendProb += firstDelProb
    #insExtendProb += firstInsProb
    # Else, this ensures gap existence cost >= 0:
    delExtendProb = max(delExtendProb, firstDelProb)
    insExtendProb = max(insExtendProb, firstInsProb)

    # after the max() above, these ratios are at most 1
    delExistProb = firstDelProb / delExtendProb
    insExistProb = firstInsProb / insExtendProb

    return delExistProb, insExistProb, delExtendProb, insExtendProb

def scoreFromLetterProbs(scale, pairProb, prob1, prob2):
    """Return the score for a letter pair.

    The score is that of the ratio between the pair's probability and the
    product of the two letters' separate probabilities.
    """
    independentProb = prob1 * prob2
    return scoreFromProb(scale, pairProb / independentProb)

def matScoresFromProbs(scale, probs):
    """Convert a matrix of letter-pair probabilities to a matrix of scores.

    Each pair's probability is compared to the product of its row sum and
    its column sum.
    """
    rowProbs = [sum(row) for row in probs]
    colProbs = [sum(column) for column in zip(*probs)]
    scores = []
    for row, rowProb in zip(probs, rowProbs):
        scores.append([scoreFromLetterProbs(scale, pairProb, rowProb, colProb)
                       for pairProb, colProb in zip(row, colProbs)])
    return scores

def gapCostsFromProbs(scale, probs):
    """Convert gap probabilities to integer gap costs.

    probs is (delExistProb, insExistProb, delExtendProb, insExtendProb);
    the costs are returned in the same order.  An extension cost of zero
    is replaced by 1.
    """
    delExistProb, insExistProb, delExtendProb, insExtendProb = probs
    existCosts = [costFromProb(scale, p) for p in (delExistProb, insExistProb)]
    extendCosts = [costFromProb(scale, p) for p in (delExtendProb, insExtendProb)]
    for k, cost in enumerate(extendCosts):
        if cost == 0:
            extendCosts[k] = 1
    return tuple(existCosts + extendCosts)

def writeLine(out, *things):
    """Write the given items, converted to strings, as one spaced line."""
    text = " ".join(str(thing) for thing in things)
    out.write(text + "\n")

def writeGapCosts(gapCosts, out):
    """Write the four gap costs as "#last" option lines.

    gapCosts is (delExistCost, insExistCost, delExtendCost, insExtendCost),
    written with the option letters a, A, b, B respectively.
    """
    delExistCost, insExistCost, delExtendCost, insExtendCost = gapCosts
    optionsAndCosts = (("-a", delExistCost), ("-A", insExistCost),
                       ("-b", delExtendCost), ("-B", insExtendCost))
    for option, cost in optionsAndCosts:
        writeLine(out, "#last " + option, cost)

def printGapCosts(gapCosts):
    delExistCost, insExistCost, delExtendCost, insExtendCost = gapCosts
    print "# delExistCost:", delExistCost
    print "# insExistCost:", insExistCost
    print "# delExtendCost:", delExtendCost
    print "# insExtendCost:", insExtendCost
    print

def tryToMakeChildProgramsFindable():
    """Prepend the "src" directory beside this script's directory to PATH.

    The directory is <directory of this file>/../src.  If PATH is not set
    in the environment, the system default search path (os.defpath) is
    used as the old value, instead of failing with a KeyError.
    """
    myDir = os.path.dirname(__file__)
    srcDir = os.path.join(myDir, os.pardir, "src")
    oldPath = os.environ.get("PATH", os.defpath)
    # put srcDir first, to avoid getting older versions of LAST:
    os.environ["PATH"] = srcDir + os.pathsep + oldPath

def fixedLastalArgs(opts):
    """Return the start of the lastal command line, as a list of words.

    It has the program name, "-j7", and each of the options
    D, E, s, S, T, m, P, Q that has a non-empty value in opts.
    """
    words = ["lastal", "-j7"]
    for letter in "DEsSTmPQ":
        value = getattr(opts, letter)
        if value:
            words.append("-" + letter + value)
    return words

def doTraining(opts, args):
    """Repeatedly align and re-estimate score parameters, then print them.

    Each round pipes the output of lastal through "last-split -n", counts
    substitutions and gaps in the result, and derives new parameters from
    the counts.  The rounds stop when a parameter set is obtained that was
    already seen in an earlier round.  Progress is printed to standard
    output as "#" comment lines, followed by the final gap costs and score
    matrix.

    opts holds the parsed command-line options; args holds the lastdb name
    followed by the query sequence file name(s).
    """
    tryToMakeChildProgramsFindable()
    scaleIncrease = 20  # while training, up-scale the scores by this amount
    # first round: lastal with the initial parameters from the command line
    x = fixedLastalArgs(opts)
    if opts.r: x.append("-r" + opts.r)
    if opts.q: x.append("-q" + opts.q)
    if opts.p: x.append("-p" + opts.p)
    if opts.a: x.append("-a" + opts.a)
    if opts.b: x.append("-b" + opts.b)
    if opts.A: x.append("-A" + opts.A)
    if opts.B: x.append("-B" + opts.B)
    x += args
    y = ["last-split", "-n"]
    # NOTE(review): the child processes are never waited for, and their
    # exit statuses are not checked - confirm that this is intended.
    p = subprocess.Popen(x, stdout=subprocess.PIPE)
    q = subprocess.Popen(y, stdin=p.stdout, stdout=subprocess.PIPE)
    # The header is read from the same stream that countsFromLastOutput
    # reads later, so each reader continues where the previous one stopped.
    externalScale = scaleFromHeader(q.stdout)
    internalScale = externalScale * scaleIncrease
    # NOTE(review): opts.Q is a string, so any non-empty value (even "0")
    # takes the "if opts.Q" branches in this function, whereas the main
    # script treats "0" like no -Q at all - confirm that this is intended.
    if opts.Q:
        externalMatrix = scoreMatrixFromHeader(q.stdout)
        internalMatrix = scaledMatrix(externalMatrix, scaleIncrease)
    oldParameters = []  # parameter sets from all previous rounds

    print "# maximum percent identity:", opts.pid
    print "# scale of score parameters:", externalScale
    print "# scale used while training:", internalScale
    print

    while True:
        print "#", " ".join(x)
        print
        sys.stdout.flush()
        matCounts, gapCounts = countsFromLastOutput(q.stdout, opts)
        gapProbs = gapProbsFromCounts(gapCounts, opts)
        gapCosts = gapCostsFromProbs(internalScale, gapProbs)
        printGapCosts(gapCosts)
        if opts.Q:
            # only the gap costs are re-estimated: the score matrix stays
            # as the (up-scaled) one read from the first header
            if gapCosts in oldParameters: break
            oldParameters.append(gapCosts)
        else:
            matProbs = matProbsFromCounts(matCounts, opts)
            matScores = matScoresFromProbs(internalScale, matProbs)
            print "# score matrix (query letters = columns, reference letters = rows):"
            writeScoreMatrix(sys.stdout, matScores, "# ")
            print
            parameters = gapCosts, matScores
            if parameters in oldParameters: break
            oldParameters.append(parameters)
            internalMatrix = matrixWithLetters(matScores)
        # next round: the new gap costs and matrix are written to the
        # standard input of lastal, which is run with "-p-"
        x = fixedLastalArgs(opts)
        x.append("-p-")
        x += args
        p = subprocess.Popen(x, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        writeGapCosts(gapCosts, p.stdin)
        writeMatrixWithLetters(p.stdin, internalMatrix, "")
        p.stdin.close()
        # in python2.6, the next line must come after p.stdin.close()
        q = subprocess.Popen(y, stdin=p.stdout, stdout=subprocess.PIPE)

    # final output: the parameters are re-derived from the last
    # probabilities at the external scale (without the training up-scale)
    gapCosts = gapCostsFromProbs(externalScale, gapProbs)
    writeGapCosts(gapCosts, sys.stdout)
    if opts.s: writeLine(sys.stdout, "#last -s", opts.s)
    if opts.S: writeLine(sys.stdout, "#last -S", opts.S)
    if not opts.Q:
        matScores = matScoresFromProbs(externalScale, matProbs)
        externalMatrix = matrixWithLetters(matScores)
    print "# score matrix (query letters = columns, reference letters = rows):"
    writeMatrixWithLetters(sys.stdout, externalMatrix, "")

def lastTrain(opts, args):
    """Run the training, on a random sample of the queries if requested.

    If opts.sample is non-zero, a sample of the query sequences is written
    to a temporary file, training is done with that file as the only query
    file, and the file is removed afterwards (even if an error occurs).
    Otherwise, training is done with the arguments as given.

    args holds the lastdb name followed by the query sequence file name(s).
    """
    if opts.sample:
        random.seed(math.pi)  # fixed seed: the same sample every time
        refName = args[0]
        queryFiles = args[1:]
        # Create the file before entering "try": if its creation fails,
        # there is nothing to remove, and the cleanup must not run (it
        # would hide the real error behind a NameError).
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            with f:
                getSeqSample(opts, queryFiles, f)
            # the file is closed now, so the child programs can read it
            doTraining(opts, [refName, f.name])
        finally:
            os.remove(f.name)
    else:
        doTraining(opts, args)

if __name__ == "__main__":
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)  # avoid silly error message

    parser = optparse.OptionParser(
        usage="%prog [options] lastdb-name sequence-file(s)",
        description="Try to find suitable score parameters for aligning the given sequences.")

    # Options that control the training itself:
    group = optparse.OptionGroup(parser, "Training options")
    group.add_option("--revsym", action="store_true",
                     help="force reverse-complement symmetry")
    group.add_option("--matsym", action="store_true",
                     help="force symmetric substitution matrix")
    group.add_option("--gapsym", action="store_true",
                     help="force insertion/deletion symmetry")
    group.add_option("--pid", type="float", default=100,
                     help="skip alignments with > PID% identity (default: %default)")
    group.add_option("--sample", type="float", default=1000000, metavar="N",
                     help="use a random sample of N letters (default: %default)")
    parser.add_option_group(group)

    # Options that are passed to lastal in the first round only:
    group = optparse.OptionGroup(parser, "Initial parameter options")
    group.add_option("-r", metavar="SCORE",
                     help="match score (default: 6 if Q>0, else 5)")
    group.add_option("-q", metavar="COST",
                     help="mismatch cost (default: 18 if Q>0, else 5)")
    group.add_option("-p", metavar="NAME",
                     help="match/mismatch score matrix")
    group.add_option("-a", metavar="COST",
                     help="gap existence cost (default: 21 if Q>0, else 15)")
    group.add_option("-b", metavar="COST",
                     help="gap extension cost (default: 9 if Q>0, else 3)")
    group.add_option("-A", metavar="COST",
                     help="insertion existence cost")
    group.add_option("-B", metavar="COST",
                     help="insertion extension cost")
    parser.add_option_group(group)

    # Options that are passed to lastal in every round:
    group = optparse.OptionGroup(parser, "Alignment options")
    group.add_option("-D", metavar="LENGTH",
                     help="query letters per random alignment (default: 1e6)")
    group.add_option("-E", metavar="EG2",
                     help="maximum expected alignments per square giga")
    group.add_option("-s", metavar="STRAND",
                     help="0=reverse, 1=forward, 2=both (default: 2 if DNA, else 1)")
    group.add_option("-S", metavar="NUMBER", default="1",
                     help="score matrix applies to forward strand of: "
                     "0=reference, 1=query (default: %default)")
    group.add_option("-T", metavar="NUMBER",
                     help="type of alignment: 0=local, 1=overlap (default: 0)")
    group.add_option("-m", metavar="COUNT",
                     help="maximum initial matches per query position (default: 10)")
    group.add_option("-P", metavar="THREADS",
                     help="number of parallel threads")
    group.add_option("-Q", metavar="NUMBER",
                     help="input format: 0=fasta, 1=fastq-sanger")
    parser.add_option_group(group)

    opts, args = parser.parse_args()
    if len(args) < 2:
        parser.error("I need a lastdb index and query sequences")

    # With no score matrix and no fastq input, fill in any initial
    # parameters that were not given:
    if not opts.p and (not opts.Q or opts.Q == "0"):
        initialValues = (("r", "5"), ("q", "5"), ("a", "15"), ("b", "3"))
        for optionName, value in initialValues:
            if not getattr(opts, optionName):
                setattr(opts, optionName, value)

    try:
        lastTrain(opts, args)
    except KeyboardInterrupt:
        pass  # avoid silly error message
    except Exception as e:
        # report any other error as one line, without a traceback
        prog = os.path.basename(sys.argv[0])
        sys.exit(prog + ": error: " + str(e))
