# Source code for mascaf.validation

"""
Validation module for SWC models.

Provides the Validation class to compare SWC models against original mesh
and skeleton data.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Union

import trimesh

from .mesh import MeshManager
from .morphology_graph import MorphologyGraph
from .skeleton import SkeletonGraph

# Module-level logger named after this module. A NullHandler is attached, as
# the Python logging docs recommend for libraries, so records emitted here do
# not fall through to logging's "last resort" stderr handler; the application
# decides where (and whether) they are actually output.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class Validation:
    """
    Validation of MorphologyGraph models against original mesh and skeleton.

    This class provides methods to validate a MorphologyGraph by comparing it
    to the original mesh (geometry and surface) and the original skeleton
    (topology and centerline).

    Parameters
    ----------
    mesh : trimesh.Trimesh or MeshManager
        The original mesh to validate against. Can be either a trimesh.Trimesh
        object or a MeshManager instance.
    skeleton : SkeletonGraph
        The original skeleton graph used to generate the morphology.
    morphology : MorphologyGraph or str or Path
        Either a MorphologyGraph instance to validate, or a path to an SWC
        file which will be loaded as a MorphologyGraph.

    Attributes
    ----------
    mesh : trimesh.Trimesh
        The mesh object (extracted from MeshManager if needed).
    mesh_manager : MeshManager or None
        The MeshManager instance if provided, otherwise None.
    skeleton : SkeletonGraph
        The original skeleton graph.
    morphology : MorphologyGraph
        The morphology graph to validate.
    swc_path : Path or None
        Path to the SWC file if loaded from file, otherwise None.

    Raises
    ------
    TypeError
        If ``mesh``, ``skeleton`` or ``morphology`` has an unsupported type.
    FileNotFoundError
        If ``morphology`` is a path that does not point to an existing file.

    Examples
    --------
    >>> from mascaf import MeshManager, SkeletonGraph, Validation
    >>> from mascaf.morphology_graph import MorphologyGraph
    >>> mesh_mgr = MeshManager(mesh_path="neuron.obj")
    >>> skeleton = SkeletonGraph.from_polylines(polylines)
    >>> # Option 1: Validate from SWC file
    >>> validator = Validation(mesh_mgr, skeleton, "output.swc")
    >>> # Option 2: Validate existing MorphologyGraph
    >>> graph = MorphologyGraph.from_swc_file("output.swc")
    >>> validator = Validation(mesh_mgr, skeleton, graph)
    >>> volume_ratio = validator.compare_volumes()["ratio"]
    >>> area_ratio = validator.compare_surface_areas()["ratio"]
    """

    def __init__(
        self,
        mesh: Union[trimesh.Trimesh, MeshManager],
        skeleton: SkeletonGraph,
        morphology: Union[MorphologyGraph, str, Path],
    ):
        # Handle mesh input: accept a bare Trimesh or unwrap a MeshManager.
        if isinstance(mesh, MeshManager):
            self.mesh_manager = mesh
            self.mesh = mesh.mesh
        elif isinstance(mesh, trimesh.Trimesh):
            self.mesh = mesh
            self.mesh_manager = None
        else:
            raise TypeError(
                "mesh must be trimesh.Trimesh or MeshManager, "
                f"got {type(mesh)}"
            )

        # Handle skeleton input
        if not isinstance(skeleton, SkeletonGraph):
            raise TypeError(f"skeleton must be SkeletonGraph, got {type(skeleton)}")
        self.skeleton = skeleton

        # Handle morphology input: an in-memory graph or a path to an SWC file.
        if isinstance(morphology, MorphologyGraph):
            self.morphology = morphology
            self.swc_path = None
        elif isinstance(morphology, (str, Path)):
            self.swc_path = Path(morphology)
            # is_file() rather than exists(): a directory path should fail here
            # with a clear message instead of deep inside the SWC reader.
            if not self.swc_path.is_file():
                raise FileNotFoundError(f"SWC file not found: {self.swc_path}")
            self.morphology = MorphologyGraph.from_swc_file(str(self.swc_path))
        else:
            raise TypeError(
                "morphology must be MorphologyGraph or path to SWC file, "
                f"got {type(morphology)}"
            )

        # Log initialization
        if self.swc_path is not None:
            logger.info(f"Initialized Validation from SWC: {self.swc_path.name}")
        else:
            logger.info("Initialized Validation from MorphologyGraph")
        logger.info(
            f"  Mesh: {len(self.mesh.vertices)} vertices, "
            f"{len(self.mesh.faces)} faces"
        )
        logger.info(
            f"  Skeleton: {self.skeleton.number_of_nodes()} nodes, "
            f"{self.skeleton.number_of_edges()} edges"
        )
        logger.info(
            f"  MorphologyGraph: {self.morphology.number_of_nodes()} nodes, "
            f"{self.morphology.number_of_edges()} edges"
        )

    def __repr__(self) -> str:
        swc_name = self.swc_path.name if self.swc_path is not None else None
        return (
            f"Validation(\n"
            f"  swc_path={swc_name},\n"
            f"  mesh_vertices={len(self.mesh.vertices)},\n"
            f"  skeleton_nodes={self.skeleton.number_of_nodes()},\n"
            f"  morphology_nodes={self.morphology.number_of_nodes()}\n"
            f")"
        )

    def compare_volumes(self, account_for_overlaps: bool = False) -> dict:
        """
        Compare total volume between mesh and morphology model.

        Computes the mesh volume using trimesh and the morphology volume via
        :meth:`MorphologyGraph.compute_volume`.

        Parameters
        ----------
        account_for_overlaps : bool, default False
            If True, subtract branch-point overlap corrections in the
            morphology volume (see :meth:`MorphologyGraph.compute_volume`).

        Returns
        -------
        dict
            Dictionary containing:
            - 'mesh_volume': float, volume of the mesh
            - 'morphology_volume': float, volume of the morphology model
            - 'ratio': float, morphology_volume / mesh_volume
            - 'error': float, morphology_volume - mesh_volume (signed;
              positive means the morphology over-estimates the mesh)
            - 'relative_error': float, error / mesh_volume (signed)

        Raises
        ------
        ValueError
            If the mesh volume is not strictly positive (zero, negative or
            NaN), since every ratio above divides by it.

        Examples
        --------
        >>> result = validator.compare_volumes()
        >>> print(f"Volume ratio: {result['ratio']:.3f}")
        """
        # trimesh's volume is only meaningful for a closed surface; warn rather
        # than fail so partially-open meshes can still be inspected.
        if not self.mesh.is_watertight:
            logger.warning(
                "Mesh is not watertight; its volume (and the volume "
                "comparison) may be unreliable."
            )

        mesh_volume = float(self.mesh.volume)
        # `not x > 0` (instead of `x <= 0`) also rejects NaN. trimesh volumes
        # are signed, so a negative value may indicate inverted face winding.
        if not mesh_volume > 0.0:
            raise ValueError(
                f"Mesh volume must be positive, got {mesh_volume}; "
                "a negative value may indicate inverted face winding."
            )

        morphology_volume = self.morphology.compute_volume(
            account_for_overlaps=account_for_overlaps
        )

        error = morphology_volume - mesh_volume
        return {
            "mesh_volume": mesh_volume,
            "morphology_volume": morphology_volume,
            "ratio": morphology_volume / mesh_volume,
            "error": error,
            "relative_error": error / mesh_volume,
        }

    def compare_surface_areas(self, account_for_overlaps: bool = False) -> dict:
        """
        Compare total surface area between mesh and morphology model.

        Computes the mesh surface area using trimesh and the morphology
        surface area via :meth:`MorphologyGraph.compute_surface_area`.

        Parameters
        ----------
        account_for_overlaps : bool, default False
            If True, subtract branch-point overlap corrections in the
            morphology surface area (see
            :meth:`MorphologyGraph.compute_surface_area`).

        Returns
        -------
        dict
            Dictionary containing:
            - 'mesh_area': float, surface area of the mesh
            - 'morphology_area': float, surface area of the morphology model
            - 'ratio': float, morphology_area / mesh_area
            - 'error': float, morphology_area - mesh_area (signed)
            - 'relative_error': float, error / mesh_area (signed)

        Raises
        ------
        ValueError
            If the mesh surface area is not strictly positive.

        Examples
        --------
        >>> result = validator.compare_surface_areas()
        >>> print(f"Surface area ratio: {result['ratio']:.3f}")
        """
        mesh_area = float(self.mesh.area)
        # `not x > 0` also rejects NaN; every ratio below divides by mesh_area.
        if not mesh_area > 0.0:
            raise ValueError("Mesh has zero area.")

        morphology_area = self.morphology.compute_surface_area(
            account_for_overlaps=account_for_overlaps
        )

        error = morphology_area - mesh_area
        return {
            "mesh_area": mesh_area,
            "morphology_area": morphology_area,
            "ratio": morphology_area / mesh_area,
            "error": error,
            "relative_error": error / mesh_area,
        }

    def full_validation(self) -> None:
        """Run all validation checks and log comprehensive results.

        Compares volumes and surface areas (with and without branch-point
        overlap correction when branch vertices are present) and logs each
        result via the module logger at ``INFO`` level.

        Returns
        -------
        None
            Results are emitted via logging rather than returned.
        """
        # Overlap correction only changes anything at branch points
        # (degree > 2), so skip the corrected pass when there are none.
        has_branch_vertices = any(
            self.morphology.degree[n] > 2 for n in self.morphology.nodes()
        )
        if not has_branch_vertices:
            logger.info(
                "No branch vertices (degree > 2) found; overlap correction "
                "has no effect. Showing account_for_overlaps=False only."
            )
        overlap_flags = [False, True] if has_branch_vertices else [False]

        for account_for_overlaps in overlap_flags:
            vol_result = self.compare_volumes(account_for_overlaps=account_for_overlaps)
            area_result = self.compare_surface_areas(
                account_for_overlaps=account_for_overlaps
            )
            logger.info(
                f"Validation Results, account_for_overlaps={account_for_overlaps}:"
            )
            logger.info("-- Volume Comparison:")
            logger.info(f"---- Mesh volume: {vol_result['mesh_volume']:.4f}")
            logger.info(
                f"---- Morphology volume: {vol_result['morphology_volume']:.4f}"
            )
            logger.info(f"---- Ratio: {vol_result['ratio']:.4f}")
            logger.info(f"---- Error: {vol_result['error']:.4f}")
            logger.info(f"---- Relative error: {vol_result['relative_error']:.2%}")
            logger.info("-- Surface Area Comparison:")
            logger.info(f"---- Mesh area: {area_result['mesh_area']:.4f}")
            logger.info(f"---- Morphology area: {area_result['morphology_area']:.4f}")
            logger.info(f"---- Ratio: {area_result['ratio']:.4f}")
            logger.info(f"---- Error: {area_result['error']:.4f}")
            logger.info(f"---- Relative error: {area_result['relative_error']:.2%}")
            logger.info("")
        return None