#!/usr/bin/env python

#Copyright (C) 2006-2011 by Benedict Paten (benedictpaten@gmail.com)
#
#Released under the MIT license, see LICENSE.txt

import sys
import os
import re
import math
import random

from sonLib.misc import close
#import bioio

#########################################################
#########################################################
#########################################################
#basic tree datastructures
#########################################################
#########################################################
#########################################################

# Lower bound applied to branch lengths by correctTreeDistances; also used as
# the root branch length of the trees built by upgma and nj.
MIN_TREE_DISTANCE = 0.00001

class BinaryTree:
    """
    A node of a binary tree.

    distance -- length of the branch above this node
    internal -- True for an internal node, False for a leaf
    left, right -- child nodes (leaves built in this module use None)
    iD -- node label (internal nodes built by upgma/nj use None)
    """
    def __init__(self, distance, internal, left, right, iD):
        self.distance = distance
        self.internal = internal
        self.left = left
        self.right = right
        self.iD = iD

class TraversalID:
    """
    tree traversal numbers, used as nodeIDs for identifying orders in the tree

    mid is the node's own mid-order number and [midStart, midEnd) is the
    half-open range of mid-order numbers used by the node's subtree.
    binaryTree_depthFirstNumbers additionally gives every node it numbers as
    a leaf a leafNo attribute (its left-to-right leaf index).
    """
    def __init__(self, midStart, mid, midEnd):
        self.midStart = midStart
        self.mid = mid
        self.midEnd = midEnd

#########################################################
#########################################################
#########################################################
#tree functions
#########################################################
#########################################################
#########################################################

def binaryTree_depthFirstNumbers(binaryTree, labelTree=True, dontStopAtID=True):
    """
    get mid-order depth first tree numbers

    Numbers the nodes in mid-order (left subtree, node, right subtree)
    starting from 0, and numbers the leaves left to right starting from 0.

    labelTree -- if True, set node.traversalID on every numbered node.
    dontStopAtID -- if False, an internal node whose iD is not None is
        numbered as if it were a leaf and its subtree is not visited.

    Returns a dict mapping each numbered node to its TraversalID. Nodes are
    dict keys, so every node object in the tree must be distinct.
    """
    traversalIDs = {}
    def traverse(binaryTree, mid=0, leafNo=0):
        # Returns (next free mid-order number, next free leaf number).
        if binaryTree.internal and (dontStopAtID or binaryTree.iD is None):
            midStart = mid
            j, leafNo = traverse(binaryTree.left, mid, leafNo)
            mid = j
            j, leafNo = traverse(binaryTree.right, j+1, leafNo)
            traversalIDs[binaryTree] = TraversalID(midStart, mid, j)
            return j, leafNo
        traversalID = TraversalID(mid, mid, mid+1)
        traversalID.leafNo = leafNo
        #thus nodes must be unique
        traversalIDs[binaryTree] = traversalID
        return mid+1, leafNo+1
    traverse(binaryTree)
    if labelTree:
        for node, traversalID in traversalIDs.items():
            node.traversalID = traversalID
    return traversalIDs

def binaryTree_nodeNames(binaryTree):
    """
    creates names for the leaf and internal nodes of the newick tree from the leaf labels

    A leaf is named str(iD); an internal node is named by joining the names
    of its left and right children with "_". Requires traversal numbers (see
    binaryTree_depthFirstNumbers). Returns a list indexed by traversalID.mid.
    """
    def fn(binaryTree, labels):
        if binaryTree.internal:
            fn(binaryTree.left, labels)
            fn(binaryTree.right, labels)
            labels[binaryTree.traversalID.mid] = labels[binaryTree.left.traversalID.mid] + "_" + labels[binaryTree.right.traversalID.mid]
        else:
            labels[binaryTree.traversalID.mid] = str(binaryTree.iD)
        return labels[binaryTree.traversalID.mid]
    labels = [None]*binaryTree.traversalID.midEnd
    fn(binaryTree, labels)
    return labels

def getBinaryTreeNodes(binaryTree, l):
    """
    Appends every node of the tree to the list l in mid-order (left subtree,
    node, right subtree).
    """
    if binaryTree is not None:
        getBinaryTreeNodes(binaryTree.left, l)
        l.append(binaryTree)
        getBinaryTreeNodes(binaryTree.right, l)

def binaryTree_leafNo(binaryTree):
    """
    Returns the number of leaves below (and including) binaryTree, computed
    from its traversal numbers: a subtree with n leaves spans 2n-1 numbers.
    """
    return (binaryTree.traversalID.midEnd - binaryTree.traversalID.midStart + 1)//2

def getDistanceMatrix(tree):
    """
    Returns a dict mapping (i, j) to the path length between the nodes whose
    traversalID.mid are i and j, for every ordered pair of distinct nodes.
    Requires traversal numbers (see binaryTree_depthFirstNumbers).
    """
    m = {}
    def fn(tree):
        # Returns (mid, distance to the parent of tree) for every node in the
        # subtree rooted at tree.
        if tree.internal:
            leftNodes = fn(tree.left)
            rightNodes = fn(tree.right)
            # Paths between the two child subtrees pass through this node.
            for i, d1 in leftNodes:
                for j, d2 in rightNodes:
                    m[(i, j)] = d1 + d2
                    m[(j, i)] = d1 + d2
            j = tree.traversalID.mid
            for i, d1 in leftNodes + rightNodes:
                m[(i, j)] = d1
                m[(j, i)] = d1
            return [ (i, d + tree.distance) for i, d in leftNodes + rightNodes + [(j, 0.0)] ]
        return [ (tree.traversalID.mid, tree.distance) ]
    fn(tree)
    return m

def makeRandomBinaryTree(leafNodeNumber=None):
    """Creates a random binary tree.

    Each node is internal with probability 0.4 and gets a branch length in
    [0.00001, 0.80001). Nodes are labelled with str() of their pre-order
    creation index. If leafNodeNumber is given, trees are regenerated until
    one with exactly that many leaves is produced.
    """
    while True:
        nodeNo = [-1]
        def fn():
            nodeNo[0] += 1
            if random.random() > 0.6:
                i = str(nodeNo[0])
                return BinaryTree(0.00001 + random.random()*0.8, True, fn(), fn(), i)
            else:
                return BinaryTree(0.00001 + random.random()*0.8, False, None, None, str(nodeNo[0]))
        tree = fn()
        def countLeaves(tree):
            if tree.internal:
                return countLeaves(tree.left) + countLeaves(tree.right)
            return 1
        if leafNodeNumber is None or countLeaves(tree) == leafNodeNumber:
            return tree

def getRandomBinaryTreeLeafNode(binaryTree):
    """Get random binary tree node.

    Walks down from binaryTree choosing the left or right child with equal
    probability and returns the leaf reached.
    """
    if binaryTree.internal:
        if random.random() > 0.5:
            return getRandomBinaryTreeLeafNode(binaryTree.left)
        else:
            return getRandomBinaryTreeLeafNode(binaryTree.right)
    else:
        return binaryTree

#########################################################
#########################################################
#########################################################
#substition functions and felsensteins algorithm
#########################################################
#########################################################
#########################################################

def transformByDistance(wV, subModel, alphabetSize=4):
    """
    transform wV by given substitution matrix

    Returns nc with nc[l] = sum over i of wV[i] * subModel[i][l].
    """
    nc = [0.0]*alphabetSize
    for i in range(alphabetSize):
        j = wV[i]
        k = subModel[i]
        for l in range(alphabetSize):
            nc[l] += j * k[l]
    return nc

def multiplyWV(wVX, wVY, alphabetSize=4):
    """Returns the element-wise product of the first alphabetSize entries."""
    return [ wVX[i] * wVY[i] for i in range(alphabetSize) ]

def sumWV(wVX, wVY, alphabetSize=4):
    """Returns the element-wise sum of the first alphabetSize entries."""
    return [ wVX[i] + wVY[i] for i in range(alphabetSize) ]

def normaliseWV(wV, normFac=1.0):
    """
    Rescales wV so that its entries sum to normFac.
    """
    f = sum(wV) / normFac
    return [ i/f for i in wV ]

def sumWVA(wVA, alphabetSize=4):
    """Returns the element-wise sum of the vectors in wVA."""
    totals = [0.0]*alphabetSize
    for wV in wVA:
        for i in range(alphabetSize):
            totals[i] += wV[i]
    return totals

def felsensteins(binaryTree, subMatrices, ancestorProbs, leaves, alphabetSize):
    """
    calculates the un-normalised probabilities of each non-gap residue position

    Runs an upward pass followed by a downward pass over binaryTree (which
    must carry traversal numbers).
    subMatrices -- indexed by traversalID.mid; the matrix for the branch
        above each node
    leaves -- indexed by traversalID.leafNo; the vector at each leaf
    ancestorProbs -- the vector applied at the root

    Returns a dict keyed by traversalID.mid. For an internal node the value
    is the element-wise product of its subtree vector and the vector passed
    down from above it; for a leaf it is the leaf's vector.
    """
    l = {}
    def upPass(binaryTree):
        if binaryTree.internal: #is internal binaryTree
            i = branchUp(binaryTree.left)
            j = branchUp(binaryTree.right)
            k = multiplyWV(i, j, alphabetSize)
            # Keep the children's messages for the downward pass.
            l[binaryTree.traversalID.mid] = (k, i, j)
            return k
        l[binaryTree.traversalID.mid] = leaves[binaryTree.traversalID.leafNo]
        return leaves[binaryTree.traversalID.leafNo]
    def downPass(binaryTree, ancestorProbs):
        if binaryTree.internal: #is internal binaryTree
            i = l[binaryTree.traversalID.mid]
            l[binaryTree.traversalID.mid] = multiplyWV(ancestorProbs, i[0], alphabetSize)
            # Each child receives the message from above combined with its
            # sibling's upward message.
            branchDown(binaryTree.left, multiplyWV(ancestorProbs, i[2], alphabetSize))
            branchDown(binaryTree.right, multiplyWV(ancestorProbs, i[1], alphabetSize))
    def branchUp(binaryTree):
        return transformByDistance(upPass(binaryTree), subMatrices[binaryTree.traversalID.mid], alphabetSize)
    def branchDown(binaryTree, ancestorProbs):
        downPass(binaryTree, transformByDistance(ancestorProbs, subMatrices[binaryTree.traversalID.mid], alphabetSize))
    upPass(binaryTree)
    downPass(binaryTree, ancestorProbs)
    return l

def calculateCharacterFrequencies(seq, map, alphabetSize):
    """
    Counts the characters of seq into a list of alphabetSize floats, where
    map(char) gives the index of each character.
    """
    counts = [0.0]*alphabetSize
    for i in seq:
        counts[map(i)] += 1
    return counts

#########################################################
#########################################################
#########################################################
#distance matrix to tree building functions
#########################################################
#########################################################
#########################################################

class DistancePair:
    """
    A distance from cluster leaf1 (containing leafNo1 leaves) to cluster
    leaf2 (containing leafNo2 leaves). Pairs order by distance.
    """
    def __init__(self, distance, leaf1, leafNo1, leaf2, leafNo2):
        self.distance = distance
        self.leaf1 = leaf1
        self.leaf2 = leaf2
        self.leafNo1 = leafNo1
        self.leafNo2 = leafNo2

    def __lt__(self, distancePair):
        # Used by list.sort() on both Python 2 and Python 3.
        return self.distance < distancePair.distance

    def __cmp__(self, distancePair):
        # Python 2 cmp() support; compares explicitly rather than delegating
        # to the float.
        if self.distance < distancePair.distance:
            return -1
        if self.distance > distancePair.distance:
            return 1
        return 0

def correctTreeDistances(tree):
    """
    Raises, in place, every branch length below MIN_TREE_DISTANCE to
    MIN_TREE_DISTANCE.
    """
    if tree is not None:
        if tree.distance < MIN_TREE_DISTANCE:
            tree.distance = MIN_TREE_DISTANCE
        correctTreeDistances(tree.left)
        correctTreeDistances(tree.right)

def calculateDNADistanceMatrix(seqNo, fastaIter, transitionTransversionRatio=2.0):
    """
    Computes pairwise Jukes-Cantor corrected distances from an alignment.

    fastaIter yields alignment columns; column[i] is the character of
    sequence i. For each pair only columns in which both characters are one
    of A, C, G, T are counted. Pseudocounts (0.1 transitions, 0.1
    transversions, 1.0 compared columns) keep every estimate defined.

    transitionTransversionRatio is currently unused (kept for compatibility).

    Returns a seqNo x seqNo list of lists; diagonal entries are left as None.
    """
    bases = ('A', 'C', 'T', 'G')
    purines = ('A', 'G')
    transitions = [0.1]*seqNo*seqNo
    transversions = [0.1]*seqNo*seqNo
    counts = [1.0]*seqNo*seqNo
    for column in fastaIter:
        for i in range(seqNo):
            if column[i] not in bases:
                continue
            for j in range(i+1, seqNo):
                if column[j] not in bases:
                    continue
                k = i*seqNo + j
                counts[k] += 1
                if column[i] != column[j]:
                    # Purine <-> pyrimidine is a transversion, otherwise a transition.
                    if (column[i] in purines) != (column[j] in purines):
                        transversions[k] += 1
                    else:
                        transitions[k] += 1
    # The Jukes-Cantor distance diverges as the difference proportion p
    # approaches 3/4; flooring the log argument caps saturated pairs at a
    # large finite distance (about 17.3) instead of raising a math domain error.
    minLogArgument = 1e-10
    distanceMatrix = [ [None]*seqNo for i in range(seqNo) ]
    for i in range(seqNo):
        for j in range(i+1, seqNo):
            k = i*seqNo + j
            p = (transitions[k] + transversions[k])/counts[k]
            #jukes cantor correction: d = -3/4 ln(1 - 4p/3); 4.0/3.0 avoids Python 2 integer division
            distanceMatrix[i][j] = -0.75*math.log(max(1.0 - (4.0/3.0)*p, minLogArgument))
            distanceMatrix[j][i] = distanceMatrix[i][j]
    return distanceMatrix

def makeDistancePairs(distanceMatrix, iDs, seqNo):
    """
    Creates one leaf (labelled iDs[i]) per sequence and a DistancePair in
    each direction for every pair, reading distanceMatrix[i][j] for i < j.
    """
    binaryTrees = [ BinaryTree(0.0, False, None, None, iDs[i]) for i in range(seqNo) ]
    distancePairs = []
    for i in range(seqNo):
        for j in range(i+1, seqNo):
            distancePairs.append(DistancePair(distanceMatrix[i][j], binaryTrees[i], 1, binaryTrees[j], 1))
            distancePairs.append(DistancePair(distanceMatrix[i][j], binaryTrees[j], 1, binaryTrees[i], 1))
    return distancePairs

def _pairsToOtherClusters(distancePairs, distancePair, leafNo):
    """
    For the pair being joined, returns two dicts that map every other cluster
    to the DistancePair from distancePair.leaf1 (first dict) and from
    distancePair.leaf2 (second dict) to that cluster.
    """
    holder1 = {}
    holder2 = {}
    for i in distancePairs:
        if i.leaf1 == distancePair.leaf1 and i.leaf2 != distancePair.leaf2:
            holder1[i.leaf2] = i
        if i.leaf1 == distancePair.leaf2 and i.leaf2 != distancePair.leaf1:
            holder2[i.leaf2] = i
    assert len(holder1) == leafNo-2
    assert len(holder2) == leafNo-2
    assert set(holder1.keys()) == set(holder2.keys())
    return holder1, holder2

def _pairsNotTouching(distancePairs, distancePair):
    """
    Returns the pairs that involve neither cluster of distancePair.
    """
    joined = (distancePair.leaf1, distancePair.leaf2)
    return [ i for i in distancePairs if i.leaf1 not in joined and i.leaf2 not in joined ]

def upgma(distanceMatrix, iDs, leafNo):
    """
    Builds a tree from distanceMatrix by UPGMA.

    distanceMatrix[i][j] (i < j) is the distance between sequences i and j,
    iDs[i] labels leaf i and leafNo is the number of sequences. Returns the
    root; branch lengths below MIN_TREE_DISTANCE are raised to it.
    """
    binaryTree = upgmaI(makeDistancePairs(distanceMatrix, iDs, leafNo), leafNo)
    def fn(tree):
        # upgmaI leaves each node's distance set to its parent's height;
        # subtract the node's own height (stored on its children) top-down,
        # before the children are themselves converted.
        if tree.internal:
            tree.distance -= tree.left.distance
            fn(tree.left)
            fn(tree.right)
    fn(binaryTree)
    binaryTree.distance = MIN_TREE_DISTANCE
    correctTreeDistances(binaryTree)
    return binaryTree

def upgmaI(distancePairs, leafNo):
    """
    Recursive UPGMA step. Joins the closest pair of clusters under a new
    node, sets both children's distance to half the pair distance (the new
    node's height), replaces the pair by size-weighted average distances to
    the new node and recurses until one cluster remains. distancePairs must
    contain both directions of every pair among leafNo clusters; it is
    sorted in place.
    """
    #get min pair
    distancePairs.sort()
    distancePair = distancePairs[0]
    #calculate shared distance
    distancePair.leaf1.distance = distancePair.distance/2
    distancePair.leaf2.distance = distancePair.distance/2
    newLeaf = BinaryTree(0.0, True, distancePair.leaf1, distancePair.leaf2, None)
    if leafNo-1 == 1:
        return newLeaf
    #replace references
    holder1, holder2 = _pairsToOtherClusters(distancePairs, distancePair, leafNo)
    newDistances = []
    for i in holder1:
        j = holder1[i]
        k = holder2[i]
        newSize = j.leafNo1 + k.leafNo1
        newDistance = (j.distance*j.leafNo1 + k.distance*k.leafNo1)/newSize
        newDistances.append(DistancePair(newDistance, j.leaf2, j.leafNo2, newLeaf, newSize))
        newDistances.append(DistancePair(newDistance, newLeaf, newSize, j.leaf2, j.leafNo2))
    distancePairs = _pairsNotTouching(distancePairs, distancePair) + newDistances
    return upgmaI(distancePairs, leafNo-1)

def nj(distanceMatrix, iDs, leafNo):
    """
    Builds a tree from distanceMatrix by neighbour joining (arguments as for
    upgma). Returns the root; branch lengths below MIN_TREE_DISTANCE are
    raised to it.
    """
    binaryTree = njI(makeDistancePairs(distanceMatrix, iDs, leafNo), leafNo)
    correctTreeDistances(binaryTree)
    return binaryTree

def getMinPair(distancePairs, rValues, leafNo):
    """
    Returns the pair minimising
    distance - (rValues[leaf1] + rValues[leaf2]) / (leafNo - 2),
    the first such pair on ties.
    """
    j = None
    k = float('inf')
    for i in distancePairs:
        adjustD = i.distance - (rValues[i.leaf1] + rValues[i.leaf2])/(leafNo-2)
        if adjustD < k:
            k = adjustD
            j = i
    return j

def calculateRValues(distancePairs, leafNo):
    """
    Returns a dict mapping each cluster to the sum of its distances to all
    other clusters; asserts that exactly leafNo clusters are present.
    """
    rValues = {}
    for i in distancePairs:
        rValues[i.leaf1] = rValues.get(i.leaf1, 0) + i.distance
    assert len(rValues) == leafNo
    return rValues

def njI(distancePairs, leafNo):
    """
    Recursive neighbour-joining step over leafNo clusters. distancePairs
    must contain both directions of every pair. Joins the pair chosen by
    getMinPair under a new node, sets the two branch lengths, replaces the
    pair by distances to the new node and recurses; with two clusters left
    it joins them under a root whose distance is MIN_TREE_DISTANCE.
    """
    assert leafNo >= 2
    if leafNo == 2:
        assert len(distancePairs) == 2
        distancePair = distancePairs[0]
        distancePair.leaf1.distance = distancePair.distance*0.5
        distancePair.leaf2.distance = distancePair.distance*0.5
        return BinaryTree(MIN_TREE_DISTANCE, True, distancePair.leaf1, distancePair.leaf2, None)
    #calculate r values
    rValues = calculateRValues(distancePairs, leafNo)
    #get min pair
    distancePair = getMinPair(distancePairs, rValues, leafNo)
    distancePair.leaf1.distance = 0.5*(distancePair.distance + (rValues[distancePair.leaf1] - rValues[distancePair.leaf2])/(leafNo-2))
    distancePair.leaf2.distance = distancePair.distance - distancePair.leaf1.distance
    newLeaf = BinaryTree(0.0, True, distancePair.leaf1, distancePair.leaf2, None)
    #replace references
    holder1, holder2 = _pairsToOtherClusters(distancePairs, distancePair, leafNo)
    newDistances = []
    for i in holder1:
        j = holder1[i]
        k = holder2[i]
        assert j.leaf2 == k.leaf2
        newDistance = 0.5*(j.distance + k.distance - distancePair.distance)
        #leaf numbers are unimportant for nj, and hence omitted
        newDistances.append(DistancePair(newDistance, j.leaf2, 0, newLeaf, 0))
        newDistances.append(DistancePair(newDistance, newLeaf, 0, j.leaf2, 0))
    distancePairs = _pairsNotTouching(distancePairs, distancePair) + newDistances
    return njI(distancePairs, leafNo-1)

#########################################################
#########################################################
#########################################################
#substitution matrix functions
#########################################################
#########################################################
#########################################################

def checkMatrix(m, fV, AS=4, reversible=True):
    """
    Asserts that m is a well formed AS x AS substitution matrix.

    Every row must sum to 1 (within 0.0001). If reversible is True, m must
    also satisfy fV[i] * m[i][k] == fV[k] * m[k][i] (checked with close,
    tolerance 0.00001). Raises AssertionError on failure.
    """
    for i in range(AS):
        rowSum = sum(m[i])
        assert rowSum <= 1.0001
        assert rowSum >= 0.9999
        if reversible:
            for k in range(AS):
                assert close(fV[i] * m[i][k], fV[k] * m[k][i], 0.00001)
    # NOTE(review): wV and wV2 are the same vector, so i and j below are
    # computed identically and this final assertion cannot fail.
    wV = fV
    wV2 = fV
    wV3 = transformByDistance(wV, m, AS)
    wV4 = transformByDistance(wV2, m, AS)
    i = sum(multiplyWV(wV2, wV3, AS))
    j = sum(multiplyWV(wV, wV4, AS))
    assert close(i, j, 0.00001)

def reverseSubMatrix(m, AS=4):
    """
    Returns the transpose of the AS x AS matrix m.
    """
    k = [ [None]*AS for i in range(AS) ]
    for i in range(AS):
        for j in range(AS):
            k[j][i] = m[i][j]
    return k

def subMatrix_jukesCantor(d):
    """
    Returns the 4x4 Jukes-Cantor substitution probability matrix for branch
    length d: 0.25 + 0.75*exp(-4d/3) on the diagonal and
    0.25 - 0.25*exp(-4d/3) elsewhere.
    """
    i = 0.25 + 0.75*math.exp(-(4.0/3.0)*d)
    j = 0.25 - 0.25*math.exp(-(4.0/3.0)*d)
    return [ [i, j, j, j], [j, i, j, j], [j, j, i, j], [j, j, j, i] ]

# Unused draft of a Tamura-Nei distance calculation, kept for reference as a
# module-level string literal (it is never executed).
"""
def distanceTamureiNei(aF, cF, gF, tF, a2G, t2C, tV):
    rF = aF + gF
    yF = cF + tF
    xx = 1.0 - a2G * rF/(2.0 * aF * gF) - tV/(2.0 * rF)
    yy = 1.0 - t2C * yF/(2.0 * tF * cF) - tV/(2.0 * yF)
    i = -(2.0 * aF * gF / rF) * math.log(1.0 - a2G * rF/(2.0 * aF * gF) - tV/(2.0 * rF))
    j = -(2.0 * tF * cF / yF) * math.log(1.0 - t2C * yF/(2.0 * tF * cF) - tV/(2.0 * yF))
    k = -2.0 * (rF * yF - (aF * gF * yF / rF) - (tF * cF * rF / yF)) * math.log(1.0 - tV/(2.0 * rF * yF))
    print i, j, k, xx, yy
    d = -(2.0 * aF * gF / rF) * math.log(1.0 - a2G * rF/(2.0 * aF * gF) - tV/(2.0 * rF)) \
    -(2.0 * tF * cF / yF) * math.log(1.0 - t2C * yF/(2.0 * tF * cF) - tV/(2.0 * yF)) \
    -2.0 * (rF * yF - (aF * gF * yF / rF) - (tF * cF * rF / yF)) * math.log(1.0 - tV/(2.0 * rF * yF))
    return d
"""

def subMatrix_TamuraNei(d, fA, fC, fG, fT, alphaPur, alphaPyr, beta):
    """
    Returns the 4x4 Tamura-Nei substitution probability matrix for branch
    length d (see page 203 of Felsenstein's Inferring Phylogenies).

    States are ordered A, C, G, T with base frequencies fA, fC, fG, fT,
    which must sum to 1. alphaPur applies to rows A and G, alphaPyr to rows
    C and T, and beta to every row. The result is validated with checkMatrix.
    """
    i = fA + fC + fG + fT
    assert i < 1.00001
    assert i > 0.99999
    assert d >= 0.0
    AS = 4
    freq = ( fA, fC, fG, fT )
    alpha = ( alphaPur, alphaPyr, alphaPur, alphaPyr )
    matrix = [ [ 0.0 ]*AS for i in range(AS) ]
    def watKro(j, k):
        # 1.0 if states j and k are in the same class (A/G at even indices,
        # C/T at odd indices).
        if (j % 2) == (k % 2):
            return 1.0
        return 0.0
    def kroenickerDelta(i, j):
        if i == j:
            return 1.0
        return 0.0
    for i in range(AS):
        #long winded, totally unoptimised method for calculating matrix
        for j in range(AS):
            # l is the total frequency of the class that state j belongs to.
            l = 0.0
            for k in range(AS):
                l += watKro(j, k) * freq[k]
            matrix[i][j] =\
            math.exp(-(alpha[i] + beta) * d) * kroenickerDelta(i, j) + \
            math.exp(-beta*d) * (1.0 - math.exp(-alpha[i]*d)) * (freq[j] * watKro(i, j) / l) + \
            (1.0 - math.exp(-beta * d)) * freq[j]
    checkMatrix(matrix, (fA, fC, fG, fT))
    return matrix

def subMatrix_HKY(d, fA, fC, fG, fT, transitionTransversionR):
    """
    Returns the HKY substitution matrix for branch length d by deriving
    Tamura-Nei parameters (with alphaPur = (fPur/fPyr) * alphaPyr) from the
    base frequencies and the transition/transversion ratio
    transitionTransversionR.
    """
    i = fA + fC + fG + fT
    assert i < 1.00001
    assert i > 0.99999
    fPur = fA + fG
    fPyr = fC + fT
    p = fPur/fPyr #makes like HKY
    beta = 1.0 / (2.0 * fPur * fPyr * (1.0 + transitionTransversionR))
    alphaPyr = ((fPur * fPyr * transitionTransversionR) - (fA * fG) - (fC * fT)) \
    / (2.0 * (1.0 + transitionTransversionR) * (fPyr * fA * fG * p + fPur * fC * fT))
    alphaPur = p * alphaPyr
    return subMatrix_TamuraNei(d, fA, fC, fG, fT, alphaPur, alphaPyr, beta)

def subMatrix_HalpernBruno(d, freqColumn, subMatrix, AS=4):
    """
    Reweights subMatrix for the site frequencies freqColumn.

    With a = freqColumn[i]*subMatrix[i][j] and b = freqColumn[j]*subMatrix[j][i],
    entry [i][j] is multiplied by log(b/a) / (1 - a/b); entries where a and b
    are close (within 0.0001) are copied unchanged. d is unused. Rows are not
    renormalised.
    """
    matrix = [ [ 0.0 ]*AS for i in range(AS) ]
    for i in range(AS):
        for j in range(AS):
            a = freqColumn[i] * subMatrix[i][j]
            b = freqColumn[j] * subMatrix[j][i]
            if not close(a, b, 0.0001):
                matrix[i][j] = subMatrix[i][j] * (math.log(b/a) / (1 - (a/b)))
            else:
                matrix[i][j] = subMatrix[i][j]
    return matrix

#########################################################
#########################################################
#########################################################
#misc tree functions
#########################################################
#########################################################
#########################################################

def annotateTree(bT, fn):
    """
    annotate a tree in an external array using the given function

    Returns a list indexed by traversalID.mid holding fn(node) for every
    node. Requires traversal numbers.
    """
    l = [None]*bT.traversalID.midEnd
    def fn2(bT):
        l[bT.traversalID.mid] = fn(bT)
        if bT.internal:
            fn2(bT.left)
            fn2(bT.right)
    fn2(bT)
    return l

def mapTraversalIDsBetweenTrees(oldTree, newTree):
    """
    Maps the traversal numbers of newTree onto those of oldTree.

    Leaves are matched by iD; an internal node of newTree is matched via the
    mapped numbers of its two children. oldTree's root and the root's two
    children are all mapped to the root's number. Both trees must carry
    traversal numbers. Returns a dict from newTree traversalID.mid to
    oldTree traversalID.mid; raises KeyError if a node cannot be matched.
    """
    mapping = {}
    leafMap = {}
    internalMap = {}
    rootMid = oldTree.traversalID.mid
    def fn(node):
        j = node.traversalID.mid
        if j == rootMid or (oldTree.internal and (j == oldTree.left.traversalID.mid or j == oldTree.right.traversalID.mid)):
            return rootMid
        return j
    def fn2(tree):
        if tree.internal:
            fn2(tree.left)
            fn2(tree.right)
        else:
            leafMap[tree.iD] = fn(tree)
    fn2(oldTree)
    def fn3(tree):
        if tree.internal:
            fn3(tree.left)
            fn3(tree.right)
            # Record every ordered pair drawn from {node, left, right} as
            # mapping to the remaining one.
            internalMap[(fn(tree.left), fn(tree.right))] = fn(tree)
            internalMap[(fn(tree.right), fn(tree.left))] = fn(tree)
            internalMap[(fn(tree), fn(tree.right))] = fn(tree.left)
            internalMap[(fn(tree.right), fn(tree))] = fn(tree.left)
            internalMap[(fn(tree), fn(tree.left))] = fn(tree.right)
            internalMap[(fn(tree.left), fn(tree))] = fn(tree.right)
    fn3(oldTree)
    def fn4(tree):
        if tree.internal:
            fn4(tree.left)
            fn4(tree.right)
            mapping[tree.traversalID.mid] = internalMap[(mapping[tree.left.traversalID.mid], mapping[tree.right.traversalID.mid])]
        else:
            mapping[tree.traversalID.mid] = leafMap[tree.iD]
    fn4(newTree)
    return mapping

def remodelTreeRemovingRoot(root, node):
    """
    Node is mid order number

    node must be the traversalID.mid of a leaf (asserted) other than the
    root. Walks from that leaf up to the root, writing each ancestor's other
    subtree with bioio.printBinaryTree, and returns the tree parsed from the
    resulting Newick string with bioio.newickTreeParser.
    """
    import bioio
    assert root.traversalID.mid != node
    parents = {}
    def fn(bT):
        # Returns the path [leaf, ..., bT] if the leaf is below bT, else None.
        if bT.traversalID.mid == node:
            assert bT.internal == False
            return [ bT ]
        elif bT.internal:
            i = fn(bT.left)
            if i is None:
                i = fn(bT.right)
            if i is not None:
                parents[i[-1]] = bT
                i.append(bT)
            return i
        return None
    l = fn(root)
    def fn2(i, j):
        # Returns the child of i that is not j.
        if i.left == j:
            return i.right
        assert i.right == j
        return i.left
    def fn3(bT):
        # NOTE(review): printBinaryTree is called with three arguments here
        # but with two in moveRoot; confirm against bioio.
        if parents[bT] == root:
            s = '(' + bioio.printBinaryTree(fn2(parents[bT], bT), bT, True)[:-1] + ')'
        else:
            s = '(' + bioio.printBinaryTree(fn2(parents[bT], bT), bT, True)[:-1] + ',' + fn3(parents[bT]) + ')'
        return s + ":" + str(bT.distance)
    s = fn3(l[0]) + ';'
    t = bioio.newickTreeParser(s)
    return t

def moveRoot(root, branch):
    """
    Removes the old root and places the new root at the mid point along the given branch

    branch is the traversalID.mid of the node below the chosen branch; the
    tree must carry traversal numbers. The tree is rebuilt as a Newick string
    (the chosen node's distance is halved while printing, then restored) and
    a newly parsed tree is returned.
    """
    import bioio
    if root.traversalID.mid == branch:
        return bioio.newickTreeParser(bioio.printBinaryTree(root, True))
    def fn2(tree, seq):
        if seq is not None:
            return '(' + bioio.printBinaryTree(tree, True)[:-1] + ',' + seq + ')'
        return bioio.printBinaryTree(tree, True)[:-1]
    def fn(tree, seq):
        # Descends towards branch (using mid-order numbers), accumulating the
        # off-path subtrees in seq.
        if tree.traversalID.mid == branch:
            i = tree.distance
            tree.distance /= 2
            seq = '(' + bioio.printBinaryTree(tree, True)[:-1] + ',(' + seq + ('):%s' % tree.distance) + ');'
            tree.distance = i
            return seq
        if tree.internal:
            if branch < tree.traversalID.mid:
                seq = fn2(tree.right, seq)
                return fn(tree.left, seq)
            else:
                assert branch > tree.traversalID.mid
                seq = fn2(tree.left, seq)
                return fn(tree.right, seq)
        else:
            return bioio.printBinaryTree(tree, True)[:-1]
    s = fn(root, None)
    return bioio.newickTreeParser(s)

def checkGeneTreeMatchesSpeciesTree(speciesTree, geneTree, processID):
    """
    Function to check ids in gene tree all match nodes in species tree

    Asserts that processID(iD) of every gene tree leaf equals processID(iD)
    of some species tree leaf.
    """
    def fn(tree, l):
        if tree.internal:
            fn(tree.left, l)
            fn(tree.right, l)
        else:
            l.append(processID(tree.iD))
    l = []
    fn(speciesTree, l)
    l2 = []
    fn(geneTree, l2)
    for i in l2:
        assert i in l

def calculateDupsAndLossesByReconcilingTrees(speciesTree, geneTree, processID):
    """
    Reconciles the given gene tree with the species tree and reports the
    number of needed duplications and losses.

    Both trees must carry traversal numbers. processID is applied to the iD
    of every leaf of both trees to obtain the identifier leaves are matched
    on. Each gene tree node is mapped to the lowest species tree node whose
    leaf set contains the gene node's leaf set; a gene node is counted as a
    duplication when one of its children maps to the same species node.

    Returns (dupCount, lossCount).
    """
    checkGeneTreeMatchesSpeciesTree(speciesTree, geneTree, processID)
    def fn(tree, m):
        # Fills m with traversalID.mid -> set of processed leaf iDs below.
        if tree.internal:
            nodes = fn(tree.left, m)
            nodes = nodes.union(fn(tree.right, m))
            m[tree.traversalID.mid] = nodes
        else:
            m[tree.traversalID.mid] = set((processID(tree.iD),))
        return m[tree.traversalID.mid]
    a = {}
    fn(speciesTree, a)
    b = {}
    fn(geneTree, b)
    def fn2(nodes, speciesTree):
        # Descends to the lowest species node whose leaf set contains nodes.
        assert nodes.issubset(a[speciesTree.traversalID.mid])
        if speciesTree.internal:
            if nodes.issubset(a[speciesTree.left.traversalID.mid]):
                return fn2(nodes, speciesTree.left)
            if nodes.issubset(a[speciesTree.right.traversalID.mid]):
                return fn2(nodes, speciesTree.right)
        return speciesTree.traversalID.mid
    # b now becomes: gene tree mid -> species tree mid.
    for mid in list(b.keys()):
        b[mid] = fn2(b[mid], speciesTree)
    dups = set()
    def fn3(geneTree):
        if geneTree.internal:
            i = b[geneTree.traversalID.mid]
            if b[geneTree.left.traversalID.mid] == i or b[geneTree.right.traversalID.mid] == i:
                dups.add(geneTree.traversalID.mid)
            fn3(geneTree.left)
            fn3(geneTree.right)
    fn3(geneTree)
    # Sentinel "species node" standing for the position above the species
    # root; larger than any traversal number.
    aboveRoot = sys.maxsize
    # lossMap[(x, y)] is the number of species nodes strictly between
    # ancestor x and descendant y (0 when x == y).
    lossMap = {}
    def fn4(speciesTree):
        nodes = [(speciesTree.traversalID.mid, -1)]
        lossMap[(speciesTree.traversalID.mid, speciesTree.traversalID.mid)] = 0
        if speciesTree.internal:
            for node, losses in fn4(speciesTree.left) + fn4(speciesTree.right):
                lossMap[(speciesTree.traversalID.mid, node)] = losses+1
                nodes.append((node, losses+1))
        return nodes
    for node, losses in fn4(speciesTree):
        lossMap[(aboveRoot, node)] = losses+1
    losses = [0]
    def fn5(geneTree, ancestor):
        if geneTree.internal:
            i = b[geneTree.traversalID.mid]
            if geneTree.traversalID.mid in dups:
                losses[0] += lossMap[(ancestor, b[geneTree.left.traversalID.mid])]
                losses[0] += lossMap[(ancestor, b[geneTree.right.traversalID.mid])]
            else:
                losses[0] += lossMap[(i, b[geneTree.left.traversalID.mid])]
                losses[0] += lossMap[(i, b[geneTree.right.traversalID.mid])]
            fn5(geneTree.left, i)
            fn5(geneTree.right, i)
    ancestorHolder = [None]
    def fn6(speciesTree, ancestor, node):
        # Records the parent (or aboveRoot) of the species node numbered node.
        if speciesTree.traversalID.mid == node:
            ancestorHolder[0] = ancestor
        if speciesTree.internal:
            fn6(speciesTree.left, speciesTree.traversalID.mid, node)
            fn6(speciesTree.right, speciesTree.traversalID.mid, node)
    fn6(speciesTree, aboveRoot, b[geneTree.traversalID.mid])
    assert ancestorHolder[0] is not None
    fn5(geneTree, ancestorHolder[0])
    return len(dups), losses[0]

def calculateProbableRootOfGeneTree(speciesTree, geneTree, processID=lambda x : x):
    """
    Goes through each root possible branch making it the root.
    Returns tree that requires the minimum number of duplications.

    Both trees must carry traversal numbers. Candidate rootings are produced
    with moveRoot for every node except the root's two children (the root's
    own position is included), renumbered, and scored with
    calculateDupsAndLossesByReconcilingTrees.

    Returns (rootedGeneTree, dupCount, lossCount) for the candidate with the
    fewest duplications, ties broken by fewest losses and then by the first
    candidate found. A gene tree with at most two leaves is returned
    unchanged as (geneTree, 0, 0).
    """
    if geneTree.traversalID.midEnd <= 3:
        return geneTree, 0, 0
    checkGeneTreeMatchesSpeciesTree(speciesTree, geneTree, processID)
    l = []
    def fn(tree):
        if tree.traversalID.mid != geneTree.left.traversalID.mid and tree.traversalID.mid != geneTree.right.traversalID.mid:
            newGeneTree = moveRoot(geneTree, tree.traversalID.mid)
            binaryTree_depthFirstNumbers(newGeneTree)
            dupCount, lossCount = calculateDupsAndLossesByReconcilingTrees(speciesTree, newGeneTree, processID)
            l.append((dupCount, lossCount, newGeneTree))
        if tree.internal:
            fn(tree.left)
            fn(tree.right)
    fn(geneTree)
    # Compare on the counts only, so ties never fall through to comparing
    # tree objects.
    dupCount, lossCount, bestTree = min(l, key=lambda x: (x[0], x[1]))
    return bestTree, dupCount, lossCount

#add traversalID.mid to each node name
#print tree
#parse tree
#remove names, and add them to traversalID

def main():
    pass

def _test():
    import doctest
    return doctest.testmod()

if __name__ == '__main__':
    _test()
    main()