"""
Fichier contenant une classe pour utiliser la DLL externe CPU_PALM, exécuter les algorithmes de détection de points et les paramètres liés.
"""
import ctypes
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import pandas as pd
import psutil
from palm_tracer.Processing.Parsing import FILES_COLUMNS, N_COL_LOC, N_COL_TRC, parse_localization_for_tracking, parse_result, SHAPE_MODEL
from palm_tracer.Tools import FileIO, Ui
N_TRC_CP_FIT = 12
DENSITY = 0.1
C_IMG, C_TAB = ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_double)
C_UINT, C_BOOL, C_DBL = ctypes.c_uint32, ctypes.c_bool, ctypes.c_double
##################################################
[docs]
@dataclass
class Palm:
"""Classe permettant d'utiliser la DLL externe PALM, exécuter les algorithmes de détection de points et les paramètres liés."""
_type: str = field(init=False, default="CPU")
"""Type de DLL, par défaut CPU, GPU également possible."""
_dll: ctypes.CDLL = field(init=False)
"""DLL chargée."""
# ==================================================
# region Initialization
# ==================================================
##################################################
def __post_init__(self):
"""Méthode appelée automatiquement après l'initialisation du dataclass."""
self._dll = FileIO.load_dll(self._type)
self._bind()
##################################################
[docs]
def is_valid(self) -> bool:
"""
Vérifie la validité de la DLL utilisée pour PALM.
:return: True si la DLL est valide, False sinon.
"""
return self._dll is not None
##################################################
def _bind(self):
"""Déclare les signatures ctypes (argtypes/restype) une seule fois."""
# double AutoThreshold(uint16_t* image, uint32_t h, uint32_t w, double* fit_params)
fn = self._dll.AutoThreshold
fn.restype = C_DBL
fn.argtypes = [C_IMG, C_UINT, C_UINT, C_TAB]
# uint32_t Localization(uint16_t* stack, double* locs, uint32_t n, uint32_t h, uint32_t w, uint32_t planes,
# double thr, double watershed, uint32_t fit, double* fit_params)
fn = self._dll.Localization
fn.restype = C_UINT
fn.argtypes = [C_IMG, C_TAB, C_UINT, C_UINT, C_UINT, C_UINT, C_DBL, C_DBL, C_UINT, C_TAB]
# uint32_t Tracking(double* points, double* tracks, double max_dist, uint32_t min_life, double decrease, double cost_birth, uint32_t planes)
fn = self._dll.Tracking
fn.restype = C_UINT
fn.argtypes = [C_TAB, C_TAB, C_DBL, C_UINT, C_DBL, C_DBL, C_UINT]
# uint32_t BlinkingReconnection(double* input, double* output, uint32_t nRow, double pixel_size, uint32_t mode,
# uint32_t max_duration, double max_speed)
fn = self._dll.BlinkingReconnection
fn.restype = C_UINT
fn.argtypes = [C_TAB, C_TAB, C_UINT, C_DBL, C_UINT, C_UINT, C_DBL]
# bool TracksCompute(double* input, double* o_msd, double* o_ind, double* o_fit, uint32_t nRow, bool is_msd, bool is_ind, bool is_3d,
# double pixel_size, double exposure_time, uint32_t fit_mode, double* fit_params)
fn = self._dll.TracksCompute
fn.restype = C_BOOL
fn.argtypes = [C_TAB, C_TAB, C_TAB, C_TAB, C_UINT, C_BOOL, C_BOOL, C_BOOL, C_DBL, C_DBL, C_UINT, C_TAB]
# void Alignment(uint16_t* input, uint16_t* output, uint32_t h, uint32_t w, uint32_t planes, double* factors, uint32_t upsampling)
fn = self._dll.Alignment
fn.restype = None
fn.argtypes = [C_IMG, C_IMG, C_UINT, C_UINT, C_UINT, C_TAB, C_UINT]
# void Wavelett(uint16_t* input, double* output, uint32_t h, uint32_t w, uint32_t planes, uint32_t level)
fn = self._dll.Wavelett
fn.restype = None
fn.argtypes = [C_IMG, C_TAB, C_UINT, C_UINT, C_UINT, C_UINT]
# void Astigmatism3DCalibration(const double* input, double* output, uint32_t size, double pixelSize)
fn = self._dll.Astigmatism3DCalibration
fn.restype = None
fn.argtypes = [C_TAB, C_TAB, C_UINT, C_DBL, C_BOOL]
# void Astigmatism3DEstimation(const double* input, double* output, uint32_t size, double pixelSize, double* model, double zMax)
fn = self._dll.Astigmatism3DEstimation
fn.restype = None
fn.argtypes = [C_TAB, C_TAB, C_UINT, C_DBL, C_TAB, C_DBL]
# ==================================================
# endregion Initialization
# ==================================================
# ==================================================
# region Argument Parser
# ==================================================
##################################################
@staticmethod
def _as_c_contig(a: np.ndarray, dtype: np.dtype, *, writeable: bool) -> np.ndarray:
"""
Retourne un tableau C-contigu du dtype demandé, sans copie si possible.
:param a: Tableau d'entrée.
:param dtype: Dtype souhaité.
:param writeable: Garantit un buffer modifiable si True.
:return: Tableau C-contigu compatible DLL.
"""
if not isinstance(a, np.ndarray): a = np.asarray(a)
# Vérification de copie nécessaire ou non
if (a.dtype != dtype) or (not a.flags["C_CONTIGUOUS"]) or (writeable and not a.flags["WRITEABLE"]):
a = np.ascontiguousarray(a, dtype=dtype) # Rend contigu
if writeable: a.setflags(write=True) # . Rend modifiable ou non
return a
##################################################
[docs]
@staticmethod
def max_allocation_bytes(fraction_available: float = 0.5, safety_gb: int = 1) -> int:
"""
Permet de calculer la quantité de mémoire disponible au maximum pour une allocation.
:param fraction_available: Pourcentage de la RAM disponible à utiliser au maximum.
:param safety_gb: Marge de sécurité à garder disponible.
:return: Valeur en byte de l'allocation maximum tolérée.
"""
giga = 1024 * 1024 * 1024
avail = psutil.virtual_memory().available
safety = safety_gb * giga
budget = int(max(0, avail - safety) * fraction_available)
return max(budget, giga)
# ==================================================
# endregion Argument Parser
# ==================================================
# ==================================================
# region DLL Call
# ==================================================
##################################################
[docs]
def localization(self, stack: np.ndarray, threshold: float, watershed: bool, fit: int, fit_params: np.ndarray,
planes: Optional[list[int]] = None) -> pd.DataFrame:
"""
Exécute un traitement d'image avec une DLL PALM externe pour détecter des points dans une pile ou une image.
:param stack: Pile d'images en entrée sous forme de tableau numpy (possibilité d'envoyer une image directement).
:param threshold: Seuil pour la détection.
:param watershed: Active ou désactive le mode watershed.
:param fit: Mode d'ajustement (défini par `get_fit`).
:param fit_params: Paramètres du mode d'ajustement.
:param planes: Liste des plans à analyser (None pour tous les plans, les plans sont contigus par principe).
:return: Liste des points détectés sous forme de dataframe contenant toutes les informations reçues de la DLL.
"""
# --- Initialisation ---
stk = self._as_c_contig(stack, np.dtype(np.uint16), writeable=False) # . Assurance de contiguité
params = self._as_c_contig(fit_params, np.dtype(np.float64), writeable=False) # Assurance de contiguité
height, width = stk.shape[-2:] # . Récupère les deux dernières dimensions
n_planes = 1 if stk.ndim == 2 else stk.shape[0] # . Récupère le nombre de plans si 3D sinon 1
if planes is None: planes = list(range(n_planes)) # . Si aucune sélection, liste de tous les plans
else: planes = [p for p in planes if 0 <= p < n_planes] # . Sinon, sélection des plans valides
n_planes = len(planes) # . Nouveau nombre de plans
# Ajoute une dimension plan artificielle pour une Image 2D ou une vue mémoire (slice) pour une pile 3D
stk = stk[np.newaxis, :, :] if stk.ndim == 2 else stk[planes[0]:planes[0] + n_planes]
# --- Calcul du budget RAM Disponible ---
max_points = int(self.max_allocation_bytes() // 8) # . Nombre de points maximum allouable en une fois
plane_points = int(height * width * DENSITY) * N_COL_LOC # . Taille théorique max pour un seul plan (N points max * N Col localisation)
if max_points < plane_points: return pd.DataFrame() # . pragma: no cover — Cas extrême un seul plan est gargantuesque.
n_plane_max = int(min(max_points // plane_points, n_planes)) # Nombre de plans qui tiennent dans max_allocation
dfs: list[pd.DataFrame] = []
i = 0
while i < len(planes):
k = min(n_plane_max, n_planes - i) # . Taille réelle du bloc, soit le max, soit "ce qui reste".
stk_block = stk[i:i + k] # . Indices relatifs (0..n_planes-1)
n_block = plane_points * k # . Nombre de points pour ce bloc
locs = np.empty((n_block,), dtype=np.float64, order="C") # Création de la sortie
count = self._dll.Localization(stk_block.ctypes.data_as(C_IMG), locs.ctypes.data_as(C_TAB), C_UINT(n_block), C_UINT(height), C_UINT(width),
C_UINT(k), C_DBL(threshold), C_DBL(0 if watershed else 10), C_UINT(fit), params.ctypes.data_as(C_TAB))
res = parse_result(locs[:count], "Localization")
if "Plane" in res.columns: res["Plane"] += planes[0] + i # . En cas de filtre des plans, on incrémente par i + premier plan.
dfs.append(res)
i += k
res = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
if not res.empty:
res.reset_index(drop=True, inplace=True)
res["Id"] = res.index + 1 # 1-based comme attendu
return res
##################################################
[docs]
def auto_threshold(self, image: np.ndarray, fit_params: np.ndarray) -> float:
"""
Calcule un seuil automatique basé sur la segmentation de l'image.
:param image: Image 2D sous forme de tableau NumPy.
:param fit_params: Paramètres du mode d'ajustement.
:return: Seuil calculé (écart type final).
"""
img = self._as_c_contig(image, np.dtype(np.uint16), writeable=False) # . Assurance de contiguité
params = self._as_c_contig(fit_params, np.dtype(np.float64), writeable=False) # Assurance de contiguité
height, width = img.shape # . Récupère les dimensions
return float(self._dll.AutoThreshold(img.ctypes.data_as(C_IMG), C_UINT(height), C_UINT(width), params.ctypes.data_as(C_TAB)))
##################################################
[docs]
def tracking(self, localizations: pd.DataFrame, max_distance: float, min_life: int = 1, decrease: float = 10, cost_birth: float = 0.5) -> pd.DataFrame:
"""
Exécute l'algorithme de tracking sur les points localisés.
Cette méthode applique un algorithme de suivi (tracking) sur les données de localisation fournies,
en prenant en compte divers paramètres influençant le coût et la durée de vie des trajectoires.
:param localizations: Liste des points détectés sous forme de dataframe contenant toutes les informations reçues de la DLL.
:param max_distance: Distance maximale autorisée entre deux points pour les relier entre deux plans successifs.
:param min_life: Longueur minimale d'une trajectoire pour qu'elle soit conservée dans le résultat final.
:param decrease: Facteur de pénalisation appliqué au coût d'association entre des plans éloignés.
:param cost_birth: Coût associé à la création d'une nouvelle trajectoire (point non associé à une trajectoire existante).
:return: :class:`DataFrame <pandas.DataFrame>` contenant les trajectoires détectées.
"""
required = FILES_COLUMNS["Localization"]["columns"]
if localizations.empty or not set(required).issubset(localizations.columns): return pd.DataFrame()
points = parse_localization_for_tracking(localizations[required]) # . Tracking a besoin d'un format particulier de localisations
points = self._as_c_contig(points, np.dtype(np.float64), writeable=False) # Assurance de contiguité
track_size = len(localizations) * N_COL_TRC # . Taille du tableau final
tracks = np.empty((track_size,), dtype=np.float64, order="C") # . Tableau final
count = self._dll.Tracking(points.ctypes.data_as(C_TAB), tracks.ctypes.data_as(C_TAB), C_DBL(max_distance), C_UINT(min_life),
C_DBL(decrease), C_DBL(cost_birth), C_UINT(localizations["Plane"].max()))
return parse_result(tracks[:count], "Tracking")
##################################################
[docs]
def blinking_reconnection(self, tracks: pd.DataFrame, pixel_size: float, mode: int, max_duration: int, max_speed: float) -> pd.DataFrame:
"""
Exécute l'algorithme de reconnexion des trajectoires sur celles déjà localisées.
:param pixel_size: Taille des pixels en nanomètres.
:param tracks: Liste des points déjà trackés sous forme de dataframe contenant toutes les informations reçues de la DLL.
:param mode: Mode de dispersion des points (0: immobile, 1: diffus, 2: linéaire).
:param max_duration: Durée maximale d'un scintillement.
:param max_speed: Vitesse maximale d'un point entre deux plans (en pixel).
:return: :class:`DataFrame <pandas.DataFrame>` contenant les trajectoires détectées.
"""
required = FILES_COLUMNS["Tracking"]["columns"]
if tracks.empty or not set(required).issubset(tracks.columns): return pd.DataFrame()
in_tracks = tracks[required].to_numpy(dtype=np.float64, copy=False)
in_tracks = self._as_c_contig(in_tracks, np.dtype(np.float64), writeable=False)
n = len(tracks)
track_size = n * N_COL_TRC
out = np.empty((track_size,), dtype=np.float64, order="C")
count = self._dll.BlinkingReconnection(in_tracks.ctypes.data_as(C_TAB), out.ctypes.data_as(C_TAB), C_UINT(n), C_DBL(pixel_size),
C_UINT(mode), C_UINT(max_duration), C_DBL(max_speed))
return parse_result(out[:count], "Tracking")
##################################################
[docs]
def tracks_compute(self, tracks: pd.DataFrame, is_msd: bool, is_ind: bool, is_3d: bool, is_log: bool,
pixel_size: float, exposure_time: float, fit_mode: int, fit_params: np.ndarray) -> dict[str, pd.DataFrame]:
"""
Exécute l'algorithme de calcul sur les trajectoires.
:param tracks: Liste des points déjà trackés sous forme de dataframe contenant toutes les informations reçues de la DLL.
:param is_msd: Calcul MSD.
:param is_ind: Calcul de la diffusion instantanée.
:param is_3d: Calcul sur la 3D.
:param is_log: Applique un logarithme sur le résultat.
:param pixel_size: Taille des pixels en micromètre.
:param exposure_time: Calibration temporelle utile pour les calculs.
:param fit_mode: Mode d'ajustement.
:param fit_params: Paramètres de l'ajustement (pour le moment uniquement fit length).
:return: :class:`DataFrame <pandas.DataFrame>` contenant les trajectoires détectées.
.. note::
Pour obtenir des valeurs de diffusion en **µm²**, le paramètre ``pixel_size``
doit être exprimé en micromètres par pixel (µm/px)
"""
res: dict[str, pd.DataFrame] = {"MSD": pd.DataFrame(), "InD": pd.DataFrame(), "Fit": pd.DataFrame()}
required = FILES_COLUMNS["Tracking"]["columns"]
if tracks.empty or not set(required).issubset(tracks.columns): return res
in_tracks = tracks[required].copy()
if not is_3d: in_tracks["Z"] = 0 # On simplifie, la suite les calculs se font toujours en 3D, mais la dernière dimension sera nulle
in_tracks = in_tracks.to_numpy(dtype=np.float64, copy=False) # . Passage en tableau numpy
in_tracks = self._as_c_contig(in_tracks, np.dtype(np.float64), writeable=False) # Assurance de contiguité
params = self._as_c_contig(fit_params, np.dtype(np.float64), writeable=False) # . Assurance de contiguité
n_row = len(tracks) # . Nombre de points pour les trajectoires
n = n_row * N_TRC_CP_FIT # . Taille maximale des tableaux finaux
# Sorties : buffer réel si demandé, sinon buffer dummy (évite pointeurs NULL)
o_msd = np.empty((n,), dtype=np.float64, order="C") if is_msd else np.empty((1,), dtype=np.float64, order="C")
o_ind = np.empty((n,), dtype=np.float64, order="C") if is_ind else np.empty((1,), dtype=np.float64, order="C")
o_fit = np.empty((n,), dtype=np.float64, order="C") if fit_mode != 0 else np.empty((1,), dtype=np.float64, order="C")
self._dll.TracksCompute(in_tracks.ctypes.data_as(C_TAB), o_msd.ctypes.data_as(C_TAB), o_ind.ctypes.data_as(C_TAB), o_fit.ctypes.data_as(C_TAB),
C_UINT(n_row), C_BOOL(is_msd), C_BOOL(is_ind), C_BOOL(is_3d), C_DBL(pixel_size), C_DBL(exposure_time),
C_UINT(fit_mode), params.ctypes.data_as(C_TAB))
if is_msd: res["MSD"] = parse_result(o_msd[:n], "MSD", is_log)
if is_ind: res["InD"] = parse_result(o_ind[:n], "Instant Diffusion", is_log)
if fit_mode != 0: res["Fit"] = parse_result(o_fit[:n], "Fit", is_log, fit_mode)
# Restauration des identifiants de trajectoire
# TODO un fix devra être fait dans la DLL pour qu'elle stocke l'identifiant elle même et que cette partie devienne inutile
track_ids = pd.unique(tracks["Track"])
for key in res:
if len(res[key]) != track_ids.size: Ui.print_warning("Problem with trajectory identifiers, be careful with filtering")
else:
res[key].drop(columns=["Track"], inplace=True)
res[key].insert(0, "Track", track_ids)
return res
##################################################
[docs]
def align(self, stack: np.ndarray, factors: np.ndarray, upsampling: int = 1) -> np.ndarray:
"""
Exécute un traitement d'image avec une DLL PALM externe pour détecter des points dans une pile ou une image.
:param stack: Pile d'images en entrée sous forme de tableau numpy (possibilité d'envoyer une image directement).
:param factors: Facteurs d'alignement.
:param upsampling: Facteur d'agrandissement de l'image (par défaut : `1` aucun agrandissement).
:return: Image alignée.
"""
stk = self._as_c_contig(stack, np.dtype(np.uint16), writeable=False) # . Assurance de contiguité
params = self._as_c_contig(factors, np.dtype(np.float64), writeable=False) # Assurance de contiguité
height, width = stk.shape[-2:] # . Récupère les deux dernières dimensions
planes = 1 if stk.ndim == 2 else stk.shape[0] # . Récupère le nombre de plans si 3D sinon 1
out = np.empty((planes, height * upsampling, width * upsampling), dtype=np.uint16, order="C")
self._dll.Alignment(stk.ctypes.data_as(C_IMG), out.ctypes.data_as(C_IMG), C_UINT(height), C_UINT(width), C_UINT(planes),
params.ctypes.data_as(C_TAB), C_UINT(upsampling))
return out
##################################################
[docs]
def wavelett(self, stack: np.ndarray, level: int = 2) -> np.ndarray:
"""
Exécute un traitement d'image avec une DLL PALM externe pour détecter des points dans une pile ou une image.
:param stack: Pile d'images en entrée sous forme de tableau numpy (possibilité d'envoyer une image directement).
:param level: Niveau d'ondelette en sortie.
:return: Image décomposée.
"""
stk = self._as_c_contig(stack, np.dtype(np.uint16), writeable=False) # Assurance de contiguité
height, width = stk.shape[-2:] # . Récupère les deux dernières dimensions
planes = 1 if stk.ndim == 2 else stk.shape[0] # . Récupère le nombre de plans si 3D sinon 1
out = np.empty((planes, height, width), dtype=np.float64, order="C")
self._dll.Wavelett(stk.ctypes.data_as(C_IMG), out.ctypes.data_as(C_TAB), C_UINT(height), C_UINT(width), C_UINT(planes), C_UINT(level))
return out
##################################################
[docs]
def astigmatism_3d_calibration(self, points: np.ndarray, pixel_size: float, center: bool = True) -> pd.DataFrame:
"""
Exécute un traitement avec une DLL PALM externe pour calibrer un modèle d'astigmatisme permettant d'estimer une position axiale.
:param points: Ensemble des points nécessaire à la calibration sous forme de tableau numpy 2D avec pour colonnes [Sigma X, Sigma Y, Z].
:param pixel_size: Taille des pixels en nanomètres.
:param center: Permet de centrer le modèle pour que si :math:`\\sigma_x(0) \\approx \\sigma_y(0)`.
:return: Modèle d'astigmatisme (un tableau numpy 2D de 2 lignes et 5 paramètres par ligne).
"""
pts = self._as_c_contig(points, np.dtype(np.float64), writeable=False)
out = np.empty(SHAPE_MODEL, dtype=np.float64, order="C")
n = pts.shape[0]
self._dll.Astigmatism3DCalibration(pts.ctypes.data_as(C_TAB), out.ctypes.data_as(C_TAB), C_UINT(n), C_DBL(pixel_size), C_BOOL(center))
return parse_result(out, "Astigmatism 3D Model")
##################################################
[docs]
def astigmatism_3d_estimation(self, sigmas: np.ndarray, pixel_size: float, model: np.ndarray, z_max: float = 800) -> np.ndarray:
"""
Exécute un traitement avec une DLL PALM externe pour estimer une position axiale à partir d'un modèle.
:param sigmas: Ensemble des points à estimer sous forme de tableau numpy 2D avec pour colonnes [sigma_x, sigma_y].
:param pixel_size: Taille des pixels en nanomètres.
:param model: Modèle d'astigmatisme (un tableau numpy 2D de 2 lignes et 5 paramètres par ligne).
:param z_max: Distance absolue maximale sur Z par rapport à l'origine.
:return: Ensemble des Z estimés pour chaque point.
"""
n = len(sigmas)
s = self._as_c_contig(sigmas, np.dtype(np.float64), writeable=False)
m = self._as_c_contig(model, np.dtype(np.float64), writeable=False)
out = np.empty((n,), dtype=np.float64, order="C")
self._dll.Astigmatism3DEstimation(s.ctypes.data_as(C_TAB), out.ctypes.data_as(C_TAB), C_UINT(n), C_DBL(pixel_size),
m.ctypes.data_as(C_TAB), C_DBL(z_max))
return out