# Source: https://github.com/tskit-dev/msprime
# Raw file: algorithms.py
# Tip revision: becc7b948123f8683c49ed41480ca2682d979a7f authored by Yan Wong
# on 09 December 2022, 18:50:12 UTC
# Commit message: Update docs/mutations.md
# Tip revision (short): becc7b9
"""
Python version of the simulation algorithm.
"""
import argparse
import heapq
import logging
import math
import random
import sys

import bintrees
import daiquiri
import numpy as np
import tskit

import msprime


logger = daiquiri.getLogger()


class FenwickTree:
    """
    A Fenwick Tree (Binary Indexed Tree) holding a cumulative frequency
    table over the integer indexes 1..max_index. Every index starts with
    a frequency of zero.

    The structure follows "A new data structure for cumulative frequency
    tables", Software Practice and Experience, Vol 24, No 3, pp 327-336,
    Mar 1994. The search procedure (``find``) returns the smallest index
    whose cumulative frequency is >= the query value; it is a slightly
    modified version of the one in Tech Report 110, "A new data structure
    for cumulative frequency tables: an improved frequency-to-symbol
    algorithm", available at
    https://www.cs.auckland.ac.nz/~peter-f/FTPfiles/TechRep110.ps
    """

    def __init__(self, max_index):
        assert max_index > 0
        self.__max_index = max_index
        # Slot 0 is unused so that indexes are 1-based.
        self.__tree = [0] * (max_index + 1)
        self.__value = [0] * (max_index + 1)
        # Repeatedly clear the lowest set bit. The last non-zero value seen
        # is the highest set bit of max_index, i.e. the largest power of
        # two that is <= max_index; find() uses it as its first step size.
        remaining = max_index
        while remaining != 0:
            self.__log_max_index = remaining
            remaining &= remaining - 1

    def get_total(self):
        """
        Returns the sum of the frequencies of all indexes.
        """
        return self.get_cumulative_sum(self.__max_index)

    def increment(self, index, v):
        """
        Adds v to the frequency stored at the specified index.
        """
        assert 0 < index <= self.__max_index
        self.__value[index] += v
        upper = self.__max_index
        node = index
        while node <= upper:
            self.__tree[node] += v
            # Move to the next tree node whose range covers this index.
            node += node & -node

    def set_value(self, index, v):
        """
        Overwrites the frequency at the specified index with v.
        """
        self.increment(index, v - self.get_value(index))

    def get_cumulative_sum(self, index):
        """
        Returns the sum of the frequencies of indexes 1..index inclusive.
        """
        assert 0 < index <= self.__max_index
        total = 0
        node = index
        while node > 0:
            total += self.__tree[node]
            # Drop the lowest set bit to reach the preceding range.
            node -= node & -node
        return total

    def get_value(self, index):
        """
        Returns the frequency stored at the specified index.
        """
        return self.__value[index]

    def find(self, v):
        """
        Returns the smallest index with cumulative sum >= v.
        """
        position = 0
        remainder = v
        step = self.__log_max_index
        while step > 0:
            # Shrink the step until it lands on an entry that exists.
            while position + step > self.__max_index:
                step >>= 1
            candidate = position + step
            subtotal = self.__tree[candidate]
            if remainder > subtotal:
                position = candidate
                remainder -= subtotal
            step >>= 1
        return position + 1


class Segment:
    """
    A class representing a single segment. Each segment has a left
    and right, denoting the loci over which it spans, a node, and
    prev/next links to its neighbours in a doubly linked chain. It also
    records the population and label it belongs to, and the fixed index
    it was created with.
    """

    def __init__(self, index):
        self.left = None
        self.right = None
        self.node = None
        self.prev = None
        self.next = None
        self.population = None
        self.label = 0
        self.index = index

    def __repr__(self):
        return repr((self.left, self.right, self.node))

    @staticmethod
    def show_chain(seg):
        """
        Returns a string describing every segment in the chain starting at
        seg, following the next links. Returns "" when seg is None.
        """
        parts = []
        while seg is not None:
            parts.append(f"[{seg.left}, {seg.right}: {seg.node}]")
            seg = seg.next
        return ", ".join(parts)

    def __lt__(self, other):
        """
        Orders segments lexicographically by (left, right, population, node).
        """
        # The right-hand key must be built entirely from ``other``; using
        # self.node there would make the node never act as a tie-breaker.
        return (self.left, self.right, self.population, self.node) < (
            other.left,
            other.right,
            other.population,
            other.node,
        )


class Population:
    """
    A single population (deme) in the simulation, holding its demographic
    parameters and the ancestors currently residing in it, grouped by label.
    """

    def __init__(self, id_, num_labels=1):
        self.id = id_
        self.start_time = 0
        self.start_size = 1.0
        self.growth_rate = 0
        # One list of ancestors per label.
        # AVLTrees would be preferable here but their API doesn't quite
        # fit. Plain lists are inefficient and should not be used in a
        # real implementation.
        self._ancestors = [list() for _ in range(num_labels)]

    def print_state(self):
        print("Population ", self.id)
        print("\tstart_size = ", self.start_size)
        print("\tgrowth_rate = ", self.growth_rate)
        print("\tAncestors: ", len(self._ancestors))
        for label in range(len(self._ancestors)):
            print("\tLabel = ", label)
            for head in self._ancestors[label]:
                print("\t\t" + Segment.show_chain(head))

    def set_growth_rate(self, growth_rate, time):
        # TODO This doesn't work because we need to know what the time
        # is so we can set the start size accordingly. Need to look at
        # ms's model carefully to see what it actually does here.
        #
        # The size reached at ``time`` under the old parameters becomes the
        # new start size, so it must be computed before anything is updated.
        self.start_size = self.get_size(time)
        self.start_time = time
        self.growth_rate = growth_rate

    def set_start_size(self, start_size):
        self.start_size = start_size
        self.growth_rate = 0

    def get_num_ancestors(self, label=None):
        """
        Returns the number of ancestors with the given label, or the number
        over all labels when label is None.
        """
        if label is not None:
            return len(self._ancestors[label])
        return sum(map(len, self._ancestors))

    def get_size(self, t):
        """
        Returns the size of this population at time t.
        """
        elapsed = t - self.start_time
        return self.start_size * math.exp(-self.growth_rate * elapsed)

    def get_common_ancestor_waiting_time(self, t):
        """
        Returns the random waiting time until a common ancestor event
        occurs within this population. The largest representable float is
        returned when no such event can happen.
        """
        never = sys.float_info.max
        k = self.get_num_ancestors()
        if k <= 1:
            return never
        u = random.expovariate(k * (k - 1))
        if self.growth_rate == 0:
            return self.start_size * u
        dt = t - self.start_time
        z = (
            1
            + self.growth_rate
            * self.start_size
            * math.exp(-self.growth_rate * dt)
            * u
        )
        if z > 0:
            return math.log(z) / self.growth_rate
        return never

    def get_ind_range(self, t):
        """Returns ind labels at time t"""
        earlier_sizes = [self.get_size(t_prev) for t_prev in range(int(t))]
        first_ind = np.sum(earlier_sizes)
        last_ind = first_ind + self.get_size(t)
        return range(int(first_ind), int(last_ind) + 1)

    def remove(self, index, label=0):
        """
        Removes the individual at the specified index of the given label's
        list and returns it.
        """
        bucket = self._ancestors[label]
        return bucket.pop(index)

    def remove_individual(self, individual, label=0):
        """
        Removes the given individual from the given label's list.
        """
        bucket = self._ancestors[label]
        return bucket.remove(individual)

    def add(self, individual, label=0):
        """
        Appends the specified individual to the given label's list.
        """
        assert individual.label == label
        bucket = self._ancestors[label]
        bucket.append(individual)

    def __iter__(self):
        # Plain iteration only covers label 0;
        # iter_label() handles the other labels.
        return self.iter_label(0)

    def iter_label(self, label):
        """
        Iterates over the ancestors in this population with the given label.
        """
        return iter(self._ancestors[label])

    def iter_ancestors(self):
        """
        Iterates over every ancestor in this population, label by label.
        """
        for bucket in self._ancestors:
            for ancestor in bucket:
                yield ancestor

    def find_indv(self, indv):
        """
        Returns the position of an ancestor within the list for its label.
        """
        bucket = self._ancestors[indv.label]
        return bucket.index(indv)


class Pedigree:
    """
    Class representing a pedigree for use with the DTWF model, as implemented
    in C library
    """

    def __init__(self, tables):
        self.ploidy = 2
        self.individuals = []

        ts = tables.tree_sequence()
        for tsk_ind in ts.individuals():
            assert len(tsk_ind.nodes) == self.ploidy
            assert len(tsk_ind.parents) == self.ploidy
            member_nodes = [ts.node(node) for node in tsk_ind.nodes]
            time = member_nodes[0].time
            # All nodes of an individual must agree on flags, time and
            # population.
            for attr in ("flags", "time", "population"):
                assert len({getattr(node, attr) for node in member_nodes}) == 1
            ind = Individual(
                tsk_ind.id,
                ploidy=self.ploidy,
                nodes=list(tsk_ind.nodes),
                parents=list(tsk_ind.parents),
                time=time,
            )
            # Individuals are stored so that list position equals ID.
            assert ind.id == len(self.individuals)
            self.individuals.append(ind)

    def print_state(self):
        rule = "-------"
        print("Pedigree")
        print(rule)
        print("Individuals = ")
        for ind in self.individuals:
            print("\t", ind)
        print(rule)


class Individual:
    """
    Class representing a diploid individual in the DTWF pedigree model.
    """

    def __init__(self, id_, *, ploidy, nodes, parents, time):
        self.id = id_
        self.ploidy = ploidy
        self.nodes = nodes
        self.parents = parents
        self.time = time
        # One heap of (left, head) entries per ploid; filled by
        # add_common_ancestor().
        self.common_ancestors = [list() for _ in range(ploidy)]

    def __str__(self):
        fields = (
            f"ID: {self.id}",
            f"time: {self.time}",
            f"parents: {self.parents}",
            f"nodes: {self.nodes}",
            f"common_ancestors: {self.common_ancestors}",
        )
        return "(" + ", ".join(fields) + ")"

    def add_common_ancestor(self, head, ploid):
        """
        Records the specified ancestor (given as the head of its segment
        chain) among those that find a common ancestor in the specified
        ploid of this individual. Entries are kept in a heap keyed on the
        head's left coordinate.
        """
        entry = (head.left, head)
        heapq.heappush(self.common_ancestors[ploid], entry)


class TrajectorySimulator:
    """
    Class to simulate an allele frequency trajectory on which to condition
    the coalescent simulation.
    """

    def __init__(self, initial_freq, end_freq, alpha, time_slice):
        self._initial_freq = initial_freq
        self._end_freq = end_freq
        self._alpha = alpha
        self._time_slice = time_slice
        self._reset()

    def _reset(self):
        """
        Discards any trajectory accumulated so far.
        """
        self._allele_freqs = []
        self._times = []

    def _genic_selection_stochastic_forwards(self, dt, freq, alpha):
        """
        Returns the frequency after one step of length dt, starting from
        freq: a drift term scaled by dt plus a random +/- jump of size
        sqrt(freq * (1 - freq) * dt).
        """
        ux = (alpha * freq * (1 - freq)) / np.tanh(alpha * freq)
        sign = 1 if random.random() < 0.5 else -1
        freq += (ux * dt) + sign * np.sqrt(freq * (1.0 - freq) * dt)
        return freq

    def _simulate(self):
        """
        Proposes a sweep trajectory and returns the acceptance probability.

        The proposed frequencies and their times are appended to the
        internal lists. Starting from end_freq at time 0, steps are taken
        until the frequency is no longer above initial_freq.
        """
        x = self._end_freq  # backward time
        current_size = 1
        t_inc = self._time_slice
        t = 0
        while x > self._initial_freq:
            self._allele_freqs.append(max(x, self._initial_freq))
            self._times.append(t)
            # just a note below
            # current_size = self._size_calculator(t)
            #
            x = 1.0 - self._genic_selection_stochastic_forwards(
                t_inc, 1.0 - x, self._alpha * current_size
            )
            t += self._time_slice
        # will want to return current_size / N_max
        # for prototype this always equals 1
        return 1

    def run(self):
        """
        Proposes trajectories until one is accepted and returns it as the
        pair (allele_freqs, times).
        """
        while random.random() > self._simulate():
            # Rejected: throw the proposal away before trying again. The
            # method is named _reset; there is no public reset().
            self._reset()
        return self._allele_freqs, self._times


class RateMap:
    """
    A piecewise-constant rate over a sequence: rates[i] applies between
    positions[i] and positions[i + 1]. The cumulative mass at each position
    is precomputed to convert between positions and mass.
    """

    def __init__(self, positions, rates):
        self.positions = positions
        self.rates = rates
        self.cumulative = RateMap.recomb_mass(positions, rates)

    @staticmethod
    def recomb_mass(positions, rates):
        """
        Returns the list of cumulative masses at each of the positions,
        starting with 0 at positions[0].
        """
        running = 0
        totals = [running]
        for right_idx in range(1, len(positions)):
            left_idx = right_idx - 1
            span = positions[right_idx] - positions[left_idx]
            running += span * rates[left_idx]
            totals.append(running)
        return totals

    @property
    def sequence_length(self):
        """The last position of the map."""
        return self.positions[-1]

    @property
    def total_mass(self):
        """The cumulative mass at the last position."""
        return self.cumulative[-1]

    @property
    def mean_rate(self):
        """The total mass divided by the sequence length."""
        return self.total_mass / self.sequence_length

    def mass_between(self, left, right):
        """
        Returns the mass at position right minus the mass at position left.
        """
        mass_at_left = self.position_to_mass(left)
        mass_at_right = self.position_to_mass(right)
        return mass_at_right - mass_at_left

    def position_to_mass(self, pos):
        """
        Returns the cumulative mass at the specified position. Positions at
        or beyond the last one map to the total mass.
        """
        if pos == self.positions[0]:
            return 0
        if pos >= self.positions[-1]:
            return self.cumulative[-1]

        interval = self._search(self.positions, pos)
        assert interval > 0
        # Step back to the interval that contains pos.
        interval -= 1
        distance = pos - self.positions[interval]
        return self.cumulative[interval] + distance * self.rates[interval]

    def mass_to_position(self, recomb_mass):
        """
        Returns the position at which the cumulative mass equals the
        specified value.
        """
        if recomb_mass == 0:
            return 0
        interval = self._search(self.cumulative, recomb_mass)
        assert interval > 0
        interval -= 1
        surplus = recomb_mass - self.cumulative[interval]
        return self.positions[interval] + (surplus / self.rates[interval])

    def shift_by_mass(self, pos, mass):
        """
        Returns the position reached by moving from pos by the given mass.
        """
        target = self.position_to_mass(pos) + mass
        return self.mass_to_position(target)

    def _search(self, values, query):
        """
        Binary search: returns the first index whose value is not less than
        query, capped at the last index of values.
        """
        lo, hi = 0, len(values) - 1
        while lo < hi:
            mid = (lo + hi) // 2
            if values[mid] < query:
                lo = mid + 1
            else:
                hi = mid
        return lo


class OverlapCounter:
    """
    Tracks a count for every position in [0, seq_length) as a linked chain
    of Segment objects, each covering [left, right) and storing its count
    in the ``node`` attribute.
    """

    def __init__(self, seq_length):
        self.seq_length = seq_length
        # Start with a single interval spanning everything with count 0.
        self.overlaps = self._make_segment(0, seq_length, 0)

    def overlaps_at(self, pos):
        """
        Returns the count stored for the interval containing pos.
        """
        assert 0 <= pos < self.seq_length
        interval = self.overlaps
        while interval is not None:
            if interval.left <= pos < interval.right:
                return interval.node
            interval = interval.next
        raise ValueError("Bad overlap count chain")

    def increment_interval(self, left, right):
        """
        Adds one to the count over [left, right), splitting existing
        intervals in the chain wherever the endpoints fall inside them.
        """
        cursor = self.overlaps
        while left < right:
            if cursor.left != left:
                # Not yet aligned with ``left``: cut the current interval at
                # ``left`` unless it ends before it, then move on.
                # NOTE(review): when cursor.right == left this cuts at the
                # interval's own right end, leaving a zero-width interval in
                # the chain; lookups are unaffected. Confirm whether ``<=``
                # was intended in the original comparison.
                if cursor.right >= left:
                    self._split(cursor, left)
                cursor = cursor.next
                continue
            if cursor.right > right:
                # The target ends inside this interval: cut it and count
                # only the left part.
                self._split(cursor, right)
                cursor.node += 1
                break
            cursor.node += 1
            left = cursor.right
            cursor = cursor.next

    def _split(self, seg, bp):  # noqa: A002
        """
        Cuts seg at the breakpoint bp: seg keeps [seg.left, bp) and a new
        segment with the same count covering [bp, old seg.right) is linked
        in directly after it.
        """
        tail = self._make_segment(bp, seg.right, seg.node)
        successor = seg.next
        if successor is not None:
            successor.prev = tail
            tail.next = successor
        tail.prev = seg
        seg.next = tail
        seg.right = bp

    def _make_segment(self, left, right, count):
        """
        Returns a fresh, unlinked segment over [left, right) holding count.
        """
        created = Segment(0)
        created.left, created.right = left, right
        created.node = count
        return created


class Simulator:
    """
    A reference implementation of the multi locus simulation algorithm.
    """

    def __init__(
        self,
        *,
        tables,
        recombination_map,
        migration_matrix,
        population_growth_rates,
        population_sizes,
        population_growth_rate_changes,
        population_size_changes,
        migration_matrix_element_changes,
        bottlenecks,
        census_times,
        model="hudson",
        max_segments=100,
        num_labels=1,
        sweep_trajectory=None,
        full_arg=False,
        store_unary=False,
        time_slice=None,
        gene_conversion_rate=0.0,
        gene_conversion_length=1,
        discrete_genome=True,
    ):
        """
        Builds the simulator state from the initial tables and parameters.

        - tables: table collection providing the sequence length, the
          populations and the initial state (via ``tables.tree_sequence()``).
        - recombination_map: rate map for recombination; must expose
          ``total_mass``.
        - migration_matrix: square N x N matrix with a zero diagonal.
        - population_growth_rates, population_sizes: one value per
          population.
        - population_size_changes: iterable of (time, pop_id, new_size).
        - population_growth_rate_changes: iterable of
          (time, pop_id, new_rate).
        - migration_matrix_element_changes: iterable of
          (time, pop_i, pop_j, new_rate).
        - bottlenecks: iterable of (time, pop_id, intensity).
        - census_times: iterable of sequences whose first element is the
          census time; the whole sequence is passed on as the arguments of
          the census event.
        - max_segments: size of the preallocated segment pool.
        - gene_conversion_length: tract length, must be >= 1.

        The remaining arguments are stored on the instance unchanged.
        """
        # Must be a square matrix.
        N = len(migration_matrix)
        assert len(tables.populations) == N
        assert len(population_growth_rates) == N
        assert len(population_sizes) == N
        for j in range(N):
            assert N == len(migration_matrix[j])
            assert migration_matrix[j][j] == 0
        assert gene_conversion_length >= 1

        self.tables = tables
        self.model = model
        self.L = tables.sequence_length
        self.recomb_map = recombination_map
        # Gene conversion uses a single constant rate over the whole sequence.
        self.gc_map = RateMap([0, self.L], [gene_conversion_rate, 0])
        self.tract_length = gene_conversion_length
        self.discrete_genome = discrete_genome
        self.migration_matrix = migration_matrix
        self.num_labels = num_labels
        self.num_populations = N
        self.max_segments = max_segments
        self.full_arg = full_arg
        self.store_unary = store_unary
        self.pedigree = None
        # Preallocate the segment pool. Segment indexes are 1-based so that
        # they can be used directly as Fenwick tree indexes; slot 0 of
        # self.segments stays None.
        self.segment_stack = []
        self.segments = [None for j in range(self.max_segments + 1)]
        for j in range(self.max_segments):
            s = Segment(j + 1)
            self.segments[j + 1] = s
            self.segment_stack.append(s)
        self.P = [Population(id_, num_labels) for id_ in range(N)]
        # The mass indexes (one Fenwick tree per label) are only created
        # when the corresponding map has non-zero total mass.
        if self.recomb_map.total_mass == 0:
            self.recomb_mass_index = None
        else:
            self.recomb_mass_index = [
                FenwickTree(self.max_segments) for j in range(num_labels)
            ]
        if self.gc_map.total_mass == 0:
            self.gc_mass_index = None
        else:
            self.gc_mass_index = [
                FenwickTree(self.max_segments) for j in range(num_labels)
            ]
        self.S = bintrees.AVLTree()
        for pop in self.P:
            pop.set_start_size(population_sizes[pop.id])
            pop.set_growth_rate(population_growth_rates[pop.id], 0)
        self.edge_buffer = []

        self.initialise(tables.tree_sequence())

        self.num_ca_events = 0
        self.num_re_events = 0
        self.num_gc_events = 0

        # Sweep variables
        self.sweep_site = (self.L // 2) - 1  # need to add options here
        self.sweep_trajectory = sweep_trajectory
        self.time_slice = time_slice

        # Each modifier event is (time, callable, args). The entry at the
        # largest float acts as a sentinel so the list is never empty.
        self.modifier_events = [(sys.float_info.max, None, None)]
        for time, pop_id, new_size in population_size_changes:
            self.modifier_events.append(
                (time, self.change_population_size, (int(pop_id), new_size))
            )
        for time, pop_id, new_rate in population_growth_rate_changes:
            self.modifier_events.append(
                (
                    time,
                    self.change_population_growth_rate,
                    (int(pop_id), new_rate, time),
                )
            )
        for time, pop_i, pop_j, new_rate in migration_matrix_element_changes:
            self.modifier_events.append(
                (
                    time,
                    self.change_migration_matrix_element,
                    (int(pop_i), int(pop_j), new_rate),
                )
            )
        for time, pop_id, intensity in bottlenecks:
            self.modifier_events.append(
                (time, self.bottleneck_event, (int(pop_id), 0, intensity))
            )
        for time in census_times:
            self.modifier_events.append((time[0], self.census_event, time))
        # Order by time only. Sorting the raw tuples would fall through to
        # comparing the bound methods whenever two different kinds of event
        # share a time, which raises TypeError. The sort is stable, so
        # events at the same time keep the order in which they were added.
        self.modifier_events.sort(key=lambda event: event[0])

    def initialise(self, ts):
        """
        Loads the initial state from the tree sequence ts: the current time
        is set to the largest node time in the tables, the per-interval
        root counts are recorded in self.S, and a segment chain is created
        for every root of every tree that has more than one root.
        """
        self.t = np.max(self.tables.nodes.time)

        chain_head = [None] * ts.num_nodes
        chain_tail = [None] * ts.num_nodes
        previous_count = -1
        for tree in ts.trees():
            left, right = tree.interval
            num_roots = tree.num_roots
            # A tree with a single root is recorded with a count of 0.
            count = num_roots if num_roots != 1 else 0
            if count != previous_count:
                self.S[left] = count
                previous_count = count
            # If we have 1 root this is a special case and we don't add in
            # any ancestral segments to the state.
            if num_roots <= 1:
                continue
            for root in tree.roots:
                population = ts.node(root).population
                tail = chain_tail[root]
                if chain_head[root] is None:
                    first = self.alloc_segment(left, right, root, population)
                    chain_head[root] = first
                    chain_tail[root] = first
                elif tail.right == left:
                    # Adjacent to the previous interval: just extend it.
                    tail.right = right
                else:
                    fresh = self.alloc_segment(left, right, root, population, tail)
                    tail.next = fresh
                    chain_tail[root] = fresh
        self.S[self.L] = -1

        # Insert the segment chains into the algorithm state.
        for head in chain_head:
            if head is None:
                continue
            self.P[head.population].add(head)
            seg = head
            while seg is not None:
                self.set_segment_mass(seg)
                seg = seg.next

    def ancestors_remain(self):
        """
        Returns True if the simulation is not finished, i.e., there is some ancestral
        material that has not fully coalesced.
        """
        return sum(pop.get_num_ancestors() for pop in self.P) != 0

    def change_population_size(self, pop_id, size):
        self.P[pop_id].set_start_size(size)

    def change_population_growth_rate(self, pop_id, rate, time):
        self.P[pop_id].set_growth_rate(rate, time)

    def change_migration_matrix_element(self, pop_i, pop_j, rate):
        self.migration_matrix[pop_i][pop_j] = rate

    def alloc_segment(
        self,
        left,
        right,
        node,
        population,
        prev=None,
        next=None,  # noqa: A002
        label=0,
    ):
        """
        Pops a new segment off the stack and sets its properties.
        """
        s = self.segment_stack.pop()
        s.left = left
        s.right = right
        s.node = node
        s.population = population
        s.next = next
        s.prev = prev
        s.label = label
        return s

    def copy_segment(self, segment):
        return self.alloc_segment(
            left=segment.left,
            right=segment.right,
            node=segment.node,
            population=segment.population,
            next=segment.next,
            prev=segment.prev,
            label=segment.label,
        )

    def free_segment(self, u):
        """
        Frees the specified segment making it ready for reuse and
        setting its weight to zero.
        """
        if self.recomb_mass_index is not None:
            self.recomb_mass_index[u.label].set_value(u.index, 0)
        if self.gc_mass_index is not None:
            self.gc_mass_index[u.label].set_value(u.index, 0)
        self.segment_stack.append(u)

    def store_node(self, population, flags=0):
        self.flush_edges()
        return self.tables.nodes.add_row(
            time=self.t, flags=flags, population=population
        )

    def flush_edges(self):
        """
        Flushes the edges in the edge buffer to the table, squashing any adjacent edges.
        """
        if len(self.edge_buffer) > 0:
            self.edge_buffer.sort(key=lambda e: (e.child, e.left))
            left = self.edge_buffer[0].left
            right = self.edge_buffer[0].right
            child = self.edge_buffer[0].child
            parent = self.edge_buffer[0].parent
            for e in self.edge_buffer[1:]:
                assert e.parent == parent
                if e.left != right or e.child != child:
                    self.tables.edges.add_row(left, right, parent, child)
                    left = e.left
                    child = e.child
                right = e.right
            self.tables.edges.add_row(left, right, parent, child)
            self.edge_buffer = []

    def store_edge(self, left, right, parent, child):
        """
        Appends the specified edge to the edge buffer; it reaches the edge
        table when the buffer is next flushed.
        """
        edge = tskit.Edge(left=left, right=right, parent=parent, child=child)
        self.edge_buffer.append(edge)

    def finalise(self):
        """
        Completes the simulation and returns the resulting tree sequence.

        Buffered edges are flushed, every lineage still present is joined
        to a node at the current time, and the tables are sorted.
        """
        self.flush_edges()

        # Insert unary edges for any remaining lineages.
        now = self.t
        for population in self.P:
            for head in population.iter_ancestors():
                top_node = tskit.NULL
                # Reuse a node of this ancestor that already sits at the
                # current time, if there is one.
                seg = head
                while seg is not None:
                    if self.tables.nodes[seg.node].time == now:
                        top_node = seg.node
                        break
                    seg = seg.next
                if top_node == tskit.NULL:
                    # Otherwise create a new node for this ancestor.
                    top_node = self.tables.nodes.add_row(
                        flags=0, time=now, population=population.id
                    )
                # Connect every other node in the chain to that node.
                seg = head
                while seg is not None:
                    if seg.node != top_node:
                        self.tables.edges.add_row(
                            seg.left, seg.right, top_node, seg.node
                        )
                    seg = seg.next

        # Need to work around limitations in tskit Python API to prevent
        # individuals from getting unsorted:
        # https://github.com/tskit-dev/tskit/issues/1726
        saved_node_individuals = self.tables.nodes.individual
        saved_individuals = self.tables.individuals.copy()
        self.tables.sort()
        self.tables.individuals.clear()
        for row in saved_individuals:
            self.tables.individuals.append(row)
        self.tables.nodes.individual = saved_node_individuals
        return self.tables.tree_sequence()

    def simulate(self, end_time):
        self.verify()
        if self.model == "hudson":
            self.hudson_simulate(end_time)
        elif self.model == "dtwf":
            self.dtwf_simulate()
        elif self.model == "fixed_pedigree":
            self.pedigree_simulate()
        elif self.model == "single_sweep":
            self.single_sweep_simulate()
        else:
            print("Error: bad model specification -", self.model)
            raise ValueError
        return self.finalise()

    def get_potential_destinations(self):
        """
        For each population return the set of populations for which it has a
        non-zero migration into.
        """
        N = len(self.P)
        potential_destinations = [set() for _ in range(N)]
        for j in range(N):
            for k in range(N):
                if self.migration_matrix[j][k] > 0:
                    potential_destinations[j].add(k)
        return potential_destinations

    def get_total_recombination_rate(self, label):
        total_rate = 0
        if self.recomb_mass_index is not None:
            total_rate = self.recomb_mass_index[label].get_total()
        return total_rate

    def get_total_gc_rate(self, label):
        total_rate = 0
        if self.gc_mass_index is not None:
            total_rate = self.gc_mass_index[label].get_total()
        return total_rate

    def get_total_gc_left_rate(self, label):
        gc_left_total = self.get_total_gc_left(label)
        return gc_left_total

    def get_total_gc_left(self, label):
        gc_left_total = 0
        num_ancestors = sum(pop.get_num_ancestors() for pop in self.P)
        mean_gc_rate = self.gc_map.mean_rate
        gc_left_total = num_ancestors * mean_gc_rate * self.tract_length
        return gc_left_total

    def find_cleft_individual(self, label, cleft_value):
        mean_gc_rate = self.gc_map.mean_rate
        individual_index = math.floor(cleft_value / (mean_gc_rate * self.tract_length))
        for pop in self.P:
            num_ancestors = pop.get_num_ancestors()
            if individual_index < num_ancestors:
                return pop._ancestors[label][individual_index]
            individual_index -= num_ancestors
        raise AssertionError()

    def hudson_simulate(self, end_time):
        """
        Runs the event loop of the Hudson model.

        The loop continues while at least one population holds ancestors
        and stops early once the current time ``self.t`` reaches end_time.
        In each iteration a waiting time is drawn for every kind of event
        (recombination, the two kinds of gene conversion, common ancestor,
        migration) and the earliest one is carried out, unless the next
        scheduled modifier event falls before it, in which case that
        modifier event is applied instead. All rates are taken for label 0.
        """
        # The largest float stands in for "this event cannot happen".
        infinity = sys.float_info.max
        non_empty_pops = {pop.id for pop in self.P if pop.get_num_ancestors() > 0}
        potential_destinations = self.get_potential_destinations()

        # only worried about label 0 below
        while len(non_empty_pops) > 0:
            self.verify()
            if self.t >= end_time:
                break
            # self.print_state()
            # Waiting time to the next recombination event.
            re_rate = self.get_total_recombination_rate(label=0)
            t_re = infinity
            if re_rate > 0:
                t_re = random.expovariate(re_rate)

            # Gene conversion can occur within segments ..
            gc_rate = self.get_total_gc_rate(label=0)
            t_gcin = infinity
            if gc_rate > 0:
                t_gcin = random.expovariate(gc_rate)
            # ... or to the left of the first segment.
            gc_left_rate = self.get_total_gc_left_rate(label=0)
            t_gc_left = infinity
            if gc_left_rate > 0:
                t_gc_left = random.expovariate(gc_left_rate)

            # Common ancestor events occur within demes.
            # Keep the smallest waiting time and the population it came from.
            t_ca = infinity
            for index in non_empty_pops:
                pop = self.P[index]
                assert pop.get_num_ancestors() > 0
                t = pop.get_common_ancestor_waiting_time(self.t)
                if t < t_ca:
                    t_ca = t
                    ca_population = index
            t_mig = infinity
            # Migration events happen at the rates in the matrix.
            # One waiting time is drawn per (source, destination) pair with
            # a non-zero rate; the smallest one wins.
            for j in non_empty_pops:
                source_size = self.P[j].get_num_ancestors()
                assert source_size > 0
                # for k in range(len(self.P)):
                for k in potential_destinations[j]:
                    rate = source_size * self.migration_matrix[j][k]
                    assert rate > 0
                    t = random.expovariate(rate)
                    if t < t_mig:
                        t_mig = t
                        mig_source = j
                        mig_dest = k
            min_time = min(t_re, t_ca, t_gcin, t_gc_left, t_mig)
            # At least one kind of event must be possible.
            assert min_time != infinity
            if self.t + min_time > self.modifier_events[0][0]:
                # The next scheduled modifier event comes first: jump to its
                # time and apply it; the sampled event is discarded.
                t, func, args = self.modifier_events.pop(0)
                self.t = t
                func(*args)
                # Don't bother trying to maintain the non-zero lists
                # through demographic events, just recompute them.
                non_empty_pops = {
                    pop.id for pop in self.P if pop.get_num_ancestors() > 0
                }
                potential_destinations = self.get_potential_destinations()
                event = "MOD"
            else:
                self.t += min_time
                # Dispatch on whichever waiting time was the minimum; on an
                # exact tie the first matching branch below is taken.
                if min_time == t_re:
                    event = "RE"
                    self.hudson_recombination_event(0)
                elif min_time == t_gcin:
                    event = "GCI"
                    self.wiuf_gene_conversion_within_event(0)
                elif min_time == t_gc_left:
                    event = "GCL"
                    self.wiuf_gene_conversion_left_event(0)
                elif min_time == t_ca:
                    event = "CA"
                    self.common_ancestor_event(ca_population, 0)
                    if self.P[ca_population].get_num_ancestors() == 0:
                        non_empty_pops.remove(ca_population)
                else:
                    event = "MIG"
                    self.migration_event(mig_source, mig_dest)
                    # Keep the set of non-empty populations up to date.
                    if self.P[mig_source].get_num_ancestors() == 0:
                        non_empty_pops.remove(mig_source)
                    assert self.P[mig_dest].get_num_ancestors() > 0
                    non_empty_pops.add(mig_dest)
            logger.info(
                "%s time=%f n=%d",
                event,
                self.t,
                sum(pop.get_num_ancestors() for pop in self.P),
            )

            # Sanity check: the incrementally maintained set must match a
            # fresh recomputation.
            X = {pop.id for pop in self.P if pop.get_num_ancestors() > 0}
            assert non_empty_pops == X

    def single_sweep_simulate(self):
        """
        Does a structured coalescent along the sweep trajectory stored in
        self.sweep_trajectory (a pair ``(allele_freqs, times)``).

        Lineages in the single population are first split between label 0
        and label 1: each lineage is moved to label 1 with probability equal
        to the first allele frequency of the trajectory. The trajectory is
        then stepped through, drawing coalescence and recombination events
        within each label, until the trajectory is exhausted, no ancestors
        remain, or the total event rate is zero. Finally every lineage still
        in label 1 is moved back to label 0.
        """
        allele_freqs, times = self.sweep_trajectory
        sweep_traj_step = 0
        x = allele_freqs[sweep_traj_step]

        assert self.num_populations == 1

        # go through segments and assign labels
        # a bit ugly with the two loops because
        # of dealing with the pops
        indices = []
        for idx, u in enumerate(self.P[0].iter_label(0)):
            if random.random() < x:
                self.set_labels(u, 1)
                indices.append(idx)
            else:
                assert u.label == 0
        # Each removal shifts the later entries down by one, hence the
        # running correction of the recorded positions.
        popped = 0
        for i in indices:
            tmp = self.P[0].remove(i - popped, 0)
            popped += 1
            self.P[0].add(tmp, 1)

        # main loop time
        t_inc_orig = self.time_slice
        e_time = 0.0
        while self.ancestors_remain() and sweep_traj_step < len(times) - 1:
            self.verify()
            event_prob = 1.0
            # Advance along the trajectory until an event is drawn.
            while event_prob > random.random() and sweep_traj_step < len(times) - 1:
                sweep_traj_step += 1
                x = allele_freqs[sweep_traj_step]
                e_time += times[sweep_traj_step]
                # self.t = self.t + times[sweep_traj_step]
                sweep_pop_sizes = [
                    self.P[0].get_num_ancestors(label=0),
                    self.P[0].get_num_ancestors(label=1),
                ]
                p_rec_b = self.get_total_recombination_rate(0) * t_inc_orig
                p_rec_B = self.get_total_recombination_rate(1) * t_inc_orig

                # JK NOTE: We should probably factor these pop size calculations
                # into a method in Population like get_common_ancestor_waiting_time().
                # That way we can handle exponentially growing populations as well?
                p_coal_b = (
                    (sweep_pop_sizes[0] * (sweep_pop_sizes[0] - 1))
                    / (1.0 - x)
                    * t_inc_orig
                    / self.P[0].start_size
                )
                p_coal_B = (
                    (sweep_pop_sizes[1] * (sweep_pop_sizes[1] - 1))
                    / x
                    * t_inc_orig
                    / self.P[0].start_size
                )
                sweep_pop_tot_rate = p_rec_b + p_rec_B + p_coal_b + p_coal_B

                total_rate = sweep_pop_tot_rate
                if total_rate == 0:
                    break
                event_prob *= 1.0 - total_rate

            if total_rate == 0:
                break
            if self.t + e_time > self.modifier_events[0][0]:
                t, func, args = self.modifier_events.pop(0)
                self.t = t
                func(*args)
            else:
                self.t += e_time
                # choose which event happened
                # (total_rate equals sweep_pop_tot_rate here, so this ratio
                # is 1; the draw is kept so the random stream is unchanged.)
                if random.random() < sweep_pop_tot_rate / total_rate:
                    # even in sweeping pop, choose which kind
                    r = random.random()
                    e_sum = p_coal_B
                    if r < e_sum / sweep_pop_tot_rate:
                        # coalescent in B
                        self.common_ancestor_event(0, 1)
                    else:
                        e_sum += p_coal_b
                        if r < e_sum / sweep_pop_tot_rate:
                            # coalescent in b
                            self.common_ancestor_event(0, 0)
                        else:
                            e_sum += p_rec_B
                            if r < e_sum / sweep_pop_tot_rate:
                                # recomb in B
                                self.hudson_recombination_event_sweep_phase(
                                    1, self.sweep_site, x
                                )
                            else:
                                # recomb in b
                                self.hudson_recombination_event_sweep_phase(
                                    0, self.sweep_site, 1.0 - x
                                )
        # clean up the labels at end
        # The label-1 collection must not be mutated while it is being
        # iterated (that skips every other lineage), so fix the number of
        # moves up front and always take the current first entry.
        pop = self.P[0]
        for _ in range(pop.get_num_ancestors(label=1)):
            u = pop.remove(0, 1)
            self.set_labels(u, 0)
            pop.add(u)

    def pedigree_simulate(self):
        """
        Builds the pedigree from self.tables and simulates through it,
        stopping at the top.
        """
        pedigree = Pedigree(self.tables)
        self.pedigree = pedigree
        self.dtwf_climb_pedigree()

    def dtwf_simulate(self):
        """
        Runs the simulation one generation at a time until no ancestors
        remain, i.e. until all loci have coalesced.
        """
        keep_going = self.ancestors_remain()
        while keep_going:
            # Advance the clock before verifying and evolving the generation.
            self.t = self.t + 1
            self.verify()
            self.dtwf_generation()
            keep_going = self.ancestors_remain()

    def dtwf_generation(self):
        """
        Evolves one generation of a Wright Fisher population.

        For each population, every label-0 lineage draws a parent index
        uniformly from the range returned by ``pop.get_ind_range(self.t)``.
        Lineages that share a parent are recombined with dtwf_recombine and
        the resulting segment chains are merged separately for each of the
        two inheritance directions. Afterwards, migration events are drawn
        for every ordered pair of distinct populations.
        """
        for pop_idx, pop in enumerate(self.P):
            # Cluster haploid inds by parent
            parent_inds = pop.get_ind_range(self.t)
            offspring = bintrees.AVLTree()
            for anc in pop.iter_label(0):
                parent_index = np.random.choice(parent_inds.stop - parent_inds.start)
                parent = parent_inds.start + parent_index
                if parent not in offspring:
                    offspring[parent] = []
                offspring[parent].append(anc)

            # Draw recombinations in children and sort segments by
            # inheritance direction
            for children in offspring.values():
                # H[i] is a heap of (left, head segment) pairs for the chains
                # returned at position i of dtwf_recombine's result.
                H = [[], []]
                for child in children:
                    segs_pair = self.dtwf_recombine(child)
                    for seg in segs_pair:
                        # A returned head that is not the child itself is a
                        # new chain and must be registered with the population.
                        if seg is not None and seg.index != child.index:
                            pop.add(seg)

                    self.verify()
                    # Collect segments inherited from the same individual
                    for i, seg in enumerate(segs_pair):
                        if seg is None:
                            continue
                        assert seg.prev is None
                        heapq.heappush(H[i], (seg.left, seg))

                # Merge segments
                for h in H:
                    segments_to_merge = len(h)
                    if segments_to_merge == 1:
                        # A single chain has nothing to merge with.
                        # NOTE(review): rebinding the loop variable here has
                        # no further effect.
                        h = []
                    elif segments_to_merge >= 2:
                        for _, individual in h:
                            pop.remove_individual(individual)
                        if segments_to_merge == 2:
                            self.merge_two_ancestors(pop_idx, 0, h[0][1], h[1][1])
                        else:
                            self.merge_ancestors(h, pop_idx, 0)  # label 0 only
            self.verify()

        # Migration events happen at the rates in the matrix.
        for j in range(len(self.P)):
            # NOTE(review): source_size is read once per source population,
            # before any of its migrations below are applied; confirm the
            # min() cap is still sufficient once earlier destinations have
            # already drawn lineages from this source.
            source_size = self.P[j].get_num_ancestors()
            for k in range(len(self.P)):
                if j == k:
                    continue
                mig_rate = source_size * self.migration_matrix[j][k]
                num_migs = min(source_size, np.random.poisson(mig_rate))
                for _ in range(num_migs):
                    mig_source = j
                    mig_dest = k
                    self.migration_event(mig_source, mig_dest)

    def process_pedigree_common_ancestors(self, ind, ploid):
        """
        Merge the ancestral material that has been inherited on this "ploid"
        (i.e., single genome, chromosome strand, tskit node) of this
        individual, then recombine and distribute the remaining ancestral
        material among its parent ploids.

        If nothing was inherited on this ploid the method returns
        immediately. If the merge leaves no ancestral material (everything
        coalesced fully in this ploid), the pending edges are flushed and
        nothing is passed on to the parents.
        """
        common_ancestors = ind.common_ancestors[ploid]
        if len(common_ancestors) == 0:
            # No ancestral material inherited on this ploid of this individual
            return

        # All the segment chains in common_ancestors reach a common
        # ancestor in this ploid of this individual. First we remove
        # them from the populations they are stored in:
        for _, anc in common_ancestors:
            pop = self.P[anc.population]
            pop.remove_individual(anc)
        # From here on ``pop`` is the population of the last chain removed
        # above (the list is known to be non-empty at this point).

        # Merge together these lists of ancestral segments to create the
        # monoploid genome for this ploid of this individual.
        # If any coalescences occur, they use the corresponding node ID.
        # FIXME update the population/label here
        node = ind.nodes[ploid]
        genome = self.merge_ancestors(common_ancestors, 0, 0, node)
        if genome is None:
            # merge_ancestors returns None when no segment survives the
            # merge. There is then no chain to remove from the population,
            # to mark with unary edges, or to recombine into the parents.
            self.flush_edges()
            self.verify()
            return

        if ind.parents[ploid] == tskit.NULL:
            # If this individual is a founder we need to make sure that all
            # lineages that are present are marked with unary nodes to show
            # where they emerged from the pedigree. These can then be
            # picked up as ancient samples by later simulations. Note that
            # pedigree simulations are a special case here in that we
            # don't want to extend unary edges to the last time point in the
            # simulation because we are *not* simulating the entire
            # population process, only the subset that we have information
            # about within the pedigree.
            seg = genome
            while seg is not None:
                if seg.node != node:
                    self.store_edge(seg.left, seg.right, parent=node, child=seg.node)
                seg = seg.next
            pop.remove_individual(genome)
        else:
            # If this individual is not a founder, it inherited the current
            # monoploid genome as a gamete from one parent. This gamete was
            # created by recombining between the parent's monoploid genomes
            # to create two independent lines of ancestry.
            parent = self.pedigree.individuals[ind.parents[ploid]]
            parent_ancestry = self.dtwf_recombine(genome)
            for parent_ploid in range(ind.ploidy):
                seg = parent_ancestry[parent_ploid]
                if seg is not None:
                    # Add this segment chain of ancestry to the accumulating
                    # set in the parent on the corresponding ploid.
                    parent.add_common_ancestor(seg, ploid=parent_ploid)
                    if seg != genome:
                        # Add the recombined ancestor to the population
                        pop.add(seg)

        self.flush_edges()
        self.verify()

    def dtwf_climb_pedigree(self):
        """
        Simulates transmission of ancestral material through a pre-specified
        pedigree: the extant lineages are first attached to the pedigree
        individuals owning their nodes, then the individuals are processed
        ploid by ploid in order of increasing (time, id).
        """
        # Single pop/pedigree for now
        assert self.num_populations == 1
        individuals = self.pedigree.individuals

        # Hand each extant lineage to the pedigree individual that carries
        # its node, on the ploid at which that node sits.
        for lineage in self.P[0].iter_ancestors():
            node_row = self.tables.nodes[lineage.node]
            assert node_row.individual != tskit.NULL
            owner = individuals[node_row.individual]
            owner.add_common_ancestor(
                lineage, ploid=owner.nodes.index(lineage.node)
            )

        def visit_key(individual):
            return individual.time, individual.id

        # Visit pedigree individuals in time order.
        for ind in sorted(individuals, key=visit_key):
            self.t = ind.time
            for ploid in range(ind.ploidy):
                self.process_pedigree_common_ancestors(ind, ploid)

    def store_arg_edges(self, segment, u=None):
        """
        Points every segment of the chain containing ``segment`` at node u.

        For each segment whose node differs from u an edge from u to the
        segment's node is stored over the segment's interval, after which
        the segment's node is set to u. When u is None the index of the
        last row of self.tables.nodes is used.
        """
        if u is None:
            u = len(self.tables.nodes) - 1
        # Walk leftwards first, then rightwards; the starting segment is
        # already relabelled by the time the second walk begins.
        for link in ("prev", "next"):
            seg = segment
            while seg is not None:
                if seg.node != u:
                    self.store_edge(seg.left, seg.right, u, seg.node)
                seg.node = u
                seg = getattr(seg, link)

    def migration_event(self, j, k):
        """
        Moves one uniformly chosen label-0 individual from population j
        into population k. Only label 0 is handled.
        """
        label = 0
        origin, target = self.P[j], self.P[k]
        last_index = origin.get_num_ancestors(label) - 1
        migrant = origin.remove(random.randint(0, last_index), label)
        target.add(migrant, label)
        if self.full_arg:
            self.store_node(k, flags=msprime.NODE_IS_MIG_EVENT)
            self.store_arg_edges(migrant)
        # Every segment of the migrant's chain now lives in population k.
        seg = migrant
        while seg is not None:
            seg.population = k
            seg = seg.next

    def get_recomb_left_bound(self, seg):
        """
        Returns the left end of the genomic region whose recombination
        events are attributed to the specified segment: the right end of
        the previous segment if there is one, otherwise the segment's own
        left end (shifted by one when the genome is discrete).
        """
        if seg.prev is not None:
            return seg.prev.right
        if self.discrete_genome:
            return seg.left + 1
        return seg.left

    def get_gc_left_bound(self, seg):
        """
        Returns the left bound used for gene conversion mass on the
        specified segment; currently the same as the recombination bound.
        """
        # TODO remove me
        left_bound = self.get_recomb_left_bound(seg)
        return left_bound

    def set_segment_mass(self, seg):
        """
        Recomputes and stores the recombination and gene conversion mass of
        the specified segment in whichever mass indexes are in use. All
        links *must* be appropriately set before calling this function,
        because the left bound depends on the previous segment.
        """
        label = seg.label
        right = seg.right
        if self.recomb_mass_index is not None:
            left = self.get_recomb_left_bound(seg)
            self.recomb_mass_index[label].set_value(
                seg.index, self.recomb_map.mass_between(left, right)
            )
        if self.gc_mass_index is not None:
            left = self.get_gc_left_bound(seg)
            self.gc_mass_index[label].set_value(
                seg.index, self.gc_map.mass_between(left, right)
            )

    def set_labels(self, segment, new_label):
        """
        Move the specified segment, and every segment after it in its
        chain, to the specified label.

        For each mass index in use (recombination and/or gene conversion),
        the segment's mass is zeroed under its old label and stored under
        the new label, so the per-label totals stay consistent.
        """
        # Only the indexes that are actually in use take part. Pairing the
        # saved masses positionally with a list that still contains None
        # entries would drop the gene conversion mass whenever the
        # recombination index is absent.
        active_indexes = [
            mass_index
            for mass_index in (self.recomb_mass_index, self.gc_mass_index)
            if mass_index is not None
        ]
        while segment is not None:
            old_label = segment.label
            for mass_index in active_indexes:
                mass = mass_index[old_label].get_value(segment.index)
                mass_index[old_label].set_value(segment.index, 0)
                mass_index[new_label].set_value(segment.index, mass)
            segment.label = new_label
            segment = segment.next

    def choose_breakpoint(self, mass_index, rate_map):
        """
        Draws a breakpoint with probability proportional to the mass stored
        in mass_index. Returns the pair (segment, position), where the
        segment is the one the drawn mass falls in and the position is
        obtained by mapping the mass back through rate_map (rounded down
        when the genome is discrete).
        """
        total_mass = mass_index.get_total()
        assert total_mass > 0
        drawn_mass = random.uniform(0, total_mass)
        seg = self.segments[mass_index.find(drawn_mass)]
        # Mass lying between the drawn point and the right end of seg.
        mass_to_right_end = mass_index.get_cumulative_sum(seg.index) - drawn_mass
        bp_mass = rate_map.position_to_mass(seg.right) - mass_to_right_end
        position = rate_map.mass_to_position(bp_mass)
        if self.discrete_genome:
            position = math.floor(position)
        return seg, position

    def hudson_recombination_event(self, label, return_heads=False):
        """
        Implements a recombination event: a breakpoint is drawn from the
        recombination mass index of the given label and the chain it falls
        in is split into a left-hand chain and a new right-hand chain,
        which is added to the population.

        Returns None unless return_heads is true, in which case the pair
        (head of the left-hand chain, head of the right-hand chain) is
        returned.
        """
        self.num_re_events += 1
        seg, bp = self.choose_breakpoint(self.recomb_mass_index[label], self.recomb_map)
        if bp <= seg.left:
            # The breakpoint lies in the gap before seg:
            #   prev          seg
            # =====  |   =========  ...
            # so the chain is simply cut between the two segments and seg
            # becomes the head of the right-hand chain.
            left_tail = seg.prev
            left_tail.next = None
            seg.prev = None
            right_head = seg
        else:
            # The breakpoint lies inside seg:
            # =====  ===|====  ...
            #          bp
            # seg keeps [left, bp) and a copy takes over [bp, right) along
            # with everything that followed.
            right_head = self.copy_segment(seg)
            right_head.left = bp
            right_head.prev = None
            follower = seg.next
            if follower is not None:
                follower.prev = right_head
            seg.next = None
            seg.right = bp
            self.set_segment_mass(seg)
            left_tail = seg
        self.set_segment_mass(right_head)
        self.P[right_head.population].add(right_head, label)
        if self.full_arg:
            self.store_node(left_tail.population, flags=msprime.NODE_IS_RE_EVENT)
            self.store_arg_edges(left_tail)
            self.store_node(right_head.population, flags=msprime.NODE_IS_RE_EVENT)
            self.store_arg_edges(right_head)
        if not return_heads:
            return None
        # Seek back to the head of the left-hand chain
        left_head = left_tail
        while left_head.prev is not None:
            left_head = left_head.prev
        return left_head, right_head

    def generate_gc_tract_length(self):
        """
        Draws a gene conversion tract length with mean self.tract_length:
        geometric for a discrete genome, exponential otherwise.
        """
        if not self.discrete_genome:
            return np.random.exponential(self.tract_length)
        return np.random.geometric(1 / self.tract_length)

    def wiuf_gene_conversion_within_event(self, label):
        """
        Implements a gene conversion event that starts within a segment.

        A left breakpoint is drawn from the gene conversion mass index of
        the given label and a tract length is added to obtain the right
        breakpoint. The material between the two breakpoints is split off
        into its own chain, and whatever lies to the right of the tract is
        re-attached to the material on the left. Returns None in all cases;
        num_gc_events is only incremented when the tract reaches the chosen
        segment.
        """
        # TODO This is more complicated than it needs to be now because
        # we're not trying to simulate the full GC process with this
        # one event anymore. Look into what bits can be dropped now
        # that we're simulating gc_left separately again.
        y, left_breakpoint = self.choose_breakpoint(
            self.gc_mass_index[label], self.gc_map
        )
        x = y.prev
        # generate tract_length
        tl = self.generate_gc_tract_length()
        assert tl > 0
        right_breakpoint = left_breakpoint + tl
        if y.left >= right_breakpoint:
            #                  y
            # ...  |   |   ========== ...
            #     lbp rbp
            # The whole tract falls in the gap before y: nothing changes.
            return None
        self.num_gc_events += 1

        # Process left break
        insert_alpha = True
        if left_breakpoint <= y.left:
            #  x             y
            # =====  |  ==========
            #       lbp
            #
            # becomes
            #  x
            # =====         α
            #           ==========
            if x is None:
                # In this case we *don't* insert alpha because it is already
                # the head of a segment chain
                insert_alpha = False
            else:
                x.next = None
            y.prev = None
            alpha = y
            tail = x
        else:
            #  x             y
            # =====     ====|=====
            #              lbp
            #
            # becomes
            #  x         y
            # =====     ====   α
            #               ======
            alpha = self.copy_segment(y)
            alpha.left = left_breakpoint
            alpha.prev = None
            if y.next is not None:
                y.next.prev = alpha
            y.next = None
            y.right = left_breakpoint
            self.set_segment_mass(y)
            tail = y
        self.set_segment_mass(alpha)

        # Find the segment z that the right breakpoint falls in
        z = alpha
        while z is not None and right_breakpoint >= z.right:
            z = z.next

        # head is the first segment to the right of the tract (if any); it
        # is joined back onto tail, the last segment left of the tract.
        head = None
        # Process the right break
        if z is not None:
            if z.left < right_breakpoint:
                #   tail             z
                # ======
                #       ...  ===|==========
                #              rbp
                #
                # becomes
                #  tail              head
                # =====         ===========
                #      ...   ===
                #             z
                head = self.copy_segment(z)
                head.left = right_breakpoint
                if z.next is not None:
                    z.next.prev = head
                z.right = right_breakpoint
                z.next = None
                self.set_segment_mass(z)
            else:
                #   tail             z
                # ======
                #   ...   |   =============
                #        rbp
                #
                # becomes
                #  tail             z
                # ======      =============
                #  ...
                if z.prev is not None:
                    z.prev.next = None
                head = z
            if tail is not None:
                tail.next = head
            head.prev = tail
            self.set_segment_mass(head)

        #        y            z
        #  |  ========== ... ===== |
        # lbp                     rbp
        # When y and z are the head and tail of the segment chains, then
        # this GC event does nothing. This logic takes care of this situation.
        new_individual_head = None
        if insert_alpha:
            new_individual_head = alpha
        elif head is not None:
            new_individual_head = head
        if new_individual_head is not None:
            self.P[new_individual_head.population].add(
                new_individual_head, new_individual_head.label
            )

    def wiuf_gene_conversion_left_event(self, label):
        """
        Implements a gene conversion event that started left of a first
        segment. The tract runs from the left end of the chosen chain; the
        chain is cut where the tract ends and the material to the right of
        the cut is added to the population as a new chain. Nothing happens
        if the tract covers the whole chain.
        """
        threshold = random.uniform(0, self.get_total_gc_left(label))
        # Get segment where gene conversion starts from left
        first = self.find_cleft_individual(label, threshold)
        assert first is not None

        # The tract ends one tract length to the right of the chain start.
        bp = first.left + self.generate_gc_tract_length()

        # Skip every segment that lies entirely inside the tract.
        seg = first
        while seg is not None:
            if bp < seg.right:
                break
            seg = seg.next

        if seg is None:
            #    last segment
            # ... ==========   |
            #                  bp
            # The tract spans the whole individual: state is unchanged.
            return None

        self.num_gc_events += 1
        if bp <= seg.left:
            #  prev        seg
            # ===== |  =========
            #       bp
            # The tract ends in a gap, so just split the link in front of
            # seg, which becomes the head of the new chain.
            seg.prev.next = None
            seg.prev = None
            alpha = seg
        else:
            #  prev       seg
            # =====   =====|====
            #              bp
            # The tract ends inside seg: seg keeps [left, bp) and a copy
            # takes over [bp, right) plus everything that followed.
            alpha = self.copy_segment(seg)
            alpha.left = bp
            alpha.prev = None
            if alpha.next is not None:
                alpha.next.prev = alpha
            seg.next = None
            seg.right = bp
            self.set_segment_mass(seg)
        self.set_segment_mass(alpha)
        assert alpha.prev is None
        self.P[alpha.population].add(alpha, label)

    def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq):
        """
        Implements a recombination event during a selective sweep.

        After the recombination, the chain that does not carry the sweep
        site (the right-hand chain if the site lies left of it, otherwise
        the left-hand chain) is moved to the other label with probability
        1 - pop_freq.
        """
        lhs, rhs = self.hudson_recombination_event(label, return_heads=True)

        r = random.random()
        mover = rhs if sweep_site < rhs.left else lhs
        if r < 1.0 - pop_freq:
            # Take the chain out under its current label, relabel all of
            # its segments, then put it back under the new label.
            pop = self.P[mover.population]
            position = pop.find_indv(mover)
            pop.remove(position, mover.label)
            self.set_labels(mover, 1 - label)
            self.P[mover.population].add(mover, mover.label)

    def dtwf_generate_breakpoint(self, start):
        """
        Returns the next breakpoint to the right of start, found by moving
        an exponentially distributed (mean 1) amount of recombination mass
        along the recombination map. For a discrete genome the search
        starts one position further right and the result is rounded down.
        """
        left_bound = start
        if self.discrete_genome:
            left_bound = start + 1
        drawn_mass = np.random.exponential(1.0)
        bp = self.recomb_map.shift_by_mass(left_bound, drawn_mass)
        return math.floor(bp) if self.discrete_genome else bp

    def dtwf_recombine(self, x):
        """
        Chooses breakpoints and returns segments sorted by inheritance
        direction, by iterating through segment chain starting with x.

        The chain is distributed over two output chains; each breakpoint
        switches the chain that receives the following material. Returns a
        pair of chain heads, either of which may be None if no material
        was assigned to that direction.
        """
        # u and v are temporary sentinel segments standing in for the heads
        # of the two output chains; they are freed again before returning.
        u = self.alloc_segment(-1, -1, -1, -1, None, None)
        v = self.alloc_segment(-1, -1, -1, -1, None, None)
        seg_tails = [u, v]

        # TODO Should this be the recombination rate going forward from x.left?
        if self.recomb_map.total_mass > 0:
            k = self.dtwf_generate_breakpoint(x.left)
        else:
            # With no recombination mass there is never a breakpoint.
            k = np.inf

        # Pick the output chain that receives the start of x at random.
        ix = np.random.randint(2)
        seg_tails[ix].next = x

        while x is not None:
            seg_tails[ix] = x
            y = x.next

            if x.right > k:
                # The breakpoint k falls inside x: split x at k and hand
                # the right-hand part to the other output chain.
                assert x.left < k
                self.num_re_events += 1
                ix = (ix + 1) % 2
                # Make new segment
                if seg_tails[ix] is u or seg_tails[ix] is v:
                    # The sentinel is not a real predecessor.
                    tail = None
                else:
                    tail = seg_tails[ix]
                z = self.copy_segment(x)
                z.left = k
                z.prev = tail
                self.set_segment_mass(z)
                if x.next is not None:
                    x.next.prev = z
                seg_tails[ix].next = z
                seg_tails[ix] = z
                x.next = None
                x.right = k
                self.set_segment_mass(x)
                x = z
                k = self.dtwf_generate_breakpoint(k)
            elif x.right <= k and y is not None and y.left >= k:
                # Recombine between segment and the next
                assert seg_tails[ix] == x
                x.next = None
                y.prev = None
                # Several breakpoints may fall in the gap before y; each
                # one switches the receiving chain.
                while y.left >= k:
                    self.num_re_events += 1
                    ix = (ix + 1) % 2
                    k = self.dtwf_generate_breakpoint(k)
                seg_tails[ix].next = y
                if seg_tails[ix] is u or seg_tails[ix] is v:
                    tail = None
                else:
                    tail = seg_tails[ix]
                y.prev = tail
                self.set_segment_mass(y)
                seg_tails[ix] = y
                x = y
            else:
                # No recombination between x.right and y.left
                x = y

        # Remove sentinel segments - this can be handled more simply
        # with pointers in C implementation
        s = u
        u = s.next
        self.free_segment(s)

        s = v
        v = s.next
        self.free_segment(s)

        return u, v

    def census_event(self, time):
        """
        Records a census at the given time: for every lineage in every
        population a new census node is added and each segment of the
        lineage is attached to it by an edge.
        """
        for pop in self.P:
            for head in pop.iter_ancestors():
                # Pending edges are flushed before the new node row is added.
                self.flush_edges()
                census_node = self.tables.nodes.add_row(
                    time=time, flags=msprime.NODE_IS_CEN_EVENT, population=pop.id
                )
                seg = head
                while seg is not None:
                    # Add an edge joining the segment to the new node.
                    self.store_edge(seg.left, seg.right, census_node, seg.node)
                    seg.node = census_node
                    seg = seg.next

    def bottleneck_event(self, pop_id, label, intensity):
        """
        Merges some of the ancestors of population pop_id: one draw is made
        per ancestor currently present, and on each success (probability
        ``intensity``) the ancestor at index 0 is removed and queued; the
        queued ancestors are then merged together.
        """
        # self.print_state()
        pop = self.P[pop_id]
        heap = []
        for _ in range(pop.get_num_ancestors()):
            if random.random() >= intensity:
                continue
            lineage = pop.remove(0)
            heapq.heappush(heap, (lineage.left, lineage))
        self.merge_ancestors(heap, pop_id, label)

    def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
        """
        Merges the segment chains on the heap H into a single chain in
        population pop_id with the given label.

        H is a heap of ``(left, segment)`` pairs and is consumed by this
        method. Intervals covered by a single chain are carried over
        unchanged; intervals covered by several chains are replaced by a
        segment of a common parent node, with an edge stored for each
        child. The parent node is new_node_id if given, otherwise it is
        created with store_node the first time an overlap is found.

        Returns the head of the merged chain, or None if no segment was
        produced.
        """
        pop = self.P[pop_id]
        defrag_required = False
        coalescence = False
        alpha = None
        # z is the tail of the merged chain built so far.
        z = None
        merged_head = None
        while len(H) > 0:
            alpha = None
            left = H[0][0]
            # X collects all segments starting at the current left
            # coordinate; r_max is where the current overlap set ends.
            X = []
            r_max = self.L
            while len(H) > 0 and H[0][0] == left:
                x = heapq.heappop(H)[1]
                X.append(x)
                r_max = min(r_max, x.right)
            if len(H) > 0:
                r_max = min(r_max, H[0][0])
            if len(X) == 1:
                # No overlap at this coordinate.
                x = X[0]
                if len(H) > 0 and H[0][0] < x.right:
                    # Another segment starts inside x: emit only the part
                    # before it and push the remainder back on the heap.
                    alpha = self.alloc_segment(x.left, H[0][0], x.node, x.population)
                    alpha.label = label
                    x.left = H[0][0]
                    heapq.heappush(H, (x.left, x))
                else:
                    if x.next is not None:
                        y = x.next
                        heapq.heappush(H, (y.left, y))
                    alpha = x
                    alpha.next = None
            else:
                if new_node_id == -1:
                    coalescence = True
                    new_node_id = self.store_node(pop_id)
                # We must also break if the next left value is less than
                # any of the right values in the current overlap set.
                # (self.S is used here as a map from breakpoint to a count
                # of extant segments; a missing key inherits the value of
                # the nearest key at or below it.)
                if left not in self.S:
                    j = self.S.floor_key(left)
                    self.S[left] = self.S[j]
                if r_max not in self.S:
                    j = self.S.floor_key(r_max)
                    self.S[r_max] = self.S[j]
                # Update the number of extant segments.
                if self.S[left] == len(X):
                    # All extant segments over this interval meet here, so
                    # no segment is carried forward (alpha stays None).
                    self.S[left] = 0
                    right = self.S.succ_key(left)
                else:
                    right = left
                    while right < r_max and self.S[right] != len(X):
                        self.S[right] -= len(X) - 1
                        right = self.S.succ_key(right)
                    alpha = self.alloc_segment(left, right, new_node_id, pop_id)
                # Update the heaps and make the record.
                for x in X:
                    self.store_edge(left, right, new_node_id, x.node)
                    if x.right == right:
                        self.free_segment(x)
                        if x.next is not None:
                            y = x.next
                            heapq.heappush(H, (y.left, y))
                    elif x.right > right:
                        x.left = right
                        heapq.heappush(H, (x.left, x))

            # loop tail; update alpha and integrate it into the state.
            if alpha is not None:
                if z is None:
                    # First segment produced: it heads the merged chain.
                    pop.add(alpha, label)
                    merged_head = alpha
                else:
                    if self.full_arg:
                        defrag_required |= z.right == alpha.left
                    else:
                        defrag_required |= (
                            z.right == alpha.left and z.node == alpha.node
                        )
                    z.next = alpha
                alpha.prev = z
                self.set_segment_mass(alpha)
                z = alpha
        if self.full_arg:
            if not coalescence:
                self.store_node(pop_id, flags=msprime.NODE_IS_CA_EVENT)
            self.store_arg_edges(z)
        if defrag_required:
            self.defrag_segment_chain(z)
        if coalescence:
            self.defrag_breakpoints()
        return merged_head

    def defrag_segment_chain(self, z):
        """
        Walks the segment chain backwards from its tail z and squashes
        every pair of neighbouring segments that touch (the right end of
        the first equals the left end of the second) and point at the same
        node. The earlier segment of the pair absorbs the later one, has
        its mass recomputed, and the later segment is freed.
        """
        current = z
        while True:
            before = current.prev
            if before is None:
                # Reached the head of the chain.
                break
            touching = before.right == current.left
            if touching and before.node == current.node:
                after = current.next
                # Stretch the earlier segment over the later one and
                # unlink the later one from the chain.
                before.right = current.right
                before.next = after
                if after is not None:
                    after.prev = before
                self.set_segment_mass(before)
                self.free_segment(current)
            current = before

    def defrag_breakpoints(self):
        """
        Removes redundant entries from the breakpoint map self.S: scanning
        from position 0 up to self.L, any key whose value equals the value
        of the closest key kept before it is deleted.
        """
        kept = 0
        candidate = 0
        while candidate < self.L:
            candidate = self.S.succ_key(kept)
            if self.S[kept] != self.S[candidate]:
                # The value changes here, so this breakpoint is needed.
                kept = candidate
            else:
                del self.S[candidate]

    def common_ancestor_event(self, population_index, label):
        """
        Implements a coancestry event: two ancestors carrying the given
        label are drawn from the given population and merged.
        """
        pop = self.P[population_index]
        chosen = []
        # Each draw is uniform over the ancestors still present and removes
        # the selected one, so the same ancestor cannot be picked twice.
        for _ in range(2):
            index = random.randint(0, pop.get_num_ancestors(label) - 1)
            chosen.append(pop.remove(index, label))
        first, second = chosen
        self.merge_two_ancestors(population_index, label, first, second)

    def merge_two_ancestors(self, population_index, label, x, y):
        """
        Merges the two ancestors whose segment chains start at x and y into
        a single chain, which is added to the population at
        population_index under the given label.

        Stretches covered by only one of the two chains are carried over
        as they are. Where the chains overlap a coalescence is recorded: a
        new node is stored (once per call), an edge from that node is
        stored for each of the two overlapping segments, and the overlap
        counts in self.S are updated. When the overlap count at the left
        end of an overlap is exactly 2 it is set to 0 and no segment is
        carried forward for that stretch.

        Nothing is returned; the result is visible through the population,
        self.S, the stored nodes/edges and the num_ca_events counter.
        """
        pop = self.P[population_index]
        self.num_ca_events += 1
        # z is the tail of the merged chain built so far (None while empty).
        z = None
        coalescence = False
        defrag_required = False
        while x is not None or y is not None:
            # alpha is the segment (if any) appended to the merged chain in
            # this iteration.
            alpha = None
            if x is None or y is None:
                # Only one chain is left: its remainder is appended as is
                # and the loop ends.
                if x is not None:
                    alpha = x
                    x = None
                if y is not None:
                    alpha = y
                    y = None
            else:
                # Make x the chain whose current segment starts first.
                if y.left < x.left:
                    beta = x
                    x = y
                    y = beta
                if x.right <= y.left:
                    # x lies entirely before y: detach it and pass it on.
                    alpha = x
                    x = x.next
                    alpha.next = None
                elif x.left != y.left:
                    # x starts before y but overlaps it: split off the part
                    # of x that precedes y and keep the rest for the next
                    # iteration.
                    alpha = self.copy_segment(x)
                    alpha.prev = None
                    alpha.next = None
                    alpha.right = y.left
                    x.left = y.left
                else:
                    # x and y start at the same position: coalescence.
                    if not coalescence:
                        coalescence = True
                        self.store_node(population_index)
                    # ID of the most recently added row of the node table.
                    u = len(self.tables.nodes) - 1
                    # Put in breakpoints for the outer edges of the coalesced
                    # segment
                    left = x.left
                    r_max = min(x.right, y.right)
                    if left not in self.S:
                        j = self.S.floor_key(left)
                        self.S[left] = self.S[j]
                    if r_max not in self.S:
                        j = self.S.floor_key(r_max)
                        self.S[r_max] = self.S[j]
                    # Update the number of extant segments.
                    if self.S[left] == 2:
                        # x and y were the only two segments counted here:
                        # the count drops to zero and alpha stays None.
                        self.S[left] = 0
                        right = self.S.succ_key(left)
                    else:
                        # Advance over the breakpoints up to r_max, lowering
                        # each count by one, and stop early at a breakpoint
                        # whose count is exactly 2.
                        right = left
                        while right < r_max and self.S[right] != 2:
                            self.S[right] -= 1
                            right = self.S.succ_key(right)
                        alpha = self.alloc_segment(
                            left=left,
                            right=right,
                            node=u,
                            population=population_index,
                            label=label,
                        )
                    self.store_edge(left, right, u, x.node)
                    self.store_edge(left, right, u, y.node)
                    # Now trim the ends of x and y to the right sizes.
                    # NOTE(review): x.next / y.next are read after the
                    # segment has been passed to free_segment; presumably
                    # free_segment leaves the link intact - confirm.
                    if x.right == right:
                        self.free_segment(x)
                        x = x.next
                    else:
                        x.left = right
                    if y.right == right:
                        self.free_segment(y)
                        y = y.next
                    else:
                        y.left = right

            # loop tail; update alpha and integrate it into the state.
            if alpha is not None:
                if z is None:
                    # First segment of the merged chain: register the new
                    # ancestor with the population.
                    pop.add(alpha, label)
                else:
                    # Remember whether the new segment touches the previous
                    # one, so that the chain is squashed afterwards. In full
                    # ARG mode touching alone is enough; otherwise the two
                    # must also share a node.
                    if self.full_arg:
                        defrag_required |= z.right == alpha.left
                    else:
                        defrag_required |= (
                            z.right == alpha.left and z.node == alpha.node
                        )
                    z.next = alpha
                alpha.prev = z
                self.set_segment_mass(alpha)
                z = alpha

        if self.full_arg:
            # In full ARG mode an event without any coalescence still gets
            # a node, flagged as a common ancestor event.
            if not coalescence:
                u = self.store_node(population_index, flags=msprime.NODE_IS_CA_EVENT)
            self.store_arg_edges(z, u)
        elif self.store_unary and coalescence:
            self.store_arg_edges(z, u)
        if defrag_required:
            self.defrag_segment_chain(z)
        if coalescence:
            self.defrag_breakpoints()

    def print_state(self, verify=False):
        """
        Prints a human readable dump of the simulator state to stdout: the
        current time, the per-label recombination and gene conversion
        masses, pending modifier events, populations, migration matrix,
        overlap counts, the non-zero entries of the mass indexes and the
        node and edge tables. If verify is true, self.verify() is run
        afterwards.
        """
        print("State @ time ", self.t)
        for label in range(self.num_labels):
            if self.recomb_mass_index is None:
                recomb_total = 0
            else:
                recomb_total = self.recomb_mass_index[label].get_total()
            print("Recomb mass = ", recomb_total)
            if self.gc_mass_index is None:
                gc_total = 0
            else:
                gc_total = self.gc_mass_index[label].get_total()
            print("GC mass = ", gc_total)

        print("Modifier events = ")
        for event_time, func, func_args in self.modifier_events:
            print("\t", event_time, func, func_args)

        sizes = [pop.get_num_ancestors() for pop in self.P]
        print("Population sizes:", sizes)
        print("Migration Matrix:")
        for row in self.migration_matrix:
            print("\t", row)
        for population in self.P:
            population.print_state()
        if self.pedigree is not None:
            self.pedigree.print_state()

        print("Overlap counts", len(self.S))
        for position, count in self.S.items():
            print("\t", position, "\t:\t", count)

        for label in range(self.num_labels):
            if self.recomb_mass_index is not None:
                index = self.recomb_mass_index[label]
                print("recomb_mass_index [%d]: %d" % (label, index.get_total()))
                for j in range(1, self.max_segments + 1):
                    stored = index.get_value(j)
                    if stored == 0:
                        continue
                    seg = self.segments[j]
                    start = self.get_recomb_left_bound(seg)
                    # Show the stored mass next to a fresh computation.
                    recomputed = self.recomb_map.mass_between(start, seg.right)
                    print("\t", j, "->", stored, recomputed)
            if self.gc_mass_index is not None:
                index = self.gc_mass_index[label]
                print("gc_mass_index [%d]: %d" % (label, index.get_total()))
                for j in range(1, self.max_segments + 1):
                    stored = index.get_value(j)
                    if stored == 0:
                        continue
                    seg = self.segments[j]
                    start = self.get_gc_left_bound(seg)
                    recomputed = self.gc_map.mass_between(start, seg.right)
                    print("\t", j, "->", stored, recomputed)

        print("nodes")
        print(self.tables.nodes)
        print("edges")
        print(self.tables.edges)
        if verify:
            self.verify()

    def verify_segments(self):
        """
        Checks the integrity of every segment chain in every population
        and label: the head has no predecessor, the prev/next links agree
        with each other, segments do not start before the previous one
        ends, and all segments of a chain share the head's label and
        population.
        """
        for pop in self.P:
            for label in range(self.num_labels):
                for head in pop.iter_label(label):
                    assert head.prev is None
                    before, seg = head, head.next
                    while seg is not None:
                        assert before.next is seg
                        assert seg.prev is before
                        assert seg.left >= before.right
                        assert seg.label == head.label
                        assert seg.population == head.population
                        before, seg = seg, seg.next

    def verify_overlaps(self):
        """
        Checks the overlap counts stored in self.S against two independent
        recomputations from the extant segments: one using an
        OverlapCounter and one rebuilding the breakpoint tree from scratch.
        """

        def extant_segments():
            # Every segment of every chain, in population/label/chain order.
            for pop in self.P:
                for label in range(self.num_labels):
                    for head in pop.iter_label(label):
                        seg = head
                        while seg is not None:
                            yield seg
                            seg = seg.next

        counter = OverlapCounter(self.L)
        for seg in extant_segments():
            counter.increment_interval(seg.left, seg.right)

        for pos, count in self.S.items():
            if pos != self.L:
                assert count == counter.overlaps_at(pos)
        # The entry at the end of the sequence holds the sentinel value.
        assert self.S[self.L] == -1

        # Check the ancestry tracking by rebuilding the tree independently.
        rebuilt = bintrees.AVLTree()
        rebuilt[0] = 0
        rebuilt[self.L] = -1
        for seg in extant_segments():
            # Make sure both ends of the segment are breakpoints, copying
            # the count from the breakpoint at or below the new one.
            for edge in (seg.left, seg.right):
                if edge not in rebuilt:
                    rebuilt[edge] = rebuilt[rebuilt.floor_key(edge)]
            pos = seg.left
            while pos < seg.right:
                rebuilt[pos] += 1
                pos = rebuilt.succ_key(pos)

        # Now, defrag the rebuilt tree so that it is comparable with self.S
        kept = 0
        candidate = 0
        while candidate < self.L:
            candidate = rebuilt.succ_key(kept)
            if rebuilt[kept] != rebuilt[candidate]:
                kept = candidate
            else:
                del rebuilt[candidate]
        assert list(rebuilt.items()) == list(self.S.items())

    def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound):
        """
        Checks the given mass index against the segments carrying the given
        label. For every segment the stored value must match the mass the
        rate map assigns between the segment's left bound (as given by
        compute_left_bound) and its right end. The sum of the stored values
        must match both the index total and the sum, over all chains, of
        the mass between the head's left bound and the tail's right end.
        """
        assert mass_index is not None
        indexed_total = 0
        span_total = 0
        for pop_index, pop in enumerate(self.P):
            for head in pop.iter_label(label):
                assert head.prev is None
                chain_left = compute_left_bound(head)
                seg = head
                while seg is not None:
                    assert seg.population == pop_index
                    assert seg.left < seg.right
                    seg_left = compute_left_bound(seg)
                    expected = rate_map.mass_between(seg_left, seg.right)
                    # After the loop this holds the right end of the tail.
                    chain_right = seg.right
                    stored = mass_index.get_value(seg.index)
                    indexed_total += stored
                    assert math.isclose(expected, stored, abs_tol=1e-6)
                    follower = seg.next
                    if follower is not None:
                        assert follower.prev == seg
                        assert seg.right <= follower.left
                    seg = follower

                span_total += rate_map.mass_between(chain_left, chain_right)
        assert math.isclose(indexed_total, mass_index.get_total(), abs_tol=1e-6)
        assert math.isclose(indexed_total, span_total, abs_tol=1e-6)

    def verify(self):
        """
        Checks that the state of the simulator is consistent.
        """
        self.verify_segments()
        if self.model == "fixed_pedigree":
            # The fixed_pedigree model doesn't maintain a bunch of stuff.
            # It would probably be simpler if it did.
            return

        self.verify_overlaps()
        for label in range(self.num_labels):
            # When a mass index is absent, the corresponding rate map must
            # carry no mass at all.
            if self.recomb_mass_index is not None:
                self.verify_mass_index(
                    label,
                    self.recomb_mass_index[label],
                    self.recomb_map,
                    self.get_recomb_left_bound,
                )
            else:
                assert self.recomb_map.total_mass == 0

            if self.gc_mass_index is not None:
                self.verify_mass_index(
                    label,
                    self.gc_mass_index[label],
                    self.gc_map,
                    self.get_gc_left_bound,
                )
            else:
                assert self.gc_map.total_mass == 0


def run_simulate(args):
    """
    Runs a single simulation as described by the parsed command line
    arguments, writes the result to ``args.output_file`` and, when
    ``args.verbose`` is set, prints the final simulator state.

    :param args: Namespace of parsed arguments carrying the attributes that
        ``add_simulator_arguments`` defines.
    :raises ValueError: If the ``single_sweep`` model is requested together
        with more than one population, or without a trajectory.
    """
    # Seed both generators before anything else happens. The sweep trajectory
    # is simulated further down, before the Simulator is built, so seeding
    # only after it would leave any random draws made while computing the
    # trajectory outside the control of --random-seed.
    random.seed(args.random_seed)
    np.random.seed(args.random_seed + 1)

    n = args.sample_size
    m = args.sequence_length
    rho = args.recombination_rate
    # --gene-conversion-rate carries two values: the rate, then the tract
    # length.
    gc_rate = args.gene_conversion_rate[0]
    mean_tract_length = args.gene_conversion_rate[1]
    num_populations = args.num_populations
    # The same migration rate between every pair of distinct populations,
    # zero on the diagonal.
    migration_matrix = [
        [args.migration_rate * int(j != k) for j in range(num_populations)]
        for k in range(num_populations)
    ]
    # Defaults: all samples in population 0, no growth, unit sizes. Each can
    # be overridden from the command line.
    sample_configuration = [0 for j in range(num_populations)]
    population_growth_rates = [0 for j in range(num_populations)]
    population_sizes = [1 for j in range(num_populations)]
    sample_configuration[0] = n
    if args.sample_configuration is not None:
        sample_configuration = args.sample_configuration
    if args.population_growth_rates is not None:
        population_growth_rates = args.population_growth_rates
    if args.population_sizes is not None:
        population_sizes = args.population_sizes
    if args.recomb_positions is None or args.recomb_rates is None:
        # No explicit map given: build one from the scalar rate over the
        # whole sequence length.
        recombination_map = RateMap([0, m], [rho, 0])
    else:
        positions = args.recomb_positions
        rates = args.recomb_rates
        recombination_map = RateMap(positions, rates)
    num_labels = 1
    sweep_trajectory = None
    if args.model == "single_sweep":
        if num_populations > 1:
            raise ValueError("Multiple populations not currently supported")
        # Compute the trajectory
        if args.trajectory is None:
            raise ValueError("Must provide trajectory (init_freq, end_freq, alpha)")
        init_freq, end_freq, alpha = args.trajectory
        traj_sim = TrajectorySimulator(init_freq, end_freq, alpha, args.time_slice)
        sweep_trajectory = traj_sim.run()
        num_labels = 2

    if args.from_ts is None:
        # Start from scratch: one population row per entry of the sample
        # configuration, and one sample node at time 0 per requested sample.
        tables = tskit.TableCollection(m)
        for pop_id, sample_count in enumerate(sample_configuration):
            tables.populations.add_row()
            for _ in range(sample_count):
                tables.nodes.add_row(
                    flags=tskit.NODE_IS_SAMPLE, time=0, population=pop_id
                )
    else:
        # Continue from the tables of an existing tree sequence.
        from_ts = tskit.load(args.from_ts)
        tables = from_ts.dump_tables()

    s = Simulator(
        tables=tables,
        recombination_map=recombination_map,
        migration_matrix=migration_matrix,
        model=args.model,
        population_growth_rates=population_growth_rates,
        population_sizes=population_sizes,
        population_growth_rate_changes=args.population_growth_rate_change,
        population_size_changes=args.population_size_change,
        migration_matrix_element_changes=args.migration_matrix_element_change,
        bottlenecks=args.bottleneck,
        census_times=args.census_time,
        max_segments=100000,
        num_labels=num_labels,
        full_arg=args.full_arg,
        store_unary=args.store_unary,
        sweep_trajectory=sweep_trajectory,
        time_slice=args.time_slice,
        gene_conversion_rate=gc_rate,
        gene_conversion_length=mean_tract_length,
        discrete_genome=args.discrete,
    )
    ts = s.simulate(args.end_time)
    ts.dump(args.output_file)
    if args.verbose:
        s.print_state()


def add_simulator_arguments(parser):
    """
    Registers the simulator's command line interface on the given argparse
    parser: the two positional arguments (sample size and output file)
    followed by all the optional simulation settings.
    """
    add = parser.add_argument

    # Positionals.
    add("sample_size", type=int)
    add("output_file")

    # Output and logging.
    add("-v", "--verbose", action="store_true", help="increase output verbosity")
    add(
        "-l",
        "--log-level",
        type=int,
        default=logging.WARNING,
        help="Set log-level to the specified value",
    )

    # Basic simulation parameters.
    add("--random-seed", "-s", type=int, default=1)
    add("--sequence-length", "-L", type=int, default=100)
    add("--discrete", "-d", action="store_true")
    add("--num-replicates", "-R", type=int, default=1000)
    add("--recombination-rate", "-r", type=float, default=0.01)
    for option in ("--recomb-positions", "--recomb-rates"):
        add(option, type=float, nargs="+", default=None)
    add("--gene-conversion-rate", "-c", type=float, nargs=2, default=[0, 3])

    # Population structure.
    add("--num-populations", "-p", type=int, default=1)
    add("--migration-rate", "-g", type=float, default=1)
    add("--sample-configuration", type=int, nargs="+", default=None)
    for option in ("--population-growth-rates", "--population-sizes"):
        add(option, type=float, nargs="+", default=None)

    # Repeatable events; every occurrence appends one list of floats.
    repeatable = [
        ("--population-size-change", 3),
        ("--population-growth-rate-change", 3),
        ("--migration-matrix-element-change", 4),
        ("--bottleneck", 3),
        ("--census-time", 1),
    ]
    for option, num_values in repeatable:
        add(option, type=float, nargs=num_values, action="append", default=[])

    add(
        "--trajectory",
        type=float,
        nargs=3,
        default=None,
        help="Parameters for the allele frequency trajectory simulation",
    )
    add(
        "--full-arg",
        action="store_true",
        default=False,
        help="Store the full ARG with all recombination and common ancestor nodes",
    )
    add("--store-unary", action="store_true", default=False, help="Store unary nodes.")
    add(
        "--time-slice",
        type=float,
        default=1e-6,
        help="The delta_t value for selective sweeps",
    )
    add("--model", default="hudson")
    add(
        "--from-ts",
        "-F",
        default=None,
        help=(
            "Specify the tree sequence to complete. The sample_size argument "
            "is ignored if this is provided"
        ),
    )
    add("--end-time", type=float, default=np.inf, help="The end for simulations.")


def main(args=None):
    """
    Command line entry point: parses the arguments, sets up logging to
    stdout at the requested level and runs the simulation.

    The args value is handed to the parser's parse_args unchanged.
    """
    parser = argparse.ArgumentParser()
    add_simulator_arguments(parser)
    options = parser.parse_args(args)

    formatter = daiquiri.formatter.ColorFormatter(fmt="[%(levelname)s] %(message)s")
    stream = daiquiri.output.Stream(sys.stdout, formatter=formatter)
    daiquiri.setup(level=options.log_level, outputs=[stream])

    run_simulate(options)


# Run the command line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
back to top