Source code for osaft.plotting.scattering.particle_plots

from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

from osaft.core.functions import pi
from osaft.plotting.datacontainers.scattering_datacontainer import (
    ParticleScatteringData,
)
from osaft.plotting.datacontainers.wireframe_datacontainer import (
    ParticleWireframeData,
)
from osaft.plotting.scattering.base_plotter import BaseScatteringPlotter
from osaft.plotting.scattering.tri_plotter import TriangulationPlotter
from osaft.plotting.scattering.wireframe_plotter import WireframePlotter
from osaft.solutions.base_scattering import BaseScattering


class ParticleScatteringPlot:
    """Class for plotting scattering field of the particle

    Plots the acoustic field inside the particle using Matplotlib
    tricontourf or tripcolor plotting methods.

    :param sol: solution to be plotted
    :param theta_min: lower limit for tangential plot range
    :param theta_max: upper limit for tangential plot range
    :param resolution: if tuple (radial resolution, tangential resolution)
    :param cmap: color map
    """

    def __init__(
        self,
        sol: BaseScattering,
        theta_min: float = 0,
        theta_max: float = pi,
        resolution: Union[int, tuple[int, int]] = 100,
        cmap: str = 'winter',
    ):
        """Constructor method"""
        self.data = ParticleScatteringData(
            sol, theta_min, theta_max, resolution,
        )
        self.plotter = TriangulationPlotter(True, cmap)

    # -------------------------------------------------------------------------
    # API
    # -------------------------------------------------------------------------

    def plot(
        self,
        inst: bool = True,
        phase: float = 0,
        mode: Optional[int] = None,
        displacement: bool = True,
        symmetric: bool = True,
        tripcolor: bool = False,
        ax: Optional[plt.Axes] = None,
        **kwargs,
    ) -> tuple[plt.Figure, plt.Axes]:
        """Tricontourf plot for the acoustic field of the particle

        Plots the magnitude of the first-order displacement field
        (default) or velocity field of the particle using Matplotlib's
        `tricontourf
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tricontourf.html>`_
        or `tripcolor
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tripcolor.html>`_
        if ``tripcolor = True``

        :param inst: if ``True`` instantaneous amplitude is plotted
        :param phase: phase [0, 2 * pi]
        :param mode: mode of oscillation
        :param displacement: if ``True`` displacement, else velocity plot
        :param symmetric: if ``True`` the symmetry of the solution is used
        :param tripcolor: switches between tripcolor and tricontourf plot
        :param ax: Axes object
        :param kwargs: passed through to the plotting method
        :return: figure and axes of the plot
        """
        return self._triangulation_plot(
            tripcolor=tripcolor,
            inst=inst,
            phase=phase,
            mode=mode,
            displacement=displacement,
            symmetric=symmetric,
            ax=ax,
            **kwargs,
        )

    def animate(
        self,
        frames: int = 64,
        interval: float = 100.0,
        mode: Optional[int] = None,
        displacement: bool = True,
        symmetric: bool = True,
        tripcolor: bool = False,
        ax: Optional[plt.Axes] = None,
        **kwargs,
    ) -> FuncAnimation:
        """Tricontourf animation for the acoustic field of the particle

        Animates the magnitude of the first-order displacement field
        (default) or velocity field of the particle using Matplotlib's
        `tricontourf
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tricontourf.html>`_
        or `tripcolor
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tripcolor.html>`_
        if ``tripcolor = True``

        :param frames: number of frames for the animation
        :param interval: interval between frames in ms
        :param mode: mode of oscillation
        :param displacement: if ``True`` displacement, else velocity plot
        :param symmetric: if ``True`` the symmetry of the solution is used
        :param tripcolor: switches between tripcolor and tricontourf plot
        :param ax: Axes object
        :param kwargs: passed through to the plotting method
        :return: the animation object created by the plotter
        """
        return self._triangulation_animation(
            tripcolor=tripcolor,
            frames=frames,
            interval=interval,
            mode=mode,
            displacement=displacement,
            symmetric=symmetric,
            ax=ax,
            **kwargs,
        )

    def plot_evolution(
        self,
        inst: bool = True,
        mode: Optional[int] = None,
        displacement: bool = False,
        symmetric: bool = True,
        tripcolor: bool = False,
        layout: tuple[int, int] = (3, 3),
        **kwargs,
    ) -> tuple[plt.Figure, Union[plt.Axes, np.ndarray]]:
        """Tricontourf plot for the acoustic field evolution of the particle

        Plots the magnitude of the first-order velocity field (default) or
        displacement field of the particle at different phases using
        Matplotlib's `tricontourf
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tricontourf.html>`_
        or `tripcolor
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tripcolor.html>`_
        if ``tripcolor = True``.

        The first phase value is always :math:`0\\pi` and the last one
        :math:`2\\pi`. The total number of plots and, hence, also the steps
        between the different phase values is defined by the product of
        the ``layout`` tuple.

        All subplots share one colour scale and a single colour bar.

        :param inst: if ``True`` instantaneous amplitude is plotted
        :param mode: mode of oscillation
        :param displacement: if ``True`` displacement, else velocity plot
        :param symmetric: if ``True`` the symmetry of the solution is used
        :param tripcolor: switches between tripcolor and tricontourf plot
        :param layout: number of rows and columns for plotting
        :param kwargs: passed through to the parent subplots command
        :return: figure and the axes exactly as returned by
            ``plt.subplots`` (a single Axes for a 1 x 1 layout, otherwise
            an array of Axes)
        """
        n_row, n_col = layout
        phases = np.linspace(0, 2, num=n_row * n_col)

        fig, axes = plt.subplots(
            n_row, n_col, sharex=True, sharey=True, **kwargs,
        )
        # plt.subplots squeezes a 1 x 1 layout to a bare Axes object which
        # has neither ``flat`` nor ``ravel``; normalise to a flat list.
        if isinstance(axes, np.ndarray):
            axes_list = list(axes.flat)
        else:
            axes_list = [axes]

        get_magnitude, cbar_label = self._select_field(displacement)

        # Common upper colour limit for all subplots
        _, _, C_norm = get_magnitude(False, mode=mode)
        vmax = 1.01 * np.nanmax(C_norm)

        for i, (ax, phase) in enumerate(zip(axes_list, phases)):
            row, col = divmod(i, n_col)

            X, Z, C = get_magnitude(inst, phase * np.pi, mode)

            # Replace nan in the center
            C[0] = C[1]

            # Plot
            _, _, cnf, cbar, _ = self.plotter.plot(
                X=Z,
                Y=X,
                C=C,
                radius=self.data.sol.R_0,
                symmetric=symmetric,
                tripcolor=tripcolor,
                cbar_label=cbar_label,
                use_diverging_cmap=False,
                ax=ax,
                vmin=0,
                vmax=vmax,
            )

            # remove the per-subplot colorbar, a shared one is added below
            cbar.remove()

            ax.set_title(f'{phase:.2f}' + r'$\pi$')

            # Only the bottom row keeps x labels, only the first column
            # keeps y labels
            if row != (n_row - 1):
                ax.set_xlabel('')
            if col > 0:
                ax.set_ylabel('')

        fig.tight_layout()
        cbar = fig.colorbar(cnf, ax=axes_list)
        cbar.ax.set_ylabel(cbar_label)

        return fig, axes

    # -------------------------------------------------------------------------
    # Private Methods
    # -------------------------------------------------------------------------

    def _select_field(self, displacement: bool):
        """Returns the data getter and color bar label of the chosen field

        :param displacement: if ``True`` displacement, else velocity
        :return: bound magnitude getter of the data container and the
            matching color bar label
        """
        if displacement:
            return self.data.get_displacement_magnitude, 'Displacement [m]'
        return self.data.get_velocity_magnitude, 'Velocity [m/s]'

    def _triangulation_plot(
        self,
        tripcolor: bool,
        inst: bool,
        phase: float,
        mode: Optional[int],
        displacement: bool,
        symmetric: bool,
        ax: Optional[plt.Axes],
        **kwargs,
    ) -> tuple[plt.Figure, plt.Axes]:
        """Helper function for tripcolor/tricontourf plot

        :param tripcolor: if ``True`` tripcolor, else tricontourf plot
        :param inst: if ``True`` instantaneous amplitude is plotted
        :param phase: phase [0, 2 * pi]
        :param mode: mode of oscillation
        :param displacement: if ``True`` displacement, else velocity plot
        :param symmetric: if ``True`` the symmetry of the solution is used
        :param ax: Axes object
        :param kwargs: passed through to plotting method
        :return: figure and axes of the plot
        """
        # Data
        get_magnitude, cbar_label = self._select_field(displacement)
        X, Z, C = get_magnitude(inst, phase, mode)

        # Replace nan in the center
        C[0] = C[1]

        # Plot
        fig, ax, _, _, _ = self.plotter.plot(
            X=Z,
            Y=X,
            C=C,
            radius=self.data.sol.R_0,
            symmetric=symmetric,
            tripcolor=tripcolor,
            cbar_label=cbar_label,
            use_diverging_cmap=False,
            ax=ax,
            **kwargs,
        )

        return fig, ax

    def _triangulation_animation(
        self,
        tripcolor: bool,
        frames: int,
        interval: float,
        mode: Optional[int],
        displacement: bool,
        symmetric: bool,
        ax: Optional[plt.Axes],
        **kwargs,
    ) -> FuncAnimation:
        """Helper function for tripcolor/tricontourf animation

        :param tripcolor: if ``True`` tripcolor, else tricontourf plot
        :param frames: number of frames for the animation
        :param interval: interval between frames in ms
        :param mode: mode of oscillation
        :param displacement: if ``True`` displacement, else velocity plot
        :param symmetric: if ``True`` the symmetry of the solution is used
        :param ax: Axes object
        :param kwargs: passed through to the plotting method
        :return: the animation object created by the plotter
        """
        get_magnitude, cbar_label = self._select_field(displacement)

        # Data function for animation
        def data_func(phase: float) -> tuple[
            np.ndarray, np.ndarray, np.ndarray,
        ]:
            """Returns the instantaneous field magnitude for given phase

            Closure is used to fix all additional parameters

            :param phase: phase
            """
            return get_magnitude(True, phase, mode)

        # Get norm
        _, _, C = get_magnitude(False, mode=mode)

        return self.plotter.animate(
            frames=frames,
            interval=interval,
            tripcolor=tripcolor,
            symmetric=symmetric,
            cbar_label=cbar_label,
            use_diverging_cmap=False,
            animate_meth=data_func,
            C_norm=C,
            radius=self.data.sol.R_0,
            ax=ax,
            **kwargs,
        )
class ParticleWireframePlot(BaseScatteringPlotter):
    """Plotting class for wireframe plot of the particle

    :param sol: solution to be plotted
    :param nbr_r_levels: number of circles shown in the wireframe
    :param nbr_theta_levels: number of radii shown in the wireframe
    :param resolution: resolution, if tuple `(radial res, tangential res)`
    :param scale_factor: scaling factor for the displacements, if `None` auto
    """

    def __init__(
        self,
        sol: BaseScattering,
        nbr_r_levels: int = 10,
        nbr_theta_levels: int = 10,
        resolution: Union[int, tuple[int, int]] = (100, 100),
        scale_factor: Optional[float] = 0.1,
    ) -> None:
        """Constructor method"""
        # Call to parent class
        super().__init__(True)

        self.data = ParticleWireframeData(
            sol, nbr_r_levels, nbr_theta_levels, resolution, scale_factor,
        )
        self.plotter = WireframePlotter()

    # -------------------------------------------------------------------------
    # API
    # -------------------------------------------------------------------------

    def plot(
        self,
        phase: float = 0,
        mode: Optional[int] = None,
        ax: Optional[plt.Axes] = None,
    ) -> tuple[plt.Figure, plt.Axes]:
        """Plot wireframe

        :param phase: phase to be plotted
        :param mode: mode to be plotted, if ``None`` superposition of all
            modes
        :param ax: if ``ax`` is passed, plot will be drawn on ``ax``
        :return: figure and axes of the plot
        """
        self.data.mode = mode
        radii, circles = self.data.get_displacements()
        fig, ax = self.plotter.plot(radii, circles, phase, ax)
        return fig, ax

    def animate(
        self,
        frames: int = 64,
        interval: float = 100.0,
        mode: Optional[int] = None,
        ax: Optional[plt.Axes] = None,
    ):
        """Animate wireframe

        :param frames: number of frames for the animation
        :param interval: interval between frames in milliseconds
        :param mode: mode to be plotted, if ``None`` superposition of all
            modes
        :param ax: if ``ax`` is passed, animation will be drawn on ``ax``
        :return: the animation object returned by the wireframe plotter
        """
        self.data.mode = mode
        anim = self.plotter.animate(
            frames=frames,
            interval=interval,
            deformed_radii=self.data.deformed_radii,
            deformed_circles=self.data.deformed_circles,
            ax=ax,
        )
        return anim

    def plot_evolution(
        self,
        layout: tuple[int, int] = (3, 3),
        **kwargs,
    ) -> tuple[plt.Figure, Union[plt.Axes, np.ndarray]]:
        """Plot wireframe over one period

        The phases are spaced evenly from :math:`0\\pi` to :math:`2\\pi`;
        their number is the product of the ``layout`` tuple. The mode is
        not set here, i.e. the mode currently stored on the data container
        is used.

        :param layout: subplot layout of the figure (rows, columns)
        :param kwargs: passed through to the parent subplots command
        :return: figure and the axes exactly as returned by
            ``plt.subplots`` (a single Axes for a 1 x 1 layout, otherwise
            an array of Axes)
        """
        n_row, n_col = layout
        phases = np.linspace(0, 2, num=n_row * n_col)

        fig, axes = plt.subplots(
            n_row, n_col, sharex=True, sharey=True, **kwargs,
        )
        # plt.subplots squeezes a 1 x 1 layout to a bare Axes object which
        # has no ``flat`` attribute; normalise to a flat list.
        if isinstance(axes, np.ndarray):
            axes_list = list(axes.flat)
        else:
            axes_list = [axes]

        for i, (ax, phase) in enumerate(zip(axes_list, phases)):
            row, col = divmod(i, n_col)

            radii, circles = self.data.get_displacements()
            self.plotter.plot(radii, circles, phase * np.pi, ax)

            ax.set_title(f'{phase:.2f}' + r'$\pi$')

            # Only the bottom row keeps x labels, only the first column
            # keeps y labels
            if row != (n_row - 1):
                ax.set_xlabel('')
            if col > 0:
                ax.set_ylabel('')

        fig.tight_layout()

        return fig, axes
# Script entry point: intentionally a no-op, the module only provides classes
if __name__ == '__main__':
    pass