Commit 0e01e20b authored by Sean Solari
Browse files

Bug fixes

parent d15820b7
......@@ -27,7 +27,7 @@ TAX_SPLIT_FILE = os.path.join(TAX_RESULTS, SPLIT_NAME)
ResultsPathConfig = namedtuple(
'ResultsPathConfig',
[
'phy', 'tax',
'base', 'phy', 'tax',
'temp',
'phy_raw', 'tax_raw',
'phy_classified', 'phy_split',
......
......@@ -2,6 +2,7 @@ from math import floor
from multiprocessing import shared_memory
import os
import re
import shutil
from typing import Union
import numpy as np
......@@ -37,7 +38,7 @@ def run_classifier(
colour_list: list[str] = None, paired_end: bool = False, alpha: float = 1.0,
log_scores: bool = False, itol_mode: bool = False
):
output_config: ResultsPathConfig = load_results_config(out_dir)
output_config: ResultsPathConfig = load_results_config(out_dir, create=True)
database_config: FileLocationConfig = load_database_config(db_dir)
# Load the kmer dictionary.
......@@ -83,8 +84,7 @@ def run_classifier(
results = ClassificationResults(
index=index,
phylogeny_index=phylogeny_index,
in_dir=output_config.phy,
out_dir=out_dir,
results_config=output_config,
groups=groups,
keep_zeros=keep_zeros,
cutoff=cutoff,
......@@ -123,7 +123,7 @@ def run_classifier(
values_shm.unlink()
values_shm.close()
os.rmdir(output_config.temp)
shutil.rmtree(output_config.temp)
def name_to_id(phylogeny_path: str):
......@@ -141,7 +141,7 @@ def name_to_id(phylogeny_path: str):
class Distribution:
def __init__(self, k, kmer_db, index, lca_matrix, read_paths, out_dir, temp_dir, logging_dir, alpha,
def __init__(self, k, kmer_db, index: Index, lca_matrix, read_paths, out_dir, temp_dir, logging_dir, alpha,
keep_zeros=False, cutoff=0.0, cpm=0.0, paired_end=False):
"""
......@@ -160,7 +160,7 @@ class Distribution:
self.temp_dir = temp_dir
self.logging_dir = logging_dir
self.index = index
self.index: Index = index
self.node_names = [node.name if i > 0 else "unclassified" for i, node in enumerate(index.pool)]
self.keep_zeros = keep_zeros
......@@ -227,18 +227,15 @@ class Distribution:
# Combine raw read output.
temporary_files = ls(self.temp_dir, ext=".reads_%s" % EXPAM_TEMP_EXT)
base_file_names = {
re.match(r'(\S+)_\d+.reads_%s' % EXPAM_TEMP_EXT, file_name).group(1)
re.match(r'(\S+)_\d+.reads_%s' % EXPAM_TEMP_EXT, os.path.basename(file_name)).group(1)
for file_name in temporary_files
}
result_files = []
for base_file in base_file_names:
results_file_name = os.path.join(results_config.phy_raw, base_file + ".csv")
results = [
self.get_data(os.path.join(self.temp_dir, file))
for file in temporary_files
if file[:len(base_file)] == base_file
]
raw_files = [file for file in temporary_files if re.match(r'.*{base}_\d+.reads_{ext}$'.format(base=base_file, ext=EXPAM_TEMP_EXT), file)]
results = [self.get_data(file) for file in raw_files]
with open(results_file_name, "w") as f:
f.write("\n".join(results))
......@@ -446,14 +443,12 @@ class Distribution:
class ClassificationResults:
def __init__(self, index, phylogeny_index, in_dir, out_dir, groups=None, keep_zeros=False, cutoff=0.0, cpm=0.0,
def __init__(self, index, phylogeny_index, results_config, groups=None, keep_zeros=False, cutoff=0.0, cpm=0.0,
use_node_names=True, phyla=False, name_taxa=None, colour_list=None, circle_scale=1.0,
log_scores=False):
self.index = index
self.phylogeny_index: Index = phylogeny_index
self.in_dir = in_dir
self.out_dir = out_dir
self.results_config: ResultsPathConfig = results_config
self.groups = groups # [(colour, (name1, name2, ...)), ...]
self.keep_zeros = keep_zeros
......@@ -477,17 +472,11 @@ class ClassificationResults:
def to_taxonomy(self, name_to_lineage, taxon_to_rank, tax_dir):
col_names = ["c_perc", "c_cumul", "c_count", "s_perc", "s_cumul", "s_count", "rank", "scientific name"]
raw_counts_dir = os.path.join(self.in_dir, 'raw')
raw_output_dir = os.path.join(tax_dir, 'raw')
if not os.path.exists(raw_output_dir):
os.mkdir(raw_output_dir)
class_counts = pd.read_csv(os.path.join(self.in_dir, "classified_counts.csv"), sep="\t", index_col=0, header=0)
split_counts = pd.read_csv(os.path.join(self.in_dir, "splits_counts.csv"), sep="\t", index_col=0, header=0)
class_counts = pd.read_csv(self.results_config.phy_classified, sep="\t", index_col=0, header=0)
split_counts = pd.read_csv(self.results_config.phy_split, sep="\t", index_col=0, header=0)
# Get rid of phylogenetically printed node names.
def fix_index(df):
def fix_index(df: pd.DataFrame):
df.index = [
index.lstrip("p")
if index not in self.phylogeny_index._pointers
......@@ -552,7 +541,7 @@ class ClassificationResults:
df.loc[:, "s_perc"] = round(df["s_perc"], 3).map(str) + "%"
# Employ cutoff.
cutoff = max((self.cutoff, (total_counts / 1e6) * self.cpm))
cutoff = max(self.cutoff, (total_counts / 1e6) * self.cpm)
df = df[(df['c_cumul'] > cutoff) | (df['s_cumul'] > cutoff) | (df.index == 'unclassified')]
df.to_csv(os.path.join(tax_dir, sample_name + ".csv"), sep="\t", header=True)
......@@ -561,7 +550,7 @@ class ClassificationResults:
# Map raw read output to taxonomy.
#
raw_counts_file = os.path.join(raw_counts_dir, sample_name + ".csv")
raw_counts_file = os.path.join(self.results_config.phy_raw, sample_name + ".csv")
raw_read_data = []
with open(raw_counts_file, 'r') as f:
......@@ -577,7 +566,7 @@ class ClassificationResults:
raw_read_data.append('\t'.join(read_data))
with open(os.path.join(raw_output_dir, sample_name + '.csv'), 'w') as f:
with open(os.path.join(self.results_config.tax_raw, sample_name + '.csv'), 'w') as f:
f.write('\n'.join(raw_read_data))
def map_phylogeny_node(self, node_name, name_to_lineage, taxon_to_rank):
......@@ -644,8 +633,8 @@ class ClassificationResults:
def draw_results(self, itol_mode=False):
# Draw classified tree.
self.phylogeny_index.draw_results(
os.path.join(self.in_dir, "classified_counts.csv"),
os.path.join(self.out_dir, "phylotree_classified.pdf"),
self.results_config.phy_classified,
os.path.join(self.results_config.base, "phylotree_classified.pdf"),
skiprows=[1],
groups=self.groups,
cutoff=self.cutoff,
......@@ -661,8 +650,8 @@ class ClassificationResults:
# Draw unclassified tree.
self.phylogeny_index.draw_results(
os.path.join(self.in_dir, "splits_counts.csv"),
os.path.join(self.out_dir, "phylotree_splits.pdf"),
self.results_config.phy_split,
os.path.join(self.results_config.base, "phylotree_splits.pdf"),
groups=self.groups,
cutoff=self.cutoff,
cpm=self.cpm,
......
......@@ -6,6 +6,7 @@ from expam.utils import die
def make_results_config(out_path: str) -> ResultsPathConfig:
output_file_locations = {
'base': out_path,
'phy': os.path.join(out_path, PHY_RESULTS),
'tax': os.path.join(out_path, TAX_RESULTS),
'temp': os.path.join(out_path, TEMP_RESULTS),
......@@ -30,14 +31,13 @@ def load_results_config(out_path: str, create: bool = False) -> ResultsPathConfi
print("Failed to make results path %s." % out_path)
create_results(proposed_config)
if not validate_results_configuration(proposed_config, check_taxonomy=False):
elif not validate_results_configuration(proposed_config, check_taxonomy=False):
die("Results path does not exist!")
return proposed_config
def create_results(config: ResultsPathConfig):
for path_field in ('phy', 'tax', 'phy_raw', 'tax_raw'):
for path_field in ('phy', 'tax', 'phy_raw', 'tax_raw', 'temp'):
path = getattr(config, path_field)
if not os.path.exists(path):
......
import os
import re
import time
import requests
......@@ -10,12 +11,15 @@ from expam.utils import yield_csv
class TaxonomyNCBI:
def __init__(self, file_config: FileLocationConfig) -> None:
self.config = file_config
self.config: FileLocationConfig = file_config
if not validate_taxonomy_files(file_config):
def find_downloaded_taxonomy(self):
if not validate_taxonomy_files(self.config):
raise OSError("Taxonomy files not located!")
def load_taxonomy_map(self, convert_to_name=True):
self.find_downloaded_taxonomy()
# Create map from scientific name --> (taxid, rank).
taxon_data = {}
for data in yield_csv(self.config.taxon_rank):
......@@ -40,14 +44,18 @@ class TaxonomyNCBI:
return list(yield_csv(self.config.accession_id))
def load_taxid_lineage_map(self):
return list(yield_csv(self.config.taxid_lineage))
if os.path.exists(self.config.taxid_lineage):
return list(yield_csv(self.config.taxid_lineage))
else:
return []
def load_rank_map(self):
name_to_rank = {}
for data in yield_csv(self.config.taxon_rank):
if len(data) > 1:
name_to_rank[data[0]] = ",".join(data[1:])
if os.path.exists(self.config.taxon_rank):
for data in yield_csv(self.config.taxon_rank):
if len(data) > 1:
name_to_rank[data[0]] = ",".join(data[1:])
return name_to_rank
......
......@@ -11,7 +11,7 @@ from expam.utils import die, is_hex, make_path_absolute
class ClassifyCommand(CommandGroup):
commands: set[str] = {
'run', 'to_taxonomy', 'download_taxonomy'
'classify', 'to_taxonomy'
}
def __init__(
......@@ -50,6 +50,9 @@ class ClassifyCommand(CommandGroup):
cutoff = cls.parse_ints(args.cutoff)
cpm, alpha = cls.parse_floats(args.cpm, args.alpha)
if args.out_url is None:
die("Must supply -o/--out.")
# Format groups.
if args.groups is not None:
groups = [v for v in args.groups if v]
......@@ -90,8 +93,7 @@ class ClassifyCommand(CommandGroup):
}
def check_database_exists(self):
if not validate_database_file_configuration(self.config):
die("Database %s does not exist!" % self.config.database)
validate_database_file_configuration(self.config)
if not os.path.exists(self.config.database_file):
die("Database has not been built! Not found at %s." % self.config.database_file)
......@@ -101,7 +103,7 @@ class ClassifyCommand(CommandGroup):
===========
"""
def run(self):
def classify(self):
self.check_database_exists()
clear_logs(self.config.logs)
......@@ -114,14 +116,20 @@ class ClassifyCommand(CommandGroup):
if len(keys_shape) == 1:
keys_shape = keys_shape + (1,)
tax_obj = TaxonomyNCBI()
name_to_lineage, _ = tax_obj.load_taxonomy_map(self.config)
try:
tax_obj = TaxonomyNCBI(self.config)
name_to_lineage, _ = tax_obj.load_taxonomy_map(self.config)
except OSError:
if self.convert_to_taxonomy:
die("First run `download_taxonomy` to collect associated taxonomy data.")
else:
name_to_lineage = None
# Run expam classification.
run_classifier(
read_paths=self.files,
out_dir=self.out_dir,
db_dir=self.config.database,
db_dir=self.config.base,
k=k,
n=n - 1, # Account for main process.
phylogeny_path=phylogeny_path,
......@@ -153,7 +161,7 @@ class ClassifyCommand(CommandGroup):
if self.out_dir is None:
die("Require output directory (-o, --out_dir)!")
else:
validate_results_configuration(self.out_dir)
validate_results_configuration(self.results_config)
if not os.path.exists(self.config.taxid_lineage):
die("Run command `download_taxonomy` first to collect taxa for your genomes!")
......@@ -163,7 +171,7 @@ class ClassifyCommand(CommandGroup):
index, phylogenyIndex = name_to_id(phylogeny_path)
tax_obj = TaxonomyNCBI()
tax_obj = TaxonomyNCBI(self.config)
name_to_lineage, taxon_to_rank = tax_obj.load_taxonomy_map(self.config)
if not os.path.exists(self.results_config.tax):
......@@ -172,8 +180,7 @@ class ClassifyCommand(CommandGroup):
results = ClassificationResults(
index=index,
phylogeny_index=phylogenyIndex,
in_dir=self.results_config.phy,
out_dir=self.out_dir,
results_config=self.results_config,
groups=self.groups,
keep_zeros=self.keep_zeros,
cutoff=self.cutoff,
......@@ -185,15 +192,4 @@ class ClassifyCommand(CommandGroup):
log_scores=self.log_scores
)
results.to_taxonomy(name_to_lineage, taxon_to_rank, self.results_config.tax)
"""
Download taxonomy command
=========================
"""
def download_taxonomy(self):
self.check_database_exists()
tax_obj: TaxonomyNCBI = TaxonomyNCBI(self.config)
tax_obj.accession_to_taxonomy()
\ No newline at end of file
......@@ -126,8 +126,7 @@ class TreeCommand(CommandGroup):
return int(n_processes)
def check_database_exists(self):
if not validate_database_file_configuration(self.config):
die("Database %s does not exist!" % self.config.database)
validate_database_file_configuration(self.config)
def check_results_exist(self):
if self.results_config is None:
......
......@@ -2,6 +2,7 @@ import os
import matplotlib.pyplot as plt
from expam.classify import ResultsPathConfig
from expam.classify.config import make_results_config, validate_classification_results, validate_results_configuration
from expam.classify.taxonomy import TaxonomyNCBI
from expam.cli.main import CommandGroup, ExpamOptions
from expam.database import FileLocationConfig
from expam.database.config import JSONConfig, make_database_config, validate_database_file_configuration
......@@ -12,7 +13,7 @@ from expam.utils import die, ls
class UtilsCommand(CommandGroup):
commands: set[str] = {
'cutoff', 'fake_phylogeny', 'plot_memory'
'download_taxonomy', 'cutoff', 'fake_phylogeny', 'plot_memory'
}
def __init__(
......@@ -22,7 +23,7 @@ class UtilsCommand(CommandGroup):
self.config: FileLocationConfig = config
self.out_dir: str = out_dir
self.results_config: ResultsPathConfig = make_results_config(out_dir)
self.results_config: ResultsPathConfig = None if out_dir is None else make_results_config(out_dir)
self.json_conf = None
......@@ -49,8 +50,18 @@ class UtilsCommand(CommandGroup):
return self.json_conf
def check_database_exists(self):
if not validate_database_file_configuration(self.config):
die("Database %s does not exist!" % self.config.database)
validate_database_file_configuration(self.config)
"""
Download taxonomy command
=========================
"""
def download_taxonomy(self):
self.check_database_exists()
tax_obj: TaxonomyNCBI = TaxonomyNCBI(self.config)
tax_obj.accession_to_taxonomy()
"""
Employ cutoff on taxonomic classification output
......@@ -58,6 +69,9 @@ class UtilsCommand(CommandGroup):
"""
def cutoff(self):
if self.out_dir is None:
die("Must supply -o/--out!")
validate_results_configuration(self.results_config, check_taxonomy=True)
if os.path.exists(self.out_dir):
......
......@@ -5,6 +5,7 @@ from expam.cli.build import BuildCommand
from expam.cli.classify import ClassifyCommand
from expam.cli.main import CommandGroup, ExpamOptions, retrieve_arguments
from expam.cli.tree import TreeCommand
from expam.cli.utils import UtilsCommand
from expam.utils import die
......@@ -15,7 +16,7 @@ def main():
args: ExpamOptions = retrieve_arguments()
handlers: tuple[CommandGroup] = (BuildCommand, ClassifyCommand, TreeCommand)
handlers: tuple[CommandGroup] = (BuildCommand, ClassifyCommand, TreeCommand, UtilsCommand)
for handler in handlers:
if args.command in handler.commands:
handler(**handler.take_args(args)).run(args.command)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment