#! /usr/bin/env python
# Author: Martin C. Frith 2018

from __future__ import print_function

import argparse
import collections
import gzip
import itertools
import signal
import sys

def myOpen(fileName):  # faster than fileinput
    """Return a readable file object for fileName.

    "-" means standard input, and a name ending in ".gz" is opened with
    gzip in text mode.  Anything else is opened with the built-in open().
    """
    if fileName == "-":
        return sys.stdin
    if not fileName.endswith(".gz"):
        return open(fileName)
    return gzip.open(fileName, "rt")  # xxx dubious for Python2

def alnBegFromSeqBeg(gappedSequence, seqBeg):
    """Return the column index of non-gap character number seqBeg.

    Non-gap characters (anything but "-") are numbered from zero.  The
    result is None if gappedSequence has too few non-gap characters.
    """
    lettersToSkip = seqBeg
    for column, symbol in enumerate(gappedSequence):
        if symbol == "-":
            continue
        if lettersToSkip == 0:
            return column
        lettersToSkip -= 1
    return None

def alnEndFromSeqEnd(gappedSequence, seqEnd):
    """Return the column just after the first seqEnd non-gap characters.

    That is: one plus the column index of the seqEnd-th non-gap character
    (counting from one).  The result is None if seqEnd is zero, or if
    gappedSequence has fewer than seqEnd non-gap characters.
    """
    lettersSeen = 0
    # Enumerate from 1, so that "column" is already the end coordinate
    for column, symbol in enumerate(gappedSequence, 1):
        if symbol == "-":
            continue
        lettersSeen += 1
        if lettersSeen == seqEnd:
            return column
    return None

def alignmentRange(cutBeg, cutEnd, sLineFields):
    """Return the alignment columns (beg, end) covering a sequence range.

    cutBeg and cutEnd are forward-strand coordinates.  sLineFields holds
    the whitespace-separated fields of one MAF "s" line: fields 2-6 are
    used as start, span, strand, sequence length, and gapped sequence.
    If the range does not overlap the aligned part, (0, 0) is returned.
    """
    isReverseStrand = (sLineFields[4] == "-")
    if isReverseStrand:
        # Convert the cut range to coordinates on the aligned strand
        seqLen = int(sLineFields[5])
        cutBeg, cutEnd = seqLen - cutEnd, seqLen - cutBeg

    nothing = 0, 0
    sBeg = int(sLineFields[2])
    if cutEnd <= sBeg:
        return nothing
    sSpan = int(sLineFields[3])
    if sBeg + sSpan <= cutBeg:
        return nothing

    gapped = sLineFields[6]
    letterCount = len(gapped) - gapped.count("-")
    # Number of sequence characters per unit of the span
    monomerLen = letterCount // sSpan
    firstUnit = max(cutBeg - sBeg, 0)
    lastUnit = min(cutEnd - sBeg, sSpan)
    alnBeg = alnBegFromSeqBeg(gapped, firstUnit * monomerLen)
    alnEnd = alnEndFromSeqEnd(gapped, lastUnit * monomerLen)
    return alnBeg, alnEnd

def findTheSpecifiedSequence(seqName, mafLines):
    """Return the split fields of the first matching "s" line, or None.

    A line matches if it starts with "s" and its second field equals
    seqName.  If seqName is None, the first "s" line matches.
    """
    sLines = (line.split() for line in mafLines if line[0] == "s")
    wanted = (f for f in sLines if seqName is None or f[1] == seqName)
    return next(wanted, None)

def sLineCutData(fields, alnBeg, alnEnd):
    """Get coordinates for an "s" line cut to columns alnBeg to alnEnd.

    fields holds the split fields of one "s" line.  Returns a 4-tuple:
    the new start, the new span, and the numbers of non-gap characters
    left over (not making up a whole span unit) before alnBeg and after
    alnEnd.
    """
    span = int(fields[3])
    gapped = fields[6]
    numOfColumns = len(gapped)
    lettersBefore = alnBeg - gapped.count("-", 0, alnBeg)
    lettersAfter = (numOfColumns - alnEnd
                    - gapped.count("-", alnEnd, numOfColumns))
    # Number of sequence characters per unit of the span
    lettersPerUnit = (numOfColumns - gapped.count("-")) // span
    unitsBefore, begRemainder = divmod(lettersBefore, lettersPerUnit)
    unitsAfter, endRemainder = divmod(lettersAfter, lettersPerUnit)
    newBeg = int(fields[2]) + unitsBefore
    newSpan = span - unitsBefore - unitsAfter
    return newBeg, newSpan, begRemainder, endRemainder

def addMafData(dataPerSeq, mafLines, alnBeg, alnEnd):
    """Record which part of each sequence lies in columns alnBeg to alnEnd.

    For each "s" line, a tuple (sequence length, beg, end) is appended to
    dataPerSeq[sequence name], where beg and end are forward-strand
    coordinates.  The column range is widened by the remainders that
    sLineCutData reports, and the widened range carries over to later lines.
    """
    for line in mafLines:
        if line[0] != "s":
            continue
        fields = line.split()
        seqName, strand = fields[1], fields[4]  # xxx omit theSpecifiedSequence ?
        seqLen = int(fields[5])
        cutBeg, cutSpan, begRem, endRem = sLineCutData(fields, alnBeg, alnEnd)
        alnBeg -= begRem
        alnEnd += endRem
        cutEnd = cutBeg + cutSpan
        if strand == "-":
            # Flip reverse-strand coordinates to the forward strand
            seqData = seqLen, seqLen - cutEnd, seqLen - cutBeg
        else:
            seqData = seqLen, cutBeg, cutEnd
        dataPerSeq[seqName].append(seqData)

def cutMafRecords(mafLines, alnBeg, alnEnd):
    """Yield the fields of each MAF line, cut to columns alnBeg to alnEnd.

    "s" lines get a new start, span, and sequence.  "q" and "p" lines get
    their last field sliced.  Other lines are yielded as-is (split into
    fields).  Each "s" line may widen the column range, by the remainders
    that sLineCutData reports, and this affects the lines after it.
    """
    # Index of the per-column field, for line types other than "s"
    perColumnField = {"q": 2, "p": 1}
    for line in mafLines:
        fields = line.split()
        kind = line[0]
        if kind == "s":
            beg, span, begRem, endRem = sLineCutData(fields, alnBeg, alnEnd)
            alnBeg -= begRem
            alnEnd += endRem
            record = fields[:2]
            record += [str(beg), str(span)]
            record += fields[4:6]
            record.append(fields[6][alnBeg:alnEnd])
            yield record
        elif kind in perColumnField:
            k = perColumnField[kind]
            yield fields[:k] + [fields[k][alnBeg:alnEnd]]
        else:
            yield fields

def mafFieldWidths(mafRecords):
    """Yield the maximum length of each field, over the "s" records.

    Records whose first field is not "s" are ignored.
    """
    sRecords = [r for r in mafRecords if r[0] == "s"]
    for column in zip(*sRecords):
        yield max(len(field) for field in column)

def printMafLine(fieldWidths, fields):
    formatParams = itertools.chain.from_iterable(zip(fieldWidths, fields))
    print("%*s %-*s %*s %*s %*s %*s %*s" % tuple(formatParams))

def printOneMaf(mafRecords):
    """Print one alignment with its columns lined up, then a blank line.

    mafRecords is a list of field lists.  It is traversed twice: once to
    get the field widths, and once to print.  "q" and "p" records are
    padded with empty fields, so that their last field lines up with the
    sequences of the "s" records.  Other records are printed with their
    fields joined by single spaces.
    """
    widths = list(mafFieldWidths(mafRecords))
    # How many leading fields to keep before the empty padding fields
    numOfLeadingFields = {"q": 2, "p": 1}
    for record in mafRecords:
        tag = record[0]
        if tag == "s":
            printMafLine(widths, record)
        elif tag in numOfLeadingFields:
            k = numOfLeadingFields[tag]
            blanks = [""] * (6 - k)
            printMafLine(widths, record[:k] + blanks + record[k:])
        else:
            print(" ".join(record))
    print()

def nameFromTitleLine(line):
    """Return the first word of line, after skipping its first character."""
    withoutMarker = line[1:]
    return withoutMarker.split(None, 1)[0]

def doOneFastaSequence(dataPerSeq, topLine, seqLines):
    """Print the wanted part of one fasta sequence, if it has one.

    dataPerSeq maps sequence names to lists of (sequence length, beg, end).
    Only entries whose length equals the actual sequence length are used.
    The printed part runs from the smallest beg to the largest end.
    """
    hits = dataPerSeq.get(nameFromTitleLine(topLine))
    if not hits:
        return
    sequence = "".join(line.rstrip() for line in seqLines)
    sequenceLength = len(sequence)
    sameLengthHits = [h for h in hits if h[0] == sequenceLength]
    if not sameLengthHits:
        return
    lo = min(h[1] for h in sameLengthHits)
    hi = max(h[2] for h in sameLengthHits)
    print(topLine, end="")
    print(sequence[lo:hi])

def doOneFastqSequence(dataPerSeq, lines):
    """Print the wanted part of one 4-line fastq record, if it has one.

    dataPerSeq maps sequence names to lists of (sequence length, beg, end).
    Only entries whose length equals the actual sequence length are used.
    The sequence and quality lines are cut from the smallest beg to the
    largest end; the other two lines are printed unchanged.
    """
    hits = dataPerSeq.get(nameFromTitleLine(lines[0]))
    if not hits:
        return
    sequenceLength = len(lines[1].rstrip())
    sameLengthHits = [h for h in hits if h[0] == sequenceLength]
    if not sameLengthHits:
        return
    lo = min(h[1] for h in sameLengthHits)
    hi = max(h[2] for h in sameLengthHits)
    wantedPart = slice(lo, hi)
    print(lines[0], end="")
    print(lines[1][wantedPart])
    print(lines[2], end="")
    print(lines[3][wantedPart])

def printFastaParts(dataPerSeq, topLine, lines):
    """Process each fasta record: topLine, then the records in lines.

    topLine is the title line of the first record, and lines supplies all
    the lines after it.  A line starting with ">" begins a new record.
    """
    pending = []
    for line in lines:
        if line[0] != ">":
            pending.append(line)
            continue
        # A new title line: finish the previous record first
        doOneFastaSequence(dataPerSeq, topLine, pending)
        topLine, pending = line, []
    doOneFastaSequence(dataPerSeq, topLine, pending)

def printFastqParts(dataPerSeq, topLine, lines):
    """Process each fastq record: topLine, then the records in lines.

    topLine is the first line of the first record, and lines supplies all
    the lines after it.  Each run of 4 lines is taken to be one record;
    leftover lines at the end are ignored.
    """
    record = [topLine]
    for line in lines:
        record.append(line)
        if len(record) == 4:
            doOneFastqSequence(dataPerSeq, record)
            record = []

def printSequenceParts(dataPerSeq, lines):
    """Print wanted sequence parts, from fasta or fastq lines.

    The format is judged by the first character of the first line: ">"
    for fasta, "@" for fastq.  Anything else raises RuntimeError.  If
    there are no lines, nothing happens.
    """
    nothing = object()
    firstLine = next(iter(lines), nothing)
    if firstLine is nothing:
        return
    marker = firstLine[0]
    # The rest of "lines" is consumed by whichever parser gets called
    if marker == ">":
        printFastaParts(dataPerSeq, firstLine, lines)
    elif marker == "@":
        printFastqParts(dataPerSeq, firstLine, lines)
    else:
        raise RuntimeError("can't read the sequence data")

def cutOneMaf(seqName, cutBeg, cutEnd, dataPerSeq, mafLines):
    """Cut one alignment (the lines in mafLines) to the wanted range.

    Nothing happens if the alignment lacks the wanted sequence, or does
    not overlap the range.  Otherwise: if dataPerSeq is None the cut
    alignment is printed, else the covered part of each sequence is
    recorded in dataPerSeq.
    """
    sLineFields = findTheSpecifiedSequence(seqName, mafLines)
    if sLineFields:
        alnBeg, alnEnd = alignmentRange(cutBeg, cutEnd, sLineFields)
        if alnBeg < alnEnd:
            if dataPerSeq is not None:
                addMafData(dataPerSeq, mafLines, alnBeg, alnEnd)
            else:
                printOneMaf(list(cutMafRecords(mafLines, alnBeg, alnEnd)))

def wantedRange(cutSpecification):
    """Parse a range specification into (seqName, beg, end).

    Accepted forms: "name beg end", "beg end", "name:beg-end", and
    "beg-end".  seqName is None for the forms without a name.  For the
    forms with ":" and "-", the last ":" and the last "-" are the
    separators.
    """
    words = cutSpecification.split()
    numOfWords = len(words)
    if numOfWords == 3:
        seqName, beg, end = words
    elif numOfWords == 2:
        seqName = None
        beg, end = words
    else:
        seqName, colon, coords = cutSpecification.rpartition(":")
        if not colon:  # no ":", so coords is the whole specification
            seqName = None
        beg, end = coords.rsplit("-", 1)
    return seqName, int(beg), int(end)

def _closeUnlessStdin(fileObject):
    """Close fileObject, unless it is standard input."""
    if fileObject is not sys.stdin:
        fileObject.close()

def mafCut(args):
    """Cut the alignments to the wanted segment, as specified by args.

    args needs these attributes: segment (the range specification),
    alignments (a file name, or "-" for standard input), and sequences
    (file names: may be empty or None).

    If no sequence files are given, the cut alignments are printed.
    Otherwise, the alignments are used only to find which part of each
    sequence is wanted, and those parts of the sequences are printed.

    Each opened file is closed when finished with (but not standard input).
    """
    seqName, cutBeg, cutEnd = wantedRange(args.segment)
    dataPerSeq = collections.defaultdict(list) if args.sequences else None

    mafLines = []
    alignmentFile = myOpen(args.alignments)
    try:
        for line in alignmentFile:
            if line.isspace():  # a blank line ends one alignment
                cutOneMaf(seqName, cutBeg, cutEnd, dataPerSeq, mafLines)
                mafLines = []
            elif line[0] != "#":
                mafLines.append(line)
    finally:
        _closeUnlessStdin(alignmentFile)
    # The last alignment might not have a blank line after it
    cutOneMaf(seqName, cutBeg, cutEnd, dataPerSeq, mafLines)

    if dataPerSeq:
        for fileName in args.sequences:
            sequenceFile = myOpen(fileName)
            try:
                printSequenceParts(dataPerSeq, sequenceFile)
            finally:
                _closeUnlessStdin(sequenceFile)

if __name__ == "__main__":
    # Restore the default SIGPIPE action, so that writing to a closed pipe
    # ends the program without a traceback.  SIGPIPE is not defined on all
    # platforms, so skip this step where it is missing.
    if hasattr(signal, "SIGPIPE"):
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)  # avoid silly error message
    description = "Get parts of MAF-format alignments."
    ap = argparse.ArgumentParser(description=description)
    ap.add_argument("segment",
                    help="zero-based sequence range, e.g. chr7:123400-123500")
    ap.add_argument("alignments", nargs="?", default="-",
                    help="file of sequence alignments in MAF format")
    ap.add_argument("sequences", nargs="*", default=None,  # xxx need default?
                    help="file of sequences in fasta or fastq format")
    args = ap.parse_args()
    mafCut(args)
