Commit 0328eb3d authored by Sean Solari
Browse files

ITOL compatibility

parent ed154aca
......@@ -91,7 +91,7 @@ def string_to_tuple(st):
def run_classifier(read_paths, out_dir, db_dir, k, n, phylogeny_path, keys_shape, values_shape, logging_dir,
taxonomy=False, cutoff=0.0, groups=None, keep_zeros=False, cpm=0.0, use_node_names=True,
phyla=False, name_taxa=None, colour_list=None, circle_scale=1.0,
paired_end=False, alpha=1.0, log_scores=False):
paired_end=False, alpha=1.0, log_scores=False, itol_mode=False):
# Verify results path.
if not os.path.exists(out_dir):
......@@ -178,7 +178,7 @@ def run_classifier(read_paths, out_dir, db_dir, k, n, phylogeny_path, keys_shape
name_to_lineage, taxon_to_rank = load_taxonomy_map(db_dir)
results.to_taxonomy(name_to_lineage, taxon_to_rank, tax_results_path)
results.draw_results()
results.draw_results(itol_mode=itol_mode)
finally:
......@@ -1300,7 +1300,7 @@ class ClassificationResults:
return intersection
def draw_results(self):
def draw_results(self, itol_mode=False):
# Draw classified tree.
self.phylogeny_index.draw_results(
os.path.join(self.in_dir, "classified_counts.csv"),
......@@ -1314,7 +1314,8 @@ class ClassificationResults:
use_phyla=self.phyla,
keep_zeros=self.keep_zeros,
use_node_names=self.use_node_names,
log_scores=self.log_scores
log_scores=self.log_scores,
itol_mode=itol_mode
)
# Draw unclassified tree.
......@@ -1329,5 +1330,6 @@ class ClassificationResults:
use_phyla=self.phyla,
keep_zeros=self.keep_zeros,
use_node_names=self.use_node_names,
log_scores=self.log_scores
log_scores=self.log_scores,
itol_mode=itol_mode
)
......@@ -4,6 +4,7 @@ import json.decoder
import math
import platform
import re
from tkinter import E
import matplotlib.pyplot as plt
import multiprocessing as mp
......@@ -804,7 +805,7 @@ def validate_results_path(results_dir):
classified_path = os.path.join(phy_results_path, "classified_counts.csv")
splits_path = os.path.join(phy_results_path, "splits_counts.csv")
if not os.path.exists(classified_path) or not os.path.exists(splits_path):
if not (os.path.exists(classified_path) or os.path.exists(splits_path)):
raise Exception("Path does not look like expam results folder!")
......@@ -843,7 +844,7 @@ def main():
parser.add_argument("-k", "--kmer", dest="k",
help="Length of mer used for analysis.",
metavar="[k value (int)]")
parser.add_argument("-n", "--n_processes", dest="n",
parser.add_argument("-n", "--n-processes", dest="n",
help="Number of CPUs to use for processing.",
metavar="[n (int)]")
parser.add_argument("-s", "--sketch", dest="sketch",
......@@ -864,7 +865,7 @@ def main():
parser.add_argument("-y", "--pile", dest="pile",
help="Number of genomes to pile at a time (or inf).",
metavar="[pile size]")
parser.add_argument("-e", "--error_rate", dest="error_rate",
parser.add_argument("-e", "--error-rate", dest="error_rate",
help="Generate error in reads (error ~ reads with errors / reads).",
metavar="[error rate]")
parser.add_argument("-t", "--truth", dest="truth_dir",
......@@ -883,14 +884,14 @@ def main():
help="Colour phylotree results by phyla.")
parser.add_argument("--rank", dest="rank", default=None,
help="Rank at which to sort results.")
parser.add_argument("--keep_zeros", dest="keep_zeros", default=False, action="store_true",
parser.add_argument("--keep-zeros", dest="keep_zeros", default=False, action="store_true",
help="Keep nodes of output where no reads have been assigned.")
parser.add_argument("--ignore_names", dest="ignore_node_names", default=False, action="store_true")
parser.add_argument("--ignore-names", dest="ignore_node_names", default=False, action="store_true")
parser.add_argument("--group", dest="groups", action="append", nargs="+",
help="Space-separated list of sample files to be treated as a single group in phylotree.")
parser.add_argument("--colour_list", dest="colour_list", nargs="+",
parser.add_argument("--colour-list", dest="colour_list", nargs="+",
help="List of colours to use when plotting groups in phylotree.")
parser.add_argument("--circle_scale", dest="circle_scale", default=1.0,
parser.add_argument("--circle-scale", dest="circle_scale", default=1.0,
help="Scale of circles that represent splits in phylotree.")
parser.add_argument("--sourmash", dest="use_sourmash", default=False, action="store_true",
help="Use sourmash for distance estimation.")
......@@ -904,6 +905,8 @@ def main():
help="Percentage requirement for classification subtrees (see Tutorials 1 & 2).")
parser.add_argument("--log-scores", dest="log_scores", default=False, action="store_true",
help="Log transformation to opacity scores on phylotree (think uneven distributions).")
parser.add_argument("--itol", dest="itol_mode", default=False, action="store_true",
help="Output plotting data in ITOL format.")
# Parse arguments.
args = parser.parse_args()
......@@ -914,14 +917,14 @@ def main():
param_args = args.length, args.pile, args.error_rate, args.first_n, args.sketch, args.paired_end
summary_args = args.plot, args.cutoff, args.cpm, args.taxonomy
plot_args = args.groups, args.phyla, args.keep_zeros, not args.ignore_node_names, args.colour_list, \
args.circle_scale, args.rank, args.log_scores
args.circle_scale, args.rank, args.log_scores, args.itol_mode
tree_args = args.use_sourmash, args.use_rapidnj, args.use_quicktree
command, db_name, k, n, phylogeny, alpha = runtime_args
directories, out_dir, truth_dir = directory_args
length, pile_size, error_rate, first_n, sketch, paired_end = param_args
plot, cutoff, cpm, taxonomy = summary_args
groups, plot_phyla, keep_zeros, use_node_names, colour_list, circle_scale, at_rank, log_scores = plot_args
groups, plot_phyla, keep_zeros, use_node_names, colour_list, circle_scale, at_rank, log_scores, itol_mode = plot_args
use_sourmash, use_rapidnj, use_quicktree = tree_args
group = None if groups is None else groups[0][0] # When referring to sequence groups, not plotting groups.
......@@ -1091,7 +1094,8 @@ def main():
circle_scale=circle_scale,
paired_end=paired_end,
alpha=alpha,
log_scores=log_scores
log_scores=log_scores,
itol_mode=itol_mode
)
#
......@@ -1159,20 +1163,21 @@ def main():
die("Require output directory (-o, --out_dir)!")
else:
validate_results_path(out_dir)
pass
from expam.classification import TAXID_LINEAGE_MAP_NAME, PHY_RESULTS, \
load_taxonomy_map, name_to_id, ClassificationResults
map_url = os.path.join(DB_DIR, TAXID_LINEAGE_MAP_NAME)
if not os.path.exists(map_url):
die("Run command `download_taxonomy` first to collect taxa for your genomes!")
config = load_configuration_file(DB_DIR, return_config=True)
phylogeny_path = config["phylogeny_path"]
phylogeny_path = make_path_absolute(phylogeny_path, DB_DIR)
index, phylogenyIndex = name_to_id(phylogeny_path)
name_to_lineage, taxon_to_rank = load_taxonomy_map(DB_DIR)
try:
name_to_lineage, _ = load_taxonomy_map(DB_DIR)
except FileNotFoundError:
name_to_lineage = None
# Load phylogenetic results.
phy_results_url = os.path.join(out_dir, PHY_RESULTS)
......@@ -1188,12 +1193,12 @@ def main():
cpm=cpm,
use_node_names=use_node_names,
phyla=plot_phyla,
name_taxa=name_to_lineage,
name_taxa=name_to_lineage,
colour_list=colour_list,
circle_scale=circle_scale,
log_scores=log_scores
)
results.draw_results()
results.draw_results(itol_mode=itol_mode)
#
# Plot tree.
......
......@@ -2,6 +2,8 @@ import random
from math import floor, log
import json
import os
import re
import sys
import traceback
import numpy as np
......@@ -294,7 +296,7 @@ class Index:
return f'<Phylogeny Index, length={len(self)}>'
@classmethod
def load_newick(cls, path):
def load_newick(cls, path, keep_names=False, verbose=True):
"""load_newick Load Newick tree from file.
:param path: path to Newick file
......@@ -312,10 +314,10 @@ class Index:
for line in f:
newick_str += line.strip()
return cls.from_newick(newick_str)
return cls.from_newick(newick_str, keep_names=keep_names, verbose=verbose)
@classmethod
def from_newick(cls, newick_string, keep_names=False, verbose=True):
    """from_newick Parse Newick string.

    :param newick_string: Newick string encoding tree.
    :type newick_string: str
    :param keep_names: forwarded unchanged to ``from_pool``.
    :type keep_names: bool
    :param verbose: when False, progress messages are written to
        ``os.devnull`` instead of standard output.
    :type verbose: bool
    :return: name of leaves and phylogeny Index object
    :rtype: list[str], expam.tree.Index
    """
    stream = sys.stdout if verbose else open(os.devnull, 'w')

    try:
        # Remove whitespace.
        newick_string = newick_string.replace(" ", "").replace("\n", "")

        print("* Initialising node pool...", file=stream)
        index_pool = cls.init_pool(newick_string)

        print("* Checking for polytomies...", file=stream)
        cls.resolve_polytomies(index_pool)

        print("* Finalising index...", file=stream)
        leaf_names, index = cls.from_pool(index_pool, keep_names=keep_names)
    finally:
        # Close the devnull handle even if parsing raised; never close
        # the interpreter's real stdout.
        if stream is not sys.stdout:
            stream.close()

    return leaf_names, index
......@@ -400,10 +406,19 @@ class Index:
def parse_string(force_digits=False):
nonlocal i # Inherit current index.
j = i + 1
# Detect quotations.
if newick[i] == "'":
terminal = "'"
i += 1
while newick[j] not in NEWICK_PARSE:
else:
terminal = NEWICK_PARSE
j = i + 1
while newick[j] not in terminal:
j += 1
string = newick[i:j]
# Check formatting.
......@@ -414,7 +429,10 @@ class Index:
except ValueError:
raise ValueError("Invalid distance declaration %s!" % newick[i:j])
i = j - 1 # Update current index position.
if terminal == "'":
i = j
else:
i = j - 1
return string
......@@ -505,7 +523,7 @@ class Index:
i += 1
@classmethod
def from_pool(cls, pool):
def from_pool(cls, pool, keep_names=False):
leaf_names = []
# Initialise expam Index.
......@@ -518,8 +536,9 @@ class Index:
continue
if node.type == "Branch":
# Set branch name.
node.name = str(branch_id)
if not keep_names:
node.name = str(branch_id)
branch_id += 1
else:
......@@ -767,19 +786,43 @@ class Index:
"""
return list(self.yield_leaves(node_name))
def lca(self, name_one, name_two):
    """lca Return name of the lowest common ancestor of these two nodes.

    Note that coordinates are read from right-to-left, so the ancestor's
    coordinate is the longest suffix shared by both nodes' coordinates.

    :param name_one: name of node
    :type name_one: str
    :param name_two: name of node
    :type name_two: str
    :return: ``name`` attribute of the node found at the shared coordinate
    """
    lca_coord = self.right_intersect(self[name_one].coordinate, self[name_two].coordinate)
    return self.coord(lca_coord).name

@staticmethod
def right_intersect(a_list, b_list):
    """right_intersect Return the longest common suffix of two sequences.

    Elements are compared pairwise starting from the right-hand end. The
    result is sliced from ``a_list``; an empty list is returned when the
    final elements already differ or when either input is empty.

    :param a_list: first sequence (must support ``reversed`` and slicing)
    :param b_list: second sequence
    :return: trailing elements of ``a_list`` shared with ``b_list``
    """
    shared = 0

    for a, b in zip(reversed(a_list), reversed(b_list)):
        if a != b:
            break

        shared += 1

    # Running out of elements without a mismatch (identical inputs, or one
    # a suffix of the other) must still yield the shared suffix, not None.
    if shared == 0:
        # a_list[-0:] would be the whole sequence, so handle this explicitly.
        return []

    return a_list[-shared:]
def draw_results(self, file_path, out_dir, skiprows=None, groups=None, cutoff=None, cpm=None, colour_list=None,
                 name_to_taxon=None, use_phyla=False, keep_zeros=True, use_node_names=True, log_scores=False,
                 itol_mode=False):
    """draw_results Load a counts table from file and pass it to ``draw_tree``.

    :param file_path: path to a tab-separated counts file; the first column
        is used as the index and the first (non-skipped) row as the header.
    :param out_dir: output location handed to ``draw_tree``.
    :param skiprows: forwarded to ``pandas.read_csv``.
    :param itol_mode: forwarded to ``draw_tree``.

    All remaining keyword arguments are forwarded unchanged to ``draw_tree``.
    """
    counts = pd.read_csv(file_path, sep='\t', index_col=0, header=0, skiprows=skiprows)

    self.draw_tree(out_dir, counts=counts, groups=groups, cutoff=cutoff, cpm=cpm, colour_list=colour_list,
                   name_to_taxon=name_to_taxon, use_phyla=use_phyla, keep_zeros=keep_zeros,
                   use_node_names=use_node_names, log_scores=log_scores, itol_mode=itol_mode)
def draw_tree(self, out_dir, counts, groups=None, cutoff=None, cpm=None, colour_list=None, name_to_taxon=None,
use_phyla=False, keep_zeros=True, use_node_names=True, log_scores=True, itol_mode=False):
from expam.sequences import format_name
def draw_tree(self, out_dir, counts=None, groups=None, cutoff=None, cpm=None, colour_list=None, name_to_taxon=None,
use_phyla=False, keep_zeros=True, use_node_names=True, log_scores=True):
try:
import ete3.coretype.tree
from ete3 import AttrFace, faces, Tree, TreeStyle, NodeStyle, TextFace
from ete3 import Tree
except ModuleNotFoundError as e:
print("Could not import ete3 plotting modules! Error raised:")
print(traceback.format_exc())
......@@ -787,8 +830,6 @@ class Index:
return
from expam.sequences import format_name
"""
Phylogenetic printing of nodes.
"""
......@@ -820,118 +861,165 @@ class Index:
for col, group in groups
]
# *** Relies on counts being defined.
if counts is not None:
# Combine counts within the group.
for _, group in groups:
if len(group) > 1:
group_name = group[0]
# Combine counts within the group.
for _, group in groups:
if len(group) > 1:
group_name = group[0]
counts.loc[:, group_name] = counts[list(group)].sum(axis=1)
counts.drop(labels=list(groups[1:]), axis=1, inplace=True)
counts.loc[:, group_name] = counts[list(group)].sum(axis=1)
counts.drop(labels=list(groups[1:]), axis=1, inplace=True)
# Remove any groups that weren't specified.
all_groups = counts.columns.tolist()
specified_groups = set(group[0] for _, group in groups)
unspecified_groups = [group for group in all_groups if group not in specified_groups]
# Remove any groups that weren't specified.
all_groups = counts.columns.tolist()
specified_groups = set(group[0] for _, group in groups)
unspecified_groups = [group for group in all_groups if group not in specified_groups]
counts.drop(labels=unspecified_groups, axis=1, inplace=True)
counts.drop(labels=unspecified_groups, axis=1, inplace=True)
"""
Employ cutoff.
"""
sections = list(specified_groups)
nodes = counts.index.tolist()
"""
Employ cutoff.
"""
sections = list(specified_groups)
nodes = counts.index.tolist()
if cutoff is None and cpm is None:
nodes_with_counts = nodes
if cutoff is None and cpm is None:
nodes_with_counts = nodes
else:
nodes_with_counts = set()
else:
nodes_with_counts = set()
for section in sections:
total = sum(counts[section])
section_cutoff = max(cutoff, (total / 1e6) * cpm)
for section in sections:
total = sum(counts[section])
section_cutoff = max(cutoff, (total / 1e6) * cpm)
for index in nodes:
if index not in nodes_with_counts:
if np.any(counts.loc[index, :] >= section_cutoff):
nodes_with_counts.add(index)
for index in nodes:
if index not in nodes_with_counts:
if np.any(counts.loc[index, :] >= section_cutoff):
nodes_with_counts.add(index)
nodes_with_counts = list(nodes_with_counts)
nodes_with_counts = list(nodes_with_counts)
"""
Attempt pruning of tree.
"""
"""
Attempt pruning of tree.
"""
try:
if not keep_zeros:
tree.prune(nodes_with_counts, preserve_branch_length=True)
except ete3.coretype.tree.TreeError:
print("Tree pruning failed, aborting tree drawing!")
return
"""
Colour generation.
"""
max_vector = tuple(counts.max())
min_vector = tuple(counts.min())
colours = [colour for colour, _ in groups]
colour_generator = ColourGenerator(
colours,
max_vector,
min_vector,
colour_list=colour_list,
log_scores=log_scores
)
if itol_mode:
itol_data = [
"TREE_COLORS",
"# Drag me onto the iTOL tree (in browser) to apply me!",
"SEPARATOR TAB",
"DATA",
]
for node in tree.traverse():
if node.name in counts.index:
node_counts = tuple(counts.loc[node.name, :])
if any(node_counts):
colour = colour_generator.generate(node_counts)
itol_data.append("%s\trange\t%s\t%s" % (node.name, colour, node.name.upper()))
# Write itol_data and Newick tree to file.
name_modifier = re.findall(r"phylotree_(\S+).pdf", os.path.basename(out_dir))[0]
itol_dir = os.path.join(os.path.dirname(out_dir), 'itol_%s' % name_modifier)
if not os.path.exists(itol_dir):
os.mkdir(itol_dir)
print("iTOL output directory created at %s." % itol_dir)
# Write style file.
itol_style_dir = os.path.join(itol_dir, 'style.txt')
with open(itol_style_dir, 'w') as f:
f.write('\n'.join(itol_data))
print("iTOL style written to %s." % itol_style_dir)
# Write Newick file.
itol_tree_dir = os.path.join(itol_dir, 'itol_tree.nwk')
tree.write(format=1, outfile=itol_tree_dir)
print("iTOL tree written to %s. Put this in itol (in your browser)." % itol_tree_dir)
else:
# Draw with ete3.
try:
if not keep_zeros:
tree.prune(nodes_with_counts, preserve_branch_length=True)
from ete3 import AttrFace, faces, TreeStyle, NodeStyle, TextFace
except ModuleNotFoundError as e:
print("Could not import ete3 plotting modules! Error raised:")
print(traceback.format_exc())
print("Skipping plotting...")
except ete3.coretype.tree.TreeError:
print("Tree pruning failed, aborting tree drawing!")
return
"""
Colour generation.
Function to render any given node.
"""
max_vector = tuple(counts.max())
min_vector = tuple(counts.min())
colours = [colour for colour, _ in groups]
colour_generator = ColourGenerator(
colours,
max_vector,
min_vector,
colour_list=colour_list,
log_scores=log_scores
)
"""
Function to render any given node.
"""
def layout(node):
nonlocal name_to_taxon
def layout(node):
nonlocal name_to_taxon
if node.is_leaf():
if use_node_names:
faces.add_face_to_node(AttrFace('name', fsize=20), node, column=0, position='aligned')
if node.is_leaf():
if use_node_names:
faces.add_face_to_node(AttrFace('name', fsize=20), node, column=0, position='aligned')
if use_phyla:
for phyla, colour in PHYLA_COLOURS:
lineage = name_to_taxon[node.name]
if use_phyla:
for phyla, colour in PHYLA_COLOURS:
lineage = name_to_taxon[node.name]
if phyla in lineage:
tax_face = TextFace(" ")
tax_face.background.color = colour
faces.add_face_to_node(tax_face, node, column=1, position='aligned')
if phyla in lineage:
tax_face = TextFace(" ")
tax_face.background.color = colour
faces.add_face_to_node(tax_face, node, column=1, position='aligned')
if node.name in counts.index:
node_counts = tuple(counts.loc[node.name, :])
if (counts is not None) and (node.name in counts.index):
node_counts = tuple(counts.loc[node.name, :])
if any(node_counts):
colour = colour_generator.generate(node_counts)
if any(node_counts):
colour = colour_generator.generate(node_counts)
ns = NodeStyle()
ns['bgcolor'] = colour
node.set_style(ns)
ns = NodeStyle()
ns['bgcolor'] = colour
node.set_style(ns)
"""
Render tree.
"""
ts = TreeStyle()
ts.mode = "c"
ts.show_leaf_name = False
ts.layout_fn = layout
ts.force_topology = False
ts.allow_face_overlap = False
ts.draw_guiding_lines = True
ts.root_opening_factor = 1
tree.render(
out_dir,
tree_style=ts,
dpi=300
)
print("Phylogenetic tree written to %s!" % out_dir)
"""
Render tree.
"""
ts = TreeStyle()
ts.mode = "c"
ts.show_leaf_name = False
ts.layout_fn = layout
ts.force_topology = False
ts.allow_face_overlap = False
ts.draw_guiding_lines = True
ts.root_opening_factor = 1
tree.render(
out_dir,
tree_style=ts,
dpi=300
)
print("Phylogenetic tree written to %s!" % out_dir)
class HexColour:
......@@ -1028,7 +1116,7 @@ class RandomColour:
class ColourGenerator:
def __init__(self, colours, max_vector, min_vector=None, colour_list=None, log_score=False):
def __init__(self, colours, max_vector, min_vector=None, colour_list=None, log_scores=False):
"""
:param colours:
......@@ -1059,7 +1147,7 @@ class ColourGenerator:
#
self.max_vector = max_vector
self.min_vector = min_vector
self.make_log_score = log_score
self.make_log_score = log_scores
if self.make_log_score and self.min_vector is None: