# Archived from https://github.com/geomstats/geomstats (raw file view).
# Tip revision: 2e530e7934d7b49090232743bb83bf0e52f1e815 (2e530e7), authored
# by Nina Miolane on 09 March 2021, 02:40:46 UTC:
# "Merge branch 'stable' of github.com:geomstats/geomstats into stable".
# File: plot_knn_s2.py
"""Plot the result of a KNN classification on the sphere."""

import logging
import os

import matplotlib.pyplot as plt

import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.knn import KNearestNeighborsClassifier


def _plot_labeled_points(figure_index, title, points, point_labels, n_labels):
    """Draw labelled points on a fresh 3D sphere figure, one scatter per label.

    Parameters
    ----------
    figure_index : int
        Matplotlib figure number passed to ``plt.figure``.
    title : str
        Title of the figure.
    points : array-like
        Points on the sphere; rows are selected with a boolean mask built
        from ``point_labels``.
    point_labels : array-like
        Integer label of each row of ``points``.
    n_labels : int
        Number of classes; labels ``0, ..., n_labels - 1`` are drawn.
    """
    plt.figure(figure_index)
    ax = plt.subplot(111, projection='3d')
    plt.title(title)
    sphere_plot = visualization.Sphere()
    sphere_plot.draw(ax=ax)
    # One draw_points call per class, so each class is its own scatter on the
    # axes (presumably picking a distinct colour from matplotlib's cycle).
    for i_label in range(n_labels):
        points_label_i = points[point_labels == i_label, ...]
        sphere_plot.draw_points(ax=ax, points=points_label_i)


def main():
    """Plot the result of a KNN classification on the sphere.

    Two clusters of points are sampled on the 2-sphere with
    ``random_von_mises_fisher``. The second cluster is negated, i.e. mapped
    to the antipodal points. The clusters are labelled 0 and 1. A
    K-nearest-neighbours classifier that uses the sphere's metric distance
    (``sphere.metric.dist``) is fitted on them. It then labels points drawn
    with ``random_uniform``. Figure 0 shows the training set and figure 1
    the predicted classes.
    """
    sphere = Hypersphere(dim=2)
    sphere_distance = sphere.metric.dist

    n_labels = 2
    n_samples_per_dataset = 10
    n_targets = 200

    dataset_1 = sphere.random_von_mises_fisher(
        kappa=10,
        n_samples=n_samples_per_dataset)
    # Negating unit vectors gives the antipodal cluster, so the two classes
    # sit on opposite sides of the sphere.
    dataset_2 = -sphere.random_von_mises_fisher(
        kappa=10,
        n_samples=n_samples_per_dataset)
    training_dataset = gs.concatenate((dataset_1, dataset_2), axis=0)
    labels_dataset_1 = gs.zeros([n_samples_per_dataset], dtype=gs.int64)
    labels_dataset_2 = gs.ones([n_samples_per_dataset], dtype=gs.int64)
    labels = gs.concatenate((labels_dataset_1, labels_dataset_2))
    target = sphere.random_uniform(n_samples=n_targets)

    neigh = KNearestNeighborsClassifier(
        n_neighbors=2,
        distance=sphere_distance)
    neigh.fit(training_dataset, labels)
    target_labels = neigh.predict(target)

    _plot_labeled_points(0, 'Training set', training_dataset, labels, n_labels)
    _plot_labeled_points(1, 'Classification', target, target_labels, n_labels)

    plt.show()


if __name__ == '__main__':
    # Default to 'numpy' when the variable is unset (geomstats' own default)
    # instead of crashing with a KeyError.
    backend = os.environ.get('GEOMSTATS_BACKEND', 'numpy')
    if backend == 'tensorflow':
        # Use warning, not info: the root logger's default level is WARNING,
        # so an info-level message would never be shown to the user.
        logging.warning('Examples with visualizations are only implemented '
                        'with numpy backend.\n'
                        'To change backend, write: '
                        'export GEOMSTATS_BACKEND=\'numpy\'.')
    else:
        main()
# (end of archived file; web-page footer: "back to top")