https://github.com/tskit-dev/msprime
Tip revision: a3da1b02836537cf35564c914938c75d8ed1acec authored by Jerome Kelleher on 20 July 2016, 15:58:14 UTC
Changelog for 0.3.2.
Changelog for 0.3.2.
Tip revision: a3da1b0
verification.py
"""
Script to automate verification of the msprime simulator against
Hudson's ms.
"""
from __future__ import print_function
from __future__ import division
import math
import os
import random
import subprocess
import sys
import tempfile
import scipy.special
import pandas as pd
import numpy as np
import numpy.random
import statsmodels.api as sm
import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
from matplotlib import pyplot
import dendropy
import msprime.cli as cli
import msprime
class SimulationVerifier(object):
    """
    Class to compare msprime against ms to ensure that the same distributions
    of values are output under the same parameters.

    Checks are registered under string keys by the ``add_*`` methods and
    executed by :meth:`run`. Plots are written beneath ``output_dir`` in a
    subdirectory named after the check.
    """
    def __init__(self, output_dir):
        """
        :param output_dir: Directory beneath which all plots are written.
        """
        self._output_dir = output_dir
        # Maps check key -> zero-argument callable that runs the check.
        self._instances = {}
        self._ms_executable = ["./data/ms/ms"]
        self._mspms_executable = ["python", "mspms_dev.py"]

    @staticmethod
    def _ensure_directory(path):
        """
        Creates ``path`` (and any missing parents) if it does not exist.
        """
        try:
            os.makedirs(path)
        except OSError:
            # Only swallow the error when the directory is actually there.
            if not os.path.isdir(path):
                raise

    def get_ms_seeds(self):
        """
        Returns the ``-seed`` option followed by three random seeds, as a
        list of strings suitable for appending to an ms command line.
        """
        max_seed = 2**16
        seeds = [random.randint(1, max_seed) for _ in range(3)]
        # Build a real list: on Python 3 ``map`` returns an iterator, which
        # cannot be concatenated to a list.
        return ["-seed"] + [str(seed) for seed in seeds]

    def _run_sample_stats(self, args):
        """
        Runs the command ``args``, pipes its standard output through
        ``./data/ms/sample_stats`` and returns the result as parsed by
        ``pandas.read_table``.
        """
        print("\t", " ".join(args))
        p1 = subprocess.Popen(args, stdout=subprocess.PIPE)
        p2 = subprocess.Popen(
            ["./data/ms/sample_stats"], stdin=p1.stdout,
            stdout=subprocess.PIPE)
        # Drop our handle on the pipe so only p2 holds the read end.
        p1.stdout.close()
        output = p2.communicate()[0]
        # Reap the simulator process so that it does not linger as a zombie.
        p1.wait()
        with tempfile.TemporaryFile() as f:
            f.write(output)
            f.seek(0)
            df = pd.read_table(f)
        return df

    def _run_ms_mutation_stats(self, args):
        """
        Returns the sample_stats DataFrame for ms run with the ms-style
        argument string ``args`` and fresh random seeds.
        """
        return self._run_sample_stats(
            self._ms_executable + args.split() + self.get_ms_seeds())

    def _run_msprime_mutation_stats(self, args):
        """
        Returns the sample_stats DataFrame for mspms run with the ms-style
        argument string ``args`` and fresh random seeds.
        """
        return self._run_sample_stats(
            self._mspms_executable + args.split() + self.get_ms_seeds())

    def _run_ms_coalescent_stats(self, args):
        """
        Runs ``./data/ms/ms_summary_stats`` with the ms-style argument
        string ``args`` and returns its standard output as parsed by
        ``pandas.read_table``.
        """
        executable = ["./data/ms/ms_summary_stats"]
        with tempfile.TemporaryFile() as f:
            arg_list = executable + args.split() + self.get_ms_seeds()
            print("\t", " ".join(arg_list))
            subprocess.call(arg_list, stdout=f)
            f.seek(0)
            df = pd.read_table(f)
        return df

    def _run_msprime_coalescent_stats(self, args):
        """
        Runs msprime in-process for the ms-style argument string ``args``
        and returns a DataFrame with one row per replicate and the columns
        ``t``, ``num_trees``, ``ca_events``, ``re_events`` and one
        ``mig_events_<j>`` column per entry of the flattened migration
        event matrix.
        """
        print("\t msprime:", args)
        runner = cli.get_mspms_runner(args.split())
        sim = runner.get_simulator()
        rng = msprime.RandomGenerator(random.randint(1, 2**32 - 1))
        sim.set_random_generator(rng)
        num_populations = sim.get_num_populations()
        replicates = runner.get_num_replicates()
        num_trees = [0 for _ in range(replicates)]
        time = [0 for _ in range(replicates)]
        ca_events = [0 for _ in range(replicates)]
        re_events = [0 for _ in range(replicates)]
        mig_events = [None for _ in range(replicates)]
        for j in range(replicates):
            sim.reset()
            sim.run()
            num_trees[j] = sim.get_num_breakpoints() + 1
            time[j] = sim.get_time() / 4  # Convert to coalescent units
            ca_events[j] = sim.get_num_common_ancestor_events()
            re_events[j] = sim.get_num_recombination_events()
            # Flatten the per-population-pair matrix into a single row.
            mig_events[j] = [
                r for row in sim.get_num_migration_events() for r in row]
        d = {
            "t": time, "num_trees": num_trees,
            "ca_events": ca_events, "re_events": re_events}
        for j in range(num_populations**2):
            events = [mig_events[k][j] for k in range(replicates)]
            d["mig_events_{}".format(j)] = events
        df = pd.DataFrame(d)
        return df

    def _build_filename(self, *args):
        """
        Returns the path ``<output_dir>/<args[0]>/<args[1]>_<args[2]>_...``,
        creating the directory part (including ``output_dir`` itself) if
        necessary.
        """
        output_dir = os.path.join(self._output_dir, args[0])
        self._ensure_directory(output_dir)
        return os.path.join(output_dir, "_".join(args[1:]))

    def _save_qqplot(self, values_ms, values_msp, filename):
        """
        Draws a two-sample QQ plot of the ms values against the msprime
        values, saves the current figure to ``filename`` and closes all
        open figures.
        """
        sm.graphics.qqplot(values_ms)
        sm.qqplot_2samples(values_ms, values_msp, line="45")
        pyplot.savefig(filename, dpi=72)
        pyplot.close('all')

    def _plot_stats(self, key, stats_type, df_msp, df_ms):
        """
        Writes one QQ plot per column comparing the msprime DataFrame
        ``df_msp`` with the ms DataFrame ``df_ms``. Both frames must have
        the same set of columns.
        """
        assert set(df_ms.columns.values) == set(df_msp.columns.values)
        for stat in df_ms.columns.values:
            filename = self._build_filename(key, stats_type, stat)
            self._save_qqplot(df_ms[stat], df_msp[stat], filename)

    def _run_coalescent_stats(self, key, args):
        """
        Compares the coalescent statistics of msprime and ms for the
        ms-style argument string ``args`` and plots them under ``key``.
        """
        df_msp = self._run_msprime_coalescent_stats(args)
        df_ms = self._run_ms_coalescent_stats(args)
        # Pass by keyword so the two frames cannot be transposed.
        self._plot_stats(key, "coalescent", df_msp=df_msp, df_ms=df_ms)

    def _run_mutation_stats(self, key, args):
        """
        Compares the mutation statistics of msprime and ms for the
        ms-style argument string ``args`` and plots them under ``key``.
        """
        df_msp = self._run_msprime_mutation_stats(args)
        df_ms = self._run_ms_mutation_stats(args)
        # Pass by keyword so the two frames cannot be transposed.
        self._plot_stats(key, "mutation", df_msp=df_msp, df_ms=df_ms)

    def run(self, keys=None):
        """
        Runs the registered checks named in ``keys``, in the order given,
        or every registered check in sorted key order when ``keys`` is None.

        :raises KeyError: If any key is not registered. This is detected
            before any check is run.
        """
        if keys is None:
            the_keys = sorted(self._instances.keys())
        else:
            the_keys = list(keys)
        unknown = [key for key in the_keys if key not in self._instances]
        if unknown:
            raise KeyError(
                "Unknown verification key(s): {}".format(", ".join(unknown)))
        for key in the_keys:
            runner = self._instances[key]
            runner()

    def add_ms_instance(self, key, command_line):
        """
        Adds a test instance with the specified ms command line.
        """
        def f():
            print(key, command_line)
            self._run_coalescent_stats(key, command_line)
            self._run_mutation_stats(key, command_line)
        self._instances[key] = f

    def get_pairwise_coalescence_time(self, cmd, R):
        """
        Runs the command ``cmd`` and returns an array of length ``R``
        holding, for each newick tree in its output, the last of the node
        ages reported by dendropy. Entries for which no tree was output
        are left at zero.
        """
        # Text mode so that lines are ``str`` on both Python 2 and 3.
        output = subprocess.check_output(cmd, universal_newlines=True)
        T = np.zeros(R)
        j = 0
        for line in output.splitlines():
            if line.startswith("("):
                t = dendropy.Tree.get_from_string(line, schema="newick")
                a = t.calc_node_ages()
                T[j] = a[-1]
                j += 1
        return T

    def run_pairwise_island_model(self):
        """
        Runs the check for the pairwise coalescence times for within
        and between populations.

        For each number of demes d in 2..5, prints the mean times from ms
        and msprime next to the analytical expectations and writes QQ plots
        to ``<output_dir>/analytical_pairwise_island``.
        """
        R = 10000
        M = 0.2
        basedir = os.path.join(self._output_dir, "analytical_pairwise_island")
        self._ensure_directory(basedir)
        for d in range(2, 6):
            # Both samples drawn from the first population.
            cmd = "2 {} -T -I {} 2 {} {}".format(R, d, "0 " * (d - 1), M)
            T_w_ms = self.get_pairwise_coalescence_time(
                self._ms_executable + cmd.split() + self.get_ms_seeds(), R)
            T_w_msp = self.get_pairwise_coalescence_time(
                self._mspms_executable + cmd.split() + self.get_ms_seeds(), R)
            # One sample from each of the first two populations.
            cmd = "2 {} -T -I {} 1 1 {} {}".format(R, d, "0 " * (d - 2), M)
            T_b_ms = self.get_pairwise_coalescence_time(
                self._ms_executable + cmd.split() + self.get_ms_seeds(), R)
            T_b_msp = self.get_pairwise_coalescence_time(
                self._mspms_executable + cmd.split() + self.get_ms_seeds(), R)
            print(
                d, np.mean(T_w_ms), np.mean(T_w_msp), d / 2,
                np.mean(T_b_ms), np.mean(T_b_msp), (d + (d - 1) / M) / 2,
                sep="\t")
            self._save_qqplot(
                T_w_ms, T_w_msp,
                os.path.join(basedir, "within_{}.png".format(d)))
            self._save_qqplot(
                T_b_ms, T_b_msp,
                os.path.join(basedir, "between_{}.png".format(d)))

    @staticmethod
    def _segregating_sites_histogram(lines, max_s):
        """
        Returns the normalised histogram (length ``max_s``) of the values
        on the ``segsites`` lines in ``lines``. Values of ``max_s`` or more
        are ignored.
        """
        hist = np.zeros(max_s)
        for line in lines:
            if line.startswith("segsites"):
                s = int(line.split()[1])
                # Valid indexes are 0..max_s - 1.
                if s < max_s:
                    hist[s] += 1
        return hist / np.sum(hist)

    def get_segregating_sites_histogram(self, cmd):
        """
        Runs the command ``cmd`` and returns the normalised histogram of
        the number of segregating sites it reports, for values below 200.
        """
        print("\t", " ".join(cmd))
        # Text mode so that lines are ``str`` on both Python 2 and 3.
        output = subprocess.check_output(cmd, universal_newlines=True)
        max_s = 200
        return self._segregating_sites_histogram(output.splitlines(), max_s)

    def get_S_distribution(self, k, n, theta):
        """
        Returns the probability of having k segregating sites in a sample of
        size n. Wakely pg 94.
        """
        s = 0.0
        for i in range(2, n + 1):
            t1 = (-1)**i
            t2 = scipy.special.binom(n - 1, i - 1)
            t3 = (i - 1) / (theta + i - 1)
            t4 = (theta / (theta + i - 1))**k
            s += t1 * t2 * t3 * t4
        return s

    def run_s_analytical_check(self):
        """
        Runs the check for the number of segregating sites against the
        analytical prediction.

        For each sample size n in 2..14, writes a bar chart of the ms and
        msprime frequencies of 0..9 segregating sites, overlaid with the
        analytical values, to ``<output_dir>/analytical_s/<n>.png``.
        """
        R = 100000
        theta = 2
        basedir = os.path.join(self._output_dir, "analytical_s")
        self._ensure_directory(basedir)
        for n in range(2, 15):
            cmd = "{} {} -t {}".format(n, R, theta)
            S_ms = self.get_segregating_sites_histogram(
                self._ms_executable + cmd.split() + self.get_ms_seeds())
            S_msp = self.get_segregating_sites_histogram(
                self._mspms_executable + cmd.split() + self.get_ms_seeds())
            filename = os.path.join(basedir, "{}.png".format(n))
            pyplot.subplots()
            index = np.arange(10)
            S_analytical = [
                self.get_S_distribution(j, n, theta) for j in index]
            bar_width = 0.35
            pyplot.bar(index, S_ms[index], bar_width, color='b', label="ms")
            pyplot.bar(
                index + bar_width, S_msp[index], bar_width, color='r',
                label="msp")
            pyplot.plot(index + bar_width, S_analytical, "o", color='k')
            pyplot.legend()
            pyplot.xticks(index + bar_width, [str(j) for j in index])
            pyplot.tight_layout()
            pyplot.savefig(filename)
            # Release the figure; otherwise one accumulates per sample size.
            pyplot.close('all')

    def run_pi_analytical_check(self):
        """
        Runs the check for pi against analytical predictions.

        Prints the observed and predicted mean and variance of the pairwise
        diversity for sample sizes 2..14 and plots them to ``mean.png`` and
        ``var.png`` in ``<output_dir>/analytical_pi``.
        """
        R = 100000
        theta = 4.5
        basedir = os.path.join(self._output_dir, "analytical_pi")
        self._ensure_directory(basedir)
        sample_size = np.arange(2, 15)
        mean = np.zeros_like(sample_size, dtype=float)
        var = np.zeros_like(sample_size, dtype=float)
        predicted_mean = np.zeros_like(sample_size, dtype=float)
        predicted_var = np.zeros_like(sample_size, dtype=float)
        for k, n in enumerate(sample_size):
            pi = np.zeros(R)
            replicates = msprime.simulate(
                sample_size=n,
                mutation_rate=theta/4,
                num_replicates=R)
            for j, ts in enumerate(replicates):
                pi[j] = ts.get_pairwise_diversity()
            # Predicted mean is theta.
            predicted_mean[k] = theta
            # From Wakely, eqn (4.14), pg. 101
            predicted_var[k] = (
                (n + 1) * theta / (3 * (n - 1)) +
                2 * (n**2 + n + 3) * theta**2 / (9 * n * (n - 1)))
            mean[k] = np.mean(pi)
            var[k] = np.var(pi)
            print(
                n, theta, np.mean(pi), predicted_var[k], np.var(pi),
                sep="\t")
        filename = os.path.join(basedir, "mean.png")
        pyplot.plot(sample_size, predicted_mean, "-")
        pyplot.plot(sample_size, mean, "-")
        pyplot.savefig(filename)
        pyplot.close('all')
        filename = os.path.join(basedir, "var.png")
        pyplot.plot(sample_size, predicted_var, "-")
        pyplot.plot(sample_size, var, "-")
        pyplot.savefig(filename)
        pyplot.close('all')

    def get_tbl_distribution(self, n, R, executable):
        """
        Returns an array of the R total branch length values from
        the specified ms-like executable.
        """
        cmd = executable + "{} {} -T -p 10".format(n, R).split()
        cmd += self.get_ms_seeds()
        print("\t", " ".join(cmd))
        # Text mode so that lines are ``str`` on both Python 2 and 3.
        output = subprocess.check_output(cmd, universal_newlines=True)
        tbl = np.zeros(R)
        j = 0
        for line in output.splitlines():
            if line.startswith("("):
                t = dendropy.Tree.get_from_string(line, schema="newick")
                tbl[j] = t.length()
                j += 1
        return tbl

    def get_analytical_tbl(self, n, t):
        """
        Returns the probability density of the total branch length t with
        a sample of n lineages. Wakeley Page 78.
        """
        t1 = (n - 1) / 2
        t2 = math.exp(-t / 2)
        t3 = pow(1 - math.exp(-t / 2), n - 2)
        return t1 * t2 * t3

    def run_tbl_analytical_check(self):
        """
        Runs the check for the total branch length.

        For each sample size n in 2..14, writes a QQ plot of ms against
        msprime and a histogram of both overlaid with the analytical density
        to ``<output_dir>/analytical_tbl``.
        """
        R = 10000
        basedir = os.path.join(self._output_dir, "analytical_tbl")
        self._ensure_directory(basedir)
        for n in range(2, 15):
            tbl_ms = self.get_tbl_distribution(n, R, self._ms_executable)
            tbl_msp = self.get_tbl_distribution(n, R, self._mspms_executable)
            self._save_qqplot(
                tbl_ms, tbl_msp,
                os.path.join(basedir, "qqplot_{}.png".format(n)))
            # Use the ms bin edges for both so the bars are comparable.
            hist_ms, bin_edges = np.histogram(tbl_ms, 20, density=True)
            hist_msp, _ = np.histogram(tbl_msp, bin_edges, density=True)
            index = bin_edges[:-1]
            # We don't seem to have the analytical value quite right here,
            # but since the value is so very close to ms's, there doesn't
            # seem to be much point in trying to fix it.
            analytical = [self.get_analytical_tbl(n, x * 2) for x in index]
            pyplot.subplots()
            bar_width = 0.15
            pyplot.bar(index, hist_ms, bar_width, color='b', label="ms")
            pyplot.bar(
                index + bar_width, hist_msp, bar_width, color='r',
                label="msp")
            pyplot.plot(index + bar_width, analytical, "o", color='k')
            pyplot.legend()
            pyplot.tight_layout()
            filename = os.path.join(basedir, "hist_{}.png".format(n))
            pyplot.savefig(filename)
            # Release the figure; otherwise one accumulates per sample size.
            pyplot.close('all')

    def add_s_analytical_check(self):
        """
        Adds a check for the analytical predictions about the distribution
        of S, the number of segregating sites.
        """
        self._instances["analytical_s"] = self.run_s_analytical_check

    def add_pi_analytical_check(self):
        """
        Adds a check for the analytical predictions about the pi,
        the pairwise site diversity.
        """
        self._instances["analytical_pi"] = self.run_pi_analytical_check

    def add_total_branch_length_analytical_check(self):
        """
        Adds a check for the analytical check for the total branch length.
        """
        self._instances["analytical_tbl"] = self.run_tbl_analytical_check

    def add_pairwise_island_model_analytical_check(self):
        """
        Adds a check for the analytical check for pairwise island model
        """
        self._instances[
            "analytical_pairwise_island"] = self.run_pairwise_island_model

    @staticmethod
    def _random_migration_matrix(N):
        """
        Returns a flattened N x N list of random rates in [0, 1) with
        zeros on the diagonal.
        """
        # Flattened index j is on the diagonal exactly when j % (N + 1) == 0.
        return [random.random() * (j % (N + 1) != 0) for j in range(N**2)]

    def _random_command_line(
            self, num_populations, num_replicates, num_demographic_events):
        """
        Returns a randomly generated ms command line (as a string) with
        recombination, the given number of populations and replicates and,
        depending on the arguments, migration and demographic events.
        """
        m = random.randint(1, 1000)
        r = random.uniform(0.01, 0.1) * m
        theta = random.uniform(1, 100)
        N = num_populations
        sample_sizes = [random.randint(2, 10) for _ in range(N)]
        migration_matrix = self._random_migration_matrix(N)
        structure = ""
        if num_populations > 1:
            structure = "-I {} {} -ma {}".format(
                num_populations, " ".join(str(s) for s in sample_sizes),
                " ".join(str(x) for x in migration_matrix))
        cmd = "{} {} -t {} -r {} {} {}".format(
            sum(sample_sizes), num_replicates, theta, r, m, structure)
        if N > 1:
            # Add some migration matrix changes
            t = 0
            for _ in range(1, 6):
                t += 0.125
                u = random.random()
                if u < 0.33:
                    cmd += " -eM {} {}".format(t, random.random())
                elif u < 0.66:
                    # Pick two distinct populations for the changed entry.
                    pop_i = random.randint(1, N)
                    pop_j = pop_i
                    while pop_j == pop_i:
                        pop_j = random.randint(1, N)
                    rate = random.random()
                    # -em takes four values: time, i, j and the new rate.
                    cmd += " -em {} {} {} {}".format(t, pop_i, pop_j, rate)
                else:
                    migration_matrix = self._random_migration_matrix(N)
                    cmd += " -ema {} {} {}".format(
                        t, N, " ".join(str(x) for x in migration_matrix))
        # Set some initial growth rates, etc.
        if N == 1:
            if random.random() < 0.5:
                cmd += " -G {}".format(random.random())
            else:
                cmd += " -eN 0 {}".format(random.random())
        # Add some demographic events
        t = 0
        for _ in range(num_demographic_events):
            t += 0.125
            if random.random() < 0.5:
                cmd += " -eG {} {}".format(t, random.random())
            else:
                cmd += " -eN {} {}".format(t, random.random())
        return cmd

    def add_random_instance(
            self, key, num_populations=1, num_replicates=1000,
            num_demographic_events=0):
        """
        Adds a test instance under ``key`` whose ms command line is
        generated at random when this method is called.
        """
        cmd = self._random_command_line(
            num_populations, num_replicates, num_demographic_events)
        self.add_ms_instance(key, cmd)
def main():
    """
    Registers every verification instance and runs either the keys named
    on the command line or, when none are given, all of them.
    """
    # random.seed(2)
    verifier = SimulationVerifier("tmp__NOBACKUP__")
    ms_instances = [
        # Try various options independently
        ("size-change1", "10 10000 -t 2.0 -eN 0.1 2.0"),
        ("growth-rate-change1", "10 10000 -t 2.0 -eG 0.1 5.0"),
        ("growth-rate-2-pops1", "10 10000 -t 2.0 -I 2 5 5 2.5 -G 5.0"),
        ("growth-rate-2-pops2",
            "10 10000 -t 2.0 -I 2 5 5 2.5 -G 5.0 -g 1 0.1"),
        ("growth-rate-2-pops3", "10 10000 -t 2.0 -I 2 5 5 2.5 -g 1 0.1"),
        ("growth-rate-2-pops4",
            "10 10000 -t 2.0 -I 2 5 5 2.5 -eg 1.0 1 5.0"),
        ("pop-size-2-pops1", "100 10000 -t 2.0 -I 2 50 50 2.5 -n 1 0.1"),
        ("pop-size-2-pops2",
            "100 10000 -t 2.0 -I 2 50 50 2.5 -g 1 2 -n 1 0.1"),
        ("pop-size-2-pops3", "100 10000 -t 2.0 -I 2 50 50 2.5 -eN 0.5 3.5"),
        ("pop-size-2-pops4",
            "100 10000 -t 2.0 -I 2 50 50 2.5 -en 0.5 1 3.5"),
        ("migration-rate-2-pops1", "100 10000 -t 2.0 -I 2 50 50 0 -eM 3 5"),
        ("migration-matrix-2-pops1",
            "100 10000 -t 2.0 -I 2 50 50 -ma x 10 0 x"),
        ("migration-matrix-2-pops2",
            "100 10000 -t 2.0 -I 2 50 50 -m 1 2 10 -m 2 1 50"),
        ("migration-rate-change-2-pops1",
            "100 10000 -t 2.0 -I 2 50 50 -eM 5 10"),
        ("migration-matrix-entry-change-2-pops1",
            "100 10000 -t 2.0 -I 2 50 50 -em 0.5 2 1 10"),
        ("migration-matrix-change-2-pops1",
            "100 10000 -t 2.0 -I 2 50 50 -ema 10.0 2 x 10 0 x"),
        ("migration-matrix-change-2-pops2",
            "100 10000 -t 2.0 -I 2 50 50 -ema 1.0 2 x 0.1 0 x "
            "-eN 1.1 0 -ema 10 2 x 0 10 x"),
        ("population-split-2-pops1",
            "100 10000 -t 2.0 -I 2 50 50 5.0 -ej 2.0 1 2"),
        ("population-split-4-pops1",
            "100 10000 -t 2.0 -I 4 50 50 0 0 2.0 -ej 0.5 2 1"),
        ("population-split-4-pops2",
            "100 10000 -t 2.0 -I 4 25 25 25 25 -ej 1 2 1 -ej 2 3 1 -ej 3 4 1"),
        ("population-split-4-pops3",
            "100 10000 -t 2.0 -I 4 25 25 25 25 -ej 1 2 1 -em 1.5 4 1 2 "
            "-ej 2 3 1 -ej 3 4 1"),
        ("admixture-1-pop1",
            "1000 1000 -t 2.0 -es 0.1 1 0.5 -em 0.1 1 2 1"),
        ("admixture-1-pop2",
            "1000 1000 -t 2.0 -es 0.1 1 0.1 -em 0.1 1 2 1"),
        ("admixture-1-pop3",
            "1000 1000 -t 2.0 -es 0.01 1 0.1 -em 0.1 2 1 1"),
        ("admixture-1-pop4",
            "1000 1000 -t 2.0 -es 0.01 1 0.1 -es 0.1 2 0 -em 0.1 3 1 1"),
        ("admixture-1-pop5",
            "1000 1000 -t 2.0 -es 0.01 1 0.1 -ej 1 2 1"),
        ("admixture-1-pop6",
            "1000 1000 -t 2.0 -es 0.01 1 0.0 -eg 0.02 2 5.0 "),
        ("admixture-1-pop7",
            "1000 1000 -t 2.0 -es 0.01 1 0.0 -en 0.02 2 5.0 "),
        ("admixture-2-pop1",
            "1000 1000 -t 2.0 -I 2 500 500 1 -es 0.01 1 0.1 -ej 1 3 1"),
        ("admixture-2-pop2",
            "1000 1000 -t 2.0 -I 2 500 500 2 -es 0.01 1 0.75 -em 2.0 3 1 1"),
        ("admixture-2-pop3",
            "1000 1000 -t 2.0 -I 2 500 500 2 -es 0.01 1 0.75 -G 5.0 "
            "-em 2.0 3 1 1"),
        ("admixture-2-pop4",
            "1000 1000 -t 2.0 -I 2 500 500 2 -es 0.01 1 0.75 -eg 0.02 1 5.0 "
            "-em 0.02 3 1 1"),
        # Examples from ms documentation
        ("msdoc-simple-ex", "4 20000 -t 5.0"),
        ("msdoc-recomb-ex", "15 1000 -t 10.04 -r 100.0 2501"),
        ("msdoc-structure-ex1", "15 1000 -t 2.0 -I 3 10 4 1 5.0"),
        ("msdoc-structure-ex2",
            "15 1000 -t 2.0 -I 3 10 4 1 5.0 -m 1 2 10.0 -m 2 1 9.0"),
        ("msdoc-structure-ex3",
            "15 1000 -t 10.0 -I 3 10 4 1 -ma x 1.0 2.0 3.0 x 4.0 5.0 6.0 x"),
        ("msdoc-outgroup-sequence",
            "11 1000 -t 2.0 -I 2 1 10 -ej 6.0 1 2"),
        ("msdoc-two-species",
            "15 10000 -t 11.2 -I 2 3 12 -g 1 44.36 -n 2 0.125 -eg 0.03125 1 0.0 "
            "-en 0.0625 2 0.05 -ej 0.09375 2 1"),
        ("msdoc-stepping-stone",
            "15 10000 -t 3.0 -I 6 0 7 0 0 8 0 -m 1 2 2.5 -m 2 1 2.5 -m 2 3 2.5 "
            "-m 3 2 2.5 -m 4 5 2.5 -m 5 4 2.5 -m 5 6 2.5 -m 6 5 2.5 -em 2.0 3 4 "
            "2.5 -em 2.0 4 3 2.5"),
        # The order of simultaneous events matters in ms.
        ("simultaneous-ex1", "10 10000 -t 2.0 -eN 0.3 0.5 -eG .3 7.0"),
        # Add a bunch more instances...
        ("zero-growth-rate",
            "10 10000 -t 2.0 -G 6.93 -eG 0.2 0.0 -eN 0.3 0.5"),
        # Some examples provided by Konrad Lohse
        ("konrad-1",
            "4 1000 -t 2508 -I 2 2 2 0 -n 2 2.59 -ma x 0 1.502 x -ej 0.9485 1 2 "
            "-r 23.76 3000"),
        ("konrad-2",
            "3 10000 -t 0.423 -I 3 1 1 1 -es 0.0786 1 0.946635 -ej 0.0786 4 3 "
            "-ej 0.189256 1 2 -ej 0.483492 2 3"),
        ("konrad-3",
            "100 100 -t 2 -I 10 10 10 10 10 10 10 10 10 10 10 0.001 "),
    ]
    for instance_key, command_line in ms_instances:
        verifier.add_ms_instance(instance_key, command_line)

    # Add some random instances.
    verifier.add_random_instance("random1")
    verifier.add_random_instance(
        "random2", num_replicates=10**4, num_demographic_events=10)
    # verifier.add_random_instance("random2", num_populations=3)

    # Add analytical checks
    analytical_registrars = (
        verifier.add_s_analytical_check,
        verifier.add_pi_analytical_check,
        verifier.add_total_branch_length_analytical_check,
        verifier.add_pairwise_island_model_analytical_check,
    )
    for register in analytical_registrars:
        register()

    # Any command line arguments select the keys to run; none means all.
    requested = sys.argv[1:]
    verifier.run(requested if len(requested) > 0 else None)
# Run the verification suite only when this file is executed as a script.
if __name__ == "__main__":
    main()