# Source code for mascaf.basis_optimizer

"""MorphologyGraph basis optimization prior to radius fitting."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import trimesh

from .graph3d import Graph3D
from .morphology_graph import MorphologyGraph

# Module-level logger. The NullHandler keeps this library quiet unless the
# importing application configures logging handlers itself.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())


@dataclass
class BasisOptimizerOptions:
    """Configuration for morphology-basis optimization.

    All fields are keyword arguments to the dataclass constructor. The
    comment above each field gives its default behavior and which phase of
    :class:`BasisOptimizer` reads it (pruning, snapping or forcing).
    """

    #: Run the pruning phase (remove short terminal branches).
    do_pruning: bool = False
    #: Absolute branch-length threshold for pruning. Terminal branches whose
    #: traced length is below this value are removed. Takes precedence over
    #: ``pruning_min_length_fraction`` when both are set.
    pruning_min_length: Optional[float] = None
    #: Relative pruning threshold; must lie in the open interval (0, 1). The
    #: threshold becomes the ``fraction * 100`` percentile of the
    #: terminal-branch lengths. Only used when ``pruning_min_length`` is None.
    pruning_min_length_fraction: Optional[float] = None
    #: Repeat pruning passes until a pass removes nothing; ``False`` runs a
    #: single pass.
    pruning_iterative: bool = True
    #: Run the snapping phase (move nodes outside the mesh back inside).
    do_snapping: bool = True
    #: Run the forcing phase (iteratively pull nodes toward the medial axis).
    do_forcing: bool = True
    #: Snapping moves an outside node toward its closest surface point by
    #: ``distance * snap_distance_multiplier``; values above 1 overshoot it.
    snap_distance_multiplier: float = 1.1
    #: Maximum number of forcing sweeps over the graph.
    max_iterations: int = 100
    #: Scale applied to each node's (at most unit-length) forcing direction
    #: per sweep.
    step_size: float = 0.1
    #: Forcing stops once the mean per-node displacement of a sweep falls
    #: below this value.
    convergence_threshold: float = 1e-4
    #: Keep degree-1 (terminal) nodes fixed during forcing.
    preserve_terminal_nodes: bool = True
    #: Keep branch nodes fixed during forcing.
    preserve_branch_nodes: bool = False
    #: Number of rays cast per node to estimate the centering force. 6 uses
    #: the axis-aligned directions; any other count uses a Fibonacci sphere.
    n_rays: int = 6
    #: Distance assumed for a ray that hits nothing or whose cast fails.
    fallback_distance: float = 10.0
    #: Blend weight between the centering direction (``1 - w``) and the
    #: neighbor-mean smoothing direction (``w``); 0 disables smoothing.
    smoothing_weight: float = 0.5
    #: NOTE(review): not read anywhere in the basis_optimizer module.
    verbose: bool = False
class BasisOptimizer:
    """Optimize a downsampled :class:`~mascaf.morphology_graph.MorphologyGraph`
    basis against a target mesh.

    Runs up to three sequential phases controlled by
    :class:`BasisOptimizerOptions`:

    1. **Pruning** — remove short terminal branches.
    2. **Snapping** — move outside nodes back inside the mesh.
    3. **Forcing** — iteratively pull nodes toward the medial axis.

    Parameters
    ----------
    graph : MorphologyGraph
        The morphology basis to optimize. It is copied internally with
        ``graph.copy()`` and only that copy is modified by this class.
    mesh : trimesh.Trimesh
        The target mesh that defines the interior/exterior and surface.
    options : BasisOptimizerOptions or None
        Optimization configuration. Defaults to :class:`BasisOptimizerOptions`
        with all defaults when ``None``.

    Examples
    --------
    >>> from mascaf import BasisOptimizer, BasisOptimizerOptions
    >>> opts = BasisOptimizerOptions(do_snapping=True, do_forcing=True)
    >>> optimized = BasisOptimizer(morphology, mesh, opts).optimize()  # doctest: +SKIP
    """

    def __init__(
        self,
        graph: MorphologyGraph,
        mesh: trimesh.Trimesh,
        options: Optional[BasisOptimizerOptions] = None,
    ):
        self.graph = graph.copy()
        self.mesh = mesh
        self.options = options if options is not None else BasisOptimizerOptions()
        # Cached results of the most recent get_outside_nodes() call.
        self._surface_crossing_detected = False
        self._outside_nodes: list[int] = []
        logger.debug(
            "Initialized BasisOptimizer with %d nodes, %d edges, do_pruning=%s, "
            "do_snapping=%s, do_forcing=%s",
            self.graph.number_of_nodes(),
            self.graph.number_of_edges(),
            self.options.do_pruning,
            self.options.do_snapping,
            self.options.do_forcing,
        )

    def get_outside_nodes(self) -> Tuple[list[int], bool, int, float]:
        """Identify graph nodes that lie outside the mesh surface.

        The result is also cached on the instance (``_outside_nodes`` and
        ``_surface_crossing_detected``); both fields always agree with the
        value returned by the most recent call.

        Returns
        -------
        tuple
            ``(outside_node_ids, has_crossing, num_outside, max_dist)`` where:

            * **outside_node_ids** — list of node IDs outside the mesh.
            * **has_crossing** — ``True`` if any node is outside.
            * **num_outside** — count of outside nodes.
            * **max_dist** — maximum distance from any outside node to the
              nearest mesh surface point (``0.0`` when none are outside).

        Notes
        -----
        Any exception from the containment / proximity queries is logged as a
        warning and reported as "no outside nodes" rather than re-raised.
        """
        if self.graph.number_of_nodes() == 0:
            self._outside_nodes = []
            self._surface_crossing_detected = False
            logger.debug("Outside-node query on empty morphology basis")
            return [], False, 0, 0.0

        node_ids = list(self.graph.nodes())
        all_pts = self.graph.get_all_positions()
        logger.debug(
            "Checking %d basis nodes against mesh containment",
            len(node_ids),
        )

        try:
            # NOTE(review): trimesh's containment test expects a watertight mesh.
            outside_mask = ~np.asarray(self.mesh.contains(all_pts), dtype=bool)
            outside_node_ids = [
                node_id
                for node_id, is_out in zip(node_ids, outside_mask)
                if is_out
            ]
            num_outside = int(np.count_nonzero(outside_mask))
            max_dist = 0.0
            if num_outside > 0:
                from trimesh.proximity import closest_point

                _, distances, _ = closest_point(self.mesh, all_pts[outside_mask])
                max_dist = float(np.max(distances))
        except Exception as exc:
            logger.warning("Failed to identify outside basis nodes: %s", exc)
            # Reset both cached fields so they agree with the empty result.
            self._outside_nodes = []
            self._surface_crossing_detected = False
            return [], False, 0, 0.0

        has_crossing = num_outside > 0
        self._surface_crossing_detected = has_crossing
        self._outside_nodes = outside_node_ids
        logger.debug(
            "Outside-node detection complete: has_crossing=%s, outside_nodes=%s",
            has_crossing,
            outside_node_ids,
        )
        return outside_node_ids, has_crossing, num_outside, max_dist

    def optimize(self) -> MorphologyGraph:
        """Run the configured optimization phases and return the result.

        Runs pruning, snapping, and forcing in sequence (each phase is skipped
        when its corresponding ``do_*`` flag is ``False``). Afterwards, edge
        ``length`` attributes are recomputed from the final node positions and
        the outside-node cache is refreshed.

        Returns
        -------
        MorphologyGraph
            The optimized morphology basis (a modified copy of the input).
        """
        logger.info("Starting basis optimization...")
        logger.info(" Nodes: %d", self.graph.number_of_nodes())

        if self.options.do_pruning:
            self._run_pruning_phase()
        else:
            logger.debug("Skipping pruning phase because do_pruning is False")

        if self.options.do_snapping:
            self._run_snapping_phase()
        else:
            logger.debug("Skipping snapping phase because do_snapping is False")

        if self.options.do_forcing:
            self._run_forcing_phase()
        else:
            logger.debug("Skipping forcing phase because do_forcing is False")

        self._update_edge_lengths()
        self.get_outside_nodes()
        logger.info("Basis optimization complete")
        return self.graph

    def _run_pruning_phase(self) -> None:
        """Prune short terminal branches before geometric optimization.

        Pruning works on a copy of ``self.graph``, which is replaced only after
        the pruning loop finishes. With ``options.pruning_iterative`` the
        selection in :meth:`_collect_prunable_nodes` is repeated until a pass
        selects nothing; otherwise a single pass is made.
        """
        logger.info("Phase 0 - Pruning")
        if self.graph.number_of_nodes() == 0:
            logger.debug("Skipping pruning because basis graph is empty")
            return

        threshold = self._resolve_pruning_threshold()
        if threshold is None:
            logger.debug("Skipping pruning because no pruning threshold was configured")
            return

        logger.info(" Removing branches with length < %.4f", threshold)
        current = self.graph.copy()
        while True:
            nodes_to_remove = self._collect_prunable_nodes(current, threshold)
            if not nodes_to_remove:
                break
            logger.debug(
                "Pruning %d nodes from short branches",
                len(nodes_to_remove),
            )
            current.remove_nodes_from([n for n in nodes_to_remove if n in current])
            if not self.options.pruning_iterative:
                break
        self.graph = current

    def _collect_prunable_nodes(self, graph: Graph3D, threshold: float) -> set[int]:
        """Select the nodes removed by one pruning pass over ``graph``.

        Each terminal is traced with :meth:`_trace_from_terminal`. A traced
        branch is pruned only when its length is below ``threshold``:

        * a branch ending at a node of degree >= 3 loses every node except
          that branch node;
        * an isolated unbranched chain (terminal to terminal) is removed
          entirely.

        The graph itself is not modified.
        """
        nodes_to_remove: set[int] = set()
        visited: set[int] = set()
        for terminal in sorted(graph.get_terminal_nodes()):
            if terminal in visited or terminal not in graph:
                continue
            visited.add(terminal)
            end, path, length = self._trace_from_terminal(graph, terminal)
            if len(path) <= 1 or end == terminal:
                continue

            end_degree = graph.degree(end)
            if end_degree == 1:
                # Both ends of an isolated chain are terminals; handle it once.
                visited.add(end)
            if length >= threshold:
                # Long branches are kept, including long isolated chains, so an
                # unbranched morphology is not erased by pruning.
                continue

            if end_degree >= 3:
                nodes_to_remove.update(path[:-1])
            elif end_degree == 1:
                nodes_to_remove.update(path)
        return nodes_to_remove

    def _resolve_pruning_threshold(self) -> Optional[float]:
        """Resolve the branch-pruning threshold from absolute or fraction input.

        Returns
        -------
        float or None
            ``options.pruning_min_length`` when set; otherwise the
            ``fraction * 100`` percentile of the terminal-branch lengths.
            Returns ``None`` when no threshold is configured or the graph has
            no terminal branches.

        Raises
        ------
        ValueError
            If ``options.pruning_min_length_fraction`` is not in (0, 1). The
            check runs before any branch lengths are computed, so a bad value
            is reported even for graphs without branches.
        """
        if self.options.pruning_min_length is not None:
            return float(self.options.pruning_min_length)

        fraction = self.options.pruning_min_length_fraction
        if fraction is None:
            return None
        # Written so that NaN also fails validation.
        if not 0 < fraction < 1:
            raise ValueError(f"Pruning fraction must be in (0,1), got {fraction}")

        branch_lengths = list(self._compute_branch_lengths(self.graph).values())
        if not branch_lengths:
            return None
        return float(np.percentile(branch_lengths, float(fraction * 100.0)))

    def _run_snapping_phase(self) -> None:
        """Snap outside basis nodes back into the mesh.

        Each node reported by :meth:`get_outside_nodes` is moved along the
        unit direction to its closest surface point by
        ``distance * options.snap_distance_multiplier``. A multiplier above 1
        overshoots that closest point. Nodes at negligible distance are
        skipped.
        """
        outside_node_ids, has_crossing, num_outside, _ = self.get_outside_nodes()
        logger.info("Phase 1 - Snapping: %d nodes outside mesh", num_outside)
        logger.debug("Snapping candidate nodes: %s", outside_node_ids)
        if not has_crossing:
            logger.debug("No snapping required because all nodes are inside the mesh")
            return

        for node in outside_node_ids:
            pos = self.graph.get_node_position(node)
            direction, dist = self._compute_snap_direction(pos)
            if dist < 1e-10:
                logger.debug(
                    "Skipping snap for node %s because distance is negligible",
                    node,
                )
                continue
            displacement = direction * dist * self.options.snap_distance_multiplier
            logger.debug(
                "Snapping node %s from %s with distance %.6f and displacement %s",
                node,
                pos,
                dist,
                displacement,
            )
            self.graph.set_node_position(node, pos + displacement)

    def _run_forcing_phase(self) -> None:
        """Iteratively move basis nodes toward the medial axis.

        Each sweep moves every non-preserved node by ``options.step_size``
        times a blend of its centering direction (weight
        ``1 - smoothing_weight``) and its neighbor-mean smoothing direction
        (weight ``smoothing_weight``). Positions are updated in place node by
        node, so later nodes in a sweep smooth against already-moved
        neighbors. Iteration stops after ``options.max_iterations`` sweeps or
        once the mean per-node displacement drops below
        ``options.convergence_threshold``.
        """
        logger.info("Phase 2 - Forcing: max %d iterations", self.options.max_iterations)
        if self.graph.number_of_nodes() == 0:
            logger.info("Phase 2 - Forcing skipped because basis graph is empty")
            return

        # Topology does not change during forcing, so the preserved set is
        # computed once and kept as a set for O(1) membership tests.
        frozen_nodes: set[int] = set()
        if self.options.preserve_terminal_nodes:
            frozen_nodes.update(self.graph.get_terminal_nodes())
        if self.options.preserve_branch_nodes:
            frozen_nodes.update(self.graph.get_branch_nodes())

        weight = self.options.smoothing_weight
        step = self.options.step_size
        for iteration in range(self.options.max_iterations):
            old_positions = self.graph.get_all_positions()

            for node in self.graph.nodes():
                if node in frozen_nodes:
                    continue
                pos = self.graph.get_node_position(node)
                direction = self._compute_centering_direction(pos)
                if weight > 0:
                    smoothing_direction = self._compute_smoothing_direction_for_node(node)
                else:
                    smoothing_direction = np.zeros(3)
                total_direction = (1.0 - weight) * direction + weight * smoothing_direction
                self.graph.set_node_position(node, pos + step * total_direction)

            movement = self._average_movement(
                old_positions,
                self.graph.get_all_positions(),
            )
            logger.info(" Iteration %d: avg movement = %.6f", iteration, movement)
            if movement < self.options.convergence_threshold:
                logger.info(" Converged at iteration %d", iteration)
                break

    def _update_edge_lengths(self) -> None:
        """Recompute each edge's ``length`` attribute from current node positions."""
        for u, v in self.graph.edges():
            pos_u = self.graph.get_node_position(u)
            pos_v = self.graph.get_node_position(v)
            self.graph.edges[u, v]["length"] = float(np.linalg.norm(pos_v - pos_u))

    def _compute_smoothing_direction_for_node(self, node: int) -> np.ndarray:
        """Return the unit direction from ``node`` to its neighbors' mean position.

        Returns a zero vector for nodes without neighbors or when the node
        already sits (within 1e-10) on the neighbor mean.
        """
        neighbors = list(self.graph.neighbors(node))
        if not neighbors:
            return np.zeros(3)
        pos = self.graph.get_node_position(node)
        neighbor_positions = np.array(
            [self.graph.get_node_position(n) for n in neighbors],
            dtype=float,
        )
        direction = neighbor_positions.mean(axis=0) - pos
        norm = np.linalg.norm(direction)
        if norm > 1e-10:
            return direction / norm
        return np.zeros(3)

    def _compute_centering_direction(self, point: np.ndarray) -> np.ndarray:
        """Compute a unit direction that moves a point toward the medial axis.

        For a point inside the mesh, ``options.n_rays`` rays are cast and each
        contributes ``-d / distance``, so nearer surfaces push harder. The
        normalized sum is returned, or a zero vector when the forces cancel.
        Points outside the mesh, or inside points whose ray evaluation raises,
        get the direction to the closest surface point instead.
        """
        is_inside = self.mesh.contains(point.reshape(1, 3))[0]
        if not is_inside:
            return self._compute_closest_point_direction(point)

        try:
            directions = self._get_uniform_sphere_directions(self.options.n_rays)
            force = np.zeros(3)
            for direction in directions:
                distance = self._ray_distance_to_surface(point, direction)
                if distance > 1e-6:
                    force -= direction / distance
            force_mag = np.linalg.norm(force)
            if force_mag > 1e-10:
                return force / force_mag
            return np.zeros(3)
        except Exception as exc:
            logger.warning("Failed to compute centering direction: %s", exc)
            return self._compute_closest_point_direction(point)

    def _compute_snap_direction(self, point: np.ndarray) -> Tuple[np.ndarray, float]:
        """Return the unit direction and distance to the nearest mesh point.

        Returns ``(zeros(3), 0.0)`` when the point is on the surface (within
        1e-10) or when the proximity query raises (logged as a warning).
        """
        try:
            from trimesh.proximity import closest_point

            cp, _, _ = closest_point(self.mesh, point.reshape(1, 3))
            to_surface = cp[0] - point
            dist = float(np.linalg.norm(to_surface))
            if dist < 1e-10:
                return np.zeros(3), 0.0
            return to_surface / dist, dist
        except Exception as exc:
            logger.warning("Failed to compute snap direction: %s", exc)
            return np.zeros(3), 0.0

    def _compute_closest_point_direction(self, point: np.ndarray) -> np.ndarray:
        """Fallback for outside points: unit direction toward the closest mesh point."""
        direction, _ = self._compute_snap_direction(point)
        return direction

    def _get_uniform_sphere_directions(self, n_points: int) -> np.ndarray:
        """Generate approximately uniform unit directions on the sphere.

        ``n_points == 6`` returns the six axis-aligned directions; any other
        count uses a Fibonacci (golden-ratio) spiral. The result has shape
        ``(n_points, 3)``.
        """
        if n_points == 6:
            return np.array(
                [
                    [1.0, 0.0, 0.0],
                    [-1.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0],
                    [0.0, -1.0, 0.0],
                    [0.0, 0.0, 1.0],
                    [0.0, 0.0, -1.0],
                ]
            )

        indices = np.arange(0, n_points, dtype=float) + 0.5
        phi = (1 + np.sqrt(5)) / 2
        theta = 2 * np.pi * indices / phi
        z = 1 - (2 * indices / n_points)
        radius = np.sqrt(1 - z * z)
        directions = np.column_stack(
            [radius * np.cos(theta), radius * np.sin(theta), z]
        )
        norms = np.linalg.norm(directions, axis=1, keepdims=True)
        return directions / (norms + 1e-10)

    def _ray_distance_to_surface(
        self, point: np.ndarray, direction: np.ndarray
    ) -> float:
        """Return the distance to the nearest mesh hit along a ray.

        ``options.fallback_distance`` is returned when the ray hits nothing or
        the ray query raises (logged as a warning).
        """
        try:
            locations, _, _ = self.mesh.ray.intersects_location(
                ray_origins=point.reshape(1, 3),
                ray_directions=direction.reshape(1, 3),
            )
            if len(locations) == 0:
                return self.options.fallback_distance
            return float(np.min(np.linalg.norm(locations - point, axis=1)))
        except Exception as exc:
            logger.warning("Ray tracing failed, using fallback distance: %s", exc)
            return self.options.fallback_distance

    def get_optimization_stats(self) -> dict:
        """Return summary statistics for the optimized basis graph.

        This re-runs :meth:`get_outside_nodes`, so the outside-node fields
        reflect the current graph state.
        """
        _, _, num_outside, max_dist = self.get_outside_nodes()
        return {
            "surface_crossing_detected": self._surface_crossing_detected,
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "num_terminal_nodes": len(self.graph.get_terminal_nodes()),
            "num_branch_nodes": len(self.graph.get_branch_nodes()),
            "total_length": self.graph.get_total_length(),
            "nodes_outside_mesh": num_outside,
            "max_distance_outside": max_dist,
        }

    def _compute_branch_lengths(self, graph: Graph3D) -> dict[tuple[int, int], float]:
        """Map ``(terminal, end)`` to the traced length of each terminal branch.

        Used for percentile-based pruning thresholds. An isolated chain
        contributes one entry from each of its two terminal ends.
        """
        branch_lengths: dict[tuple[int, int], float] = {}
        for terminal in graph.get_terminal_nodes():
            end, _, length = self._trace_from_terminal(graph, terminal)
            if end != terminal:
                branch_lengths[(terminal, end)] = length
        return branch_lengths

    def _trace_from_terminal(
        self,
        graph: Graph3D,
        start: int,
    ) -> tuple[int, list[int], float]:
        """Walk from a degree-1 node along degree-2 nodes.

        Returns
        -------
        tuple
            ``(end, path, length)``:

            * ``end`` is the first node whose degree is not 2 (a branch node or
              another terminal);
            * ``path`` lists every visited node, including ``start`` and
              ``end``;
            * ``length`` is the summed edge length along ``path``.

            ``(start, [start], 0.0)`` is returned when ``start`` is not in the
            graph or its degree is not 1.
        """
        if start not in graph or graph.degree(start) != 1:
            return start, [start], 0.0

        path = [start]
        prev = None
        current = start
        length = 0.0
        while True:
            nbrs = [n for n in graph.neighbors(current) if n != prev]
            if not nbrs:
                break
            nxt = nbrs[0]
            length += self._edge_length(graph, current, nxt)
            prev, current = current, nxt
            path.append(current)
            if graph.degree(current) != 2:
                break
        return current, path, length

    def _edge_length(self, graph: Graph3D, u: int, v: int) -> float:
        """Return the stored edge ``length``, else the Euclidean node distance."""
        data = graph.get_edge_data(u, v) or {}
        length = data.get("length")
        if length is not None:
            return float(length)
        return float(
            np.linalg.norm(graph.get_node_position(v) - graph.get_node_position(u))
        )

    def _average_movement(
        self,
        old_positions: np.ndarray,
        new_positions: np.ndarray,
    ) -> float:
        """Return the mean per-row Euclidean displacement between two position arrays."""
        if old_positions.size == 0 or new_positions.size == 0:
            return 0.0
        return float(np.linalg.norm(new_positions - old_positions, axis=1).mean())