#! /usr/bin/env python

# Copyright 2014 Martin C. Frith

# Read MAF-format alignments, and write those that have a segment with
# score >= threshold, with gentle masking of lowercase letters.  There
# must be a lastal header with score parameters.

# Gentle masking is described in: MC Frith, PLoS One 2011;6(12):e28819
# "Gentle masking of low-complexity sequences improves homology search"

# Limitations: doesn't (yet) handle sequence quality data,
# frameshifts, or generalized affine gaps.

from __future__ import print_function

import gzip
import optparse, os, signal, sys

try:
    from future_builtins import zip
except ImportError:
    pass

def myOpen(fileName):
    """Return a readable text stream for fileName.

    "-" means standard input.  A name ending in ".gz" is opened through
    gzip in "rt" mode; anything else is opened with the built-in open().
    """
    if fileName == "-":
        return sys.stdin
    isGzipped = fileName.endswith(".gz")
    if not isGzipped:
        return open(fileName)
    return gzip.open(fileName, "rt")  # xxx dubious for Python2

def complement(base):
    """Return the complement of an uppercase nucleotide code.

    Codes that are not in the lookup string (lowercase letters, "N", "-",
    ...) are returned unchanged.
    """
    originals = "ACGTRYKMBDHV"
    partners = "TGCAYRMKVHDB"
    position = originals.find(base)
    if position < 0:
        return base
    return partners[position]

def fastScoreMatrix(rowHeads, colHeads, matrix, deleteCost, insertCost):
    """Build a 128 x 128 score table indexed by character codes.

    rowHeads, colHeads: the symbols labelling the rows and columns of matrix.
    matrix: scores, where matrix[i][j] belongs to rowHeads[i] and colHeads[j].
    deleteCost, insertCost: costs stored (negated) for pairs involving "-".

    Cells not named by the heads get the lowest score found in matrix.  A
    pair of two uppercase symbols gets the matrix score; a pair with any
    lowercase symbol gets min(score, 0), which is the "gentle masking".
    """
    size = 128
    gap = ord("-")
    worst = min(map(min, matrix))
    table = [[worst] * size for _ in range(size)]

    for i, rowHead in enumerate(rowHeads):
        for j, colHead in enumerate(colHeads):
            rowUpper, rowLower = ord(rowHead.upper()), ord(rowHead.lower())
            colUpper, colLower = ord(colHead.upper()), ord(colHead.lower())
            score = matrix[i][j]
            masked = score if score < 0 else 0
            table[rowUpper][colUpper] = score
            # The masked cells are written after the plain one, so a symbol
            # without distinct upper/lower forms ends up with the masked score.
            for r, c in ((rowUpper, colLower),
                         (rowLower, colUpper),
                         (rowLower, colLower)):
                table[r][c] = masked

    # Column "-" first, then row "-": the gap/gap cell ends as -insertCost.
    for row in table:
        row[gap] = -deleteCost
    table[gap] = [-insertCost] * size
    return table

def matrixPerStrand(rowHeads, colHeads, matrix, deleteCost, insertCost):
    """Return a pair of fast score tables: (forward, reverse).

    The reverse table is built from the same scores, but with every row and
    column head replaced by its complement.
    """
    forward = fastScoreMatrix(rowHeads, colHeads, matrix,
                              deleteCost, insertCost)
    flippedRows = list(map(complement, rowHeads))
    flippedCols = list(map(complement, colHeads))
    reverse = fastScoreMatrix(flippedRows, flippedCols, matrix,
                              deleteCost, insertCost)
    return forward, reverse

def isGoodAlignment(columns, scoreMatrix, delOpenCost, insOpenCost, minScore):
    """Does the alignment have a segment with score >= minScore?

    columns: iterable of (x, y) character pairs, one per alignment column.
    scoreMatrix: table indexed as scoreMatrix[ord(x)][ord(y)].
    delOpenCost: extra cost charged when a run of "-" starts in y.
    insOpenCost: extra cost charged when a run of "-" starts in x.

    A running score is kept and reset to zero whenever it goes negative.
    The threshold is tested right after adding each column's score, before
    any gap-opening cost for that column is subtracted.
    """
    gap = "-"
    total = 0
    prevTop = prevBottom = " "
    for top, bottom in columns:
        total += scoreMatrix[ord(top)][ord(bottom)]
        if total >= minScore:
            return True
        startsInsert = (top == gap and prevTop != gap)
        startsDelete = (bottom == gap and prevBottom != gap)
        if startsInsert:
            total -= insOpenCost
        if startsDelete:
            total -= delOpenCost
        total = max(total, 0)
        prevTop, prevBottom = top, bottom
    return False

def printIfGood(maf, seqs, scoreMatrix, delOpenCost, insOpenCost, minScore):
    """Print one alignment, plus a blank line, if it passes the score test.

    maf: the alignment's text lines, each printed exactly as given.
    seqs: the aligned sequences, which are transposed into columns and
    handed to isGoodAlignment together with the remaining arguments.
    """
    columns = zip(*seqs)
    if not isGoodAlignment(columns, scoreMatrix, delOpenCost, insOpenCost,
                           minScore):
        return
    # Every line as-is, then the newline that print appends by default.
    print(*maf, sep="")

def doOneFile(lines):
    """Filter the MAF alignments in one stream of lines, printing to stdout.

    lines: an iterable of text lines, e.g. an open file.

    Lines starting with "#" are printed unchanged, and are also scanned for
    the score parameters (a=, b=, A=, B=, e=, S=) and the score matrix.
    Other non-blank lines are collected into the current alignment, which
    ends at a blank line (or at the end of the input).  Each alignment that
    has "s" lines is handed to printIfGood, which prints it if it passes.

    Raises Exception("can't read alignment header") if the first "s" line
    is reached before a=, b=, A=, B= and e= have all been read.
    """
    aDel = bDel = aIns = bIns = minScore = matrices = None
    strandParam = 0
    scoreMatrix = []
    rowHeads = []
    colHeads = []
    maf = []  # the text lines of the current alignment
    seqs = []  # field 7 of each "s" line in the current alignment

    for line in lines:
        if line[0] == "#":
            print(line, end="")
            fields = line.split()
            nf = len(fields)
            if not colHeads:
                # Until the matrix column heads have been seen, look for
                # score parameters in every header line.
                for i in fields:
                    if i.startswith("a="): aDel = int(i[2:])
                    if i.startswith("b="): bDel = int(i[2:])
                    if i.startswith("A="): aIns = int(i[2:])
                    if i.startswith("B="): bIns = int(i[2:])
                    if i.startswith("e="): minScore = int(i[2:])
                    if i.startswith("S="): strandParam = int(i[2:])
                # A line made only of one-character fields is taken to be
                # the matrix column heads (the first field being "#").
                if nf > 1 and max(map(len, fields)) == 1:
                    colHeads = fields[1:]
            elif nf == len(colHeads) + 2 and len(fields[1]) == 1:
                # A matrix row: "#", a one-character row head, then one
                # integer score per column head.
                rowHeads.append(fields[1])
                scoreMatrix.append([int(i) for i in fields[2:]])
        elif line.isspace():
            # A blank line ends the current alignment.  Collected lines
            # without any "s" line are silently dropped.
            if seqs: printIfGood(maf, seqs, matrix, aDel, aIns, minScore)
            maf = []
            seqs = []
        else:
            maf.append(line)
            if line[0] == "s":
                if not matrices:
                    # Build the per-strand score tables once, on the first
                    # "s" line.  b= and B= become the per-column gap scores
                    # inside the tables; a= and A= are passed separately to
                    # printIfGood as the gap-opening costs.
                    if None in (aDel, bDel, aIns, bIns, minScore):
                        raise Exception("can't read alignment header")
                    matrices = matrixPerStrand(rowHeads, colHeads,
                                               scoreMatrix, bDel, bIns)
                fields = line.split()
                # In a MAF "s" line, field 7 is the aligned sequence text
                # and field 5 is the strand.
                seqs.append(fields[6])
                if len(seqs) == 2:
                    # Chosen from the strand of the second "s" line: the
                    # reverse table (index 1) is used only when that strand
                    # is "-" and S=1; otherwise the forward table (index 0).
                    # NOTE(review): this assumes S is 0 or 1 - confirm.
                    # "matrix" is only (re)bound here, i.e. once an alignment
                    # has a second "s" line.
                    isReverse = (fields[4] == "-")
                    matrix = matrices[isReverse * strandParam]
    # Handle a final alignment that is not followed by a blank line.
    if seqs: printIfGood(maf, seqs, matrix, aDel, aIns, minScore)

def lastPostmask(args):
    """Filter each named MAF file in turn, writing the result to stdout.

    args: a list of file names, where "-" means standard input.  If it is
    empty, standard input is read.

    Each file is closed once it has been processed, even if processing it
    raises an exception (which is then propagated to the caller).
    """
    if not args:
        args = ["-"]
    for fileName in args:
        f = myOpen(fileName)
        # try/finally rather than "with": the stream may be a gzip file,
        # and this script also aims to run under Python 2.
        try:
            doOneFile(f)
        finally:
            f.close()

if __name__ == "__main__":
    # Use the default SIGPIPE action, to avoid a silly error message (e.g.
    # when the output is piped into a program that stops reading early).
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)

    usage = "%prog in.maf > out.maf"
    description = ("Get alignments that have a segment with score >= "
                   "threshold, with gentle masking of lowercase letters.")
    op = optparse.OptionParser(usage=usage, description=description)
    opts, args = op.parse_args()

    try:
        lastPostmask(args)
    except KeyboardInterrupt:
        pass  # avoid silly error message
    except Exception as e:
        # Report any failure as "<program name>: error: <message>".
        prog = os.path.basename(sys.argv[0])
        sys.exit(prog + ": error: " + str(e))
