Source code for wulfric.visualization._matplotlib

# Wulfric - Cell, Atoms, K-path.
# Copyright (C) 2023-2025 Andrey Rybakov
#
# e-mail: anry@uv.es, web: adrybakov.com
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.


from random import choices
from string import ascii_lowercase
from typing import Iterable

import numpy as np

from wulfric._kpoints_class import Kpoints
from wulfric.cell._basic_manipulation import get_reciprocal
from wulfric.cell._sc_standardize import get_conventional
from wulfric.cell._voronoi import _get_voronoi_cell
from wulfric.geometry._geometry import get_volume
from wulfric.visualization._interface import AbstractBackend

try:
    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from matplotlib.patches import FancyArrowPatch
    from mpl_toolkits.mplot3d import Axes3D, proj3d

    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False


if MATPLOTLIB_AVAILABLE:
    # Better 3D arrows, see: https://stackoverflow.com/questions/22867620/putting-arrowheads-on-vectors-in-a-3d-plot
    class Arrow3D(FancyArrowPatch):
        def __init__(self, ax, xs, ys, zs, *args, **kwargs):
            FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
            self._verts3d = xs, ys, zs
            self.ax = ax

        def draw(self, renderer):
            xs3d, ys3d, zs3d = self._verts3d
            xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.ax.axes.M)
            self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
            FancyArrowPatch.draw(self, renderer)

        def do_3d_projection(self, *_, **__):
            return 0


# Save local scope at this moment
old_dir = set(dir())
old_dir.add("old_dir")


[docs] class MatplotlibBackend(AbstractBackend): r""" Plotting engine based on |matplotlib|_. Parameters ---------- fig : matplotlib figure, optional Figure to plot on. If not provided, a new figure and ``ax`` is created. ax : matplotlib axis, optional Axis to plot on. If not provided, a new axis is created. background : bool, default True Whether to keep the axis in the plot. focal_length : float, default 0.2 See: |matplotlibFocalLength|_ Attributes ---------- fig : matplotlib figure Figure to plot on. ax : matplotlib axis Axis to plot on. artists : dict Dictionary of the artists. Keys are the plot kinds, values are the lists of artists. Notes ----- This class is a part of ``wulfric[visual]`` """ def __init__(self, fig=None, ax=None, background=True, focal_length=0.2): if not MATPLOTLIB_AVAILABLE: raise ImportError( 'Matplotlib is not available. Install it with "pip install matplotlib"' ) super().__init__() if fig is None: fig = plt.figure(figsize=(6, 6)) ax = fig.add_subplot(projection="3d") elif ax is None: ax = fig.add_subplot(projection="3d") rcParams["axes.linewidth"] = 0 rcParams["xtick.color"] = "#B3B3B3" ax.set_proj_type("persp", focal_length=focal_length) if background: ax.axes.linewidth = 0 ax.xaxis._axinfo["grid"]["color"] = (1, 1, 1, 1) ax.yaxis._axinfo["grid"]["color"] = (1, 1, 1, 1) ax.zaxis._axinfo["grid"]["color"] = (1, 1, 1, 1) ax.set_xlabel("x", fontsize=15, alpha=0.5) ax.set_ylabel("y", fontsize=15, alpha=0.5) ax.set_zlabel("z", fontsize=15, alpha=0.5) ax.tick_params(axis="both", zorder=0, color="#B3B3B3") else: ax.axis("off") self.fig = fig self.ax = ax self.artists = {} def remove(self, kind="primitive"): r""" Removes a set of artists from the plot. Parameters ---------- kind : str or list of str Type of the plot to be removed. Supported kinds: * "conventional" * "primitive" * "brillouin" * "kpath" * "brillouin_kpath" * "wigner_seitz" """ if kind == "brillouin_kpath": kinds = ["brillouin", "kpath"] else: kinds = [kind] for kind in kinds: if kind not in self.artists: raise ValueError(f"No artists for the {kind} kind.") for artist in self.artists[kind]: if isinstance(artist, list): for i in artist: i.remove() else: artist.remove() del self.artists[kind] self.ax.relim(visible_only=True) self.ax.set_aspect("equal") def plot(self, cell, kind="primitive", **kwargs): r""" Main plotting method. Actual list of supported kinds can be check with: .. doctest:: >>> self.kinds.keys() # doctest: +SKIP Parameters ---------- cell : (3, 3) |array-like|_ Matrix of a cell, rows are interpreted as vectors. kind : str or list od str Type of the plot to be plotted. Supported plots: * "conventional" * "primitive" * "brillouin" * "kpath" * "brillouin-kpath" * "wigner-seitz" * "unit-cell" **kwargs Parameters to be passed to the specialized plotting function. See each function for the list of supported parameters. Raises ------ ValueError If the plot kind is not supported. See Also -------- plot_conventional : "conventional" plot. plot_primitive : "primitive" plot. plot_brillouin : "brillouin" plot. plot_kpath : "kpath" plot. plot_brillouin_kpath : "brillouin_kpath" plot. plot_wigner_seitz : "wigner-seitz" plot. plot_unit_cell : "unit-cell" plot. show : Shows the plot. save : Save the figure in the file. """ super().plot(cell, kind=kind, **kwargs) self.ax.relim() self.ax.set_aspect("equal") def show(self, elev=30, azim=-60): r""" Shows the figure in the interactive mode. Parameters ---------- elev : float, default 30 Passed directly to matplotlib. See |matplotlibViewInit|_. azim : float, default -60 Passed directly to matplotlib. See |matplotlibViewInit|_. """ self.ax.set_aspect("equal") self.ax.view_init(elev=elev, azim=azim) plt.show() self.fig = None self.ax = None plt.close() def save(self, output_name="cell_graph.png", elev=30, azim=-60, **kwargs): r""" Saves the figure in the file. Parameters ---------- output_name : str, default "cell_graph.png" Name of the file to be saved. With extension. elev : float, default 30 Passed directly to matplotlib. See |matplotlibViewInit|_. azim : float, default -60 Passed directly to matplotlib. See |matplotlibViewInit|_. **kwargs Parameters to be passed to the |matplotlibSavefig|_. """ self.ax.set_aspect("equal") self.ax.view_init(elev=elev, azim=azim) self.fig.savefig(output_name, **kwargs) def clear(self): r""" Clears the axis. """ if self.ax is not None: self.ax.cla() def legend(self, **kwargs): r""" Adds legend to the figure. Parameters ---------- **kwargs : Directly passed to the |matplotlibLegend|_. """ self.ax.legend(**kwargs) def plot_unit_cell( self, cell, vectors=True, color="#274DD1", label=None, vector_pad=1.1, conventional=False, reciprocal=False, normalize=False, ): r""" Plots real or reciprocal unit cell. Parameters ---------- cell : (3, 3) |array-like|_ Matrix of a cell, rows are interpreted as vectors. vectors : bool, default True Whether to plot lattice vectors. color : str, default "#274DD1" Colour for the plot. Any format supported by matplotlib (see |matplotlibColor|_). label : str, optional Label for the plot. vector_pad : float, default 1.1 Multiplier for the position of the vectors labels. 1 = position of the vector. conventional : bool, default False Whether to plot conventional cell. Only primitive unit cell is supported for reciprocal space. reciprocal : bool, default False Whether to plot reciprocal or real unit cell. normalize : bool, default False Whether to normalize volume of the cell to one. """ if reciprocal and conventional: raise ValueError("Conventional cell is not supported in reciprocal space.") if conventional: artist_group = "conventional" else: artist_group = "primitive" if reciprocal: artist_group += "_reciprocal" vector_label = "b" else: artist_group += "_real" vector_label = "a" self.artists[artist_group] = [] if conventional: cell = get_conventional(cell) elif reciprocal: cell = get_reciprocal(cell) if normalize: cell /= abs(get_volume(cell) ** (1 / 3.0)) if label is not None: self.artists[artist_group].append( self.ax.scatter(0, 0, 0, color=color, label=label) ) if vectors: if not isinstance(vector_pad, Iterable): vector_pad = [vector_pad, vector_pad, vector_pad] for i in range(3): self.artists[artist_group].append( self.ax.text( cell[i][0] * vector_pad[i], cell[i][1] * vector_pad[i], cell[i][2] * vector_pad[i], f"${vector_label}_{i+1}$", fontsize=20, color=color, ha="center", va="center", ) ) # Try beautiful arrows try: self.artists[artist_group].append( self.ax.add_artist( Arrow3D( self.ax, [0, cell[i][0]], [0, cell[i][1]], [0, cell[i][2]], mutation_scale=20, arrowstyle="-|>", color=color, lw=2, alpha=0.7, ) ) ) # Go to default except: self.artists[artist_group].append( self.ax.quiver( 0, 0, 0, *tuple(cell[i]), arrow_length_ratio=0.2, color=color, alpha=0.7, linewidth=2, ) ) # Ghost point to account for the plot range self.artists[artist_group].append(self.ax.scatter(*tuple(cell[i]), s=0)) def plot_line(line, shift): self.artists[artist_group].append( self.ax.plot( [shift[0], shift[0] + line[0]], [shift[1], shift[1] + line[1]], [shift[2], shift[2] + line[2]], color=color, ) ) for i in range(0, 3): j = (i + 1) % 3 k = (i + 2) % 3 plot_line(cell[i], np.zeros(3)) plot_line(cell[i], cell[j]) plot_line(cell[i], cell[k]) plot_line(cell[i], cell[j] + cell[k]) def plot_wigner_seitz( self, cell, vectors=True, color="black", label=None, vector_pad=1.1, reciprocal=False, normalize=False, ): r""" Plots Wigner-Seitz cell. Parameters ---------- cell : (3, 3) |array-like|_ Matrix of a cell, rows are interpreted as vectors. vectors : bool, default True Whether to plot lattice vectors. color : str, default "black" Colour for the plot. Any format supported by matplotlib (see |matplotlibColor|_). label : str, optional Label for the plot. vector_pad : float, default 1.1 Multiplier for the position of the vectors labels. 1 = position of the vector. reciprocal : bool, default False Whether to plot reciprocal or real Wigner-Seitz cell. normalize : bool, default False Whether to normalize volume of the cell to one. """ if reciprocal: artist_group = "brillouin" else: artist_group = "wigner_seitz" self.artists[artist_group] = [] if reciprocal: cell = get_reciprocal(cell) vector_label = "b" else: vector_label = "a" if color is None: color = "black" if normalize: cell /= abs(get_volume(cell) ** (1 / 3.0)) v1, v2, v3 = cell[0], cell[1], cell[2] vs = [v1, v2, v3] if label is not None: self.artists[artist_group].append( self.ax.scatter(0, 0, 0, color=color, label=label) ) if vectors: if not isinstance(vector_pad, Iterable): vector_pad = [vector_pad, vector_pad, vector_pad] for i in range(3): self.artists[artist_group].append( self.ax.text( vs[i][0] * vector_pad[i], vs[i][1] * vector_pad[i], vs[i][2] * vector_pad[i], f"${vector_label}_{i+1}$", fontsize=20, color=color, ha="center", va="center", ) ) # Try beautiful arrows try: self.artists[artist_group].append( self.ax.add_artist( Arrow3D( self.ax, [0, vs[i][0]], [0, vs[i][1]], [0, vs[i][2]], mutation_scale=20, arrowstyle="-|>", color=color, lw=2, alpha=0.8, ) ) ) # Go to default except: self.artists[artist_group].append( self.ax.quiver( 0, 0, 0, *tuple(vs[i]), arrow_length_ratio=0.2, color=color, alpha=0.5, ) ) # Ghost point to account for the plot range self.artists[artist_group].append(self.ax.scatter(*tuple(vs[i]), s=0)) edges, _ = _get_voronoi_cell(cell) for p1, p2 in edges: self.artists[artist_group].append( self.ax.plot( [p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, ) ) def plot_kpath(self, cell, color="black", label=None, normalize=False): r""" Plots k path in the reciprocal space. Parameters ---------- cell : (3, 3) |array-like|_ Matrix of a cell, rows are interpreted as vectors. color : str, default "black" Colour for the plot. Any format supported by matplotlib (see |matplotlibColor|_). label : str, optional Label for the plot. normalize : bool, default False Whether to normalize volume of the cell to one. """ artist_group = "kpath" self.artists[artist_group] = [] if normalize: cell /= get_volume(cell) ** (1 / 3.0) kp = Kpoints.from_cell(cell) cell = get_reciprocal(cell) for point in kp.hs_names: self.artists[artist_group].append( self.ax.scatter( *tuple(kp.hs_coordinates[point] @ cell), s=36, color=color, ) ) self.artists[artist_group].append( self.ax.text( *tuple( kp.hs_coordinates[point] @ cell + 0.025 * cell[0] + +0.025 * cell[1] + 0.025 * cell[2] ), kp.hs_labels[point], fontsize=20, color=color, ) ) if label is not None: self.artists[artist_group].append( self.ax.scatter( 0, 0, 0, s=36, color=color, label=label, ) ) for subpath in kp.path: for i in range(len(subpath) - 1): self.artists[artist_group].append( self.ax.plot( *tuple( np.concatenate( ( kp.hs_coordinates[subpath[i]] @ cell, kp.hs_coordinates[subpath[i + 1]] @ cell, ) ) .reshape(2, 3) .T ), color=color, alpha=0.5, linewidth=3, ) )
# Populate __all__ with objects defined in this file __all__ = list(set(dir()) - old_dir) # Remove all semi-private objects __all__ = [i for i in __all__ if not i.startswith("_")] del old_dir