# Source code for wulfric.visualization._plotly

# Wulfric - Cell, Atoms, K-path.
# Copyright (C) 2023-2025 Andrey Rybakov
#
# e-mail: anry@uv.es, web: adrybakov.com
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.


from random import choices
from string import ascii_lowercase
from typing import Iterable

import numpy as np

from wulfric._kpoints_class import Kpoints
from wulfric.cell._basic_manipulation import get_reciprocal
from wulfric.cell._sc_standardize import get_conventional
from wulfric.cell._voronoi import _get_voronoi_cell
from wulfric.geometry._geometry import get_volume
from wulfric.visualization._interface import AbstractBackend

try:
    import plotly.graph_objects as go

    PLOTLY_AVAILABLE = True
except ImportError:
    PLOTLY_AVAILABLE = False

# Snapshot every name bound in the module namespace so far (imports and
# helpers). Names defined after this point are the public objects of this
# file. "old_dir" itself is included so that it never ends up in the export
# list.
old_dir = set(dir()) | {"old_dir"}


class PlotlyBackend(AbstractBackend):
    r"""
    Plotting engine based on |plotly|_.

    Parameters
    ----------
    fig : plotly graph object
        Figure to plot on. If not provided, a new figure is created.

    Attributes
    ----------
    fig : plotly graph object
        Figure to plot on.

    Notes
    -----
    This class is a part of ``wulfric[visual]``
    """

    def __init__(self, fig=None):
        if not PLOTLY_AVAILABLE:
            raise ImportError(
                'Plotly is not available. Install it with "pip install plotly"'
            )
        super().__init__()

        if fig is None:
            fig = go.Figure()
        self.fig = fig

    def show(self, axes_visible=True, **kwargs):
        r"""
        Shows the figure in the interactive mode.

        Parameters
        ----------
        axes_visible : bool, default True
            Whether to show axes.
        **kwargs
            Passed directly to the |plotly-update-layout|_. Keys that are not
            given fall back to ``width=800``, ``height=700``,
            ``yaxis_scaleanchor="x"``, ``showlegend=False`` and
            ``autosize=False``.
        """

        if not axes_visible:
            self.fig.update_scenes(
                xaxis_visible=False, yaxis_visible=False, zaxis_visible=False
            )

        # Set up defaults; values passed explicitly by the caller win.
        for key, value in (
            ("width", 800),
            ("height", 700),
            ("yaxis_scaleanchor", "x"),
            ("showlegend", False),
            ("autosize", False),
        ):
            kwargs.setdefault(key, value)

        self.fig.update_layout(**kwargs)
        self.fig.show()

    def save(
        self,
        output_name="lattice_graph.png",
        kwargs_update_layout=None,
        kwargs_write_html=None,
        axes_visible=True,
    ):
        r"""
        Saves the figure in the html file.

        Parameters
        ----------
        output_name : str, default "lattice_graph.png"
            Name of the file to be saved. With extension. The figure is always
            written as HTML (via ``write_html``), whatever the extension is.
        kwargs_update_layout : dict, optional
            Passed directly to the |plotly-update-layout|_.
        kwargs_write_html : dict, optional
            Passed directly to the |plotly-write-html|_.
        axes_visible : bool, default True
            Whether to show axes.
        """

        if kwargs_update_layout is None:
            kwargs_update_layout = {}
        if kwargs_write_html is None:
            kwargs_write_html = {}

        # Draw the 3D scene axes in proportion to the data ranges.
        self.fig.update_scenes(aspectmode="data")
        if not axes_visible:
            self.fig.update_scenes(
                xaxis_visible=False, yaxis_visible=False, zaxis_visible=False
            )
        self.fig.update_layout(**kwargs_update_layout)
        self.fig.write_html(output_name, **kwargs_write_html)

    def _plot_vectors(self, cell, vector_label, color, legendgroup):
        r"""
        Plots the three rows of ``cell`` as arrows from the origin with labels.

        Each vector is drawn as a line, a cone at its tip and a text label
        placed at 1.2 times the vector.

        Parameters
        ----------
        cell : (3, 3) |array-like|_
            Rows are the vectors to draw.
        vector_label : str
            Label prefix, e.g. ``"a"`` gives ``"a1"``, ``"a2"``, ``"a3"``.
        color : str
            Colour of lines, cones and labels. Any value supported by Plotly.
        legendgroup : str
            Legend group shared with the rest of the plotted artist.
        """

        for i in range(3):
            x = [0, cell[i][0]]
            y = [0, cell[i][1]]
            z = [0, cell[i][2]]
            self.fig.add_traces(
                data=[
                    {
                        "x": x,
                        "y": y,
                        "z": z,
                        "mode": "lines",
                        "type": "scatter3d",
                        "hoverinfo": "none",
                        "line": {"color": color, "width": 3},
                        "showlegend": False,
                        "legendgroup": legendgroup,
                    },
                    {
                        "type": "cone",
                        "x": [x[1]],
                        "y": [y[1]],
                        "z": [z[1]],
                        "u": [0.2 * (x[1] - x[0])],
                        "v": [0.2 * (y[1] - y[0])],
                        "w": [0.2 * (z[1] - z[0])],
                        "anchor": "tip",
                        "hoverinfo": "none",
                        "colorscale": [[0, color], [1, color]],
                        "showscale": False,
                        "showlegend": False,
                        "legendgroup": legendgroup,
                    },
                ]
            )

            self.fig.add_traces(
                data=go.Scatter3d(
                    mode="text",
                    x=[1.2 * x[1]],
                    y=[1.2 * y[1]],
                    z=[1.2 * z[1]],
                    marker=dict(size=0, color=color),
                    text=f"{vector_label}{i+1}",
                    hoverinfo="none",
                    textposition="top center",
                    textfont=dict(size=12),
                    showlegend=False,
                    legendgroup=legendgroup,
                )
            )

    def plot_unit_cell(
        self,
        cell,
        vectors=True,
        color="#274DD1",
        label=None,
        conventional=False,
        reciprocal=False,
        normalize=False,
    ):
        r"""
        Plots real or reciprocal unit cell.

        Parameters
        ----------
        cell : (3, 3) |array-like|_
            Matrix of a cell, rows are interpreted as vectors. It is never
            modified.
        vectors : bool, default True
            Whether to plot lattice vectors.
        color : str, default "#274DD1"
            Colour for the plot. Any value supported by Plotly.
        label : str, optional
            Label for the plot.
        conventional : bool, default False
            Whether to plot conventional cell. Only primitive unit cell is
            supported for reciprocal space.
        reciprocal : bool, default False
            Whether to plot reciprocal or real unit cell.
        normalize : bool, default False
            Whether to normalize volume of the cell to one.

        Raises
        ------
        ValueError
            If both ``conventional`` and ``reciprocal`` are requested.
        """

        if reciprocal and conventional:
            raise ValueError("Conventional cell is not supported in reciprocal space.")

        vector_label = "b" if reciprocal else "a"

        if conventional:
            cell = get_conventional(cell)
        elif reciprocal:
            cell = get_reciprocal(cell)

        # Float copy: accepts any array-like, keeps the caller's data untouched
        # and makes "+" below a vector sum rather than a list concatenation.
        cell = np.array(cell, dtype=float)

        if normalize:
            # abs() before the cube root: a negative volume would otherwise
            # give a complex/NaN scale factor.
            cell = cell / abs(get_volume(cell)) ** (1 / 3.0)

        legendgroup = "".join(choices(ascii_lowercase, k=10))

        if vectors:
            self._plot_vectors(cell, vector_label, color, legendgroup)

        def plot_line(line, shift, showlegend=False):
            # One cell edge: segment from ``shift`` to ``shift + line``.
            self.fig.add_traces(
                data=go.Scatter3d(
                    mode="lines",
                    x=[shift[0], shift[0] + line[0]],
                    y=[shift[1], shift[1] + line[1]],
                    z=[shift[2], shift[2] + line[2]],
                    line=dict(color=color),
                    hoverinfo="none",
                    legendgroup=legendgroup,
                    name=label,
                    showlegend=showlegend,
                ),
            )

        # Only the very first edge carries the legend entry.
        showlegend = label is not None
        for i in range(0, 3):
            j = (i + 1) % 3
            k = (i + 2) % 3
            plot_line(cell[i], np.zeros(3), showlegend=showlegend)
            showlegend = False
            plot_line(cell[i], cell[j])
            plot_line(cell[i], cell[k])
            plot_line(cell[i], cell[j] + cell[k])

    def plot_wigner_seitz(
        self,
        cell,
        vectors=True,
        label=None,
        color="black",
        reciprocal=False,
        normalize=False,
    ):
        r"""
        Plots Wigner-Seitz cell.

        Parameters
        ----------
        cell : (3, 3) |array-like|_
            Matrix of a cell, rows are interpreted as vectors. It is never
            modified.
        vectors : bool, default True
            Whether to plot lattice vectors.
        label : str, optional
            Label for the plot.
        color : str, default "black"
            Colour for the plot. Any value supported by Plotly. ``None`` is
            treated as ``"black"``.
        reciprocal : bool, default False
            Whether to plot reciprocal or real Wigner-Seitz cell.
        normalize : bool, default False
            Whether to normalize volume of the cell to one.
        """

        if reciprocal:
            cell = get_reciprocal(cell)
            vector_label = "b"
        else:
            vector_label = "a"

        # Float copy so that normalization never mutates the caller's data.
        cell = np.array(cell, dtype=float)

        if color is None:
            color = "black"

        if normalize:
            cell = cell / abs(get_volume(cell)) ** (1 / 3.0)

        legendgroup = "".join(choices(ascii_lowercase, k=10))

        if vectors:
            self._plot_vectors(cell, vector_label, color, legendgroup)

        edges, _ = _get_voronoi_cell(cell)

        # Only the very first edge carries the legend entry.
        showlegend = label is not None
        for p1, p2 in edges:
            xyz = np.array([p1, p2]).T
            self.fig.add_traces(
                data=go.Scatter3d(
                    mode="lines",
                    x=xyz[0],
                    y=xyz[1],
                    z=xyz[2],
                    line=dict(color=color),
                    hoverinfo="none",
                    showlegend=showlegend,
                    legendgroup=legendgroup,
                    name=label,
                ),
            )
            showlegend = False

    def plot_kpath(self, cell, color="#000000", label=None, normalize=False, **kwargs):
        r"""
        Plots k path in the reciprocal space.

        Parameters
        ----------
        cell : (3, 3) |array-like|_
            Matrix of a cell, rows are interpreted as vectors. It is never
            modified.
        color : str, default "#000000"
            Colour for the plot. Any value supported by Plotly.
        label : str, optional
            Label for the plot. Currently not used by this backend.
        normalize : bool, default False
            Whether to normalize volume of the cell to one.
        **kwargs
            Accepted for interface compatibility; ignored.
        """

        # Float copy so that normalization never mutates the caller's data.
        cell = np.array(cell, dtype=float)

        if normalize:
            # abs() keeps the scale factor real, as in the other methods.
            cell = cell / abs(get_volume(cell)) ** (1 / 3.0)

        kp = Kpoints.from_cell(cell=cell)
        rcell = get_reciprocal(cell)

        # High-symmetry points: absolute (Cartesian) and relative coordinates.
        p_abs = np.array(
            [tuple(kp.hs_coordinates[name] @ rcell) for name in kp.hs_names]
        ).T
        p_rel = [kp.hs_coordinates[name] for name in kp.hs_names]
        labels = [kp.hs_labels[name] for name in kp.hs_names]

        self.fig.add_traces(
            data=go.Scatter3d(
                mode="markers+text",
                x=p_abs[0],
                y=p_abs[1],
                z=p_abs[2],
                marker=dict(size=6, color=color),
                text=labels,
                hoverinfo="text",
                hovertext=p_rel,
                textposition="top center",
                textfont=dict(size=16),
                showlegend=False,
            )
        )

        # Each subpath is drawn as one polyline through its points.
        for subpath in kp.path:
            xyz = np.array([kp.hs_coordinates[name] @ rcell for name in subpath]).T
            self.fig.add_traces(
                data=go.Scatter3d(
                    mode="lines",
                    x=xyz[0],
                    y=xyz[1],
                    z=xyz[2],
                    line=dict(color=color),
                    hoverinfo="none",
                    showlegend=False,
                ),
            )
# Populate __all__ with objects defined in this file, i.e. every name bound
# since the "old_dir" snapshot. Sorted so the export order is deterministic
# instead of depending on set (hash) iteration order.
__all__ = sorted(set(dir()) - old_dir)
# Remove all semi-private objects
__all__ = [name for name in __all__ if not name.startswith("_")]
del old_dir