feat: Report FP / FN negative variants in single table #107

Merged: 19 commits, Dec 13, 2024
Changes from 18 commits

5 changes: 5 additions & 0 deletions workflow/Snakefile
@@ -22,6 +22,11 @@ if "variant-calls" in config:
            benchmark=used_benchmarks,
            vartype=["snvs", "indels"],
        ),
        expand(
            "results/report/fp-fn/callsets/{callset}/{classification}",
            callset=used_callsets,
            classification=["fp", "fn"],
        ),
        get_fp_fn_reports,
        # collect the checkpoint inputs to avoid issues when
        # --all-temp is used: --all-temp leads to premature deletion
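For orientation (not part of the diff): the new entry in the target list expands to one report directory per callset and classification. With two hypothetical callset names, the pattern resolves like this, using snakemake.io.expand just as the Snakefile does:

from snakemake.io import expand

targets = expand(
    "results/report/fp-fn/callsets/{callset}/{classification}",
    callset=["callset-a", "callset-b"],  # hypothetical callset names
    classification=["fp", "fn"],
)
# four targets: an fp and an fn report directory per callset, e.g.
# results/report/fp-fn/callsets/callset-a/fp
print(targets)
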
37 changes: 37 additions & 0 deletions workflow/resources/datavzrd/fp-fn-per-callset-config.yte.yaml
@@ -0,0 +1,37 @@
__use_yte__: true

__variables__:
  green: "#74c476"
  orange: "#fd8d3c"

name: ?f"{wildcards.classification} of {wildcards.callset}"

webview-controls: true
default-view: results-table

datasets:
  results:
    path: ?input.table
    separator: "\t"
    offer-excel: true

views:
  results-table:
    dataset: results
    desc: |
      ?f"""
      Rows are sorted by coverage.
      Benchmark version: {params.genome} {params.version}
      """
    page-size: 12
    render-table:
      columns:
        coverage:
          plot:
            heatmap:
              scale: ordinal
        ?if params.somatic:
          true_genotype:
            display-mode: hidden
          predicted_genotype:
            display-mode: hidden
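As a rough sketch (not part of the diff), the template above can also be rendered outside the datavzrd wrapper, assuming yte's process_yaml(variables=...) API and made-up wildcard and parameter values; the wrapper performs an equivalent rendering step internally:

from types import SimpleNamespace
from yte import process_yaml

# hypothetical stand-ins for what Snakemake passes to the template
variables = {
    "wildcards": SimpleNamespace(classification="fp", callset="my-callset"),
    "input": SimpleNamespace(table="results/fp-fn/callsets/my-callset.fp.tsv"),
    "params": SimpleNamespace(genome="NA12878", version="v4.2.1", somatic=False),
}

with open("workflow/resources/datavzrd/fp-fn-per-callset-config.yte.yaml") as template:
    rendered = process_yaml(template, variables=variables)

# rendered["name"] == "fp of my-callset"; with somatic=False the ?if block is skipped,
# so the genotype columns stay visible.
print(rendered["name"])
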
54 changes: 50 additions & 4 deletions workflow/rules/common.smk
@@ -19,6 +19,8 @@ callsets = config.get("variant-calls", dict())
benchmarks.update(config.get("custom-benchmarks", dict()))
used_benchmarks = {callset["benchmark"] for callset in callsets.values()}

used_callsets = {callset for callset in callsets.keys()}

used_genomes = {benchmarks[benchmark]["genome"] for benchmark in used_benchmarks}


@@ -93,7 +95,7 @@ def get_plot_cov_labels():  # TODO check if ever used anywhere
    def label(name):
        lower, upper = get_cov_interval(name)
        if upper:
            return f"{lower}-{upper-1}"
            return f"{lower}-{upper - 1}"
        return f"≥{lower}"

    return {name: label(name) for name in low_coverages}
@@ -403,6 +405,16 @@ def get_coverages(wildcards):
    return coverages


def get_coverages_of_callset(callset):
    benchmark = config["variant-calls"][callset]["benchmark"]
    high_cov_status = benchmarks[benchmark].get("high-coverage", False)
    if high_cov_status:
        coverages = high_coverages
    else:
        coverages = low_coverages
    return coverages


def get_somatic_status(wildcards):
    if hasattr(wildcards, "benchmark"):
        return genomes[benchmarks[wildcards.benchmark]["genome"]].get("somatic")
@@ -468,10 +480,17 @@ def get_collect_stratifications_input(wildcards):
    )


def get_collect_stratifications_fp_fn_input(wildcards):
    return expand(
        "results/fp-fn/callsets/{{callset}}/{cov}.{{classification}}.tsv",
        cov=get_nonempty_coverages(wildcards),
    )


def get_fp_fn_reports(wildcards):
    for genome in used_genomes:
        yield from expand(
            "results/report/fp-fn/{genome}/{cov}/{classification}",
            "results/report/fp-fn/genomes/{genome}/{cov}/{classification}",
            genome=genome,
            cov={
                cov
@@ -482,6 +501,15 @@ def get_fp_fn_reports(wildcards):
        )


def get_fp_fn_reports_benchmarks(wildcards):
    for genome in used_genomes:
        yield from expand(
            "results/report/fp-fn/benchmarks/{benchmark}/{classification}",
            benchmark={benchmark for benchmark in used_benchmarks},
            classification=["fp", "fn"],
        )


def get_benchmark_callsets(benchmark):
    return [
        callset
@@ -497,6 +525,13 @@ def get_collect_precision_recall_input(wildcards):
    )


def get_collect_fp_fn_benchmark_input(wildcards):
    callsets = get_benchmark_callsets(wildcards.benchmark)
    return expand(
        "results/fp-fn/callsets/{callset}.{{classification}}.tsv", callset=callsets
    )


def get_genome_name(wildcards):
    if hasattr(wildcards, "benchmark"):
        return get_benchmark(wildcards.benchmark).get("genome")
@@ -546,19 +581,30 @@ def get_callset_label_entries(callsets):


def get_collect_fp_fn_callsets(wildcards):
    return get_genome_callsets(wildcards.genome)
    callsets = get_genome_callsets(wildcards.genome)
    callsets = [
        callset
        for callset in callsets
        if wildcards.cov in get_coverages_of_callset(callset)
    ]
    return callsets


def get_collect_fp_fn_input(wildcards):
    callsets = get_collect_fp_fn_callsets(wildcards)
    return expand(
        "results/fp-fn/callsets/{{cov}}/{callset}/{{classification}}.tsv",
        "results/fp-fn/callsets/{callset}/{{cov}}.{{classification}}.tsv",
        callset=callsets,
    )


def get_collect_fp_fn_labels(wildcards):
    callsets = get_genome_callsets(wildcards.genome)
    callsets = [
        callset
        for callset in callsets
        if wildcards.cov in get_coverages_of_callset(callset)
    ]
    return get_callset_label_entries(callsets)


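A small sketch (not part of the diff) of the filtering idea that get_collect_fp_fn_callsets and get_collect_fp_fn_labels now share: only callsets whose benchmark provides the requested coverage bin are collected. The coverage names below are hypothetical stand-ins for what get_coverages_of_callset returns:

# hypothetical coverage bins per callset
coverages_of_callset = {
    "callset-a": ["low", "medium"],
    "callset-b": ["medium", "high"],
}


def callsets_for_cov(callsets, cov):
    # keep only callsets that actually provide the requested coverage bin
    return [c for c in callsets if cov in coverages_of_callset[c]]


print(callsets_for_cov(["callset-a", "callset-b"], "high"))  # ['callset-b']
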
76 changes: 71 additions & 5 deletions workflow/rules/eval.smk
@@ -338,9 +338,9 @@ rule extract_fp_fn:
calls="results/vcfeval/{callset}/{cov}/output.vcf.gz",
common_src=common_src,
output:
"results/fp-fn/callsets/{cov}/{callset}/{classification}.tsv",
"results/fp-fn/callsets/{callset}/{cov}.{classification}.tsv",
log:
"logs/extract-fp-fn/{cov}/{callset}/{classification}.log",
"logs/extract-fp-fn/{callset}/{cov}.{classification}.log",
conda:
"../envs/vembrane.yaml"
script:
@@ -374,18 +374,55 @@ rule collect_fp_fn:
"../scripts/collect-fp-fn.py"


rule collect_stratifications_fp_fn:
input:
get_collect_stratifications_fp_fn_input,
output:
"results/fp-fn/callsets/{callset}.{classification}.tsv",
params:
coverages=get_nonempty_coverages,
coverage_lower_bounds=get_coverages,
log:
"logs/fp-fn/callsets/{callset}.{classification}.log",
conda:
"../envs/stats.yaml"
# This has to happen after precision/recall has been computed, otherwise we risk
# extremely high memory usage if a callset does not match the truth at all.
priority: 1
script:
"../scripts/collect-stratifications-fp-fn.py"


rule collect_fp_fn_benchmark:
input:
tables=get_collect_fp_fn_benchmark_input,
output:
"results/fp-fn/benchmarks/{benchmark}.{classification}.tsv",
params:
callsets=lambda w: get_benchmark_callsets(w.benchmark),
log:
"logs/fp-fn/benchmarks/{benchmark}.{classification}.log",
conda:
"../envs/stats.yaml"
script:
"../scripts/collect-fp-fn-benchmarks.py"


rule report_fp_fn:
    input:
        main_dataset="results/fp-fn/genomes/{genome}/{cov}/{classification}/main.tsv",
        dependency_sorting_datasets="results/fp-fn/genomes/{genome}/{cov}/{classification}/dependency-sorting",
        config=workflow.source_path("../resources/datavzrd/fp-fn-config.yte.yaml"),
    output:
        report(
            directory("results/report/fp-fn/{genome}/{cov}/{classification}"),
            directory("results/report/fp-fn/genomes/{genome}/{cov}/{classification}"),
            htmlindex="index.html",
            category="{classification} variants",
            category="{classification} variants per genome",
            subcategory=lambda w: w.genome,
            labels=lambda w: {"coverage": w.cov},
            labels=lambda w: {
                "coverage": w.cov,
                "genome": w.genome,
            },
        ),
    log:
        "logs/datavzrd/fp-fn/{genome}/{cov}/{classification}.log",
@@ -394,3 +431,32 @@ rule report_fp_fn:
        version=get_genome_version,
    wrapper:
        "v5.0.1/utils/datavzrd"


rule report_fp_fn_callset:
    input:
        table="results/fp-fn/callsets/{callset}.{classification}.tsv",
        config=workflow.source_path(
            "../resources/datavzrd/fp-fn-per-callset-config.yte.yaml"
        ),
    output:
        report(
            directory("results/report/fp-fn/callsets/{callset}/{classification}"),
            htmlindex="index.html",
            category="{classification} variants per benchmark",
            subcategory=lambda w: config["variant-calls"][w.callset]["benchmark"],
            labels=lambda w: {
                "callset": w.callset,
            },
        ),
    log:
        "logs/datavzrd/fp-fn/{callset}/{classification}.log",
    params:
        labels=lambda w: get_callsets_labels(
            get_benchmark_callsets(config["variant-calls"][w.callset]["benchmark"])
        ),
        genome=get_genome_name,
        version=get_genome_version,
        somatic=get_somatic_status,
    wrapper:
        "v5.0.1/utils/datavzrd"
42 changes: 42 additions & 0 deletions workflow/scripts/collect-fp-fn-benchmarks.py
@@ -0,0 +1,42 @@
import sys

sys.stderr = open(snakemake.log[0], "w")

import pandas as pd


def load_data(path, callset):
    d = pd.read_csv(path, sep="\t")
    d.insert(0, "callset", callset)
    return d


results = pd.concat(
    [
        load_data(f, callset)
        for f, callset in zip(snakemake.input.tables, snakemake.params.callsets)
    ],
    axis="rows",
)


def cov_key(cov_label):
    # return the lower bound as an integer for sorting
    if ".." in cov_label:
        return int(cov_label.split("..")[0])
    else:
        return int(cov_label[1:])


def sort_key(col):
    if col.name == "callset":
        return col
    if col.name == "coverage":
        return col.apply(cov_key)
    else:
        return col


results.sort_values(["callset", "coverage"], inplace=True, key=sort_key)
results["sort_index"] = results["coverage"].apply(cov_key)

results.to_csv(snakemake.output[0], sep="\t", index=False)
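As a quick check (not part of the diff), the cov_key/sort_key pair above orders coverage labels by their numeric lower bound rather than lexicographically; the labels below are hypothetical but follow the "lower..upper" / "≥lower" format produced by collect-stratifications-fp-fn.py:

import pandas as pd


def cov_key(cov_label):
    # same parsing as above: lower bound of "lower..upper", or the number after "≥"
    if ".." in cov_label:
        return int(cov_label.split("..")[0])
    return int(cov_label[1:])


df = pd.DataFrame({"coverage": ["10..30", "2..10", "≥30"]})
df = df.sort_values("coverage", key=lambda col: col.apply(cov_key))
print(df["coverage"].tolist())  # ['2..10', '10..30', '≥30'], not the lexicographic order
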
56 changes: 56 additions & 0 deletions workflow/scripts/collect-stratifications-fp-fn.py
@@ -0,0 +1,56 @@
import sys

sys.stderr = open(snakemake.log[0], "w")

import pandas as pd


def get_cov_label(coverage):
    lower = snakemake.params.coverage_lower_bounds[coverage]
    bounds = [
        bound
        for bound in snakemake.params.coverage_lower_bounds.values()
        if bound > lower
    ]
    if bounds:
        upper = min(bounds)
        return f"{lower}..{upper}"
    else:
        return f"≥{lower}"


def load_data(f, coverage):
    d = pd.read_csv(f, sep="\t")
    d.insert(0, "coverage", get_cov_label(coverage))
    return d


if snakemake.input:
    report = pd.concat(
        load_data(f, cov) for cov, f in zip(snakemake.params.coverages, snakemake.input)
    )

    # TODO: With separate files for SNVs and indels, callers such as STRELKA are expected
    # to predict no variants of the other type, so the check below is disabled.
    # If this becomes relevant later, add an annotation to the report instead.
    # if (report["tp_truth"] == 0).all():
    #     raise ValueError(
    #         f"The callset {snakemake.wildcards.callset} does not predict any variant from the truth. "
    #         "This is likely a technical issue in the callset and should be checked before further evaluation."
    #     )

    report.to_csv(snakemake.output[0], sep="\t", index=False)
else:
    pd.DataFrame(
        {
            col: []
            for col in [
                "coverage",
                "class",
                "chromosome position",
                "ref_allele",
                "alt_allele",
                "true_genotype",
                "predicted_genotype",
            ]
        }
    ).to_csv(snakemake.output[0], sep="\t", index=False)
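And a matching sketch (not part of the diff) of how get_cov_label above builds those labels, using a hypothetical coverage_lower_bounds mapping:

coverage_lower_bounds = {"low": 1, "medium": 10, "high": 30}  # hypothetical lower bounds


def get_cov_label(coverage):
    lower = coverage_lower_bounds[coverage]
    higher = [b for b in coverage_lower_bounds.values() if b > lower]
    # bounded bins become "lower..upper", the top bin becomes "≥lower"
    return f"{lower}..{min(higher)}" if higher else f"≥{lower}"


print([get_cov_label(c) for c in coverage_lower_bounds])  # ['1..10', '10..30', '≥30']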