https://github.com/geomstats/geomstats
Tip revision: 38fd937dd5cfee3851c30579f635767d7d670d55 authored by L. F. Pereira on 28 June 2023, 09:42:46 UTC
Add pytest-xdist to dependencies
Add pytest-xdist to dependencies
Tip revision: 38fd937
plot_geodesics_poincare_polydisk.py
"""Plot a geodesic on the Poincare polydisk.
Plot two geodesics on the Poincare polydisk, one defined by its two end
points and one by an initial tangent vector, with Poincare Disk visualization.
"""
import matplotlib.pyplot as plt
import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.poincare_polydisk import PoincarePolydisk
# Number of disks handed to the PoincarePolydisk constructor; also used in
# main() to replicate the sample data once per disk.
N_DISKS = 4
# Single polydisk instance shared by the whole example.
POINCARE_POLYDISK = PoincarePolydisk(n_disks=N_DISKS)
# Metric of that polydisk; the plotting helpers below ask it for geodesics.
METRIC = POINCARE_POLYDISK.metric
def plot_geodesic_between_two_points(initial_point, end_point, n_steps=11, ax=None):
    """Draw the geodesic joining two points of the polydisk.

    The geodesic is obtained from the module-level ``METRIC``, sampled at
    ``n_steps`` evenly spaced times in [0, 1] and handed to
    ``visualization.plot`` with ``space="poincare_polydisk"``.

    Parameters
    ----------
    initial_point : array-like
        Point at which the geodesic starts.
    end_point : array-like
        Point at which the geodesic ends.
    n_steps : int
        Number of sample times generated between 0 and 1 (both included).
    ax : object, optional
        Drawing target forwarded unchanged to ``visualization.plot``.
        ``main`` in this file passes a matplotlib figure here.
    """
    path = METRIC.geodesic(initial_point=initial_point, end_point=end_point)
    sampled_points = path(gs.linspace(0.0, 1.0, n_steps))
    visualization.plot(sampled_points, ax=ax, space="poincare_polydisk")
def plot_geodesic_with_initial_tangent_vector(
    initial_point, initial_tangent_vec, n_steps=11, ax=None
):
    """Draw the geodesic leaving a point with a given initial velocity.

    The geodesic is obtained from the module-level ``METRIC``, sampled at
    ``n_steps`` evenly spaced times in [0, 1] and handed to
    ``visualization.plot`` with ``space="poincare_polydisk"``.

    Parameters
    ----------
    initial_point : array-like
        Point at which the geodesic starts.
    initial_tangent_vec : array-like
        Tangent vector at ``initial_point`` used as the initial speed.
    n_steps : int
        Number of sample times generated between 0 and 1 (both included).
    ax : object, optional
        Drawing target forwarded unchanged to ``visualization.plot``.
        ``main`` in this file passes a matplotlib figure here.
    """
    path = METRIC.geodesic(
        initial_point=initial_point, initial_tangent_vec=initial_tangent_vec
    )
    sampled_points = path(gs.linspace(0.0, 1.0, n_steps))
    visualization.plot(sampled_points, ax=ax, space="poincare_polydisk")
def main():
    """Build sample data on the polydisk and show two geodesics.

    One geodesic is defined by an initial point and an end point, the other
    by the same initial point and an initial tangent vector. Both are drawn
    on the same matplotlib figure, which is then displayed.
    """
    # Starting point: the same three coordinates repeated once per disk,
    # then promoted to a 3-dimensional array.
    single_disk_start = gs.array([gs.sqrt(2.0), 1.0, 0.0])
    initial_point = gs.to_ndarray(
        gs.stack([single_disk_start] * N_DISKS, axis=0), to_ndim=3
    )

    # End point: given in intrinsic coordinates, converted to extrinsic
    # ones, then repeated N_DISKS times along axis 1.
    intrinsic_end = gs.reshape(gs.array([1.5, 1.5]), (1, 1, 2))
    single_disk_end = POINCARE_POLYDISK.intrinsic_to_extrinsic_coords(intrinsic_end)
    end_point = gs.concatenate([single_disk_end] * N_DISKS, axis=1)

    # Initial velocity: an arbitrary vector repeated once per disk and
    # passed through to_tangent at the initial point.
    single_disk_vector = gs.array([3.5, 0.6, 0.8])
    ambient_vector = gs.to_ndarray(
        gs.stack([single_disk_vector] * N_DISKS, axis=0), to_ndim=3
    )
    initial_tangent_vec = POINCARE_POLYDISK.to_tangent(
        vector=ambient_vector, base_point=initial_point
    )

    # Both helpers receive the same figure through their ``ax`` argument.
    figure = plt.figure()
    plot_geodesic_between_two_points(
        initial_point=initial_point, end_point=end_point, ax=figure
    )
    plot_geodesic_with_initial_tangent_vector(
        initial_point=initial_point,
        initial_tangent_vec=initial_tangent_vec,
        ax=figure,
    )
    plt.show()
# Run the example only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()