#! /usr/bin/env python
# Copyright 2015 Martin C. Frith
# SPDX-License-Identifier: GPL-3.0-or-later

# References:
# [Fri19] How sequence alignment scores correspond to probability models,
#         MC Frith, Bioinformatics, 2019.

from __future__ import division, print_function

import gzip
import math
import optparse
import os
import random
import signal
import subprocess
import sys
import tempfile

# The 20 amino-acid one-letter codes, in alphabetical order:
proteinAlphabet20 = "ACDEFGHIKLMNPQRSTVWY"
# The same plus "*" (stop); used to label matrix rows in the codon routines:
proteinAlphabet21 = proteinAlphabet20 + "*"

def myOpen(fileName):  # faster than fileinput
    """Return a readable text stream for fileName.

    "-" means standard input (returned as-is).  Names ending in ".gz" are
    decompressed on the fly with gzip; anything else is opened normally.
    """
    isStdin = fileName == "-"
    isGzipped = fileName.endswith(".gz")
    if isStdin:
        stream = sys.stdin
    elif isGzipped:
        stream = gzip.open(fileName, "rt")  # xxx dubious for Python2
    else:
        stream = open(fileName)
    return stream

def rootOfIncreasingFunction(func, lowerBound, upperBound, args):
    """Bisect for x in [lowerBound, upperBound] with func(x, *args) == 0.

    func must be increasing on the interval.  The step is halved each
    round, and the low end advances while func is still negative there.
    Stops (returning the probe point) once the step no longer moves the
    low end at floating-point precision.
    """
    low = lowerBound
    step = upperBound - lowerBound
    while True:
        step *= 0.5
        probe = low + step
        if probe <= low:
            break
        if func(probe, *args) < 0:
            low = probe
    return probe

def rootOfDecreasingFunction(func, lowerBound, upperBound, args):
    """Bisect for x in [lowerBound, upperBound] with func(x, *args) == 0.

    func must be decreasing on the interval.  The step is halved each
    round, and the low end advances while func is still positive there.
    Stops (returning the probe point) once the step no longer moves the
    low end at floating-point precision.
    """
    low = lowerBound
    step = upperBound - lowerBound
    while True:
        step *= 0.5
        probe = low + step
        if probe <= low:
            break
        if func(probe, *args) > 0:
            low = probe
    return probe

def homogeneousLetterFreqs(scale, matScores):
    # Solve the simultaneous equations in Section 2.1 of [Fri19]
    expMat = [[math.exp(j / scale) for j in i] for i in matScores]
    m = [row[:] + [1.0] for row in expMat]  # augmented matrix
    n = len(expMat)
    for k in range(n):
        iMax = k
        for i in range(k, n):
            if abs(m[i][k]) > abs(m[iMax][k]):
                iMax = i
        if iMax > k:
            m[k], m[iMax] = m[iMax], m[k]
        if abs(m[k][k]) <= 0:
            raise ArithmeticError("singular matrix")
        for i in range(n):
            if i != k:
                mul = m[i][k] / m[k][k]
                for j in range(k + 1, n + 1):
                    m[i][j] -= m[k][j] * mul
    return [m[k][n] / m[k][k] for k in range(n)]

def randomSample(things, sampleSize):
    """Randomly get sampleSize things (or all if fewer).

    Reservoir sampling: things may be any iterable and is read once.
    When it has at most sampleSize items, all are returned in order.
    """
    reservoir = []
    for seen, item in enumerate(things, 1):
        if seen <= sampleSize:
            reservoir.append(item)
            continue
        slot = random.randrange(seen)
        if slot < sampleSize:
            reservoir[slot] = item
    return reservoir

def writeWords(outFile, words):
    """Write the str() of each word, space-separated, as one line to outFile."""
    print(" ".join(str(word) for word in words), file=outFile)

def seqInput(fileNames):
    """Yield (sequence, qualities) pairs read from FASTA or FASTQ files.

    fileNames: list of file names; "-" means stdin, and an empty list
    means stdin only.  Names ending in ".gz" are read via gzip (myOpen).
    Each file's format is fixed by its first line starting with ">"
    (FASTA) or "@" (FASTQ); any lines before that are ignored.  FASTA
    sequence lines are joined and yielded with qualities "".  FASTQ is
    read as strict 4-line records (header, sequence, "+", qualities).
    """
    if not fileNames:
        fileNames = ["-"]
    for name in fileNames:
        f = myOpen(name)
        seqType = 0  # 0: format not yet known, 1: FASTA, 2: FASTQ
        for line in f:
            if seqType == 0:
                if line[0] == ">":
                    seqType = 1
                    seq = []
                elif line[0] == "@":
                    seqType = 2
                    # lineType = position in a 4-line FASTQ record of the
                    # NEXT line: 1 sequence, 2 "+" line, 3 qualities, 0 header
                    lineType = 1
            elif seqType == 1:  # fasta
                if line[0] == ">":
                    yield "".join(seq), ""
                    seq = []
                else:
                    seq.append(line.rstrip())
            elif seqType == 2:  # fastq
                if lineType == 1:
                    seq = line.rstrip()
                elif lineType == 3:
                    yield seq, line.rstrip()
                lineType = (lineType + 1) % 4
        if seqType == 1: yield "".join(seq), ""  # the file's last FASTA record
        # NOTE(review): this also closes sys.stdin when name is "-", and it
        # is skipped if the consumer stops iterating early.
        f.close()

def isGoodChunk(chunk):
    """Return True if any segment's sequence (item 3) has a letter other than N/n."""
    return any(letter not in "Nn"
               for segment in chunk
               for letter in segment[3])

def chunkInput(opts, sequences):
    """Split sequences into chunks of opts.sample_length letters.

    sequences: iterable of (seq, qual) pairs, where qual may be "".
    Sequences consisting only of N/n are skipped.  Each chunk is a list
    of segments (seqIndex, beg, end, seq[beg:end], qual[beg:end]), and
    the segment lengths add up to opts.sample_length.  A segment never
    crosses a sequence boundary, but a chunk may span several sequences.
    Chunks with no letter other than N/n are dropped.  A final, shorter
    chunk is yielded only if fewer than opts.sample_number full chunks
    were yielded.
    """
    chunkCount = 0
    chunk = []
    wantedLength = opts.sample_length
    for seqIndex, (seq, qual) in enumerate(sequences):
        if all(c in "Nn" for c in seq):
            continue
        seqLength = len(seq)
        beg = 0
        while beg < seqLength:
            end = beg + min(wantedLength, seqLength - beg)
            chunk.append((seqIndex, beg, end, seq[beg:end], qual[beg:end]))
            wantedLength -= end - beg
            if not wantedLength:
                if isGoodChunk(chunk):
                    yield chunk
                    chunkCount += 1
                chunk = []
                wantedLength = opts.sample_length
            beg = end
    # The leftover partial chunk gets the same all-N check as full chunks,
    # e.g. it may hold only the N-tail of the last sequence.
    if chunk and chunkCount < opts.sample_number and isGoodChunk(chunk):
        yield chunk

def writeSegment(outfile, segment):
    """Write a sample segment as a FASTQ record if it has qualities, else FASTA.

    segment: (seqIndex, beg, end, seq, qual); the record name is
    "seqIndex:beg".  A falsy segment (e.g. None) writes nothing.
    """
    if not segment:
        return
    seqIndex, beg, _, seq, qual = segment
    name = str(seqIndex) + ":" + str(beg)
    if qual:
        record = "@" + name + "\n" + seq + "\n+\n" + qual + "\n"
    else:
        record = ">" + name + "\n" + seq + "\n"
    outfile.write(record)

def getSeqSample(opts, queryFiles, outfile):
    """Write a random sample of opts.sample_number sequence chunks to outfile.

    The sampled chunks are sorted (so segments come out in input order),
    and segments that continue directly from the previous one in the same
    sequence are merged before writing.
    """
    chunks = chunkInput(opts, seqInput(queryFiles))
    sample = sorted(randomSample(chunks, opts.sample_number))
    pending = None
    for segment in (s for chunk in sample for s in chunk):
        continuesPending = (pending and segment[0] == pending[0] and
                            segment[1] == pending[2])
        if continuesPending:
            pending = (pending[0], pending[1], segment[2],
                       pending[3] + segment[3], pending[4] + segment[4])
        else:
            writeSegment(outfile, pending)
            pending = segment
    writeSegment(outfile, pending)

def scaleFromHeader(lines):
    """Return the number in the first whitespace-separated "t=NUMBER" word.

    Reads lines only as far as needed.  Raises Exception if none is found.
    """
    for line in lines:
        for word in line.split():
            if word[:2] == "t=":
                return float(word[2:])
    raise Exception("couldn't read the scale")

def countsFromLastOutput(opts, maxPercentIdentity, codonMatches, lines):
    """Sum substitution and transition counts over lastal's alignments.

    lines: lastal output lines.  Each "c" line holds one alignment's
    counts: the substitution counts (nRows * nCols, row-major), then
    nTransitions transition counts (5 normally, 9 with opts.codon).  The
    strand is taken from the 5th field of the most recent "s" line.
    maxPercentIdentity: alignments whose substitution identity percent is
    above this are skipped.
    codonMatches: (opts.codon only) for each codon column, the row index
    whose count is treated as an identity.

    Returns (countMatrix, transitionCounts + [alignmentCount]), with
    pseudocounts included.  Raises Exception if no alignment was counted.
    """
    nTransitions = 9 if opts.codon else 5
    tranCounts = [1.0] * nTransitions  # +1 pseudocounts
    tranCounts[1] = 2.0  # deletes: opens + extensions, so 2 pseudocounts
    tranCounts[2] = 2.0  # inserts: opens + extensions, so 2 pseudocounts
    countMatrix = None
    alignments = 0  # no pseudocount here
    for line in lines:
        if line[0] == "s":
            strand = line.split()[4]  # slow?
        if line[0] == "c":
            counts = [float(i) for i in line.split()[1:]]
            if not countMatrix:
                # Matrix shape is inferred from the first "c" line:
                # square normally, 64 codon columns with --codon.
                matrixSize = len(counts) - nTransitions
                nCols = 64 if opts.codon else int(math.sqrt(matrixSize))
                nRows = matrixSize // nCols
                pseudocount = 0.0 if opts.codon else 1.0
                countMatrix = [[pseudocount] * nCols for i in range(nRows)]
            if opts.codon:
                identities = sum(counts[codonMatches[i] * nCols + i]
                                 for i in range(nCols))
            else:
                identities = sum(counts[i * nCols + i] for i in range(nRows))
            alignmentLength = sum(counts[:matrixSize])
            if 100 * identities > maxPercentIdentity * alignmentLength:
                continue
            for i in range(nRows):
                for j in range(nCols):
                    # With -S1, counts from non-"+" strand alignments are
                    # added with both indices reversed (for ACGT order this
                    # is the complement).  NOTE(review): confirm intent.
                    if strand == "+" or opts.S != "1":
                        countMatrix[i][j]       += counts[i * nCols + j]
                    else:
                        countMatrix[-1-i][-1-j] += counts[i * nCols + j]
            for i in range(nTransitions):
                tranCounts[i] += counts[matrixSize + i]
            alignments += 1
    if not alignments:
        raise Exception("no alignments")
    if opts.codon:
        # Add pseudocounts proportional to (rowSum + 1) * (colSum + 1),
        # totalling nRows * 32 over the whole matrix.
        pseudocounts = nRows * 32  # xxx ???
        rowSums = [sum(i) + 1 for i in countMatrix]
        colSums = [sum(i) + 1 for i in zip(*countMatrix)]
        mul = pseudocounts / (sum(rowSums) * sum(colSums))
        countMatrix = [[x + mul * i * j for j, x in zip(colSums, row)]
                       for i, row in zip(rowSums, countMatrix)]
    return countMatrix, tranCounts + [alignments]

def scoreFromProb(scale, prob):
    """Return round(scale * ln(prob)) as an int.

    A non-positive prob is treated as having log -800, since
    exp(-800) is exactly zero, on my computer.
    """
    logProb = math.log(prob) if prob > 0 else -800
    return int(round(scale * logProb))

def costFromProb(scale, prob):
    """Return the cost for prob: the score from scoreFromProb, negated."""
    score = scoreFromProb(scale, prob)
    return -score

def guessAlphabet(matrixSize):
    """Return the letters labelling a matrix with matrixSize rows.

    4 means DNA ("ACGT") and 20 means the standard amino acids.  Any
    other size raises Exception.
    """
    if matrixSize == 4:
        return "ACGT"
    elif matrixSize == 20:
        return proteinAlphabet20
    else:
        raise Exception("can't handle unusual alphabets")

def writeMatrixHead(outFile, prefix, alphabet, formatString):
    writeWords(outFile, [prefix + " "] + [formatString % k for k in alphabet])

def writeMatrixBody(outFile, prefix, alphabet, matrix, formatString):
    """Write one line per matrix row, labelled with prefix + the row's letter."""
    for letter, row in zip(alphabet, matrix):
        cells = [formatString % value for value in row]
        print(prefix + letter, *cells, file=outFile)

def writeCountMatrix(outFile, matrix, prefix):
    """Write a count matrix: left-aligned 14-wide cells, up to 12 significant digits."""
    letters = guessAlphabet(len(matrix))
    headFormat, cellFormat = "%-14s", "%-14.12g"
    writeMatrixHead(outFile, prefix, letters, headFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def writeProbMatrix(outFile, matrix, prefix):
    """Write a probability matrix: left-aligned 14-wide cells in %g format."""
    letters = guessAlphabet(len(matrix))
    headFormat, cellFormat = "%-14s", "%-14g"
    writeMatrixHead(outFile, prefix, letters, headFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def writeScoreMatrix(outFile, matrix, prefix):
    """Write a score matrix with right-aligned 6-wide cells."""
    letters = guessAlphabet(len(matrix))
    cellFormat = "%6s"
    writeMatrixHead(outFile, prefix, letters, cellFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def matProbsFromCounts(counts, opts):
    """Turn a substitution count matrix into a probability matrix.

    With opts.revsym, each count is first added to its complement
    (both indices reversed).  With opts.matsym, the matrix is then
    symmetrized.  Prints the percent identity, the counts and the
    probabilities, and returns the probabilities (summing to 1).
    """
    indices = range(len(counts))
    if opts.revsym:  # add complement (reverse strand) substitutions
        counts = [[counts[i][j] + counts[-1 - i][-1 - j] for j in indices]
                  for i in indices]
    if opts.matsym:  # symmetrize the substitution matrix
        counts = [[counts[i][j] + counts[j][i] for j in indices]
                  for i in indices]
    total = sum(map(sum, counts))
    identities = sum(counts[i][i] for i in indices)
    probs = [[value / total for value in row] for row in counts]
    print("# substitution percent identity: %g" % (100 * identities / total))
    print()
    print("# count matrix (query letters = columns, reference letters = rows):")
    writeCountMatrix(sys.stdout, counts, "# ")
    print()
    print("# probability matrix "
          "(query letters = columns, reference letters = rows):")
    writeProbMatrix(sys.stdout, probs, "# ")
    print()
    return probs

def probImbalance(endProb, matchProb, firstDelProb, delExtendProb,
                  firstInsProb, insExtendProb):
    """Return (RHS - LHS) of Equation (12) in [Fri19] at this endProb."""
    matchTerm = matchProb / (endProb * endProb)
    delTerm = firstDelProb / (endProb - delExtendProb)
    insTerm = firstInsProb / (endProb - insExtendProb)
    return 1 - matchTerm - delTerm - insTerm

def balancedEndProb(*args):
    """Find endProb making probImbalance zero.

    args: (matchProb, firstDelProb, delExtendProb, firstInsProb,
    insExtendProb).  The search runs between the larger extend
    probability and 1.
    """
    _, _, delExtendProb, _, insExtendProb = args
    floor = max(delExtendProb, insExtendProb)
    return rootOfIncreasingFunction(probImbalance, floor, 1.0, args)

def gapProbsFromCounts(counts, opts, maxGapGrowProb):
    """Estimate match and gap probabilities from transition counts.

    counts: (matches, deletes, inserts, delOpens, insOpens, alignments).
    With opts.gapsym, insertions and deletions share pooled estimates.
    Prints the counts and estimates, then caps both extend
    probabilities at maxGapGrowProb.

    Returns (matchProb, (delOpenProb, delGrowProb),
    (insOpenProb, insGrowProb)).
    """
    matches, deletes, inserts, delOpens, insOpens, alignments = counts
    gapOpens = delOpens + insOpens
    gaps = deletes + inserts
    denominator = matches + gapOpens + (alignments + 1)  # +1 pseudocount
    matchProb = matches / denominator
    if opts.gapsym:
        delOpenProb = insOpenProb = gapOpens / denominator / 2
        delGrowProb = insGrowProb = (gaps - gapOpens) / gaps
    else:
        delOpenProb = delOpens / denominator
        insOpenProb = insOpens / denominator
        delGrowProb = (deletes - delOpens) / deletes
        insGrowProb = (inserts - insOpens) / inserts
    countItems = (("aligned letter pairs", matches), ("deletes", deletes),
                  ("inserts", inserts), ("delOpens", delOpens),
                  ("insOpens", insOpens))
    for label, value in countItems:
        print("# %s: %.12g" % (label, value))
    print("# alignments:", alignments)
    print("# mean delete size: %g" % (deletes / delOpens))
    print("# mean insert size: %g" % (inserts / insOpens))
    probItems = (("matchProb", matchProb), ("delOpenProb", delOpenProb),
                 ("insOpenProb", insOpenProb), ("delExtendProb", delGrowProb),
                 ("insExtendProb", insGrowProb))
    for label, value in probItems:
        print("# %s: %g" % (label, value))
    print()
    delProbs = delOpenProb, min(delGrowProb, maxGapGrowProb)
    insProbs = insOpenProb, min(insGrowProb, maxGapGrowProb)
    return matchProb, delProbs, insProbs

def gapRatiosFromProbs(matchProb, delProbs, insProbs):
    """Convert match/gap probabilities into probability ratios.

    delProbs, insProbs: (openProb, growProb).  The probabilities are
    divided by powers of the balanced end probability (balancedEndProb),
    which is probably negligibly less than 1.

    Returns (matchRatio, (firstDelRatio, delGrowRatio),
    (firstInsRatio, insGrowRatio)).
    """
    delOpenProb, delGrowProb = delProbs
    insOpenProb, insGrowProb = insProbs
    firstDelProb = delOpenProb * (1 - delGrowProb)
    firstInsProb = insOpenProb * (1 - insGrowProb)
    endProb = balancedEndProb(matchProb, firstDelProb, delGrowProb,
                              firstInsProb, insGrowProb)
    return (matchProb / (endProb * endProb),
            (firstDelProb / endProb, delGrowProb / endProb),
            (firstInsProb / endProb, insGrowProb / endProb))

def scoreFromLetterProbs(scale, matchRatio, pairProb, rowProb, colProb):
    """Return the integer score for one letter pair: Equation (4) in [Fri19]."""
    odds = pairProb / (rowProb * colProb)
    weightedOdds = matchRatio * odds
    return scoreFromProb(scale, weightedOdds)

def matScoresFromProbs(scale, matchRatio, matProbs, rowProbs, colProbs):
    """Return the integer score matrix, one row per rowProbs entry."""
    scores = []
    for i, rowProb in enumerate(rowProbs):
        scores.append([scoreFromLetterProbs(scale, matchRatio,
                                            matProbs[i][j], rowProb, colProb)
                       for j, colProb in enumerate(colProbs)])
    return scores

def gapCostsFromProbRatios(scale, firstGapRatio, gapExtendRatio):
    """Return (firstGapCost, gapExtendCost), each at least 1.

    The addition below gets the alignment parameter from the path
    parameters, as in Supplementary section 3.1 of [Fri19].
    """
    alignmentExtendRatio = gapExtendRatio + firstGapRatio
    return (max(costFromProb(scale, firstGapRatio), 1),
            max(costFromProb(scale, alignmentExtendRatio), 1))

def imbalanceFromGap(scale, firstGapCost, gapExtendCost):
    """Return one gap type's term of the score imbalance.

    The subtraction below gets the path parameter from the alignment
    parameters, as in Supplementary section 3.1 of [Fri19].
    """
    firstGapRatio = math.exp(-firstGapCost / scale)
    pathExtendRatio = math.exp(-gapExtendCost / scale) - firstGapRatio
    return firstGapRatio / (1 - pathExtendRatio)

def scoreImbalance(scale, matScores, delCosts, insCosts):
    """Return C' - 1, where C' is defined in Equation (13) of [Fri19]."""
    delTerm = imbalanceFromGap(scale, *delCosts)
    insTerm = imbalanceFromGap(scale, *insCosts)
    letterFreqTotal = sum(homogeneousLetterFreqs(scale, matScores))
    return 1 / letterFreqTotal + delTerm + insTerm - 1

def balancedScale(imbalanceFunc, nearScale, args):
    # Find a scale, near nearScale, with balanced length probability
    #
    # i.e. a root of imbalanceFunc(scale, *args).  The bracket around
    # nearScale is widened geometrically, by a factor of `bump` per step,
    # alternately downwards and upwards, until the function value changes
    # sign.  Bisection then runs inside that one step.  Returns 0.0 if no
    # sign change is found before the upper end reaches 2 * nearScale.
    bump = 1.000001
    # Index 0: bisector for decreasing functions, 1: for increasing ones.
    rootFinders = rootOfDecreasingFunction, rootOfIncreasingFunction
    value = imbalanceFunc(nearScale, *args)
    if abs(value) <= 0:  # already exactly balanced
        return nearScale
    oldLower = oldUpper = nearScale
    while oldUpper < 2 * nearScale:  # xxx ???
        newLower = oldLower / bump
        lowerValue = imbalanceFunc(newLower, *args)
        if (lowerValue < 0) != (value < 0):
            # Sign change within [newLower, oldLower]: going rightwards the
            # function rises iff value > 0 (the bool indexes rootFinders).
            finder = rootFinders[value > 0]
            return finder(imbalanceFunc, newLower, oldLower, args)
        oldLower = newLower
        newUpper = oldUpper * bump
        upperValue = imbalanceFunc(newUpper, *args)
        if (upperValue < 0) != (value < 0):
            # Sign change within [oldUpper, newUpper]: rising iff value < 0.
            finder = rootFinders[value < 0]
            return finder(imbalanceFunc, oldUpper, newUpper, args)
        oldUpper = newUpper
    return 0.0

def scoresAndScale(originalScale, matParams, delRatios, insRatios):
    """Round the model to integer scores, and find their balanced scale.

    matParams: (matchRatio, matProbs, rowProbs, colProbs).
    delRatios, insRatios: (firstGapRatio, gapExtendRatio).
    Integer scores are computed at a trial scale, starting at
    originalScale, and then the nearby balanced scale is found.  If
    there is no balanced scale, or the implied letter frequencies
    include a negative one, the rounding was too coarse: the trial
    scale is multiplied by 1.1 and everything is retried.

    Returns (matScores, delCosts, insCosts, scale, rowFreqs, colFreqs).
    """
    trialScale = originalScale
    while True:
        matScores = matScoresFromProbs(trialScale, *matParams)
        delCosts = gapCostsFromProbRatios(trialScale, *delRatios)
        insCosts = gapCostsFromProbRatios(trialScale, *insRatios)
        scale = balancedScale(scoreImbalance, trialScale,
                              (matScores, delCosts, insCosts))
        if scale > 0:
            rowFreqs = homogeneousLetterFreqs(scale, zip(*matScores))
            colFreqs = homogeneousLetterFreqs(scale, matScores)
            if all(freq >= 0 for freq in rowFreqs + colFreqs):
                return matScores, delCosts, insCosts, scale, rowFreqs, colFreqs
        print("# the integer-rounded scores are too inaccurate: "
              "increasing the scale")
        trialScale *= 1.1

### Routines for codons & frameshifts:

def initialCodonSubstitutionProbs(matchProb):
    """Return an initial amino-acid-by-codon substitution probability matrix.

    Rows follow proteinAlphabet20.  There are 64 columns, one per codon,
    in ACGT order (column = 16*base1 + 4*base2 + base3).  Each sense codon
    puts matchProb / 61 on the amino acid the table below assigns it, and
    splits the rest of its 1/64 share equally among the other 19 rows.
    Each stop codon ("*") puts 1/(64*20) in every row.

    Raises Exception if nothing is left for mismatches
    (matchProb >= 61/64).
    """
    codeAminos = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
    codeBase1 = "TTTTTTTTTTTTTTTTCCCCCCCCCCCCCCCCAAAAAAAAAAAAAAAAGGGGGGGGGGGGGGGG"
    codeBase2 = "TTTTCCCCAAAAGGGGTTTTCCCCAAAAGGGGTTTTCCCCAAAAGGGGTTTTCCCCAAAAGGGG"
    codeBase3 = "TCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAGTCAG"

    matchEach = matchProb / 61
    mismatchEach = (1/64 - matchEach) / 19
    if mismatchEach <= 0:
        raise Exception("initial match probability must be < 61/64")
    stopEach = 1 / (64 * 20)
    matrix = [[mismatchEach] * len(codeAminos) for _ in proteinAlphabet20]
    for amino, x, y, z in zip(codeAminos, codeBase1, codeBase2, codeBase3):
        column = ("ACGT".index(x) * 16 + "ACGT".index(y) * 4 +
                  "ACGT".index(z))
        if amino != "*":
            matrix[proteinAlphabet20.index(amino)][column] = matchEach
            continue
        for row in matrix:
            row[column] = stopEach
    return matrix

def initialCodonProbs(opts):
    """Build starting probabilities for codon training from the options.

    opts.r is the initial match probability, and opts.a/b/A/B are the
    deletion/insertion open and extend probabilities.  opts.F, if given,
    is "del1,del2,ins1,ins2" frameshift probabilities; otherwise each
    gap type uses 1 - its extend probability for both.

    Returns (matProbs, (matchProb, delProbs, insProbs)), with each gap
    tuple being (openProb, growProb, prob1, prob2).
    """
    matProbs = initialCodonSubstitutionProbs(float(opts.r))
    delOpenProb, delGrowProb, insOpenProb, insGrowProb = [
        float(value) for value in (opts.a, opts.b, opts.A, opts.B)]
    matchProb = 0.99 - delOpenProb - insOpenProb
    if not opts.F:
        delProb1 = delProb2 = 1 - delGrowProb
        insProb1 = insProb2 = 1 - insGrowProb
    else:
        delProb1, delProb2, insProb1, insProb2 = map(float, opts.F.split(","))
    delProbs = delOpenProb, delGrowProb, delProb1, delProb2
    insProbs = insOpenProb, insGrowProb, insProb1, insProb2
    return matProbs, (matchProb, delProbs, insProbs)

def formattedCodons(spec):
    """Return a generator of the 64 lowercase codons (acgt order), each format()ted with spec."""
    bases = "acgt"
    codons = (x + y + z for x in bases for y in bases for z in bases)
    return (format(codon, spec) for codon in codons)

def printCodonCountMatrix(matrix):
    """Print a codon count matrix to stdout, as "#"-prefixed lines."""
    header = list(formattedCodons("5"))
    print("#", " ", *header)
    for amino, counts in zip(proteinAlphabet21, matrix):
        cells = [format(count, "<5.4g") for count in counts]
        print("#", amino, *cells)

def writeCodonScoreMatrix(outFile, matrixAndProbs, prefix):
    """Write an amino-acid-by-codon score matrix and its letter frequencies.

    matrixAndProbs: (matrix, rowProbs, colProbs).  Cells are
    right-aligned to the width of the longest score (at least 3).  Each
    row ends with that row's frequency, and a final line gives the
    column frequencies.
    """
    matrix, rowProbs, colProbs = matrixAndProbs
    widest = max(len(str(score)) for row in matrix for score in row)
    spec = ">" + str(max(widest, 3))
    print(prefix + " ", *formattedCodons(spec), file=outFile)
    for amino, row, rowProb in zip(proteinAlphabet21, matrix, rowProbs):
        cells = " ".join([format(score, spec) for score in row])
        print(prefix + amino, cells, rowProb, file=outFile)
    footer = [format(colProb, spec) for colProb in colProbs]
    print(prefix + " ", *footer, file=outFile)

def codonMatProbsFromCounts(counts, opts):
    """Normalize a codon count matrix to probabilities (after printing it).

    opts is accepted for symmetry with matProbsFromCounts and is unused.
    """
    total = sum(map(sum, counts))
    probs = [[count / total for count in row] for row in counts]
    print("# count matrix (query letters = columns, reference letters = rows):")
    printCodonCountMatrix(counts)
    print()
    return probs

def bestAminoPerCodon(matProbs, rowProbs):
    """Yield, per column, the row index with the highest matProb / rowProb.

    Ties go to the lowest row index.
    """
    for column in zip(*matProbs):
        ratios = [prob / rowProb for prob, rowProb in zip(column, rowProbs)]
        yield max(range(len(ratios)), key=ratios.__getitem__)

def freqText(probability):
    """Return probability as a short (ideally <= 3 characters) percentage text."""
    percent = 100 * probability
    text = format(percent, ".2")
    if len(text) <= 3:
        return text
    text = format(percent, "<3.2g")
    if len(text) <= 3:
        return text
    return text.lstrip("0")

def frameshiftProbImbalance(endProb, matchProb, delProbs, insProbs):
    # Frameshift-model counterpart of probImbalance: returns
    # 1 - matchProb / endProb**6 - (insertion term) - (deletion term).
    # NOTE(review): presumably the codon analogue of Equation (12) in
    # [Fri19]; the derivation is not shown here - confirm.
    # delProbs, insProbs: (openProb, growProb, prob1, prob2), in the order
    # built by frameshiftProbsFromCounts.
    insOpenProb, insGrowProb, ins1, ins2 = insProbs
    delOpenProb, delGrowProb, del1, del2 = delProbs
    # NOTE(review): the insertion numerator multiplies by powers of endProb,
    # while the deletion numerator divides by them.  Presumably intentional,
    # but confirm against the model.
    iNum = insOpenProb * (ins1 * endProb ** 2 + ins2 * (1 - ins1) * endProb
                          + (1 - ins1) * (1 - ins2) * (1 - insGrowProb))
    iDen = endProb ** 3 - (1 - ins1) * (1 - ins2) * insGrowProb
    dNum = delOpenProb * (del1 / endProb ** 2 + del2 * (1 - del1) / endProb
                          + (1 - del1) * (1 - del2) * (1 - delGrowProb))
    dDen = endProb ** 3 - (1 - del1) * (1 - del2) * delGrowProb
    return 1 - matchProb / endProb ** 6 - iNum / iDen - dNum / dDen

def balancedFrameshiftEndProb(*args):
    """Find endProb making frameshiftProbImbalance zero.

    args: (matchProb, delProbs, insProbs).  The search starts at the
    cube root of the larger of the two "(1-p1)(1-p2)*growProb" terms,
    where the denominators in frameshiftProbImbalance vanish, and runs up to 1.
    """
    _, delProbs, insProbs = args
    _, insGrowProb, ins1, ins2 = insProbs
    _, delGrowProb, del1, del2 = delProbs
    insLoop = (1 - ins1) * (1 - ins2) * insGrowProb
    delLoop = (1 - del1) * (1 - del2) * delGrowProb
    floor = max(insLoop, delLoop) ** (1/3)
    return rootOfIncreasingFunction(frameshiftProbImbalance, floor, 1.0, args)

def frameshiftProbsFromCounts(counts, opts):
    """Estimate match, gap and frameshift probabilities from codon counts.

    counts: (matches, deletes, inserts, delOpens0, insOpens0, delOpens1,
    delOpens2, insOpens1, insOpens2, alignments).  opts is unused.
    Prints the counts and estimates.

    Returns (matchProb, delProbs, insProbs), each gap tuple being
    (openProb, growProb, prob1, prob2).
    """
    (matches, deletes, inserts, delOpens0, insOpens0,
     delOpens1, delOpens2, insOpens1, insOpens2, alignments) = counts
    delOpens = delOpens0 + delOpens1 + delOpens2
    insOpens = insOpens0 + insOpens1 + insOpens2
    denominator = matches + insOpens + delOpens + (alignments + 1)
    matchProb = matches / denominator

    def gapProbs(opens, opens0, opens1, opens2, size):
        # One gap type's (openProb, growProb, prob1, prob2).
        openProb = opens / denominator
        growProb = (size - opens0) / size
        prob2 = opens2 / (size + opens2)
        prob1 = opens1 / (size + opens2 + opens1)
        return openProb, growProb, prob1, prob2

    insProbs = gapProbs(insOpens, insOpens0, insOpens1, insOpens2, inserts)
    delProbs = gapProbs(delOpens, delOpens0, delOpens1, delOpens2, deletes)
    delOpenProb, delGrowProb, delProb1, delProb2 = delProbs
    insOpenProb, insGrowProb, insProb1, insProb2 = insProbs

    for label, value in (("aligned residue/codon pairs", matches),
                         ("whole codon deletes", deletes),
                         ("whole codon inserts", inserts),
                         ("delOpens", delOpens), ("insOpens", insOpens)):
        print("# %s: %.12g" % (label, value))
    print("# frameshifts del-1,del-2,ins+1,ins+2: %.12g,%.12g,%.12g,%.12g"
          % (delOpens1, delOpens2, insOpens1, insOpens2))
    print("# alignments:", alignments)
    for label, value in (("matchProb", matchProb),
                         ("delOpenProb", delOpenProb),
                         ("insOpenProb", insOpenProb),
                         ("delExtendProb", delGrowProb),
                         ("insExtendProb", insGrowProb)):
        print("# %s: %g" % (label, value))
    print("# frameshiftProbs del-1,del-2,ins+1,ins+2: %g,%g,%g,%g"
          % (delProb1, delProb2, insProb1, insProb2))
    print()
    return matchProb, delProbs, insProbs

def frameshiftRatiosFromProbs(matchProb, delProbs, insProbs):
    # Convert frameshift-model probabilities into probability ratios: the
    # --codon counterpart of gapRatiosFromProbs.
    # delProbs, insProbs: (openProb, growProb, prob1, prob2).
    # Returns (matchRatio, delRatios, insRatios).  Each gap tuple is
    # (openRatio, ratio0, ratio1, ratio2), the order unpacked by
    # frameshiftCostsFromProbRatios.
    # NOTE(review): the formulas are not derived here; presumably they
    # match the model in frameshiftProbImbalance - confirm.
    # Raises ZeroDivisionError if insGrowProb or delGrowProb is 0.
    delOpenProb, delGrowProb, del1, del2 = delProbs
    insOpenProb, insGrowProb, ins1, ins2 = insProbs

    endProb = balancedFrameshiftEndProb(matchProb, delProbs, insProbs)

    matchRatio = matchProb / endProb ** 6

    insAdj = (1 - insGrowProb) / insGrowProb
    insOpenRatio = insOpenProb * insAdj
    insMean = ((1 - ins1) * (1 - ins2) * insGrowProb) ** (1/3)
    insRatio0 = insMean / endProb
    insRatio1 = ins1 / (insAdj * insMean)
    insRatio2 = ins2 * (1 - ins1) / (insAdj * insMean ** 2)
    insRatios = insOpenRatio, insRatio0, insRatio1, insRatio2

    delAdj = (1 - delGrowProb) / delGrowProb
    delOpenRatio = delOpenProb * delAdj
    delMean = ((1 - del1) * (1 - del2) * delGrowProb) ** (1/3)
    delRatio0 = delMean / endProb
    delRatio1 = del1 / (delAdj * delMean * endProb ** 4)
    delRatio2 = del2 * (1 - del1) / (delAdj * (delMean * endProb) ** 2)
    delRatios = delOpenRatio, delRatio0, delRatio1, delRatio2

    return matchRatio, delRatios, insRatios

def frameshiftCostsFromProbRatios(scale, gapRatios):
    """Return (openCost, cost0, cost1, cost2) for one gap type.

    gapRatios: (openRatio, ratio0, ratio1, ratio2).  Only cost0 (the
    per-codon extension cost) is forced to be at least 1.
    """
    openRatio, ratio0, ratio1, ratio2 = gapRatios
    return (costFromProb(scale, openRatio),
            max(costFromProb(scale, ratio0), 1),
            costFromProb(scale, ratio1),
            costFromProb(scale, ratio2))

def frameshiftImbalanceFromGap(scale, gapCosts):
    """Return one gap type's term of the frameshift score imbalance.

    gapCosts: (openCost, cost0, cost1, cost2); each becomes exp(-cost/scale).
    """
    openCost, cost0, cost1, cost2 = gapCosts
    a, b, f, g = [math.exp(-cost / scale)
                  for cost in (openCost, cost0, cost1, cost2)]
    return a * b * (f + g * b + b * b) / (1 - b ** 3)

def frameshiftScoreImbalance(scale, matScores, rowProbs, colProbs,
                             delCosts, insCosts):
    """Return the frameshift model's score imbalance at this scale.

    It is the expected exp(score/scale) over letter pairs (weighted by
    rowProbs x colProbs), plus the deletion and insertion terms, minus 1.
    """
    delTerm = frameshiftImbalanceFromGap(scale, delCosts)
    insTerm = frameshiftImbalanceFromGap(scale, insCosts)
    matchTerm = sum(rowProb * colProb * math.exp(matScores[r][c] / scale)
                    for r, rowProb in enumerate(rowProbs)
                    for c, colProb in enumerate(colProbs))
    return matchTerm + delTerm + insTerm - 1

def normalizedFreqs(freqs):
    """Convert freqs (numbers or numeric strings) to floats that sum to 1."""
    values = [float(freq) for freq in freqs]
    total = sum(values)
    return [value / total for value in values]

def codonScoresAndScale(originalScale, matParams, delRatios, insRatios):
    """Codon counterpart of scoresAndScale.

    matParams: (matchRatio, matProbs, rowFreqs, colFreqs); the frequencies
    may be text (from freqText) and are normalized here.  Prints the mean
    base composition of the codon frequencies.  Integer scores are
    computed at a trial scale, starting at originalScale, which is
    doubled until a balanced scale exists.

    Returns ((matScores, rowFreqs, colFreqs), delCosts, insCosts, scale,
    None, None).
    """
    matchRatio, matProbs, rowFreqs, colFreqs = matParams
    rowProbs = normalizedFreqs(rowFreqs)
    colProbs = normalizedFreqs(colFreqs)

    def baseProb(v):
        # Mean probability of base v over the 3 codon positions.
        r = range(4)
        return sum(colProbs[v*16 + i*4 + j] +
                   colProbs[i*16 + v*4 + j] +
                   colProbs[i*16 + j*4 + v] for i in r for j in r) / 3

    baseProbs = [baseProb(v) for v in range(4)]
    print("# % a c g t:", *(format(p * 100, ".3") for p in baseProbs))
    print()

    probParams = matchRatio, matProbs, rowProbs, colProbs
    trialScale = originalScale
    while True:
        matScores = matScoresFromProbs(trialScale, *probParams)
        delCosts = frameshiftCostsFromProbRatios(trialScale, delRatios)
        insCosts = frameshiftCostsFromProbRatios(trialScale, insRatios)
        imbalanceArgs = matScores, rowProbs, colProbs, delCosts, insCosts
        scale = balancedScale(frameshiftScoreImbalance, trialScale,
                              imbalanceArgs)
        if scale > 0:
            return ((matScores, rowFreqs, colFreqs), delCosts, insCosts,
                    scale, None, None)
        print("# the integer-rounded scores are too inaccurate: "
              "doubling the scale")
        trialScale *= 2

def isCloseEnough(oldParameters, newParameters):
    """Say whether two codon training rounds gave nearly the same parameters.

    Each argument is (delCosts, insCosts, (matrix, rowFreqs, colFreqs)).
    The gap costs must be equal, and every pair of corresponding matrix
    scores must differ by less than 2.
    """
    oldDel, oldIns, (oldMatrix, _, _) = oldParameters
    newDel, newIns, (newMatrix, _, _) = newParameters
    if oldDel != newDel or oldIns != newIns:
        return False
    return all(abs(a - b) < 2
               for oldRow, newRow in zip(oldMatrix, newMatrix)
               for a, b in zip(oldRow, newRow))

### End of routines for codons & frameshifts

def writeGapCosts(opts, delCosts, insCosts, isLastFormat, outFile):
    """Write gap costs as "#last -x" option lines or as "# ...Cost:" comments.

    Normally delCosts and insCosts are (initCost, extendCost), and the
    exist cost written is initCost - extendCost.  With opts.codon they
    are (openCost, extendCost, frameshift1, frameshift2), and a
    frameshift line is also written.  outFile None means stdout.
    """
    if opts.codon:
        delOpen, delGrow, del1, del2 = delCosts
        insOpen, insGrow, ins1, ins2 = insCosts
        frameshiftCosts = ",".join(str(c) for c in (del1, del2, ins1, ins2))
    else:
        delInit, delGrow = delCosts
        insInit, insGrow = insCosts
        delOpen, insOpen = delInit - delGrow, insInit - insGrow
    if isLastFormat:
        labels = ("#last -a", "#last -A", "#last -b", "#last -B")
        frameshiftLabel = "#last -F"
    else:
        labels = ("# delExistCost:", "# insExistCost:",
                  "# delExtendCost:", "# insExtendCost:")
        frameshiftLabel = "# frameshiftCosts del-1,del-2,ins+1,ins+2:"
    for label, cost in zip(labels, (delOpen, insOpen, delGrow, insGrow)):
        print(label, cost, file=outFile)
    if opts.codon:
        print(frameshiftLabel, frameshiftCosts, file=outFile)

def probsFromFile(opts, lastalArgs, maxGapGrowProb, codonMatches, lines):
    """Estimate substitution and gap probabilities from lastal output lines.

    Echoes the lastal command first.  On the first iteration (the command
    does not end with "-p-") the identity cutoff is moved closer to 100:
    opts.pid * 0.8 + 20.

    Returns (matProbs, gapProbs).
    """
    print("#", *lastalArgs)
    print()
    sys.stdout.flush()
    isFirstIteration = lastalArgs[-1] != "-p-"
    pid = opts.pid * 0.8 + 20 if isFirstIteration else opts.pid
    matCounts, gapCounts = countsFromLastOutput(opts, pid, codonMatches, lines)
    if not opts.codon:
        gapProbs = gapProbsFromCounts(gapCounts, opts, maxGapGrowProb)
        matProbs = matProbsFromCounts(matCounts, opts)
    else:
        gapProbs = frameshiftProbsFromCounts(gapCounts, opts)
        matProbs = codonMatProbsFromCounts(matCounts, opts)
    return matProbs, gapProbs

def tryToMakeChildProgramsFindable():
    """Put this script's directory at the front of PATH.

    If PATH is unset, os.defpath (the search path exec would otherwise
    fall back to) is kept after the script's directory, instead of
    raising KeyError.
    """
    d = os.path.dirname(__file__)
    oldPath = os.environ.get("PATH", os.defpath)
    # put it first, to avoid getting older versions of LAST:
    os.environ["PATH"] = d + os.pathsep + oldPath

def readLastdbData(lastdbIndexName):
    """Read the alphabet and integer size from a lastdb ".prj" file.

    Returns (lastalProgName, alphabet).  The program is "lastal" for
    32-bit integers, otherwise "lastal" + bytes per integer
    (e.g. "lastal8" for 64).  Raises Exception if the file has no
    "alphabet=" line.
    """
    prjName = lastdbIndexName + ".prj"
    alphabet = None
    bitsPerInt = "32"
    with open(prjName) as f:
        for line in f:
            if line.startswith("alphabet="):
                alphabet = line.split("=")[1].strip()
            if line.startswith("integersize="):
                bitsPerInt = line.split("=")[1].strip()
    if alphabet is None:
        raise Exception("can't read the alphabet from " + prjName)
    suffix = "" if bitsPerInt == "32" else str(int(bitsPerInt) // 8)
    lastalProgName = "lastal" + suffix
    return lastalProgName, alphabet

def fixedLastalArgs(opts, lastalProgName, alphabet):
    """Build the lastal arguments that stay the same in every training round.

    Options D E s S C T R m k P X Q are passed through as "-<letter><value>"
    when set.  Alphabets with fewer than 20 letters get "--split-n" and
    "--split-m=0.01"; larger ones get "-K1", and RuntimeError is raised
    if --revsym was requested.
    """
    args = [lastalProgName, "-j7"]
    for letter in "DEsSCTRmkPXQ":
        value = getattr(opts, letter)
        if value:
            args.append("-" + letter + value)
    if opts.verbose:
        args.append("-" + "v" * opts.verbose)
    if len(alphabet) < 20:
        args += ["--split-n", "--split-m=0.01"]  # xxx ???
    elif opts.revsym:
        raise RuntimeError("--revsym is for DNA only")
    else:
        args.append("-K1")
    return args

def process(args, inStream):
    """Start the command args; return its Popen, with stdout as a text pipe.

    inStream becomes the child's stdin; callers here pass None,
    subprocess.PIPE, or another child's stdout.
    """
    popenOptions = dict(stdin=inStream, stdout=subprocess.PIPE,
                        universal_newlines=True)
    return subprocess.Popen(args, **popenOptions)

def versionFromLastal():
    """Return the second word of "lastal --version" output (the version).

    The child's output pipe is closed and the child is waited for, so no
    zombie process or unclosed-pipe warning is left behind.
    """
    proc = process(["lastal", "--version"], None)
    output = proc.stdout.read()
    proc.stdout.close()
    proc.wait()
    return output.split()[1]

def doTraining(opts, args):
    """Iteratively train score parameters with lastal, and print them.

    args: [lastdb-name, query-file(s)...]; they are appended to every
    lastal command.  Each round counts substitutions and gaps in lastal's
    alignments, converts the counts to probabilities and then to integer
    scores at an upscaled "inner" scale, and writes those scores to the
    stdin of the next lastal run (which is given "-p-").  Training stops
    when a round's parameters equal an earlier round's (for --codon, when
    isCloseEnough says so).  The final parameters are then recomputed at
    the output scale and printed on stdout, partly as "#last" lines.
    """
    tryToMakeChildProgramsFindable()
    lastalProgName, alphabet = readLastdbData(args[0])
    lastalVersion = versionFromLastal()

    # Default initial scores for the first lastal run, unless the user gave
    # a score matrix (-p) or a -Q value other than 0/fastx/keep.  Alphabets
    # with fewer than 20 letters (DNA) get different defaults from protein.
    if not opts.p and (not opts.Q or opts.Q in ("0", "fastx", "keep")):
        if not opts.r: opts.r = "5" if len(alphabet) < 20 else "12"
        if not opts.q: opts.q = "5" if len(alphabet) < 20 else "7"
        if not opts.a: opts.a = "15"
        if not opts.b: opts.b = "3"

    print("# lastal version:", lastalVersion)
    print("# maximum percent identity:", opts.pid)

    if opts.codon:
        scaleIncrease = 1
        gapRatiosFunc = frameshiftRatiosFromProbs
        scoresAndScaleFunc = codonScoresAndScale
        writeScoreMatrixFunc = writeCodonScoreMatrix
        # NOTE(review): in codon mode no initial lastal run is made, and
        # outerScale comes only from opts.scale below; presumably the option
        # handling guarantees --scale with --codon - confirm.
    else:
        scaleIncrease = 20  # while training, upscale the scores by this amount
        gapRatiosFunc = gapRatiosFromProbs
        scoresAndScaleFunc = scoresAndScale
        writeScoreMatrixFunc = writeScoreMatrix
        codonMatches = None

        # First lastal run, with the user's (or default) initial scores:
        lastalArgs = fixedLastalArgs(opts, lastalProgName, alphabet)
        if opts.r: lastalArgs.append("-r" + opts.r)
        if opts.q: lastalArgs.append("-q" + opts.q)
        if opts.p: lastalArgs.append("-p" + opts.p)
        if opts.a: lastalArgs.append("-a" + opts.a)
        if opts.b: lastalArgs.append("-b" + opts.b)
        if opts.A: lastalArgs.append("-A" + opts.A)
        if opts.B: lastalArgs.append("-B" + opts.B)
        proc = process(lastalArgs + args, None)
        if opts.postmask:
            proc = process(["last-postmask"], proc.stdout)
        if not opts.scale:
            # Use lastal's own scale, from the "t=" word in its header.
            outerScale = scaleFromHeader(proc.stdout)

    if opts.scale:
        # --scale S means score units of 1/S bits; scores here are
        # scale * ln(prob), so the scale is S / ln(2).
        outerScale = opts.scale / math.log(2)

    # minimum possible (integer) gap extend cost = 1
    maxGapGrowProb = math.exp(-1 / outerScale)

    innerScale = outerScale * scaleIncrease
    oldParameters = []

    print("# scale of score parameters:", outerScale)
    print("# scale used while training:", innerScale)
    print()

    if opts.codon:
        matProbs, gapProbs = initialCodonProbs(opts)
    else:
        matProbs, gapProbs = probsFromFile(opts, lastalArgs, maxGapGrowProb,
                                           codonMatches, proc.stdout)

    while True:
        matchRatio, delRatios, insRatios = gapRatiosFunc(*gapProbs)
        rowProbs = [sum(i) for i in matProbs]
        colProbs = [sum(i) for i in zip(*matProbs)]
        if opts.codon:
            codonMatches = list(bestAminoPerCodon(matProbs, rowProbs))
            # Codon letter frequencies are kept as short percentage texts.
            rowProbs = [freqText(i) for i in rowProbs]
            colProbs = [freqText(i) for i in colProbs]
        matParams = matchRatio, matProbs, rowProbs, colProbs
        ss = scoresAndScaleFunc(innerScale, matParams, delRatios, insRatios)
        matScores, delCosts, insCosts, scale, rowFreqs, colFreqs = ss
        writeGapCosts(opts, delCosts, insCosts, False, None)
        print()
        print("# score matrix "
              "(query letters = columns, reference letters = rows):")
        writeScoreMatrixFunc(sys.stdout, matScores, "# ")
        print()
        # Stop when the parameters repeat an earlier round's (approximately
        # for codons, exactly otherwise).
        parameters = delCosts, insCosts, matScores
        if opts.codon:
            if any(isCloseEnough(i, parameters) for i in oldParameters):
                break
        else:
            if parameters in oldParameters:
                break
        oldParameters.append(parameters)
        # Next lastal run: the new parameters are written to its stdin.
        lastalArgs = fixedLastalArgs(opts, lastalProgName, alphabet)
        lastalArgs.append("-t{0:.6}".format(scale))
        lastalArgs.append("-p-")
        proc = process(lastalArgs + args, subprocess.PIPE)
        writeGapCosts(opts, delCosts, insCosts, True, proc.stdin)
        writeScoreMatrixFunc(proc.stdin, matScores, "")
        proc.stdin.close()
        if opts.postmask and not opts.codon:
            proc = process(["last-postmask"], proc.stdout)
        matProbs, gapProbs = probsFromFile(opts, lastalArgs, maxGapGrowProb,
                                           codonMatches, proc.stdout)

    # Recompute the final parameters at the output scale (training used
    # innerScale).
    ss = scoresAndScaleFunc(outerScale, matParams, delRatios, insRatios)
    matScores, delCosts, insCosts, scale, rowFreqs, colFreqs = ss
    if not opts.codon:
        # sum_i exp(S_ii / scale) * rowFreq_i * colFreq_i, over sum(colFreqs):
        pid = sum(math.exp(matScores[i][i] / scale) * rowFreqs[i] * colFreqs[i]
                  for i in range(len(matScores))) / sum(colFreqs)
        print("# substitution percent identity: {0:.6}".format(100 * pid))
    if opts.X: print("#last -X", opts.X)
    if opts.R: print("#last -R", opts.R)
    if opts.Q: print("#last -Q", opts.Q)
    print("#last -t{0:.6}".format(scale))
    writeGapCosts(opts, delCosts, insCosts, True, None)
    if opts.s: print("#last -s", opts.s)
    if opts.S: print("#last -S", opts.S)
    print("# score matrix "
          "(query letters = columns, reference letters = rows):")
    writeScoreMatrixFunc(sys.stdout, matScores, "")

def lastTrain(opts, args):
    # Run doTraining, optionally on a random sample of the query sequences.
    #
    # opts: parsed command-line options (opts.sample_number is read here).
    # args: positional arguments: lastdb name, then query sequence file(s).
    #
    # If opts.sample_number is falsy, doTraining runs directly on args.
    # Otherwise getSeqSample writes a sample drawn from the query files into
    # a temporary file, doTraining runs on [lastdb name, temp file], and the
    # temporary file is removed afterwards even if training raises.
    if not opts.sample_number:
        doTraining(opts, args)
        return

    random.seed(math.pi)  # fixed seed, so repeated runs draw the same sample
    refName = args[0]
    queryFiles = args[1:]
    # Create the file outside try/finally. If creation itself fails there is
    # nothing to clean up. Inside the try, "f" would be unbound, and
    # os.remove(f.name) would raise NameError and mask the real error.
    f = tempfile.NamedTemporaryFile("w", delete=False)
    try:
        with f:  # close the file so its contents are complete before use
            getSeqSample(opts, queryFiles, f)
        doTraining(opts, [refName, f.name])
    finally:
        os.remove(f.name)

if __name__ == "__main__":
    # Restore default SIGPIPE handling to avoid a silly error message when
    # output is piped to a program that exits early. SIGPIPE is not defined
    # on every platform (e.g. Windows), so only do this where it exists.
    if hasattr(signal, "SIGPIPE"):
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)
    usage = "%prog [options] lastdb-name sequence-file(s)"
    description = "Try to find suitable score parameters for aligning the given sequences."
    op = optparse.OptionParser(usage=usage, description=description)
    op.add_option("-v", "--verbose", action="count",
                  help="show more details of intermediate steps")

    # Options controlling the training procedure itself.
    og = optparse.OptionGroup(op, "Training options")
    og.add_option("--revsym", action="store_true",
                  help="force reverse-complement symmetry")
    og.add_option("--matsym", action="store_true",
                  help="force symmetric substitution matrix")
    og.add_option("--gapsym", action="store_true",
                  help="force insertion/deletion symmetry")
    og.add_option("--pid", type="float", default=100, help=
                  "skip alignments with > PID% identity (default: %default)")
    og.add_option("--postmask", type="int", metavar="NUMBER", default=1, help=
                  "skip mostly-lowercase alignments (default=%default)")
    og.add_option("--sample-number", type="int", metavar="N",
                  help="number of random sequence samples "
                  "(default: 20000 if --codon else 500)")
    og.add_option("--sample-length", type="int", default=2000, metavar="L",
                  help="length of each sample (default: %default)")
    og.add_option("--scale", type="float", metavar="S",
                  help="output scores in units of 1/S bits")
    og.add_option("--codon", action="store_true",
                  help="DNA queries & protein reference, with frameshifts")
    op.add_option_group(og)

    # Starting values for the parameters being trained.
    og = optparse.OptionGroup(op, "Initial parameter options")
    og.add_option("-r", metavar="SCORE", help=
                  "match score   (default:  6 if Q>=1, or 5 if DNA, or 12)")
    og.add_option("-q", metavar="COST", help=
                  "mismatch cost (default: 18 if Q>=1, or 5 if DNA, or  7)")
    og.add_option("-p", metavar="NAME", help="match/mismatch score matrix")
    og.add_option("-a", metavar="COST",
                  help="gap existence cost (default: 21 if Q>=1, else 15)")
    og.add_option("-b", metavar="COST",
                  help="gap extension cost (default: 9 if Q>=1, else 3)")
    og.add_option("-A", metavar="COST", help="insertion existence cost")
    og.add_option("-B", metavar="COST", help="insertion extension cost")
    og.add_option("-F", metavar="LIST", help="frameshift probabilities: "
                  "del-1,del-2,ins+1,ins+2 (default: 1-b,1-b,1-B,1-B)")
    op.add_option_group(og)

    # Options describing how the alignments are done.
    og = optparse.OptionGroup(op, "Alignment options")
    og.add_option("-D", metavar="LENGTH",
                  help="query letters per random alignment (default: 1e6)")
    og.add_option("-E", metavar="EG2",
                  help="maximum expected alignments per square giga")
    og.add_option("-s", metavar="STRAND", help=
                  "0=reverse, 1=forward, 2=both (default: 2 if DNA, else 1)")
    og.add_option("-S", metavar="NUMBER", default="1", help=
                  "score matrix applies to forward strand of: " +
                  "0=reference, 1=query (default: %default)")
    og.add_option("-C", metavar="COUNT", help=
                  "omit gapless alignments in COUNT others with > score-per-length")
    og.add_option("-T", metavar="NUMBER",
                  help="type of alignment: 0=local, 1=overlap (default: 0)")
    og.add_option("-R", metavar="DIGITS",
                  help="lowercase & simple-sequence options")
    og.add_option("-m", metavar="COUNT", help=
                  "maximum initial matches per query position (default: 10)")
    og.add_option("-k", metavar="STEP", help="use initial matches starting at "
                  "every STEP-th position in each query (default: 1)")
    og.add_option("-P", metavar="THREADS",
                  help="number of parallel threads")
    og.add_option("-X", metavar="NUMBER", help="N/X is ambiguous in: "
                  "0=neither sequence, 1=reference, 2=query, 3=both "
                  "(default=0)")
    og.add_option("-Q", metavar="NAME",
                  help="input format: fastx, sanger (default=fasta)")
    op.add_option_group(og)

    (opts, args) = op.parse_args()
    if len(args) < 1:
        op.error("I need a lastdb index and query sequences")
    # The --sample-number default depends on --codon, so resolve it here.
    if opts.sample_number is None:
        opts.sample_number = 20000 if opts.codon else 500
    # Without sampling, the query files are used directly and must be real
    # files, not stdin.
    if not opts.sample_number and (len(args) < 2 or "-" in args):
        op.error("sorry, can't use stdin when --sample-number=0")
    # In --codon mode, fill in defaults for any parameters not given, make
    # insertion costs default to the deletion/gap costs, and clear -S.
    if opts.codon:
        if not opts.scale: opts.scale = 3
        if not opts.r: opts.r = "0.4"
        if not opts.a: opts.a = "0.02"
        if not opts.b: opts.b = "0.5"
        if not opts.A: opts.A = opts.a
        if not opts.B: opts.B = opts.b
        opts.S = None

    try:
        lastTrain(opts, args)
    except KeyboardInterrupt:
        pass  # avoid silly error message
    except Exception as e:
        # Report errors as a one-line message with non-zero exit status.
        prog = os.path.basename(sys.argv[0])
        sys.exit(prog + ": error: " + str(e))
