Commit d15820b7 authored by Sean Solari
Browse files

Bug fixes for tree and database building

parent d1d4c5da
import os
from expam.cli.main import CommandGroup, ExpamOptions, clear_logs
from expam.database import FileLocationConfig
from expam.database.config import JSONConfig, create_database, make_database_config, validate_database_file_configuration
from expam.database.build import main as expam
from expam.database.config import ExpamDatabaseDoesNotExistError, JSONConfig, create_database, make_database_config, validate_database_file_configuration
from expam.logger import Timer
from expam.utils import die, ls, make_path_absolute
......@@ -15,9 +16,7 @@ class BuildCommand(CommandGroup):
def __init__(
self, config: FileLocationConfig,
k: int, n: int, s: int, phylogeny_path: str, pile_size: int,
files: list[str], group: str,
first_n: int,
make_plot: bool = False
files: list[str], group: str, first_n: int
) -> None:
super().__init__()
self.config: FileLocationConfig = config
......@@ -32,8 +31,6 @@ class BuildCommand(CommandGroup):
self.group = group
self.first_n = first_n
self.make_plot = make_plot
@classmethod
def take_args(cls: CommandGroup, args: ExpamOptions) -> dict:
k, n, s, pile_size, first_n = cls.parse_ints(args.k, args.n, args.s, args.pile, args.first_n)
......@@ -47,12 +44,13 @@ class BuildCommand(CommandGroup):
'pile_size': pile_size,
'files': args.directory,
'group': None if args.groups is None else args.groups[0][0],
'first_n': first_n,
'make_plot': args.plot
'first_n': first_n
}
def validate_database(self):
    """Abort with an error message if ``self.config`` is not a database.

    ``validate_database_file_configuration`` reports a missing database by
    raising ``ExpamDatabaseDoesNotExistError``, so the check is made by
    catching that exception rather than by testing a return value.
    """
    try:
        validate_database_file_configuration(self.config)
    except ExpamDatabaseDoesNotExistError:
        # NOTE(review): `die` presumably terminates the process - confirm.
        die("%s is not a database." % self.config.database)
"""
......@@ -64,7 +62,7 @@ class BuildCommand(CommandGroup):
self.create()
self.set()
self.add()
self.build_database()
self.build()
"""
Default database command
......@@ -89,9 +87,8 @@ class BuildCommand(CommandGroup):
======================
"""
def build_database(self):
def build(self):
self.validate_database()
conf = JSONConfig(self.config.conf)
# Check if a phylogeny has been provided.
......@@ -100,7 +97,7 @@ class BuildCommand(CommandGroup):
# Read configuration file.
k, n, phylogeny, genome_paths, pile = conf.get_build_params()
phylogeny_path = make_path_absolute(phylogeny_path, self.config.phylogeny)
phylogeny_path = make_path_absolute(phylogeny, self.config.phylogeny)
clear_logs(self.config.logs)
......@@ -108,13 +105,12 @@ class BuildCommand(CommandGroup):
try:
with Timer() as t:
expam(
out_dir=self.config.database,
db_path=self.config.base,
genome_paths=genome_paths,
phylogeny_path=phylogeny,
phylogeny_path=phylogeny_path,
k=k,
n=n - 1, # Account for main process.
pile_size=pile,
plot=self.make_plot
)
print("expam: " + str(t))
......
......@@ -3,6 +3,7 @@ import os
import re
import shutil
import subprocess
from typing import Union
from expam.classify import ResultsPathConfig
from expam.classify.classify import ClassificationResults, name_to_id
from expam.classify.config import make_results_config, validate_results_configuration
......@@ -26,12 +27,12 @@ class TreeCommand(CommandGroup):
out_dir: str, cutoff: int, cpm: float, groups: list[tuple[str]],
use_node_names: bool, keep_zeros: bool, plot_phyla: bool,
colour_list: list[str], log_scores: bool, itol_mode: bool,
at_rank: str, n: int, use_sourmash: bool, use_quicktree: bool
at_rank: str, use_sourmash: bool, use_quicktree: bool
) -> None:
super().__init__()
self.config: FileLocationConfig = config
self.results_config: ResultsPathConfig = make_results_config(out_dir)
self.results_config: ResultsPathConfig = None if out_dir is None else make_results_config(out_dir)
self.json_conf = None
self.out_dir = out_dir
......@@ -50,7 +51,6 @@ class TreeCommand(CommandGroup):
self.at_rank = at_rank
self.n = n
self.use_sourmash = use_sourmash
self.use_quicktree = use_quicktree
......@@ -94,7 +94,6 @@ class TreeCommand(CommandGroup):
'log_scores': args.log_scores,
'itol_mode': args.itol_mode,
'at_rank': args.rank,
'n': n,
'use_sourmash': args.use_sourmash,
'use_quicktree': args.use_quicktree
}
......@@ -106,12 +105,36 @@ class TreeCommand(CommandGroup):
return self.json_conf
def get_group_or_die(self, group_name):
    """Fetch ``(k, s, sequences)`` for ``group_name`` from the configuration.

    Calls ``die`` when ``k`` is unset; otherwise calls ``die`` when ``s`` is
    unset. ``k`` is checked first, so at most one message is reported.

    :param group_name: name of the group to look up.
    :return: the ``(k, s, sequences)`` tuple read from the configuration.
    """
    k, s, sequences = self.get_conf().group_get(group_name)

    # Pick the first missing parameter, if any, and report it.
    message = None
    if k is None:
        message = "Parameter `k` has not been set for group %s."
    elif s is None:
        message = "Parameter `s` has not been set for group %s."

    if message is not None:
        die(message % group_name)

    return k, s, sequences
def get_n_processes(self) -> int:
    """Return the process count stored in the configuration.

    :return: ``1`` when the configuration reports ``None``, otherwise the
        configured value converted with ``int``.
    """
    configured = self.get_conf().get_n_processes()
    return 1 if configured is None else int(configured)
def check_database_exists(self):
    """Abort unless the database layout exists and the database is built.

    ``validate_database_file_configuration`` signals a missing database by
    raising ``ExpamDatabaseDoesNotExistError``. Testing its return value with
    ``not`` would treat a successful ``None`` return as a failure, so the
    exception is caught instead. An explicit ``False`` return is still
    honoured in case a boolean-returning validator is in use.
    """
    # Local import: this module's own import block is not fully visible here.
    from expam.database.config import ExpamDatabaseDoesNotExistError

    try:
        outcome = validate_database_file_configuration(self.config)
    except ExpamDatabaseDoesNotExistError:
        outcome = False

    if outcome is False:
        # NOTE(review): `die` presumably terminates the process - confirm.
        die("Database %s does not exist!" % self.config.database)

    if not os.path.exists(self.config.database_file):
        die("Database has not been built! Not found at %s." % self.config.database_file)
def check_results_exist(self):
    """Abort unless a results directory was given and passes validation.

    ``self.results_config`` is ``None`` when no output directory was
    supplied. The results configuration is what gets validated here, not the
    database configuration held in ``self.config``.
    """
    if self.results_config is None:
        die("Results not specified!")

    # NOTE(review): assumes validate_results_configuration returns a truthy
    # value on success - confirm against its definition.
    if not validate_results_configuration(self.results_config):
        die("Invalid results configuration!")
@staticmethod
def _check_command(command):
......@@ -153,10 +176,7 @@ class TreeCommand(CommandGroup):
"""
def phylotree(self):
if self.results_config is None:
die("Require output directory (-o, --out_dir)!")
validate_results_configuration(self.results_config)
self.check_results_exist()
config: JSONConfig = self.get_conf()
phylogeny_path = make_path_absolute(config["phylogeny_path"], self.config.database)
......@@ -274,36 +294,10 @@ class TreeCommand(CommandGroup):
def tree(self):
conf: JSONConfig = self.get_conf()
entry_points = (
partial( # Sketch sequences.
self.do_sketches,
conf=conf,
phy_dir=self.config.phylogeny,
group=self.group,
use_sourmash=self.use_sourmash
),
partial( # Pairwise distances.
self.do_distances,
conf=conf,
phy_dir=self.config.phylogeny,
group=self.group,
use_sourmash=self.use_sourmash
),
partial( # NJ trees.
self.do_trees,
conf=conf,
phy_dir=self.config.phylogeny,
group=self.group,
use_quicktree=self.use_quicktree
)
)
entry_stage = self.argmax(
self.check_sketches(conf, self.config.phylogeny, self.group, use_sourmash=self.use_sourmash),
self.check_distances(conf, self.config.phylogeny, self.group),
self.check_trees(conf, self.config.phylogeny, self.group)
)
entry_points = (self.do_sketches, self.do_distances, self.do_trees)
entry_stage = self.argmax(self.check_sketches(), self.check_distances(), self.check_trees())
for stage in range(entry_stage, 3):
for stage in range(entry_stage, len(entry_points)):
entry_points[stage]()
tree_dir = self.finalise_tree()
......@@ -377,9 +371,8 @@ class TreeCommand(CommandGroup):
"""
def mashtree(self):
    """Run ``do_mashtree`` and then ``tree``.

    ``do_mashtree`` takes no arguments beyond ``self``, so it is called
    without the configuration, phylogeny directory or group.
    """
    self.do_mashtree()
    self.tree()
def do_mashtree(self):
print("Creating mashtree...")
......@@ -388,17 +381,18 @@ class TreeCommand(CommandGroup):
tree_dir = os.path.join(self.config.phylogeny, 'tree')
temp_dir = os.path.join(self.config.phylogeny, 'tmp')
n: int = self.get_n_processes()
if not os.path.exists(tree_dir):
os.mkdir(tree_dir)
for group_name in conf.get_groups(self.group):
k, s, sequences = conf.group_get(group_name)
k, s, sequences = self.get_group_or_die(group_name)
tree_path = os.path.join(tree_dir, "%s.nwk" % group_name)
self.mashtree(k, s, sequences, tree_path, temp_dir)
self.make_mashtree(k, s, n, sequences, tree_path, temp_dir)
def mashtree(self, k, s, sequences, tree_dir, temp_dir):
def make_mashtree(self, k, s, n, sequences, tree_dir, temp_dir):
_names_file = os.path.join(temp_dir, 'sequence_names.txt')
def delete_temp():
......@@ -430,7 +424,7 @@ class TreeCommand(CommandGroup):
# Make mashtree and insert into configuration file.
print("Making mashtree...")
cmd = "mashtree --numcpus %d --kmerlength %d --sketch-size %d --file-of-files %s" \
% (self.n, k, s, _names_file)
% (n, k, s, _names_file)
try:
return_val = self.shell(cmd, cwd=temp_dir)
......@@ -453,9 +447,7 @@ class TreeCommand(CommandGroup):
"""
def sketch(self):
    """Run the sketching stage by delegating to ``do_sketches``.

    ``do_sketches`` takes no arguments beyond ``self``, so nothing is passed
    through here.
    """
    self.do_sketches()
def check_sketches(self):
conf: JSONConfig = JSONConfig(self.config.conf)
......@@ -466,7 +458,7 @@ class TreeCommand(CommandGroup):
return False
for group_name in conf.get_groups(self.group):
k, s, _ = conf.group_get(group_name)
k, s, _ = self.get_group_or_die(group_name)
file_name = sketch_name_fmt % (group_name, k, s, "%s")
dest = os.path.join(sketch_dir, file_name % ("sour" if self.use_sourmash else "msh"))
......@@ -477,6 +469,7 @@ class TreeCommand(CommandGroup):
def do_sketches(self):
conf: JSONConfig = JSONConfig(self.config.conf)
n: int = self.get_n_processes()
sketch_dir = os.path.join(self.config.phylogeny, 'sketch')
if not os.path.exists(sketch_dir):
......@@ -485,7 +478,7 @@ class TreeCommand(CommandGroup):
sketch_name_fmt = "%s.k%d.s%d.%s"
for group_name in conf.get_groups(self.group):
k, s, sequences = conf.group_get(group_name)
k, s, sequences = self.get_group_or_die(group_name)
file_name = sketch_name_fmt % (group_name, k, s, "%s")
if self.use_sourmash:
......@@ -499,11 +492,11 @@ class TreeCommand(CommandGroup):
self.check_mash()
file_path = os.path.join(sketch_dir, file_name % "msh")
self.mash_sketch(k=k, s=s, p=self.n, sequences=sequences, out_dir=file_path)
self.mash_sketch(k=k, s=s, p=n, sequences=sequences, out_dir=file_path)
def sour_sketch(self, k: int, s: int, sequences: list[str], sig_dir: str):
    """Create signatures for ``sequences`` via ``make_signatures``.

    The process count comes from ``get_n_processes()``; the instance no
    longer stores an ``n`` attribute, and the signatures are made once.

    :param k: forwarded to ``make_signatures`` as its fourth argument.
    :param s: forwarded to ``make_signatures`` as its fifth argument.
    :param sequences: sequence paths to create signatures for.
    :param sig_dir: destination passed to ``make_signatures``.
    """
    from expam.tree.sourmash import make_signatures

    make_signatures(self.get_n_processes(), sequences, sig_dir, k, s)
def mash_sketch(self, k, s, p, sequences, out_dir):
cmd_fmt = "mash sketch -k %d -p %d -s %d -o %s %s"
......@@ -516,8 +509,7 @@ class TreeCommand(CommandGroup):
"""
def distance(self):
    """Run the pairwise-distance stage by delegating to ``do_distances``.

    ``do_distances`` is invoked without arguments, the same way ``tree``
    invokes it.
    """
    self.do_distances()
def check_distances(self):
conf: JSONConfig = self.get_conf()
......@@ -531,7 +523,7 @@ class TreeCommand(CommandGroup):
dist_name_fmt = "%s.k%d.s%d.tab"
for group_name in conf.get_groups(self.group):
k, s, _ = conf.group_get(group_name)
k, s, _ = self.get_group_or_raise(group_name)
file_name = dist_name_fmt % (group_name, k, s)
dest = os.path.join(sketch_dir, file_name)
......@@ -553,8 +545,7 @@ class TreeCommand(CommandGroup):
dist_name_fmt = "%s.k%d.s%d.%s"
for group_name in conf.get_groups(self.group):
k, s, _ = conf.group_get(group_name)
k, s, _ = self.get_group_or_die(group_name)
sketch_name = dist_name_fmt % (group_name, k, s, '%s')
matrix_name = dist_name_fmt % (group_name, k, s, 'tab')
......@@ -575,7 +566,7 @@ class TreeCommand(CommandGroup):
cmd_fmt = "mash dist -p %d -t %s %s"
cmd = cmd_fmt % (
self.n,
self.get_n_processes(),
sketch_dir,
sketch_dir
)
......@@ -597,7 +588,7 @@ class TreeCommand(CommandGroup):
def sour_dist(self, sig_dir, matrix_dir):
    """Compute distances from the signatures via ``make_distances``.

    The process count comes from ``get_n_processes()``; the instance no
    longer stores an ``n`` attribute, and the distances are computed once.

    :param sig_dir: signature location passed to ``make_distances``.
    :param matrix_dir: destination passed to ``make_distances``.
    """
    from expam.tree.sourmash import make_distances

    make_distances(sig_dir, self.get_n_processes(), matrix_dir)
"""
Execute neighbour-joining on distance matrix
......@@ -605,8 +596,7 @@ class TreeCommand(CommandGroup):
"""
def nj(self):
    """Run the neighbour-joining stage by delegating to ``do_trees``.

    ``do_trees`` is invoked without arguments, the same way ``tree`` invokes
    it.
    """
    self.do_trees()
def check_trees(self):
conf: JSONConfig = self.get_conf()
......@@ -636,7 +626,7 @@ class TreeCommand(CommandGroup):
os.mkdir(tree_dir)
for group_name in conf.get_groups(self.group):
k, s, _ = conf.group_get(group_name)
k, s, _ = self.get_group_or_die(group_name)
matrix_name = dist_fmt % (group_name, k, s)
matrix_path = os.path.join(dist_dir, matrix_name)
......@@ -653,7 +643,7 @@ class TreeCommand(CommandGroup):
self.check_rapidnj()
cmd_fmt = "rapidnj %s -i pd -o t -c %d"
cmd = cmd_fmt % (matrix_dir, self.n)
cmd = cmd_fmt % (matrix_dir, self.get_n_processes())
unformatted_tree = self.shell(cmd)
tree = self._format_tree_string(unformatted_tree)
......
......@@ -3,7 +3,6 @@ import math
from multiprocessing import Pipe, Value, shared_memory
import os
import subprocess
import numpy as np
from expam.database import CHUNK_SIZE, TIMEOUT, UNION_RATIO, FileLocationConfig, expam_dtypes
from expam.database.config import load_database_config
......@@ -175,7 +174,7 @@ def sort_by_size(dirs):
return ordered_files, max_size
def prepare_kmer_allocations(rows, cols, n_processes):
allocation_params = ()
allocation_params = tuple()
allocation_size = rows * cols * expam_dtypes.keys_dtype_size
allocation_shape = (rows, cols)
......@@ -194,7 +193,7 @@ def prepare_kmer_allocations(rows, cols, n_processes):
next_arr[:] = 0
# Pass on allocation details to children.
allocation_params += tuple(next_shm_allocation.name, allocation_shape)
allocation_params += (next_shm_allocation.name, allocation_shape)
next_shm_allocation.close()
return allocation_params
......
......@@ -198,14 +198,14 @@ class JSONConfig:
# Check group has more than zero sequences.
for group in groups:
k, s, sequences = self.group_get(group)
_, _, sequences = self.group_get(group)
if len(sequences) > 0:
if k is None or s is None:
raise ValueError("Parameters unspecified for group %s (k=%s, s=%s)!" % (group, str(k), str(s)))
yield group
def get_n_processes(self):
    """Return the value stored under the ``'n'`` key of this configuration."""
    n_processes = self['n']
    return n_processes
class ExpamDatabaseDoesNotExistError(Exception):
    """Raised when a path required by a database configuration is missing."""
......@@ -233,9 +233,7 @@ def make_database_config(db_path: str) -> FileLocationConfig:
def load_database_config(db_path: str) -> FileLocationConfig:
    """Build the file-location configuration for ``db_path`` and validate it.

    Validation runs once. The validator reports a missing database by raising
    ``ExpamDatabaseDoesNotExistError``, which propagates to the caller. An
    explicit ``False`` return is also converted into that exception, so the
    function behaves the same with a boolean-returning validator.

    :param db_path: location of the database.
    :return: the validated configuration.
    :raises ExpamDatabaseDoesNotExistError: if the database is missing.
    """
    proposed_config: FileLocationConfig = make_database_config(db_path)

    if validate_database_file_configuration(proposed_config) is False:
        raise ExpamDatabaseDoesNotExistError("Database does not exist at %s" % db_path)

    return proposed_config
......@@ -255,10 +253,11 @@ def create_database(config: FileLocationConfig) -> None:
def validate_database_file_configuration(proposed_config: FileLocationConfig) -> bool:
    """Check that every path a database needs exists on disk.

    The ``base``, ``database``, ``phylogeny`` and ``conf`` attributes of
    ``proposed_config`` are checked in that order.

    :param proposed_config: configuration whose paths are checked.
    :return: ``True`` when every path exists, so callers that test the result
        for truthiness keep working.
    :raises ExpamDatabaseDoesNotExistError: for the first path that does not
        exist; the message names that path.
    """
    for field_to_check in ('base', 'database', 'phylogeny', 'conf'):
        path_to_check = getattr(proposed_config, field_to_check)

        if not os.path.exists(path_to_check):
            raise ExpamDatabaseDoesNotExistError("Database does not exist at %s" % path_to_check)

    return True
def validate_taxonomy_files(config: FileLocationConfig) -> bool:
......
from ctypes import Array, c_int8
from multiprocessing import Pipe, Queue, shared_memory
from ctypes import *
from multiprocessing import Array, Pipe, Queue, shared_memory
import os
import queue
import time
......@@ -673,7 +673,6 @@ class ExpamProcesses(ControlCenter):
# Save database parameters to config file.
self.save_db_params(
out_dir=self.out_dir,
db_params={
"keys_shape": current_shape,
"keys_data_type": self.dtypes.keys_dtype_str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment