swh:1:snp:46f6d1bb45c3612bf3206208647700c940699a40
Raw File
Tip revision: 630d4e32f343e86a4921d0773f1f5c5adf90553f authored by Eric Denovellis on 18 September 2021, 19:04:35 UTC
Add readme about data
Tip revision: 630d4e3
analysis.py
import os
from glob import glob

import networkx as nx
import numpy as np
import pandas as pd
import scipy
import xarray as xr
from loren_frank_data_processing import make_tetrode_dataframe
from loren_frank_data_processing.track_segment_classification import \
    project_points_to_segment
from scipy.ndimage import label
from scipy.ndimage.filters import gaussian_filter1d
from src.parameters import (_BRAIN_AREAS, ANIMALS, PROBABILITY_THRESHOLD,
                            PROCESSED_DATA_DIR)
from trajectory_analysis_tools import (get_ahead_behind_distance,
                                       get_trajectory_data)


def get_replay_info(results, spikes, ripple_times, position_info,
                    track_graph, sampling_frequency, probability_threshold,
                    epoch_key, classifier, ripple_consensus_trace_zscore):
    '''Build a table of replay metrics with one row per ripple of an epoch.

    Parameters
    ----------
    results : xarray.Dataset, shape (n_ripples, n_position_bins, n_states,
                                     n_ripple_time)
    spikes : pandas.DataFrame (n_time, n_neurons)
    ripple_times : pandas.DataFrame (n_ripples, 2)
    position_info : pandas.DataFrame (n_time, n_covariates)
    track_graph : networkx.Graph
    sampling_frequency : float
    probability_threshold : float
    epoch_key : tuple
        Unpacked as (animal, day, epoch).
    classifier : object
        Fitted classifier; `_nodes_df` and `distance_between_nodes_` are read
        here and the object is forwarded to `get_ripple_replay_info`.
    ripple_consensus_trace_zscore : pandas.Series or pandas.DataFrame
        Time-indexed trace; it is interpolated onto `position_info.index`.

    Returns
    -------
    replay_info : pandas.DataFrame, shape (n_ripples, n_covariates)

    '''
    # Put the ripple consensus trace on the same time base as position_info:
    # interpolate linearly on the union of both time indices, then keep only
    # the position_info times.
    position_time = position_info.index
    union_time = pd.Index(
        np.unique(np.concatenate(
            (ripple_consensus_trace_zscore.index, position_time))),
        name='time')
    interpolated_trace = (ripple_consensus_trace_zscore
                          .reindex(index=union_time)
                          .interpolate(method='linear'))
    ripple_consensus_trace_zscore = interpolated_trace.reindex(
        index=position_time)

    per_ripple_metrics = [
        get_ripple_replay_info(
            ripple, results, spikes, ripple_consensus_trace_zscore,
            position_info, sampling_frequency, probability_threshold,
            track_graph, classifier)
        for ripple in ripple_times.itertuples()
    ]
    replay_info = pd.DataFrame(per_ripple_metrics, index=ripple_times.index)

    animal, day, epoch = epoch_key
    replay_info['animal'] = animal
    replay_info['day'] = int(day)
    replay_info['epoch'] = int(epoch)

    # Min / max of every node column, per track edge, using bin-edge nodes
    # only.
    nodes_df = classifier._nodes_df
    edge_extent = (nodes_df[nodes_df.is_bin_edge]
                   .groupby('edge_id')
                   .aggregate(['min', 'max']))

    def edge_linear_position(edge_id):
        # Two-element series holding the 'min' and 'max' linear position.
        return edge_extent.loc[edge_id].linear_position

    # NOTE(review): the edge ids below hard-code one track layout (edge 0 for
    # the center arm, 1 and 3 for the left arm, 2 and 4 for the right arm) --
    # confirm against the track graph in use.
    replay_info['center_well_position'] = edge_linear_position(0).min()
    replay_info['choice_position'] = edge_linear_position(0).max()

    replay_info['left_arm_start'] = edge_linear_position(1).min()
    replay_info['left_well_position'] = edge_linear_position(3).max()

    replay_info['right_arm_start'] = edge_linear_position(2).min()
    replay_info['right_well_position'] = edge_linear_position(4).max()

    # Last entry of the distances stored for node 0 (the center well id).
    center_well_id = 0
    distance_from_center = classifier.distance_between_nodes_[center_well_id]
    replay_info['max_linear_distance'] = list(
        distance_from_center.values())[-1]

    return replay_info


def get_probability(results):
    '''Probability of each state over time, plus two mixture states.

    The acausal posterior is summed over position. Two derived states are
    appended along the `state` dimension: 'Hover-Continuous-Mix' (Hover +
    Continuous) and 'Fragmented-Continuous-Mix' (Fragmented + Continuous).

    Parameters
    ----------
    results : xarray.Dataset
        Must provide `acausal_posterior` with a `state` dimension and either
        (`x_position`, `y_position`) or `position` dimensions.

    Returns
    -------
    probability : xarray.DataArray

    '''
    posterior = results.acausal_posterior
    try:
        # Two-dimensional position grid.
        probability = posterior.sum(['x_position', 'y_position'], skipna=True)
    except ValueError:
        # One-dimensional position: drop bins that are NaN everywhere first.
        probability = (posterior
                       .dropna('position', how='all')
                       .sum('position', skipna=False))

    def mixture(parent_states, mixture_name):
        # Summed probability of the parent states, labelled as a new state.
        return (probability
                .sel(state=parent_states)
                .sum('state', skipna=False)
                .assign_coords(state=mixture_name))

    hover_continuous = mixture(
        ['Hover', 'Continuous'], 'Hover-Continuous-Mix')
    fragmented_continuous = mixture(
        ['Fragmented', 'Continuous'], 'Fragmented-Continuous-Mix')

    return xr.concat(
        (probability, hover_continuous, fragmented_continuous), dim='state')


def get_is_classified(probability, probablity_threshold):
    '''Classify each state by the confidence threshold and make sure two
    derived states exclude their parent states.

    With a threshold below 1, a state is classified wherever its probability
    exceeds the threshold; the two mixture states additionally require that
    neither parent state is classified on its own and that the remaining
    third state has probability below `(1 - threshold) / 2`.

    With a threshold of 1 or more, classification is winner-take-all among
    'Hover', 'Continuous' and 'Fragmented'; the mixture states are never
    classified.

    Parameters
    ----------
    probability : xarray.DataArray
        Must contain the states 'Hover', 'Continuous', 'Fragmented',
        'Hover-Continuous-Mix' and 'Fragmented-Continuous-Mix'.
    probablity_threshold : float

    Returns
    -------
    is_classified : xarray.DataArray
        Named 'is_classified'. Entries where `probability` is NaN are masked
        with `where`, so callers cast back with `.astype(bool)` if they need
        booleans.

    '''
    if probablity_threshold < 1.00:
        is_classified = probability > probablity_threshold
        is_classified.loc[dict(state='Hover-Continuous-Mix')] = (
            is_classified.sel(state='Hover-Continuous-Mix') &
            ~is_classified.sel(state='Hover') &
            ~is_classified.sel(state='Continuous') &
            (probability.sel(state='Fragmented') <
             (1 - probablity_threshold) / 2))

        is_classified.loc[dict(state='Fragmented-Continuous-Mix')] = (
            is_classified.sel(state='Fragmented-Continuous-Mix') &
            ~is_classified.sel(state='Fragmented') &
            ~is_classified.sel(state='Continuous') &
            (probability.sel(state='Hover') < (1 - probablity_threshold) / 2))
    else:
        # All-False boolean array with the same dims / coords as probability.
        is_classified = ((probability.copy() * 0.0).fillna(False)).astype(bool)
        base_probability = probability.sel(
            state=["Hover", "Continuous", "Fragmented"]).values
        # One-hot encoding of the most probable base state.
        is_most_probable = (
            base_probability.argmax(axis=-1)[..., None]
            == np.arange(base_probability.shape[-1]))
        # Pad the two mixture states with False. Using every leading axis
        # (shape[:-1]) instead of exactly two makes this work for both
        # (time, state) and (ripple, time, state) inputs.
        never_mixture = np.zeros(
            (*is_most_probable.shape[:-1], 2), dtype=bool)
        # NOTE(review): values are assigned positionally, which assumes
        # `state` is the last dimension and is ordered Hover, Continuous,
        # Fragmented followed by the two mixture states -- confirm.
        is_classified.values = np.concatenate(
            (is_most_probable, never_mixture), axis=-1)
    is_classified = is_classified.rename('is_classified')
    is_classified = is_classified.where(~np.isnan(probability))

    return is_classified


def get_ripple_replay_info(ripple, results, spikes,
                           ripple_consensus_trace_zscore, position_info,
                           sampling_frequency, probability_threshold,
                           track_graph, classifier):
    '''Compute summary metrics for a single ripple.

    Parameters
    ----------
    ripple : namedtuple
        One row of `ripple_times.itertuples()`; `Index`, `start_time`,
        `end_time` and `duration` are read.
    results : xarray.Dataset
        Decoding results containing `acausal_posterior`, selectable by
        `ripple_number` (or `event_number` as a fallback).
    spikes : pandas.DataFrame
    ripple_consensus_trace_zscore : pandas.Series or pandas.DataFrame
    position_info : pandas.DataFrame
        Must provide `speed`, `x_position`, `y_position`, `linear_distance`,
        `linear_position` and `linear_velocity` columns.
    sampling_frequency : float
    probability_threshold : float
    track_graph : networkx.Graph
    classifier : object
        `track_graph_` and `place_bin_center_ind_to_node_` are read here.

    Returns
    -------
    metrics : dict
        Whole-ripple metrics plus, for every state (including
        'Unclassified'), metrics prefixed with the state name.

    '''

    start_time = ripple.start_time
    end_time = ripple.end_time
    ripple_duration = ripple.duration

    ripple_time_slice = slice(start_time, end_time)

    # Select this ripple's decoding result, drop all-NaN time points and
    # convert the time coordinate from timedelta to float seconds. The
    # event dimension is named either `ripple_number` or `event_number`.
    try:
        result = (results
                  .sel(ripple_number=ripple.Index)
                  .dropna('time', how='all')
                  .assign_coords(
                      time=lambda ds: ds.time / np.timedelta64(1, 's')))
    except ValueError:
        result = (results
                  .sel(event_number=ripple.Index)
                  .dropna('time', how='all')
                  .assign_coords(
                      time=lambda ds: ds.time / np.timedelta64(1, 's')))
    probability = get_probability(result)
    is_classified = get_is_classified(
        probability, probability_threshold).astype(bool)
    # A time bin is 'Unclassified' when no state passes the threshold.
    is_unclassified = (is_classified.sum('state') < 1).assign_coords(
        state='Unclassified')
    is_classified = xr.concat((is_classified, is_unclassified), dim='state')

    # True if at least one time bin of the ripple is classified.
    classified = (~is_classified.sel(
        state='Unclassified')).sum('time').values > 0

    # Restrict behavior, spikes and the ripple trace to samples where the
    # animal's speed is at most 4 (presumably cm/s -- TODO confirm units).
    ripple_position_info = position_info.loc[ripple_time_slice]
    is_immobility = ripple_position_info.speed <= 4
    ripple_position_info = ripple_position_info.loc[is_immobility]

    # NOTE(review): the mask comes from position_info, so this assumes
    # spikes and the ripple trace share position_info's time index.
    ripple_spikes = spikes.loc[ripple_time_slice]
    ripple_spikes = ripple_spikes.loc[is_immobility]

    ripple_consensus = np.asarray(
        ripple_consensus_trace_zscore.loc[ripple_time_slice])
    ripple_consensus = ripple_consensus[is_immobility]

    # Position estimate from the posterior marginalized over state.
    posterior = result.acausal_posterior
    map_estimate = maximum_a_posteriori_estimate(posterior.sum('state'))

    trajectory_data = get_trajectory_data(
        posterior.sum("state"), track_graph, classifier, ripple_position_info)

    # Absolute distances returned by get_ahead_behind_distance: once with
    # its default source and once with source=center_well_id (node 0).
    replay_distance_from_actual_position = np.abs(get_ahead_behind_distance(
        track_graph, *trajectory_data))
    center_well_id = 0
    replay_distance_from_center_well = np.abs(get_ahead_behind_distance(
        track_graph, *trajectory_data, source=center_well_id))

    # Net change in distance from the center well between the first and
    # last time bin; NaN when the distance array is empty.
    try:
        replay_total_displacement = np.abs(
            replay_distance_from_center_well[-1] -
            replay_distance_from_center_well[0])
    except IndexError:
        replay_total_displacement = np.nan

    time = np.asarray(posterior.time)
    map_estimate = map_estimate.squeeze()
    replay_speed = get_map_speed(
        np.asarray(posterior.sum("state")),
        classifier.track_graph_,
        classifier.place_bin_center_ind_to_node_,
        1 / sampling_frequency,
    )
    # gaussian_smooth multiplies sigma by sampling_frequency, so this
    # value is a duration in units of 1 / sampling_frequency.
    SMOOTH_SIGMA = 0.0025
    replay_speed = gaussian_smooth(
        replay_speed, SMOOTH_SIGMA, sampling_frequency)
    replay_velocity_actual_position = np.gradient(
        replay_distance_from_actual_position, time)
    replay_velocity_center_well = np.gradient(
        replay_distance_from_center_well, time)

    # Spatial coverage: size of the 95% highest posterior density region
    # per time bin, as a length (bin count * bin width) and as a fraction
    # of the position bins with non-zero posterior at the first time bin.
    hpd_threshold = highest_posterior_density(
        posterior.sum("state"), coverage=0.95)
    isin_hpd = posterior.sum("state") >= hpd_threshold[:, np.newaxis]
    spatial_coverage = (
        isin_hpd * np.diff(posterior.position)[0]).sum("position").values
    n_position_bins = (posterior.sum("state", skipna=True)
                       > 0).sum("position").values[0]
    spatial_coverage_percentage = (isin_hpd.sum("position") /
                                   n_position_bins).values
    # Per-bin absolute change in distance; a leading 0 keeps the array the
    # same length as `time`.
    distance_change = np.abs(np.diff(replay_distance_from_center_well))
    distance_change = np.insert(distance_change, 0, 0)

    # Whole-ripple metrics.
    metrics = {
        'start_time': start_time,
        'end_time': end_time,
        'duration': ripple_duration,
        'is_classified': classified,
        'n_unique_spiking': n_tetrodes_active(ripple_spikes),
        'n_total_spikes': n_total_spikes(ripple_spikes),
        'median_fraction_spikes_under_6_ms': np.nanmedian(
            fraction_spikes_less_than_6_ms(
                ripple_spikes, sampling_frequency)
        ),
        'median_spikes_per_bin': median_spikes_per_bin(
            ripple_spikes),
        'population_rate': population_rate(
            ripple_spikes, sampling_frequency),
        'actual_x_position': np.mean(
            np.asarray(ripple_position_info.x_position)),
        'actual_y_position': np.mean(
            np.asarray(ripple_position_info.y_position)),
        'actual_linear_distance': np.mean(
            np.asarray(ripple_position_info.linear_distance)),
        'actual_linear_position': np.mean(
            np.asarray(ripple_position_info.linear_position)),
        'actual_speed': np.mean(
            np.asarray(ripple_position_info.speed)),
        'actual_velocity_center_well': np.mean(
            np.asarray(ripple_position_info.linear_velocity)),
        'replay_distance_from_actual_position': np.mean(
            replay_distance_from_actual_position),
        'replay_speed': np.mean(replay_speed),
        'replay_velocity_actual_position': np.mean(
            replay_velocity_actual_position),
        'replay_velocity_center_well': np.mean(replay_velocity_center_well),
        'replay_distance_from_center_well': np.mean(
            replay_distance_from_center_well),
        'replay_linear_position': np.mean(map_estimate),
        'replay_total_distance': np.sum(distance_change),
        'replay_total_displacement': replay_total_displacement,
        'state_order': get_state_order(is_classified),
        'spatial_coverage': np.mean(spatial_coverage),
        'spatial_coverage_percentage': np.mean(spatial_coverage_percentage),
        'mean_ripple_consensus_trace_zscore': np.mean(
            ripple_consensus),
        'max_ripple_consensus_trace_zscore': np.max(
            ripple_consensus),
        # Number of contiguous runs where the z-scored trace exceeds 2.5.
        'n_ripples': label(ripple_consensus > 2.5)[1]
    }

    # Per-state metrics, restricted to the time bins classified as that
    # state (this includes the 'Unclassified' pseudo-state).
    for state, above_threshold in is_classified.groupby('state'):
        above_threshold = above_threshold.astype(bool).values.squeeze()
        metrics[f'{state}'] = np.sum(above_threshold) > 0
        # 'Unclassified' has no entry in `probability`, hence the fallback.
        try:
            metrics[f'{state}_max_probability'] = np.max(
                np.asarray(probability.sel(state=state)))
        except (KeyError, ValueError):
            metrics[f'{state}_max_probability'] = np.nan

        metrics[f'{state}_duration'] = duration(
            above_threshold, sampling_frequency)
        metrics[f'{state}_fraction_of_time'] = fraction_of_time(
            above_threshold, time)

        # The remaining keys are only added when the state occurs at all.
        if np.any(above_threshold):
            metrics[f'{state}_replay_distance_from_actual_position'] = np.mean(
                replay_distance_from_actual_position[above_threshold])  # cm
            metrics[f'{state}_replay_speed'] = np.mean(
                replay_speed[above_threshold])  # cm / s
            metrics[f'{state}_replay_velocity_actual_position'] = np.mean(
                replay_velocity_actual_position[above_threshold])  # cm / s
            metrics[f'{state}_replay_velocity_center_well'] = np.mean(
                replay_velocity_center_well[above_threshold])  # cm / s
            metrics[f'{state}_replay_distance_from_center_well'] = np.mean(
                replay_distance_from_center_well[above_threshold])  # cm
            metrics[f'{state}_replay_linear_position'] = get_replay_linear_position(
                above_threshold, map_estimate)  # cm
            metrics[f'{state}_replay_total_distance'] = np.sum(
                distance_change[above_threshold])  # cm
            metrics[f'{state}_min_time'] = np.min(time[above_threshold])  # s
            metrics[f'{state}_max_time'] = np.max(time[above_threshold])  # s
            # NOTE(review): `above_threshold` is defined on the posterior's
            # time bins but is applied to the immobility-filtered spikes and
            # ripple trace -- confirm the two always have the same length.
            metrics[f'{state}_n_unique_spiking'] = n_tetrodes_active(
                ripple_spikes.iloc[above_threshold])
            metrics[f'{state}_n_total_spikes'] = n_total_spikes(
                ripple_spikes.iloc[above_threshold])
            metrics[f'{state}_median_fraction_spikes_under_6_ms'] = np.nanmedian(
                fraction_spikes_less_than_6_ms(
                    ripple_spikes.iloc[above_threshold], sampling_frequency)
            )
            metrics[f'{state}_population_rate'] = population_rate(
                ripple_spikes.iloc[above_threshold], sampling_frequency)
            # NOTE(review): uses `.loc` with the boolean mask where the
            # neighbouring lines use `.iloc`.
            metrics[f'{state}_median_spikes_per_bin'] = median_spikes_per_bin(
                ripple_spikes.loc[above_threshold])
            metrics[f'{state}_spatial_coverage'] = np.median(
                spatial_coverage[above_threshold])  # cm
            metrics[f'{state}_spatial_coverage_percentage'] = np.median(
                spatial_coverage_percentage[above_threshold])
            # Mean probability of each base state while in this state.
            metrics[f"{state}_Hov_avg_prob"] = float(
                probability.sel(state="Hover").isel(
                    time=above_threshold).mean()
            )
            metrics[f"{state}_Cont_avg_prob"] = float(
                probability.sel(state="Continuous").isel(
                    time=above_threshold).mean()
            )
            metrics[f"{state}_Frag_avg_prob"] = float(
                probability.sel(state="Fragmented").isel(
                    time=above_threshold).mean()
            )
            metrics[f"{state}_mean_ripple_consensus_trace_zscore"] = np.mean(
                ripple_consensus[above_threshold])
            metrics[f"{state}_max_ripple_consensus_trace_zscore"] = np.max(
                ripple_consensus[above_threshold])

    return metrics


def get_replay_linear_position(is_classified, map_estimate):
    '''Mean of `map_estimate` within each contiguous run of classified bins.

    Parameters
    ----------
    is_classified : array_like of bool, shape (n_time,)
    map_estimate : ndarray, shape (n_time,)

    Returns
    -------
    segment_means : ndarray, shape (n_segments,)
        One value per contiguous run of True in `is_classified`, in order.

    '''
    segment_labels, n_segments = scipy.ndimage.label(is_classified)
    segment_means = []
    for segment_id in range(1, n_segments + 1):
        in_segment = segment_labels == segment_id
        segment_means.append(np.mean(map_estimate[in_segment]))
    return np.asarray(segment_means)


def get_n_unique_spiking(ripple_spikes):
    '''Count, per event, the columns that have at least one spike.

    Events are grouped by `ripple_number`, falling back to `event_number`
    when that key is not present.
    '''
    def _n_active_columns(event_key):
        spikes_per_event = ripple_spikes.groupby(event_key).sum()
        return (spikes_per_event > 0).sum(axis=1)

    try:
        return _n_active_columns('ripple_number')
    except KeyError:
        return _n_active_columns('event_number')


def get_n_total_spikes(ripple_spikes):
    '''Total number of spikes across all columns for each event.

    Events are grouped by `ripple_number`, falling back to `event_number`
    when that key is not present.
    '''
    def _total_per_event(event_key):
        spikes_per_event = ripple_spikes.groupby(event_key).sum()
        return spikes_per_event.sum(axis=1)

    try:
        return _total_per_event('ripple_number')
    except KeyError:
        return _total_per_event('event_number')


def n_tetrodes_active(spikes):
    '''Number of columns of `spikes` whose sum over rows is positive.'''
    spikes_per_column = np.asarray(spikes).sum(axis=0)
    is_active = spikes_per_column > 0
    return is_active.sum()


def n_total_spikes(spikes):
    '''Sum of every entry of `spikes`, cast to an integer.'''
    total = np.asarray(spikes).sum()
    return total.astype(int)


def median_spikes_per_bin(spikes):
    '''Median, over rows, of the number of spikes summed across columns.'''
    spikes_per_row = np.asarray(spikes).sum(axis=1)
    return np.median(spikes_per_row)


def _fraction_spikes_less_than_6_ms(spikes, sampling_frequency):
    interspike_interval = (
        1000 * np.diff(np.nonzero(spikes)[0]) / sampling_frequency)  # ms
    return np.nanmean(interspike_interval < 6)


def fraction_spikes_less_than_6_ms(spikes, sampling_frequency):
    '''Per-column fraction of interspike intervals shorter than 6 ms.

    Parameters
    ----------
    spikes : array_like, shape (n_time, n_columns)
    sampling_frequency : float

    Returns
    -------
    fractions : ndarray, shape (n_columns,)

    '''
    spikes_by_column = np.asarray(spikes).T
    fractions = [
        _fraction_spikes_less_than_6_ms(column_spikes, sampling_frequency)
        for column_spikes in spikes_by_column
    ]
    return np.asarray(fractions)


def duration(above_threshold, sampling_frequency):
    '''Total time spent above threshold.

    The NaN-ignoring count of True bins is divided by `sampling_frequency`,
    so the result is in units of 1 / sampling_frequency (seconds when the
    frequency is in Hz, not milliseconds).
    '''
    n_bins_above = np.nansum(np.asarray(above_threshold))
    return n_bins_above / sampling_frequency


def fraction_of_time(above_threshold, time):
    '''Fraction of the `len(time)` bins that are above threshold.'''
    n_bins_above = np.nansum(np.asarray(above_threshold))
    n_bins_total = len(time)
    return n_bins_above / n_bins_total


def population_rate(spikes, sampling_frequency):
    '''Mean of all entries of `spikes`, scaled by `sampling_frequency`.'''
    mean_spikes_per_bin = np.asarray(spikes).mean()
    return sampling_frequency * mean_spikes_per_bin


def maximum_a_posteriori_estimate(posterior_density):
    '''Position with the largest posterior density at each time.

    Parameters
    ----------
    posterior_density : xarray.DataArray
        Has either (`x_position`, `y_position`) dimensions or a single
        `position` dimension in addition to time.

    Returns
    -------
    map_estimate : ndarray
        For a `position` dimension the shape is (n_time, 1). For
        (`x_position`, `y_position`) each row holds the stacked coordinate
        of the best bin (presumably shape (n_time, 2) -- TODO confirm).

    '''
    try:
        # Two-dimensional position: flatten the grid into a single axis `z`.
        log_posterior = np.log(posterior_density.stack(
            z=['x_position', 'y_position']))
        best_bin = log_posterior.argmax('z')
        best_coordinate = log_posterior.z[best_bin]
        map_estimate = np.asarray(best_coordinate.values.tolist())
    except KeyError:
        # One-dimensional position.
        best_bin = np.log(posterior_density).argmax('position')
        best_position = posterior_density.position[best_bin]
        map_estimate = np.asarray(best_position)[:, np.newaxis]
    return map_estimate


def get_place_field_max(classifier):
    '''Position of the peak of each place field.

    Uses `classifier.place_fields_` when the classifier has it; otherwise
    falls back to the peak of each entry of
    `classifier.ground_process_intensities_`, looked up in
    `classifier.place_bin_centers_`.
    '''
    try:
        place_fields = classifier.place_fields_
        peak_ind = place_fields.argmax('position')
        peak_position = place_fields.position[peak_ind]
        return np.asarray(peak_position.values.tolist())
    except AttributeError:
        peak_positions = []
        for intensity in classifier.ground_process_intensities_:
            peak_positions.append(
                classifier.place_bin_centers_[intensity.argmax()])
        return np.asarray(peak_positions)


def get_linear_position_order(position_info, place_field_max):
    '''Order place fields by the linear position nearest to their peak.

    For each 2D peak in `place_field_max`, the row of `position_info` whose
    (`x_position`, `y_position`) is closest in Euclidean distance is found
    and its `linear_position` is taken.

    Parameters
    ----------
    position_info : pandas.DataFrame
        Must provide `x_position`, `y_position` and `linear_position`.
    place_field_max : iterable of array_like, each of length 2

    Returns
    -------
    order : ndarray
        Indices that sort `linear_place_field_max`.
    linear_place_field_max : ndarray

    '''
    xy_position = position_info.loc[:, ['x_position', 'y_position']]

    def _nearest_linear_position(place_max):
        squared_offset = np.abs(place_max - xy_position) ** 2
        distance = np.sqrt(np.sum(squared_offset, axis=1))
        return position_info.loc[distance.idxmin(), 'linear_position']

    linear_place_field_max = np.asarray(
        [_nearest_linear_position(place_max)
         for place_max in place_field_max])
    return np.argsort(linear_place_field_max), linear_place_field_max


def reshape_to_segments(time_series, segments):
    '''Cut a time series into segments, each re-timed to start at zero.

    Parameters
    ----------
    time_series : pandas.Series or pandas.DataFrame
        Indexed by time.
    segments : pandas.DataFrame
        Must provide `start_time` and `end_time` columns.

    Returns
    -------
    segmented : pandas.Series or pandas.DataFrame
        The segments concatenated, keyed by `segments.index`, with the index
        sorted.

    '''
    def _zero_based_segment(segment):
        piece = time_series.loc[segment.start_time:segment.end_time]
        # Express time relative to the first sample of the segment.
        piece.index = piece.index - piece.index[0]
        return piece

    aligned_segments = [
        _zero_based_segment(segment) for segment in segments.itertuples()]
    combined = pd.concat(aligned_segments, axis=0, keys=segments.index)
    return combined.sort_index()


def _get_closest_ind(map_estimate, all_positions):
    map_estimate = np.asarray(map_estimate)
    all_positions = np.asarray(all_positions)
    return np.argmin(np.linalg.norm(
        map_estimate[:, np.newaxis, :] - all_positions[np.newaxis, ...],
        axis=-2), axis=1)


def _get_projected_track_positions(position, track_segments, track_segment_id):
    '''Projection of each position onto its assigned track segment.

    `project_points_to_segment` is called for all segments; the result is
    then indexed as [time, segment] to keep, for every time point, only the
    projection onto the segment given by `track_segment_id`.
    '''
    all_projections = project_points_to_segment(track_segments, position)
    time_ind = np.arange(all_projections.shape[0])
    return all_projections[time_ind, track_segment_id]


def get_state_order(is_classified):
    '''Sequence of states visited, with consecutive repeats collapsed.

    Parameters
    ----------
    is_classified : xarray.DataArray
        Boolean array with a `state` dimension whose first dimension is
        indexed by the per-time "any state is True" mask.

    Returns
    -------
    state_order : list
        State label of the first True state at each retained time point,
        keeping a label only when it differs from the previous one. The
        first time point is always kept.

    '''
    has_state = is_classified.sum("state").astype(bool)
    order = is_classified.state[
        is_classified[has_state].argmax("state")
    ]
    states = np.atleast_1d(order.values)

    # Compare each state with its predecessor. The first state itself is
    # always emitted, so a state lasting a single time bin at the start (or
    # an event with only one time bin) is not dropped.
    return [
        current_state
        for ind, current_state in enumerate(states)
        if ind == 0 or current_state != states[ind - 1]
    ]


def highest_posterior_density(posterior_density, coverage=0.95):
    """Density threshold that bounds the highest posterior density region.

    Same as credible interval
    https://stats.stackexchange.com/questions/240749/how-to-find-95-credible-interval

    For each time point the bins are sorted from most to least probable and
    accumulated until `coverage` of the total mass is reached; the density of
    the last bin needed is returned.

    Parameters
    ----------
    posterior_density : xarray.DataArray, shape (n_time, n_position_bins) or
        shape (n_time, n_x_bins, n_y_bins)
    coverage : float, optional

    Returns
    -------
    threshold : ndarray, shape (n_time,)

    """
    try:
        density = posterior_density.stack(
            z=["x_position", "y_position"]
        ).values
    except KeyError:
        density = posterior_density.values

    total_mass = np.sum(density, axis=1, keepdims=True)
    # Normalized density, largest bin first.
    descending_density = np.sort(density, axis=1)[:, ::-1] / total_mass
    reaches_coverage = np.cumsum(descending_density, axis=1) >= coverage
    crit_ind = np.argmax(reaches_coverage, axis=1)

    # If the cumulative mass never reaches the coverage, use the last bin.
    never_reaches_coverage = reaches_coverage.sum(axis=1) == 0
    crit_ind[never_reaches_coverage] = density.shape[1] - 1

    time_ind = np.arange(density.shape[0])
    # Undo the normalization so the threshold is on the input's scale.
    return descending_density[time_ind, crit_ind] * total_mass.squeeze()


def gaussian_smooth(data, sigma, sampling_frequency, axis=0, truncate=4.0):
    '''1D convolution of the data with a Gaussian.

    The function is a thin wrapper around scipy's `gaussian_filter1d`.
    `sigma` is multiplied by `sampling_frequency` to obtain the standard
    deviation in samples, so it is expressed in units of
    1 / sampling_frequency (seconds when the frequency is in Hz).

    The kernel support is truncated at `truncate` standard deviations. The
    default of 4 matches `gaussian_filter1d` and the previous behavior of
    this function; pass `truncate=8` for a wider support.

    Parameters
    ----------
    data : array_like
    sigma : float
    sampling_frequency : int
    axis : int, optional
    truncate : float, optional
        Number of standard deviations at which the kernel is cut off.

    Returns
    -------
    smoothed_data : array_like

    '''
    sigma_in_samples = sigma * sampling_frequency
    return gaussian_filter1d(
        data, sigma_in_samples, axis=axis, truncate=truncate)


def load_all_replay_info(
    n_unique_spiking=2,
    data_type="clusterless",
    dim="1D",
    probability_threshold=PROBABILITY_THRESHOLD,
    speed_threshold=4,
    exclude_interneuron_spikes=False,
    use_multiunit_HSE=False,
    brain_areas=None,
):
    '''Load and combine the saved per-epoch replay info CSV files.

    Files are searched in `PROCESSED_DATA_DIR` with the glob pattern
    `*_{data_type}_{dim}[_no_interneuron][_{areas}][_multiunit_HSE]`
    `_replay_info_{prob:02d}.csv`, where `prob` is
    `int(probability_threshold * 100)`.

    Parameters
    ----------
    n_unique_spiking : int, optional
        Keep events with at least this many `n_unique_spiking`.
    data_type : str, optional
    dim : str, optional
    probability_threshold : float, optional
    speed_threshold : float, optional
        Keep events whose `actual_speed` is at most this value.
    exclude_interneuron_spikes : bool, optional
    use_multiunit_HSE : bool, optional
    brain_areas : list of str or None, optional
        When None, `_BRAIN_AREAS` is used for counting tetrodes and no area
        suffix is added to the file pattern.

    Returns
    -------
    replay_info : pandas.DataFrame
        Indexed by ("Animal ID", "day", "epoch", event number) with added
        `fraction_unclassified`, `duration_classified` and `n_tetrodes`
        columns.

    '''
    tetrode_info = make_tetrode_dataframe(ANIMALS)
    prob = int(probability_threshold * 100)

    # Assemble the file-name pattern piece by piece.
    identifier_parts = [f'*_{data_type}_{dim}']
    if exclude_interneuron_spikes:
        identifier_parts.append('_no_interneuron')
    if brain_areas is None:
        brain_areas = _BRAIN_AREAS
    else:
        identifier_parts.append('_' + '-'.join(brain_areas))
    if use_multiunit_HSE:
        identifier_parts.append('_multiunit_HSE')
    epoch_identifier = ''.join(identifier_parts)

    file_pattern = os.path.join(
        PROCESSED_DATA_DIR, f"{epoch_identifier}_replay_info_{prob:02d}.csv")
    replay_info = pd.concat(
        [pd.read_csv(file_path) for file_path in glob(file_pattern)], axis=0,
    )

    epoch_levels = ["animal", "day", "epoch"]

    def _index_by_event(frame):
        # The event column is named `ripple_number` or `event_number`.
        try:
            return frame.set_index(epoch_levels + ["ripple_number"])
        except KeyError:
            return frame.set_index(epoch_levels + ["event_number"])

    replay_info = _index_by_event(replay_info)

    unclassified_duration = replay_info.Unclassified_duration
    replay_info["fraction_unclassified"] = (
        unclassified_duration / replay_info.duration)
    replay_info["duration_classified"] = (
        replay_info.duration - replay_info.Unclassified_duration)

    has_enough_units = replay_info.n_unique_spiking >= n_unique_spiking
    is_below_speed = replay_info.actual_speed <= speed_threshold
    replay_info = replay_info.loc[has_enough_units & is_below_speed]
    replay_info = replay_info.sort_index()

    # Number of tetrodes per epoch in the requested brain areas.
    tetrode_area = tetrode_info.area.astype(str).str.upper()
    is_brain_areas = tetrode_area.isin(brain_areas)
    n_tetrodes = (
        tetrode_info.loc[is_brain_areas]
        .groupby(epoch_levels)
        .tetrode_id.count()
        .rename("n_tetrodes")
    )
    replay_info = pd.merge(
        replay_info.reset_index(), n_tetrodes.to_frame().reset_index()
    )
    replay_info = _index_by_event(replay_info)

    replay_info = replay_info.rename(index={"Cor": "cor"})
    return replay_info.rename_axis(index={"animal": "Animal ID"})


def get_map_speed(
    posterior,
    track_graph1,
    place_bin_center_ind_to_node,
    dt,
):
    '''Speed of the most probable decoded position along the track graph.

    The most probable position bin at each time is mapped to a graph node
    and the speed is the shortest-path distance between nodes divided by
    the elapsed time: a one-sided difference at the first and last time
    points and a centered difference (over `2 * dt`) in between.

    Parameters
    ----------
    posterior : ndarray, shape (n_time, n_position_bins)
    track_graph1 : networkx.Graph
        Edges must carry a "distance" attribute, used as the path weight.
    place_bin_center_ind_to_node : ndarray
        Maps a position bin index to a node id of `track_graph1`.
    dt : float
        Time between consecutive rows of `posterior`.

    Returns
    -------
    speed : ndarray, shape (n_time,)
        Empty for empty input and `[nan]` for a single time point.

    '''
    map_position_ind = np.argmax(posterior, axis=1)
    node_ids = place_bin_center_ind_to_node[map_position_ind]
    n_time = len(node_ids)

    if n_time == 0:
        # Nothing to differentiate; avoid indexing into an empty array.
        return np.asarray([])
    if n_time == 1:
        return np.asarray([np.nan])

    def _track_distance(node1, node2):
        return nx.shortest_path_length(
            track_graph1, source=node1, target=node2, weight="distance")

    first_speed = _track_distance(node_ids[0], node_ids[1]) / dt
    central_speed = [
        _track_distance(node1, node2) / (2.0 * dt)
        for node1, node2 in zip(node_ids[:-2], node_ids[2:])
    ]
    last_speed = _track_distance(node_ids[-2], node_ids[-1]) / dt

    # The end-point speed is appended as the final element. The previous
    # `np.insert(speed, -1, ...)` placed it *before* the last centered
    # difference, swapping the final two entries.
    speed = np.asarray([first_speed, *central_speed, last_speed], dtype=float)
    return np.abs(speed)
back to top