"""Fichier contenant des fonctions pour le drift."""
from dataclasses import dataclass
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
from palm_tracer.Processing.Parsing import apply_dataframe_type, FILES_COLUMNS
##################################################
@dataclass
class _ActiveTrack:
"""Structure interne de suivi actif d'une bille à travers les plans."""
track_id: int
ids: list[int]
last_pos: np.ndarray # shape (D,)
##################################################
def _check_cols(data: pd.DataFrame, required: set[str]):
"""Vérification rapide des colonnes."""
missing = sorted(required - set(data.columns))
if missing: raise ValueError(f"Missing columns in data: {missing}.")
##################################################
def _check_planes(data: pd.DataFrame) -> np.ndarray:
"""
Vérifie les plans disponibles dans le dataframe et leur continuité.
:param data: Dataframe à analyser.
:return: Liste des plans triés.
"""
planes = np.array(sorted(pd.unique(data["Plane"])))
expected = np.arange(planes[0], planes[-1] + 1, dtype=planes.dtype)
if planes.size < 2: raise ValueError(f"We need at least 2 planes.")
if planes.size != expected.size or np.any(planes != expected): raise ValueError(f"The planes are not consecutive: {planes}.")
return planes
##################################################
def _assign_tracks_to_points_greedy(indices: np.ndarray, distances: np.ndarray, n_points: int) -> tuple[np.ndarray, np.ndarray]:
"""
Assigne des suivis à des points du plan courant en gérant les collisions, à partir des résultats de :meth:`scipy.spatial.cKDTree.query` avec ``k>1``.
Pour chaque suivi, on dispose d'une liste de candidats triés par distance croissante (jusqu'à k).
On construit toutes les paires (distance, track, point) valides puis on fait un appariement glouton par distance croissante :
- un suivi reçoit au plus 1 point ;
- un point est assigné à au plus 1 suivi.
.. note:: Cette stratégie permet à un suivi de "prendre son 2ᵉ choix" si son meilleur candidat est déjà pris.
:param indices: Indices retournés par KDTree (forme (T, k) ou (T,) si k==1).
:param distances: Distances retournées par KDTree (même forme que ``ind``).
:param n_points: Nombre de points dans le plan courant (``tree.n``).
:returns: (keep_tracks, keep_points) indices des suivis (0..T-1) et points (0..P-1) retenus.
"""
# Normalisation shape : (n, k)
if indices.ndim == 1: indices, distances = indices[:, None], distances[:, None]
n, k = indices.shape
pairs: list[tuple[float, int, int]] = []
# Construction des propositions valides.
for i in range(n): # . Pour chaque suivi.
for j in range(k): # . Pour les k plus proches voisins.
p_j = int(indices[i, j]) # . Récupération du jieme voisin du suivi i.
if p_j >= n_points: continue # . >= n_points signifie "pas de voisin dans le rayon" ⇒ on ignore.
pairs.append((float(distances[i, j]), i, p_j)) # On ajoute aux propositions, la distance, le suivi et le point.
# Aucun voisin valide pour aucun suivi (tous les indices == n_points) ⇒ aucune paire candidate.
if not pairs: return np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.int32)
# Appariement glouton : plus petites distances d'abord.
pairs.sort(key=lambda x: x[0])
used_t: set[int] = set()
used_p: set[int] = set()
keep_t: list[int] = []
keep_p: list[int] = []
for d, i, p_j in pairs:
if i in used_t or p_j in used_p: continue # Le suivi/point a déjà été assigné
used_t.add(i)
used_p.add(p_j)
keep_t.append(i)
keep_p.append(p_j)
return np.asarray(keep_t, dtype=np.int32), np.asarray(keep_p, dtype=np.int32)
##################################################
##################################################
[docs]
def remove_beads(data: pd.DataFrame, beads: pd.DataFrame, decimals: int = 5) -> pd.DataFrame:
"""
Mets à jour data en lui enlevant les billes identifiées (matching exact sur Plane/X/Y/Z).
:param data: Données de localisation (doit contenir Plane, X, Y, Z).
:param beads: Billes à retirer (doit contenir Plane, X, Y, Z).
:param decimals: Nombre de décimales conservées pour la comparaison (évite les problèmes d'arrondi float).
:returns: Copie de data sans les lignes correspondant aux billes.
"""
if data.empty or beads.empty: return data
required = {"X", "Y", "Z"}
_check_cols(data, required)
_check_cols(beads, required)
keys = ["Plane", "X", "Y", "Z"]
# --- Création de copies arrondies pour comparaison ---
data_cmp = data.assign(X=data["X"].round(decimals), Y=data["Y"].round(decimals), Z=data["Z"].round(decimals))
beads_cmp = beads.assign(X=beads["X"].round(decimals), Y=beads["Y"].round(decimals), Z=beads["Z"].round(decimals)).drop_duplicates()
# --- Anti-join ---
mask = ~pd.MultiIndex.from_frame(data_cmp[keys]).isin(pd.MultiIndex.from_frame(beads_cmp[keys]))
return data.loc[mask].copy().reset_index(drop=True)
##################################################
[docs]
def get_drift(beads: pd.DataFrame, is_3d: bool = True) -> pd.DataFrame:
"""
Calcule le drift moyen interplans à partir d'un DataFrame de billes suivies.
Le drift est calculé pour chaque bille comme la différence entre deux plans consécutifs :
- :math:`\\Delta X(n) = X(n) - X(n-1)`
- :math:`\\Delta Y(n) = Y(n) - Y(n-1)`
- :math:`\\Delta Z(n) = Z(n) - Z(n-1)` (si ``is_3d=True``)
Et il est associé au plan ``n`` (donc le drift commence au plan 2).
Ensuite, on moyenne ces drifts entre toutes les billes disponibles.
:param beads: DataFrame contenant au minimum ``Bead``, ``Plane``, ``X``, ``Y``, ``Z``.
Chaque bille doit avoir exactement une ligne par plan et les plans doivent être consécutifs.
:param is_3d: Si ``True``, calcule aussi le drift en Z. Sinon, la colonne Z est renvoyée à 0.
:returns: DataFrame avec colonnes ``Plane``, ``X``, ``Y``, ``Z`` contenant le drift moyen, pour les plans de 2 à N.
:raises ValueError: Si des colonnes sont manquantes, ou incohérence de plans au sein des billes.
"""
# ----- Vérifications initiales -----
if beads.empty: return pd.DataFrame()
required = {"Bead", "Plane", "X", "Y", "Z"}
_check_cols(beads, required)
work = beads.loc[:, list(required)] # . Copie légère : uniquement les colonnes nécessaires.
work = work.sort_values(by=["Bead", "Plane"], kind="stable") # Théoriquement inutile, mais en cas de mauvaise utilisation...
planes = _check_planes(work)
# --- Calcul des deltas par bille ---
work["dX"] = work.groupby("Bead")["X"].diff()
work["dY"] = work.groupby("Bead")["Y"].diff()
work["dZ"] = work.groupby("Bead")["Z"].diff() if is_3d else 0.0
work = work.loc[work["Plane"] != int(planes[0])]
# --- Moyenne sur les billes, plan par plan ---
drift = (work.groupby("Plane", as_index=False)[["dX", "dY", "dZ"]].mean().rename(columns={"dX": "X", "dY": "Y", "dZ": "Z"}))
return drift.sort_values(by=["Plane"], kind="stable").reset_index(drop=True)
##################################################
[docs]
def apply_drift(data: pd.DataFrame, drift: pd.DataFrame, is_3d: bool = True) -> pd.DataFrame:
"""
Applique une correction de drift à un DataFrame de points.
Le DataFrame ``drift`` doit contenir une ligne par plan, avec les colonnes ``Plane``, ``X``, ``Y`` et ``Z``.
Le drift est supposé défini pour les plans 2..N et associé au plan courant (même convention que :func:`get_drift`).
La correction appliquée est :
- :math:`X_{corr}(n) = X(n) - X_{drift}(n)`
- :math:`Y_{corr}(n) = Y(n) - Y_{drift}(n)`
- :math:`Z_{corr}(n) = Z(n) - Z_{drift}(n)`
:param data: DataFrame contenant au minimum ``Plane``, ``X``, ``Y`` ``Z``.
:param drift: DataFrame contenant au minimum ``Plane``, ``X``, ``Y`` ``Z``, typiquement pour les plans 2..N.
:param is_3d: Si ``True``, applique aussi la correction en Z.
:returns: Un nouveau DataFrame corrigé (mêmes colonnes que ``data``).
:raises ValueError: Si colonnes manquantes dans ``data`` ou ``drift``.
"""
# ----- Vérifications initiales -----
if data.empty or drift.empty: return pd.DataFrame()
required = {"Plane", "X", "Y", "Z"}
_check_cols(data, required)
_check_cols(drift, required)
# --- drift incrémental -> drift cumulatif ---
drift_work = drift.loc[:, list(required)].copy()
drift_work = drift_work.sort_values("Plane", kind="stable").reset_index(drop=True)
drift_work[["X", "Y", "Z"]] = drift_work[["X", "Y", "Z"]].cumsum()
drift_work = drift_work.set_index("Plane")
# --- Application : join + soustraction ---
out = data.copy()
out = out.join(drift_work, on="Plane", rsuffix="_drift") # On fait un merge aligné sur Plane pour vectoriser (évite une boucle Python par plan).
out = out.fillna(0.0)
# Soustraction (correction).
out["X"] -= out["X_drift"]
out["Y"] -= out["Y_drift"]
if is_3d: out["Z"] -= out["Z_drift"]
out.drop(columns=[c for c in out.columns if c.endswith("_drift")], inplace=True) # Nettoyage
return out
##################################################
[docs]
def drift_correction(data: pd.DataFrame, max_distance: float = 1, is_3d: bool = True, *, strict: bool = True, k: int = 4) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Trouve des billes au sein d'un jeu de donnée, calcule le drift et le corrige.
:param data: DataFrame contenant au minimum les colonnes ``Plane``, ``X``, ``Y`` et ``Z``.
Chaque ligne représente une détection (un point) dans un plan donné.
:param max_distance: Distance maximale autorisée entre deux plans (en unités des coordonnées) selon la norme L∞.
:param is_3d: Si ``True``, utilise (X,Y,Z). Sinon, utilise uniquement (X,Y).
:param strict: Si ``True``, la distance doit être strictement inférieure à la distance maximale (comportement par défaut).
:param k: Nombre de matchs maximum pour chaques points, permet de gérer les collisions de suivis (par défaut 4 maximum).
Dans la réalité, avec des données et paramètres cohérents, il n'y aura qu'un seul match ou aucun pour chaques points.
:returns: Les billes identifiées et un nouveau DataFrame ne contenant les données corrigées.
"""
beads = extract_beads(data, max_distance=max_distance, is_3d=is_3d, strict=strict, k=k)
drift = get_drift(beads, is_3d=is_3d)
return beads, apply_drift(data, drift, is_3d=is_3d)