"""Fichier contenant une classe pour créer des graphiques."""
from dataclasses import dataclass
from typing import Optional
import numpy as np
import plotly.graph_objects as go
from scipy.stats import gaussian_kde, multivariate_normal
from palm_tracer.Processing.Astigmatism3D import sigma_model
from palm_tracer.Processing.Parsing import SHAPE_MODEL
# Palette "deep" de seaborn (approx)
_SEABORN_DEEP = ["#4C72B0", "#55A868", "#C44E52", "#8172B2", "#CCB974", "#64B5CD", "#FFD92F", "#E7298A", "#66A61E", "#E6AB02"]
_TEMPLATE = "plotly_white"
_BLANK_ANNOTATIONS = [dict(text="No valid data.", x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False)]
_GRID_COLOR = "#e6e6e6"
_GRID_WIDTH = 0.75
_MARGIN = dict(l=60, r=30, t=60, b=50)
MESH_SIZE = 128
##################################################
[docs]
@dataclass
class Grapher:
"""Créateur de graphiques avec Plotly."""
# ==================================================
# region Statistic Figure
# ==================================================
##################################################
[docs]
@staticmethod
def blank(title: str = "") -> go.Figure:
"""
Créé une figure vide avec une annotation standard au centre ``_BLANK_ANNOTATIONS``.
:param title: Titre de la figure
:return: :class:`go.Figure <plotly.graph_objects.Figure>` Figure avec l'annotation
"""
fig = go.Figure()
fig.update_layout(title=title, template=_TEMPLATE, annotations=_BLANK_ANNOTATIONS, margin=_MARGIN)
return fig
##################################################
[docs]
def histogram(self, data: np.ndarray, title: str = "", xlabel: str = "", ylabel: str = "",
limit: bool = False, show_sigma: bool = False, kde: bool = False,
gaussian: bool = False, density: bool = True, cumulative: bool = False, bins: Optional[int] = None) -> go.Figure:
"""
Trace un histogramme des données "façon" Seaborn avec Plotly et optionnellement une courbe kernel density estimation.
:param data: Données sous forme de tableau numpy 1D/ND (aplati).
:param title: Titre du graphe.
:param xlabel: Label optionnel pour l'axe X. Si la chaine est vide, ne change rien.
:param ylabel: Label optionnel pour l'axe Y. Si la chaine est vide, ne change rien.
:param limit: Si True, applique la règle des 3 sigmas pour limiter les données (trim des outliers).
:param show_sigma: Si True, superpose la moyenne, ±1,±2,±3 sigma.
:param kde: Si True, superpose la KDE gaussienne.
:param gaussian: Si True, superpose la gaussienne.
:param density: Affiche l'histogramme en densité (True) ou en comptes (False).
:param bins: Nombre de bins explicite (sinon Sturges).
:param cumulative: Si True, affiche l'histogramme cumulé ainsi que les courbes KDE / gaussienne en version cumulée.
:return: :class:`go.Figure <plotly.graph_objects.Figure>`
"""
if data.ndim == 2: # On considère la première ligne/colonne comme l'identifiant/compteur pour la valeur d'intérêt
if data.shape[0] == 2: _, x = data[0, :], data[1, :] # . (2, N) -> lignes = (x, y)
elif data.shape[1] == 2: _, x = data[:, 0], data[:, 1] # (N, 2) -> colonnes = (x, y)
else: x = np.asarray(data).ravel()
else: x = np.asarray(data).ravel()
x = x[np.isfinite(x)]
# Aucunes données valides
if x.size == 0: return self.blank(title)
fig = go.Figure()
# Limite des données avec la règle des 3 Sigmas
x, limits, mu, sigma = self._get_range(x, limit)
# Récupération du nombre de bin
if bins is None: bins = self._get_bins_number(x)
bin_width = (limits[1] - limits[0]) / max(int(bins), 1)
# Histogramme
histnorm = "probability density" if density else None
fig.add_histogram(x=x, nbinsx=bins, histnorm=histnorm, cumulative=dict(enabled=cumulative), marker=dict(color=_SEABORN_DEEP[0], line=dict(width=0)),
opacity=0.75, name="Histogram", hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>")
# KDE
if x.size > 1 and sigma > 0:
x_grid = np.linspace(limits[0], limits[1], MESH_SIZE) # grille régulière sur l'intervalle affiché
if kde:
kde_model = gaussian_kde(x) # choisit sa propre bandwidth
y_pdf = kde_model(x_grid)
y = self._scale_curve(x_grid, y_pdf, x.size, bin_width, density, cumulative)
fig.add_trace(go.Scatter(x=x_grid, y=y, mode="lines", line=dict(dash="dash", color=_SEABORN_DEEP[1]),
name="KDE", hoverinfo="skip", hovertemplate=None))
# Gaussian
if gaussian:
y_pdf = (1.0 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x_grid - mu) / sigma) ** 2)
y = self._scale_curve(x_grid, y_pdf, x.size, bin_width, density, cumulative)
fig.add_trace(go.Scatter(x=x_grid, y=y, mode="lines", line=dict(dash="dash", color=_SEABORN_DEEP[2]),
name="Gaussian", hoverinfo="skip", hovertemplate=None))
# Mu et Sigmas
if show_sigma and x.size > 1 and sigma > 0: self._draw_sigma(fig, mu, sigma, True)
# Style "seaborn-like" + Espacement entre barres
xlabel = "Values" if xlabel == "" else xlabel
ylabel = ("Density" if density else "Count") if ylabel == "" else ylabel
fig.update_layout(title=f"{title} (μ = {mu:.2f}, σ = {sigma:.2f})", template=_TEMPLATE, margin=_MARGIN,
xaxis=self._axis_dict(xlabel, limits), yaxis=self._axis_dict(ylabel),
hovermode="x", showlegend=True, bargap=0.15, bargroupgap=0.05)
return fig
##################################################
[docs]
def scatter(self, data: np.ndarray, title: str = "", xlabel: str = "", ylabel: str = "", limit: bool = False, show_sigma: bool = False) -> go.Figure:
"""
Trace une courbe des données "façon" Seaborn avec Plotly.
:param data: Données sous forme de tableau numpy 1D ou 2D.
:param title: Titre du graphe.
:param xlabel: Label optionnel pour l'axe X. Si la chaine est vide, ne change rien.
:param ylabel: Label optionnel pour l'axe Y. Si la chaine est vide, ne change rien.
:param limit: Si True, applique la règle des 3 sigmas pour limiter les données (trim des outliers).
:param show_sigma: Si True, superpose la moyenne, ±1,±2,±3 sigma.
:return: :class:`go.Figure <plotly.graph_objects.Figure>`
:raises ValueError: Si les dimensions du tableau ne correspondent pas à ceux attendus (1D, 2D, mais avec uniquement 2 lignes ou 2 colonnes)
"""
# Déterminer x,y
if data.ndim == 1:
y = data[np.isfinite(data)]
x = np.arange(y.size, dtype=float)
elif data.ndim == 2:
if data.shape[0] == 2: x, y = data[0, :], data[1, :] # . (2, N) -> lignes = (x, y)
elif data.shape[1] == 2: x, y = data[:, 0], data[:, 1] # (N, 2) -> colonnes = (x, y)
else: raise ValueError("data 2D doit avoir 2 lignes ou 2 colonnes (x,y).")
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
else: raise ValueError("data doit être 1D ou 2D.")
# Aucunes données valides
if x.size == 0: return self.blank(title)
fig = go.Figure()
# Limite des données avec la règle des 3 Sigmas
_, limits, mu, sigma = self._get_range(y, limit)
# faire une courbe style "seaborn-like"
fig.add_trace(go.Scatter(x=x, y=y, mode="lines+markers", line=dict(color=_SEABORN_DEEP[0]), hovertemplate="x=%{x:.2f}<br>y=%{y:.2f}<extra></extra>"))
# Mu et Sigmas
if show_sigma and x.size > 1 and sigma > 0: self._draw_sigma(fig, mu, sigma, False)
# Style "seaborn-like" + Espacement entre barres
fig.update_layout(title=title, template=_TEMPLATE, margin=_MARGIN, xaxis=self._axis_dict(xlabel), yaxis=self._axis_dict(ylabel, limits),
hovermode="closest", showlegend=False)
return fig
##################################################
[docs]
def cloud(self, data: np.ndarray, title: str = "", xlabel: str = "", ylabel: str = "", limit: bool = False, show_sigma: bool = False,
kde: bool = False, gaussian: bool = False) -> go.Figure:
"""
Trace une courbe des données "façon" Seaborn avec Plotly.
:param data: Données sous forme de tableau numpy 1D ou 2D.
:param title: Titre du graphe.
:param xlabel: Label optionnel pour l'axe X. Si la chaine est vide, ne change rien.
:param ylabel: Label optionnel pour l'axe Y. Si la chaine est vide, ne change rien.
:param limit: Si True, applique la règle des 3 sigmas pour limiter les données (trim des outliers).
:param show_sigma: Si True, superpose la moyenne, ±1,±2,±3 sigma.
:param kde: Si True, superpose la KDE gaussienne 2D.
:param gaussian: Si True, superpose la gaussienne 2D.
:return: :class:`go.Figure <plotly.graph_objects.Figure>`
:raises ValueError: Si les dimensions du tableau ne correspondent pas à ceux attendus (1D, 2D, mais avec uniquement 2 lignes ou 2 colonnes)
"""
if data.size == 0: return self.blank(title)
if data.ndim == 2:
if data.shape[0] != 2:
if data.shape[1] == 2: data = data.T # (N, 2) => passage en mode ligne
else: raise ValueError("data doit avoir 2 lignes ou 2 colonnes (x,y).")
mask = np.isfinite(data[0, :]) & np.isfinite(data[1, :])
data = data[:, mask]
x, y = data[0, :], data[1, :]
else: raise ValueError("data doit être 2D.")
# Aucunes données valides
if data.size == 0: return self.blank(title)
fig = go.Figure()
# Test d'histogramme en heatmap de fond
# fig.add_trace(go.Histogram2d(x=x, y=y, nbinsx=self._get_bins_number(x), nbinsy=self._get_bins_number(y), colorscale="Viridis", showscale=True,
# opacity=0.5, name="Histogramm", hoverinfo="skip", hovertemplate=None))
# Limite des données avec la règle des 3 Sigmas
_, limits_x, mu_x, sigma_x = self._get_range(x, limit)
_, limits_y, mu_y, sigma_y = self._get_range(y, limit)
if x.size > 1 and sigma_x > 0 and sigma_y > 0:
if kde:
xg, yg = np.linspace(limits_x[0], limits_x[1], MESH_SIZE), np.linspace(limits_y[0], limits_y[1], MESH_SIZE)
xm, ym = np.meshgrid(xg, yg)
k = gaussian_kde(np.vstack([x, y])) # 2D KDE
z = k(np.vstack([xm.ravel(), ym.ravel()])).reshape(MESH_SIZE, MESH_SIZE)
fig.add_trace(go.Heatmap(x=xg, y=yg, z=z, colorscale="Viridis", opacity=0.5, name="KDE", hoverinfo="skip", hovertemplate=None))
if gaussian:
mu, cov = np.array([mu_x, mu_y]), np.cov(np.vstack([x, y]))
xg, yg = np.linspace(limits_x[0], limits_x[1], MESH_SIZE), np.linspace(limits_y[0], limits_y[1], MESH_SIZE)
xm, ym = np.meshgrid(xg, yg)
rv = multivariate_normal(mean=mu, cov=cov, allow_singular=True)
z = rv.pdf(np.dstack([xm, ym]))
fig.add_trace(go.Heatmap(x=xg, y=yg, z=z, colorscale="Viridis", opacity=0.5, name="Gaussian", hoverinfo="skip", hovertemplate=None))
fig.add_trace(go.Scattergl(x=x, y=y, mode="markers", marker=dict(size=4, color=_SEABORN_DEEP[0]), opacity=0.75, name="Data",
hovertemplate="x=%{x:.2f}<br>y=%{y:.2f}<extra></extra>"))
# Mu et Sigmas sur X
if show_sigma and x.size > 1 and sigma_x > 0: self._draw_sigma(fig, mu_x, sigma_x, False)
if show_sigma and y.size > 1 and sigma_y > 0: self._draw_sigma(fig, mu_y, sigma_y, True)
# Style "seaborn-like" + Espacement entre barres
fig.update_layout(title=title, template=_TEMPLATE, margin=_MARGIN, xaxis=self._axis_dict(xlabel, limits_x), yaxis=self._axis_dict(ylabel, limits_y),
hovermode="closest", showlegend=False)
return fig
# ==================================================
# endregion Statistic Figure
# ==================================================
# ==================================================
# region Misc Figure
# ==================================================
##################################################
[docs]
def astigmatism3d_curve(self, model: np.ndarray, title: str = "", pixel_size: float = 160, z_max: float = 500, n_points: int = 5000) -> go.Figure:
"""
:param model: Modèle astigmatique de forme (2, 5) : paramètres X puis Y, chaque ligne = [Z0, W, C3, C4, A].
:param title: Titre du graphe.
:param pixel_size: Taille du pixel dans les mêmes unités que Z (ex. nm).
:param z_max: Valeur absolue maximale sur Z.
:param n_points: Nombre de points sur la courbe (résolution)
:return: :class:`go.Figure <plotly.graph_objects.Figure>`
:raises ValueError: Si les dimensions du modèle ne correspondent pas à celles attendues (2x5)
"""
if model.shape != SHAPE_MODEL: raise ValueError(f"Le modèle doit être de dimension {SHAPE_MODEL}.")
fig = go.Figure()
z = np.linspace(-z_max, z_max, n_points, dtype=np.float64)
sx = sigma_model(model[0], z, pixel_size, 1)
sy = sigma_model(model[1], z, pixel_size, 1)
fig.add_trace(go.Scatter(x=sx, y=sy, customdata=z,
mode="markers", marker=dict(size=6, color=z, colorscale="Viridis", colorbar=dict(title="Z (nm)"), showscale=True),
hovertemplate="σ(x:%{x:.3f}, y:%{y:.3f}) = %{customdata:.0f} nm<extra></extra>"))
fig.update_layout(title=title, template=_TEMPLATE, margin=_MARGIN, xaxis=self._axis_dict("Sigma X"), yaxis=self._axis_dict("Sigma Y"),
hovermode="closest", showlegend=False)
# fig.update_xaxes(showspikes=True, spikemode="across", spikesnap="cursor", spikecolor="gray", spikethickness=1)
# fig.update_yaxes(showspikes=True, spikemode="across", spikesnap="cursor", spikecolor="gray", spikethickness=1)
return fig
# ==================================================
# endregion Misc Figure
# ==================================================
# ==================================================
# region Tools
# ==================================================
##################################################
@staticmethod
def _get_bins_number(data: np.ndarray, limits=(30, 300)) -> int:
"""
Calcule un nombre de bin adaptatif pour un histogramme.
:param data: Données à analyser.
:param limits: Bornes pour le nombre de bins.
:return: Nombre de bins.
"""
n_values = len(data)
# bins = int(np.sqrt(n_values)) # Règle de racine carrée
bins = int(np.ceil(np.log2(n_values) + 1)) # .Règle de Sturges
return max(limits[0], min(bins, limits[1])) # Bornes pour éviter des valeurs extrêmes
##################################################
@staticmethod
def _get_range(data: np.ndarray, limit) -> tuple[np.ndarray, list[float], float, float]:
"""
Calcule les limites du graphique avec la règle des 3 sigmas et ajuste le tableau si nécessaire.
:param data: Données à analyser.
:param limit: Limite ou non les données.
:return: Le tableau (en cas de modification) et les limites du graphique.
"""
mu, sigma = float(np.mean(data)), float(np.std(data))
if limit and sigma > 0:
limits = [mu - 3 * sigma, mu + 3 * sigma] # . Limite théoriques des datas
data = data[(data >= limits[0]) & (data <= limits[1])] # . Suppression des datas au dela des limites
limits = [max(limits[0], min(data)), min(limits[1], max(data))] # On resserre les limites autour des datas
else:
limits = [min(data), max(data)]
return data, limits, mu, sigma
##################################################
@staticmethod
def _scale_curve(x_grid: np.ndarray, y_pdf: np.ndarray, n: int, bin_width: float, density: bool = False, cumulative: bool = False):
"""
Adapte une courbe PDF pour l'affichage selon les modes densité / comptes et normal / cumulé.
:param x_grid: Abscisses régulières.
:param y_pdf: Densité (PDF) à convertir.
:param n: nombre de bins.
:param bin_width: Largeur d'une bin.
:param density: Affiche l'histogramme en densité (True) ou en comptes (False).
:param cumulative: Si True, calcule la version cumulée de la courbe.
:return: Courbe prête à être affichée.
"""
if cumulative:
dx = float(x_grid[1] - x_grid[0])
y_cdf = np.cumsum(y_pdf) * dx
# Protection numérique pour rester dans [0, 1] si possible.
if y_cdf.size > 0 and y_cdf[-1] > 0: y_cdf = y_cdf / y_cdf[-1]
y = np.clip(y_cdf, 0.0, 1.0)
return y if density else y * n # Conversion densité -> comptes approximatifs.
return y_pdf if density else y_pdf * n * bin_width # convertir la densité en comptes ~ dens * N * bin_width
##################################################
@staticmethod
def _draw_sigma(fig, mu, sigma, x_axis: bool = True):
"""
Ajoute les séparations entre chaque sigma.
:param fig: Figure à modifier
:param mu: Moyenne
:param sigma: Écart-type
:param x_axis: ``True`` pour des séparations verticales sur l'axe X, ``False`` sinon.
"""
params = [[mu, _SEABORN_DEEP[3], "μ"],
[mu - sigma, _SEABORN_DEEP[4], "μ - 1σ"], [mu + sigma, _SEABORN_DEEP[4], "μ + 1σ"],
[mu - 2 * sigma, _SEABORN_DEEP[5], "μ - 2σ"], [mu + 2 * sigma, _SEABORN_DEEP[5], "μ + 2σ"],
[mu - 3 * sigma, _SEABORN_DEEP[6], "μ - 3σ"], [mu + 3 * sigma, _SEABORN_DEEP[6], "μ + 3σ"]]
if x_axis:
for p in params: fig.add_vline(x=p[0], line_color=p[1], line_dash="dot", line_width=1.5, name=p[2])
else:
for p in params: fig.add_hline(y=p[0], line_color=p[1], line_dash="dot", line_width=1.5, name=p[2])
##################################################
@staticmethod
def _axis_dict(title: str, limits: Optional[list] = None) -> dict:
return dict(title=title, range=limits, zeroline=False, showgrid=True, gridcolor=_GRID_COLOR, gridwidth=_GRID_WIDTH)