# Source: https://github.com/VignoliniLab/PyLlama
# Raw File
# Tip revision: e54fdb5700002ac9d6072d315a18766c807c1c5e authored by mmbay on 03 May 2022, 20:06:44 UTC
# added acknowledgements
# Tip revision: e54fdb5
# geometry.py
"""
Geometry
========

Provides
  A visual representation of cylinders in a
  Matplotlib 3D plot.


Class
  Cylinder: contains the parameters of the
  cylinder (centre, axis, length...) and
  enables the manipulation of cylinders
  and their representation in 3D volumes
  with user-chosen resolution.

Author
  Mélanie Bay (mmb54@cam.ac.uk)
"""


import random as rnd
import numpy as np
from scipy.linalg import norm
from warnings import warn
from matplotlib.patches import FancyArrowPatch   # for the arrows
from mpl_toolkits.mplot3d import proj3d


def set_axes_equal(ax, manual_lims=None):
    """
    From https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to
    Make axes of 3D plot have equal scale so that spheres appear as spheres,
    cubes as cubes, etc..  This is one possible solution to Matplotlib's
    ax.set_aspect('equal') and ax.axis('equal') not working for 3D.

    Each axis keeps its current middle point; all three axes are given the
    extent of the widest one.

    :param ax: a Matplotlib 3D axis, e.g., as output from plt.gca().
    :param manual_lims: optional sequence ((xmin, xmax), (ymin, ymax),
        (zmin, zmax)) used instead of the limits currently set on ax.
        Lists, tuples and NumPy arrays are accepted.
    :return: nothing (the limits of ax are updated)
    """
    # Identity test on purpose: "manual_lims == None" is evaluated
    # element-wise when manual_lims is a NumPy array, and the resulting
    # array cannot be used as a condition (ValueError).
    if manual_lims is None:
        x_limits = ax.get_xlim3d()
        y_limits = ax.get_ylim3d()
        z_limits = ax.get_zlim3d()
    else:
        x_limits = manual_lims[0]
        y_limits = manual_lims[1]
        z_limits = manual_lims[2]

    all_limits = (x_limits, y_limits, z_limits)
    x_middle, y_middle, z_middle = [np.mean(lim) for lim in all_limits]

    # The plot bounding box is a sphere in the sense of the infinity
    # norm, hence half the largest range is called the plot radius.
    plot_radius = 0.5 * max(abs(lim[1] - lim[0]) for lim in all_limits)

    ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
    ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
    ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])


def line_two_points(p1, p2):
    """
    Builds the three coordinate lists needed by plot3D(x, y, z) to draw the
    straight segment joining two points.
    :param p1: first point, indexable as (x, y, z), e.g. a Numpy array
    :param p2: second point, indexable as (x, y, z), e.g. a Numpy array
    :return: tuple (x, y, z) where each item is the two-element list
        [value at p1, value at p2] for that coordinate
    """
    xs = [p1[0], p2[0]]
    ys = [p1[1], p2[1]]
    zs = [p1[2], p2[2]]
    return xs, ys, zs


def line_point_direction(p1, dir, length=1):
    """
    Builds the x, y, z lists for plot3D(x, y, z) describing the segment that
    starts at p1 and ends at p1 + length * dir.
    :param p1: starting point (expected to support arithmetic with dir,
        e.g. a Numpy array)
    :param dir: direction of the segment; it is used as given, without
        normalisation
    :param length: factor applied to dir
    :return: x, y, z to plot
    """
    end_point = p1 + length * dir
    return line_two_points(p1, end_point)


def line_middle_direction(p1, dir, length=1):
    """
    Builds the x, y, z lists for plot3D(x, y, z) describing a segment whose
    middle point is p1.
    The segment runs from p1 - length * dir to p1 + length * dir, i.e. it
    extends length * dir on each side of p1.
    :param p1: middle point (expected to support arithmetic with dir,
        e.g. a Numpy array)
    :param dir: direction of the segment; it is used as given, without
        normalisation
    :param length: factor applied to dir on each side of p1
    :return: x, y, z to plot
    """
    far_end = p1 + length * dir
    near_end = p1 - length * dir
    return line_two_points(near_end, far_end)


def translate_point(p1, dir, length=1):
    """
    Moves a point along a direction.
    :param p1: starting point (e.g. a Numpy array)
    :param dir: direction of the translation; it is used as given, without
        normalisation
    :param length: factor applied to dir
    :return: the translated point p1 + length * dir
    """
    return p1 + length * dir


class Cylinder(object):
    """
    Finite cylinder defined by its centre, the direction of its axis, its
    length and its radius, together with the meshed points used to draw it
    with Matplotlib's plot_surface.

    The meshed points (attributes ptsTub*, ptsTop*, ptsBot*) are None until
    getPoints() is called, and they are reset to None each time the centre,
    direction, length or radius is modified through its property.
    """

    # Names of the attributes that hold the meshed points: lateral surface
    # (Tub), top disc (Top) and bottom disc (Bot), for X, Y and Z.
    _MESH_ATTRIBUTES = ('ptsTubX', 'ptsTubY', 'ptsTubZ',
                        'ptsTopX', 'ptsTopY', 'ptsTopZ',
                        'ptsBotX', 'ptsBotY', 'ptsBotZ')

    def __init__(self, centre=np.array([0, 0, 0]), direction=np.array([0, 0, 1]), length=1, radius=0.25, color='white'):
        """
        :param centre: coordinates x, y, z of the middle of the cylinder
        :param direction: vector along the axis of the cylinder (it is
            normalised when stored)
        :param length: length of the cylinder along its axis
        :param radius: radius of the cylinder
        :param color: colour used by plot() for the lateral surface
        """
        self.delPoints()  # Meshed points: none yet
        # The properties below normalise / validate the values
        self.centre = centre
        self.direction = direction
        self.length = length
        self.radius = radius
        self.color = color

    @property
    def centre(self):
        return self.__centre

    @centre.setter
    def centre(self, centre):
        # Store a copy: the default argument of __init__ is a single array
        # shared by every call, and a caller's array must not be aliased.
        self.__centre = np.array(centre)
        self.delPoints()  # The points don't match the structure anymore: delete!

    @property
    def direction(self):
        return self.__direction

    @direction.setter
    def direction(self, direction):
        # The direction is always stored as a unit vector
        direction = np.asarray(direction, dtype=float)
        self.__direction = direction / norm(direction)
        self.delPoints()  # The points don't match the structure anymore: delete!

    @property
    def length(self):
        return self.__length

    @length.setter
    def length(self, length):
        if length < 0:
            length = - length
            warn("Negative Cylinder length. Absolute value has been taken.")
        self.__length = length
        self.delPoints()  # The points don't match the structure anymore: delete!

    @property
    def radius(self):
        return self.__radius

    @radius.setter
    def radius(self, radius):
        if radius < 0:
            radius = - radius
            warn("Negative Cylinder radius. Absolute value has been taken.")
        self.__radius = radius
        self.delPoints()  # The points don't match the structure anymore: delete!

    def delPoints(self):
        """
        Forgets the meshed points (all the pts* attributes are set to None)
        :return: nothing (it updates the Cylinder)
        """
        for name in self._MESH_ATTRIBUTES:
            setattr(self, name, None)

    def posrandomise(self, std_pos):
        """
        Moves the centre of the Cylinder by a random offset: an independent
        Gaussian number of mean 0 is added to each of x, y and z
        :param std_pos: standard deviation of the Gaussian offsets
        :return: nothing (it updates the Cylinder)
        """
        ct = self.centre
        newx = ct[0] + rnd.gauss(0, std_pos)
        newy = ct[1] + rnd.gauss(0, std_pos)
        newz = ct[2] + rnd.gauss(0, std_pos)
        # The centre setter also deletes the meshed points
        self.centre = np.array([newx, newy, newz])

    def rotrandomise(self, std_pos):
        """
        Gives the Cylinder a random axis (direction): it is rotated about its
        own centre by a random angle around x, then around y, then around z
        :param std_pos: standard deviation, in degrees, of the Gaussian
            angles of mean 0 (the angles are converted to radians)
        :return: nothing (it updates the Cylinder)
        """
        for axis in ('x', 'y', 'z'):
            angle = np.radians(rnd.gauss(0, std_pos))
            self.rotate(axis=axis, theta=angle, point=self.centre)

    def lengthrandomise(self, std_rodlength):
        """
        Replaces the length of the Cylinder by a Gaussian random number whose
        mean is the current length
        :param std_rodlength: standard deviation of the Gaussian
        :return: nothing (it updates the Cylinder)
        """
        # The length setter takes the absolute value (with a warning) of a
        # negative draw, and deletes the meshed points
        self.length = rnd.gauss(self.length, std_rodlength)

    def rotate(self, axis='z', theta=0, point=np.array([0, 0, 0])):
        """
        Rotate the Cylinder by an angle theta around an axis (direction) that
        goes through a given point. Both the centre and the direction of the
        Cylinder are updated.

        The vectors are multiplied on the left of the rotation matrix
        (v.dot(r)), i.e. the transpose of the matrix of the reference below is
        applied: for example axis='z' and theta=pi/2 turn the direction
        (1, 0, 0) into (0, -1, 0).

        :param axis: numpy array, or string that can be 'x', 'y' or 'z'
        :param theta: radians
        :param point: numpy array, a point of the rotation axis
        :return: nothing (it updates the Cylinder)
        :raises ValueError: if axis is an unknown string or the null vector
        """
        # Axis as char to make it easier to read
        named_axes = {'x': (1, 0, 0), 'y': (0, 1, 0), 'z': (0, 0, 1)}
        if isinstance(axis, str):
            if axis not in named_axes:
                raise ValueError('Invalid rotation axis.')
            axis = named_axes[axis]
        axis = np.asarray(axis, dtype=float)
        axis_norm = norm(axis)
        if axis_norm == 0:
            raise ValueError('Invalid axis. Axis can not be (0, 0, 0).')
        # Translate everything from point to origin
        point = np.asarray(point)
        old_dir = self.direction
        old_ctr = self.centre - point
        # Perform the rotation around the axis
        # Matrix here: https://en.wikipedia.org/wiki/Rotation_matrix
        axis = axis / axis_norm
        ux = axis[0]
        uy = axis[1]
        uz = axis[2]
        costheta = np.cos(theta)
        sintheta = np.sin(theta)
        r = np.array([[costheta+(ux**2)*(1-costheta), ux*uy*(1-costheta)-uz*sintheta, ux*uz*(1-costheta)+uy*sintheta],
                      [uy*ux*(1-costheta)+uz*sintheta, costheta+(uy**2)*(1-costheta), uy*uz*(1-costheta)-ux*sintheta],
                      [uz*ux*(1-costheta)-uy*sintheta, uz*uy*(1-costheta)+ux*sintheta, costheta+(uz**2)*(1-costheta)]])
        new_dir = old_dir.dot(r)
        # Translate everything back from origin to point
        new_ctr = old_ctr.dot(r) + point
        # Save (the setters normalise the direction and delete the meshed
        # points because the cylinder has been updated)
        self.direction = new_dir
        self.centre = new_ctr

    def getPoints(self, resolution=10):
        """
        Calculates all the points to represent the Cylinder in 3D and stores
        them in the pts* attributes: lateral surface (ptsTub*), bottom disc
        (ptsBot*) and top disc (ptsTop*). Each stored array has the shape
        (resolution, 2).

        The angular origin of the mesh is derived from a random vector, so the
        individual points (but not the surfaces they describe) differ from one
        call to the next, and numbers are drawn from the random module.

        :param resolution: integer that defines the mesh size (number of
            angles between 0 and 2 pi, both included)
        :return: nothing (it updates the Cylinder)
        From here:
        https://stackoverflow.com/questions/39822480/plotting-a-solid-cylinder-centered-on-a-plane-in-matplotlib
        """
        # Make non-collinear unit vector
        vec_nc = np.array([rnd.random(), rnd.random(), rnd.random()])
        while norm(np.cross(self.direction, vec_nc)) == 0:
            vec_nc = np.array([rnd.random(), rnd.random(), rnd.random()])
        vec_nc = vec_nc / norm(vec_nc)
        # Make unit vectors perpendicular to direction
        vec1 = np.cross(self.direction, vec_nc)
        vec1 = vec1 / norm(vec1)
        # Make third vector of the set
        vec2 = np.cross(self.direction, vec1)
        # Mesh the surface
        pts_length = np.linspace(0, self.length, 2)
        pts_theta = np.linspace(0, 2 * np.pi, resolution)
        pts_radius = np.linspace(0, self.radius, 2)
        mesh_length, mesh_thetal = np.meshgrid(pts_length, pts_theta)
        mesh_r, mesh_thetar = np.meshgrid(pts_radius, pts_theta)
        # Find the bottom point
        p0 = self.centre - 0.5 * self.length * self.direction
        # Generate coordinates for the tube
        # Bottom-center point + take all perpendicular directions from the axis and plot on a circle perpendicular to the axis
        X, Y, Z = [p0[i] + self.direction[i] * mesh_length
                   + self.radius * np.sin(mesh_thetal) * vec1[i]
                   + self.radius * np.cos(mesh_thetal) * vec2[i]
                   for i in [0, 1, 2]]
        # Offsets, in the plane perpendicular to the axis, that fill a disc.
        # The whole radius mesh is used: it must not be indexed with the
        # coordinate index i (a row index that does not exist when
        # resolution < 3).
        disc = [mesh_r * np.sin(mesh_thetar) * vec1[i]
                + mesh_r * np.cos(mesh_thetar) * vec2[i]
                for i in [0, 1, 2]]
        # Generate coordinates for the bottom
        X2, Y2, Z2 = [p0[i] + disc[i] for i in [0, 1, 2]]
        # Generate coordinates for the top
        X3, Y3, Z3 = [p0[i] + self.direction[i] * self.length + disc[i] for i in [0, 1, 2]]
        # Store points
        self.ptsTubX = X
        self.ptsTubY = Y
        self.ptsTubZ = Z
        self.ptsBotX = X2
        self.ptsBotY = Y2
        self.ptsBotZ = Z2
        self.ptsTopX = X3
        self.ptsTopY = Y3
        self.ptsTopZ = Z3

    def plot(self, ax, resolution=10):
        """
        Plots the Cylinder in 3D in a plot. The meshed points are recalculated
        on each call. The lateral surface uses the colour of the Cylinder, the
        top disc is green and the bottom disc is red.
        :param ax: axis from Pyplot (in 3D)
        :param resolution: integer that defines the mesh size (see getPoints)
        :return:
        """
        self.getPoints(resolution)
        ax.plot_surface(self.ptsTubX, self.ptsTubY, self.ptsTubZ, color=self.color)
        ax.plot_surface(self.ptsTopX, self.ptsTopY, self.ptsTopZ, color='green')  # green
        ax.plot_surface(self.ptsBotX, self.ptsBotY, self.ptsBotZ, color='red')  # red

    def copy(self, cyl):
        """
        Copies all info from one Cylinder to another
        The empty Cylinder must be created first
        The meshed points are not duplicated: both Cylinders refer to the same
        arrays (or to None) afterwards.
        :param cyl: the Cylinder to duplicate
        :return: nothing, updates the Cylinder
        """
        self.centre = cyl.centre
        self.direction = cyl.direction
        self.length = cyl.length
        self.radius = cyl.radius
        self.color = cyl.color
        # After the setters, because each of them deletes the meshed points
        for name in self._MESH_ATTRIBUTES:
            setattr(self, name, getattr(cyl, name))


class Arrow3D(FancyArrowPatch):
    """
    Arrow between two points given in 3D data coordinates, to be added to a
    Matplotlib 3D axis. The two end points are projected to 2D and the result
    is drawn as a regular FancyArrowPatch.

    From:
    https://stackoverflow.com/questions/22867620/putting-arrowheads-on-vectors-in-matplotlibs-3d-plot
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        """
        :param xs: x coordinates (start, end) of the arrow
        :param ys: y coordinates (start, end) of the arrow
        :param zs: z coordinates (start, end) of the arrow
        :param args: passed on to FancyArrowPatch
        :param kwargs: passed on to FancyArrowPatch
        """
        # The 2D positions are placeholders: the real ones are computed from
        # the 3D coordinates each time the arrow is projected
        FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project(self, renderer=None):
        """
        Projects the two 3D end points, updates the 2D positions of the arrow
        and returns the projected depths.
        """
        # NOTE(review): renderer.M is what the original code relied on; it is
        # believed to be unavailable in recent Matplotlib releases, hence the
        # fallbacks on the projection of the axes - confirm against the
        # Matplotlib versions that must be supported.
        proj = getattr(renderer, 'M', None)
        if proj is None:
            proj = getattr(self.axes, 'M', None)
        if proj is None:
            proj = self.axes.get_proj()
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, proj)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        """
        Hook that recent Matplotlib 3D axes are expected to call on their
        artists before drawing (the renderer argument is optional because it
        is not always passed).
        :return: the smallest projected depth of the arrow
        """
        return np.min(self._project(renderer))

    def draw(self, renderer):
        """
        Updates the 2D positions from the 3D coordinates, then draws the arrow
        :param renderer: Matplotlib renderer
        """
        self._project(renderer)
        FancyArrowPatch.draw(self, renderer)


# back to top