Source code for mascaf.morphology_graph

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import brentq

from .graph3d import Graph3D
from swctools import SWCModel, FrustaSet, plot_model


@dataclass
class Junction:
    """Container for a traced skeleton node.

    One instance describes a single sample along a traced polyline (the
    tracing pipeline in ``trace.py`` builds these).
    :meth:`MorphologyGraph.add_junction` turns it into a graph node keyed by
    ``id`` with ``xyz`` and ``radius`` attributes.

    Attributes
    ----------
    id : int
        Node identifier; used as the graph node key.
    xyz : numpy.ndarray
        Position of the sample. Coerced to a float ``ndarray`` on
        construction so list/tuple input works with downstream vector math
        (float arrays are kept as-is, without a copy).
    radius : float
        Radius at the sample. Coerced to ``float`` on construction.
    """

    id: int
    xyz: np.ndarray
    radius: float

    def __post_init__(self) -> None:
        # Normalise inputs once so every consumer can rely on array math.
        self.xyz = np.asarray(self.xyz, dtype=float)
        self.radius = float(self.radius)
class MorphologyGraph(Graph3D):
    """Graph representation of neuronal morphology with radii.

    This class subclasses :class:`Graph3D` and is used through the
    ``networkx`` undirected-graph API, so it can contain cycles. Nodes are
    keyed directly by their junction ``id`` and store at least the attributes
    ``xyz`` (3-vector) and ``radius`` (float).

    Use :meth:`to_swc_file` to export to SWC format, which breaks cycles by
    duplicating nodes; :meth:`from_swc_file` restores those cycles.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Create the graph; all arguments are forwarded to ``Graph3D``."""
        super().__init__(*args, **kwargs)

    # Name of the node attribute holding 3-D coordinates.
    # NOTE(review): presumably read by Graph3D helpers — confirm in graph3d.py.
    position_attr = "xyz"

    @classmethod
    def from_swc_file(cls, path: str) -> "MorphologyGraph":
        """Load a MorphologyGraph from an SWC file, restoring cycles.

        Reads an SWC file and creates a MorphologyGraph. If the file header
        contains ``# CYCLE_BREAK reconnect <dup_id> <orig_id>`` annotations
        (as written by :meth:`to_swc_file`), each duplicate node is removed
        and its neighbors are reattached to the original node, restoring the
        cycles that were broken on export.

        Parameters
        ----------
        path : str
            Path to the SWC file to load.

        Returns
        -------
        MorphologyGraph
            Graph keyed by the SWC node ids, with ``xyz`` (float array) and
            ``radius`` (float) node attributes. If CYCLE_BREAK annotations
            are present, cycles are restored.

        Examples
        --------
        >>> graph = MorphologyGraph.from_swc_file("output.swc")
        >>> graph.print_attributes()
        """
        swc_model = SWCModel.from_swc_file(path)
        cycle_breaks = cls._read_cycle_breaks(path)

        graph = cls()
        for node_id in swc_model.nodes():
            node_data = swc_model.nodes[node_id]
            graph.add_node(
                int(node_id),
                xyz=np.array(
                    [node_data["x"], node_data["y"], node_data["z"]], dtype=float
                ),
                radius=float(node_data["r"]),
            )
        # Edges from SWCModel are already undirected.
        for u, v in swc_model.edges():
            graph.add_edge(int(u), int(v))

        # Restore cycles: drop each duplicate and wire its neighbors to the original.
        for dup_id, orig_id in cycle_breaks:
            if dup_id not in graph:
                continue
            dup_neighbors = list(graph.neighbors(dup_id))
            # Removing the node also removes every edge attached to it.
            graph.remove_node(dup_id)
            if orig_id not in graph:
                continue
            for neighbor in dup_neighbors:
                # Skip the original itself, and a self-loop on the duplicate
                # (re-adding an edge to dup_id would resurrect it as a bare node).
                if neighbor != orig_id and neighbor != dup_id:
                    graph.add_edge(orig_id, neighbor)
        return graph

    @staticmethod
    def _read_cycle_breaks(path: str) -> list[tuple[int, int]]:
        """Parse ``(duplicate_id, original_id)`` pairs from CYCLE_BREAK header lines.

        Only the leading comment block is scanned: parsing stops at the first
        non-blank line that does not start with ``#``. Malformed annotation
        lines are skipped.

        Parameters
        ----------
        path : str
            SWC file to scan.

        Returns
        -------
        list of tuple of int
            Pairs in file order; empty if the file cannot be read.
        """
        cycle_breaks: list[tuple[int, int]] = []
        try:
            with open(path, "r", encoding="utf-8") as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue  # tolerate blank lines inside the header
                    if not line.startswith("#"):
                        break  # header ends at the first data row
                    # Expected form: "# CYCLE_BREAK reconnect dup_id orig_id"
                    if "CYCLE_BREAK" not in line or "reconnect" not in line:
                        continue
                    parts = line.split()
                    if len(parts) < 4:
                        continue
                    try:
                        cycle_breaks.append((int(parts[-2]), int(parts[-1])))
                    except ValueError:
                        continue
        except (OSError, UnicodeDecodeError):
            # Best effort: SWCModel has already parsed the file, so unreadable
            # annotations only mean that no (further) cycles are restored.
            pass
        return cycle_breaks

    def add_junction(self, j: Junction) -> None:
        """Add a :class:`Junction` as a graph node.

        The node key is ``j.id``; stored attributes are ``xyz`` (coerced to a
        float array; float arrays are stored without copying) and ``radius``
        (coerced to float).

        Parameters
        ----------
        j : Junction
            The junction to add.
        """
        self.add_node(
            int(j.id),
            xyz=np.asarray(j.xyz, dtype=float),
            radius=float(j.radius),
        )

    def copy(self) -> "MorphologyGraph":
        """Return a copy of the graph with independent ``xyz`` arrays.

        The copy has the same class as ``self``, the same graph-level
        attributes, nodes and edges. Each node's ``xyz`` is copied into a new
        float array; other attribute values are copied shallowly.

        Returns
        -------
        MorphologyGraph
            A new graph with the same topology and independent node data.
        """
        # type(self) keeps subclasses intact (same approach as networkx's Graph.copy).
        new_graph = type(self)()
        new_graph.graph.update(dict(self.graph))
        for node, data in self.nodes(data=True):
            node_data = dict(data)
            if "xyz" in node_data:
                node_data["xyz"] = np.asarray(node_data["xyz"], dtype=float).copy()
            new_graph.add_node(node, **node_data)
        for u, v, data in self.edges(data=True):
            new_graph.add_edge(u, v, **dict(data))
        return new_graph

    def to_swc_model(self) -> SWCModel:
        """Convert MorphologyGraph to a SWCModel instance.

        Creates a SWCModel by adding all nodes with their attributes
        (x, y, z, r, t) and edges from this graph. Missing ``xyz`` defaults to
        the origin, missing ``radius`` to 1.0 and missing ``t`` to 3.

        Returns
        -------
        SWCModel
            A SWCModel instance with the same topology and attributes.

        Examples
        --------
        >>> graph = MorphologyGraph()
        >>> # ... add nodes and edges ...
        >>> swc_model = graph.to_swc_model()
        >>> frusta = FrustaSet.from_swc_model(swc_model)
        """
        swc_model = SWCModel()
        for node_id, attrs in self.nodes(data=True):
            xyz = attrs.get("xyz", np.array([0.0, 0.0, 0.0]))
            radius = attrs.get("radius", 1.0)
            # SWC type: default to 3 (dendrite)
            swc_type = attrs.get("t", 3)
            swc_model.add_node(
                int(node_id),
                x=float(xyz[0]),
                y=float(xyz[1]),
                z=float(xyz[2]),
                r=float(radius),
                t=int(swc_type),
            )
        for u, v in self.edges():
            swc_model.add_edge(int(u), int(v))
        return swc_model

    def compute_volume(self, account_for_overlaps: bool = False) -> float:
        """Compute total volume of the morphology as a sum of frustum segments.

        Each edge represents a truncated cone (frustum) connecting two nodes.
        The volume of a frustum is:

            V = (π*h/3) * (r1² + r1*r2 + r2²)

        where h is the edge length and r1, r2 are the endpoint radii.

        Parameters
        ----------
        account_for_overlaps : bool, default False
            If True, subtract a branch-point overlap correction of
            π*r³/3 (one quarter of a ball volume) per edge beyond two at
            every node of degree > 2.

        Returns
        -------
        float
            Total volume of all segments in the morphology.

        Examples
        --------
        >>> graph = MorphologyGraph()
        >>> # ... add nodes and edges ...
        >>> volume = graph.compute_volume()
        """
        return self._metric_at_uniform_radius_scale(
            1.0, metric="volume", account_for_overlaps=account_for_overlaps
        )

    def compute_surface_area(self, account_for_overlaps: bool = False) -> float:
        """Compute total surface area of the morphology.

        Each edge represents a truncated cone (frustum) connecting two nodes
        and contributes its lateral area:

            A = π * (r1 + r2) * s

        where s = sqrt(h² + (r1 - r2)²) is the slant height. A flat end cap
        (π*r²) is added for every terminal node (degree 1).

        Parameters
        ----------
        account_for_overlaps : bool, default False
            If True, subtract a branch-point overlap correction of π*r²
            (one quarter of a ball surface area) per edge beyond two at every
            node of degree > 2.

        Returns
        -------
        float
            Lateral area of all segments plus terminal end caps.

        Examples
        --------
        >>> graph = MorphologyGraph()
        >>> # ... add nodes and edges ...
        >>> area = graph.compute_surface_area()
        """
        return self._metric_at_uniform_radius_scale(
            1.0, metric="surface_area", account_for_overlaps=account_for_overlaps
        )

    def _metric_at_uniform_radius_scale(
        self,
        k: float,
        *,
        metric: str,
        account_for_overlaps: bool,
    ) -> float:
        """Return surface area or volume as if every radius were multiplied by ``k``.

        The graph itself is not modified.

        Parameters
        ----------
        k : float
            Uniform radius scale factor.
        metric : {"surface_area", "volume"}
            Which quantity to evaluate.
        account_for_overlaps : bool
            Apply the branch-point overlap correction described in
            :meth:`compute_volume` / :meth:`compute_surface_area`.

        Returns
        -------
        float
            The requested metric at scale ``k``.

        Raises
        ------
        ValueError
            If ``metric`` is not ``"surface_area"`` or ``"volume"``.
        """
        k = float(k)
        if metric not in ("surface_area", "volume"):
            raise ValueError(
                f"metric must be 'surface_area' or 'volume', got {metric!r}"
            )
        nodes = self.nodes
        degree = self.degree

        total = 0.0
        for u, v in self.edges():
            r_u = nodes[u]["radius"] * k
            r_v = nodes[v]["radius"] * k
            # asarray lets list/tuple coordinates take part in vector math.
            h = np.linalg.norm(
                np.asarray(nodes[v]["xyz"], dtype=float)
                - np.asarray(nodes[u]["xyz"], dtype=float)
            )
            if metric == "volume":
                total += (np.pi * h / 3.0) * (r_u**2 + r_u * r_v + r_v**2)
            else:
                s = np.sqrt(h**2 + (r_u - r_v) ** 2)  # slant height
                total += np.pi * (r_u + r_v) * s

        if metric == "volume":
            if account_for_overlaps:
                for node_id in nodes:
                    deg = degree[node_id]
                    if deg > 2:
                        r = nodes[node_id]["radius"] * k
                        # pi*r^3/3 (a quarter ball) per edge beyond two.
                        total -= (deg - 2) * (np.pi * r**3 / 3.0)
            return float(total)

        for node_id in nodes:
            deg = degree[node_id]
            r = nodes[node_id]["radius"] * k
            if deg == 1:
                total += np.pi * r**2  # flat end cap on a terminal
            elif deg > 2 and account_for_overlaps:
                total -= (deg - 2) * (np.pi * r**2)
        return float(total)

    def _solve_uniform_radius_scale_factor(
        self,
        target: float,
        *,
        metric: str,
        account_for_overlaps: bool,
        rtol: float = 1e-9,
        atol: float = 1e-12,
    ) -> float:
        """Solve for k > 0 such that _metric_at_uniform_radius_scale(k, ...) == target.

        Starting from ``k = 1``, ``k`` is doubled (or halved) until the target
        is bracketed, then the bracket is refined with Brent's method.

        Parameters
        ----------
        target : float
            Desired metric value.
        metric : {"surface_area", "volume"}
            Metric to match.
        account_for_overlaps : bool
            Forwarded to :meth:`_metric_at_uniform_radius_scale`.
        rtol, atol : float
            Tolerances for "close enough" checks and for ``brentq``.

        Returns
        -------
        float
            The scale factor ``k``.

        Raises
        ------
        ValueError
            If the target cannot be bracketed, or exceeds the largest value
            the metric reaches (e.g. volume with overlap correction peaks and
            then decreases as radii grow).
        """

        def m_at(kk: float) -> float:
            return self._metric_at_uniform_radius_scale(
                kk, metric=metric, account_for_overlaps=account_for_overlaps
            )

        def is_close(value: float) -> bool:
            return bool(np.isclose(value, target, rtol=rtol, atol=atol))

        def residual(kk: float) -> float:
            return m_at(kk) - target

        m1 = m_at(1.0)
        if is_close(m1):
            return 1.0

        grow = 2.0
        if m1 < target:
            # Grow k geometrically until the metric reaches the target.
            prev_k = 0.0  # point before ``lo``; all radii zero at k = 0
            lo, m_lo = 1.0, m1
            for _ in range(100):
                hi = lo * grow
                m_hi = m_at(hi)
                if m_hi >= target or is_close(m_hi):
                    break
                if m_hi < m_lo:
                    # The metric turned over; a peak above the target may lie
                    # between samples, so look for it before giving up.
                    lo, m_lo, hi, m_hi = self._bracket_rising_side_of_peak(
                        m_at, target, prev_k, lo, m_lo, hi, is_close
                    )
                    break
                prev_k, lo, m_lo = lo, hi, m_hi
            else:
                raise ValueError(
                    "Could not bracket a uniform radius scale factor for the target metric."
                )
        else:
            # Shrink k geometrically until the metric drops to the target.
            hi, m_hi = 1.0, m1
            for _ in range(100):
                lo = hi / grow
                if lo < 1e-30:
                    raise ValueError(
                        "Could not scale radii down enough to reach the target metric."
                    )
                m_lo = m_at(lo)
                if m_lo <= target or is_close(m_lo):
                    break
                hi, m_hi = lo, m_lo
            else:
                raise ValueError(
                    "Could not bracket a uniform radius scale factor for the target metric."
                )

        if is_close(m_lo):
            return float(lo)
        if is_close(m_hi):
            return float(hi)

        a, b = float(lo), float(hi)
        # Reuse the already-evaluated endpoint values for the sign check.
        if (m_lo - target) * (m_hi - target) > 0:
            raise ValueError(
                "Failed to bracket the scale factor (non-monotonic metric or numerical issue)."
            )
        root = brentq(residual, a, b, xtol=atol, rtol=rtol, maxiter=200)
        return float(root)

    @staticmethod
    def _bracket_rising_side_of_peak(
        m_at, target, left, lo, m_lo, hi, is_close
    ) -> tuple[float, float, float, float]:
        """Bracket ``target`` after the metric decreased while growing ``k``.

        Called when doubling ``k`` from ``lo`` to ``hi`` lowered the metric
        before ``target`` was reached. A doubling step can jump over a peak
        that still exceeds ``target``, so the maximum on ``[left, hi]`` is
        located (bounded scalar minimisation of ``-m``) before concluding
        that the target is unreachable.

        Returns
        -------
        tuple of float
            ``(a, m(a), b, m(b))`` where ``b`` is the located peak and
            ``m(b)`` reaches (or is close to) ``target``.

        Raises
        ------
        ValueError
            If the located peak is still below ``target``.
        """
        from scipy.optimize import minimize_scalar

        res = minimize_scalar(
            lambda kk: -m_at(kk), bounds=(left, hi), method="bounded"
        )
        peak_k, peak_m = float(res.x), float(-res.fun)
        if peak_m < m_lo:
            # Fall back to the best sampled point if the optimiser did worse.
            peak_k, peak_m = lo, m_lo
        if peak_m < target and not is_close(peak_m):
            raise ValueError(
                "Target morphology metric exceeds the maximum achievable "
                "with the current cable model (e.g. volume with "
                "account_for_overlaps=True can decrease after a peak as "
                "radii grow). Try a different metric or overlap setting."
            )
        if peak_k > lo:
            return lo, m_lo, peak_k, peak_m
        # The peak sits left of ``lo``; the rising side starts at ``left``.
        return left, m_at(left), peak_k, peak_m

    def scale_radii_to_match_mesh(
        self,
        mesh,
        metric: str = "surface_area",
        account_for_overlaps: bool = False,
    ) -> float:
        """Scale all radii to match the mesh's surface area or volume.

        This method finds a uniform scaling factor ``k`` for all radii such
        that the cable model's surface area or volume (as defined by
        :meth:`compute_surface_area` / :meth:`compute_volume`) equals that of
        the input mesh. Because lateral frustum area and volume do not scale
        as pure powers of ``k`` when edge lengths are fixed, ``k`` is computed
        with a one-dimensional root solve (not ``sqrt`` / ``cbrt`` of a single
        ratio). Radii are modified in place.

        Parameters
        ----------
        mesh : trimesh.Trimesh or MeshManager
            The mesh to match. Can be either a trimesh.Trimesh object or a
            MeshManager instance.
        metric : str, default "surface_area"
            Which metric to match. Options are:
            - "surface_area": Match total surface area
            - "volume": Match total volume
        account_for_overlaps : bool, default False
            If True, subtract branch-point overlap corrections when computing
            the morphology's surface area or volume (same as for
            :meth:`compute_surface_area` / :meth:`compute_volume`).

        Returns
        -------
        float
            The scaling factor applied to all radii.

        Raises
        ------
        ValueError
            If metric is not "surface_area" or "volume", if the mesh or the
            morphology has zero/negative area or volume, or if no scale
            factor reaches the target.

        Examples
        --------
        >>> from mascaf import MeshManager, MorphologyGraph
        >>> mesh_mgr = MeshManager(mesh_path="neuron.obj")
        >>> graph = MorphologyGraph.from_swc_file("output.swc")
        >>> scale_factor = graph.scale_radii_to_match_mesh(mesh_mgr)
        >>> print(f"Radii scaled by factor: {scale_factor:.3f}")
        """
        # Unwrap a MeshManager; anything else is treated as a trimesh.Trimesh.
        try:
            from .mesh import MeshManager
        except ImportError:
            MeshManager = None
        if MeshManager is not None and isinstance(mesh, MeshManager):
            mesh_obj = mesh.mesh
        else:
            mesh_obj = mesh

        if metric not in ["surface_area", "volume"]:
            raise ValueError(
                f"metric must be 'surface_area' or 'volume', got '{metric}'"
            )

        if metric == "surface_area":
            target_value = float(mesh_obj.area)
            if target_value <= 0.0:
                raise ValueError("Mesh has zero or negative surface area")
            current_value = self.compute_surface_area(
                account_for_overlaps=account_for_overlaps
            )
        else:
            target_value = float(mesh_obj.volume)
            if target_value <= 0.0:
                raise ValueError("Mesh has zero or negative volume")
            current_value = self.compute_volume(
                account_for_overlaps=account_for_overlaps
            )

        if current_value <= 0.0:
            raise ValueError(
                f"Morphology has zero or negative {metric.replace('_', ' ')}"
            )

        scale_factor = self._solve_uniform_radius_scale_factor(
            target_value,
            metric=metric,
            account_for_overlaps=account_for_overlaps,
        )
        for node_id in self.nodes():
            self.nodes[node_id]["radius"] *= scale_factor
        return scale_factor

    def print_attributes(
        self, *, node_info: bool = False, edge_info: bool = False
    ) -> None:
        """Print graph statistics and optional node/edge details.

        The header line reports node, edge and component counts, the number
        of independent cycles (size of a cycle basis, ``"?"`` if networkx
        cannot compute it), branch points (degree > 2), leaves (degree 1),
        self-loops and density.

        Parameters
        ----------
        node_info: bool
            If True, print per-node attributes (xyz, radius, and any other
            attributes).
        edge_info: bool
            If True, print all edges (u -- v) with edge attributes if any.
        """
        nodes = self.number_of_nodes()
        edges = self.number_of_edges()
        components = nx.number_connected_components(self)

        # networkx raises on the null graph (is_forest), so handle it directly.
        if nodes == 0:
            cycles = 0
        else:
            try:
                cycles = 0 if nx.is_forest(self) else len(nx.cycle_basis(self))
            except nx.NetworkXException:
                cycles = "?"

        branch_points = sum(1 for n in self.nodes() if self.degree[n] > 2)
        leaves = sum(1 for n in self.nodes() if self.degree[n] == 1)
        self_loops = nx.number_of_selfloops(self)
        density = nx.density(self) if nodes > 0 else 0.0

        header = (
            f"MorphologyGraph: nodes={nodes}, edges={edges}, "
            f"components={components}, cycles={cycles}, "
            f"branch_points={branch_points}, "
            f"leaves={leaves}, self_loops={self_loops}, density={density:.4f}"
        )
        print(header)

        if node_info:
            print("Nodes:")
            for n, attrs in self.nodes(data=True):
                parts = []
                if "xyz" in attrs:
                    xyz = attrs["xyz"]
                    parts.append(f"xyz=({xyz[0]:.4f}, {xyz[1]:.4f}, {xyz[2]:.4f})")
                if "radius" in attrs:
                    parts.append(f"r={attrs['radius']:.4f}")
                for k, v in attrs.items():
                    if k not in ["xyz", "radius"]:
                        parts.append(f"{k}={v}")
                print(f"  {n}: " + ", ".join(parts))

        if edge_info:
            print("Edges:")
            for u, v, attrs in self.edges(data=True):
                if attrs:
                    print(f"  {u} -- {v}: {dict(attrs)}")
                else:
                    print(f"  {u} -- {v}")

    # ------------------------------------------------------------------
    # SWC Export
    # ------------------------------------------------------------------

    @staticmethod
    def _write_text(path: str | None, text: str) -> str:
        """Write ``text`` to ``path`` (UTF-8) when a path is given; return ``text``."""
        if path is not None:
            with open(path, "w", encoding="utf-8") as f:
                f.write(text)
        return text

    def to_swc_file(
        self,
        path: str | None = None,
        *,
        tag: int = 3,
        annotate_breaks: bool = True,
    ) -> str:
        """Export the skeleton to SWC format, breaking cycles by duplicating nodes.

        The SWC format is a line-based format with columns
        ``n T x y z R parent``, where ``n`` is the node id, ``T`` is the SWC
        type index, ``x,y,z`` are coordinates, ``R`` is radius, and ``parent``
        is the parent's id (or -1 for the root).

        A DFS spanning forest is built per connected component (rooted at the
        lowest-id terminal node when one exists, else at the lowest id). Nodes
        are renumbered 1..N in DFS preorder, so parents always precede
        children. For every non-tree edge that would introduce a cycle, the
        endpoint with the higher degree is duplicated and the copy is attached
        as a child of the other endpoint.

        Parameters
        ----------
        path : str or None
            If provided, write the SWC text to this file. If ``None``, return
            the SWC text as a string.
        tag : int, default 3
            Integer placed in the SWC ``T`` (type) column for all nodes.
        annotate_breaks : bool, default True
            If ``True``, include ``# CYCLE_BREAK reconnect <dup> <orig>``
            header lines (in SWC ids) so :meth:`from_swc_file` can restore
            each broken cycle.

        Returns
        -------
        str
            The SWC text (whether or not it was also written to a file).

        Raises
        ------
        KeyError
            If any node is missing the ``xyz`` or ``radius`` attribute.
        """
        if self.number_of_nodes() == 0:
            return self._write_text(path, "# Empty skeleton\n")

        for n in self.nodes:
            attrs = self.nodes[n]
            if "xyz" not in attrs or "radius" not in attrs:
                raise KeyError(
                    f"Node {n} missing required attributes 'xyz' and 'radius'"
                )

        # Parent of each original node in the spanning forest; None marks a
        # root (a sentinel that cannot collide with a real node id such as -1).
        parents_orig: dict[int, int | None] = {}
        tree_edges: set[frozenset[int]] = set()
        comp_roots: list[int] = []
        order_all: list[int] = []

        # Visit components deterministically, ordered by their smallest node id.
        components = sorted(
            (sorted(int(x) for x in comp) for comp in nx.connected_components(self)),
            key=lambda comp_nodes: comp_nodes[0],
        )
        for comp_nodes in components:
            # Prefer a terminal as root.
            terminals = [n for n in comp_nodes if int(self.degree[n]) == 1]
            root = min(terminals) if terminals else comp_nodes[0]
            comp_roots.append(root)
            parents_orig[root] = None
            # DFS from root reaches every node of the component.
            for u, v in nx.dfs_edges(self, source=root):
                parents_orig[int(v)] = int(u)
                tree_edges.add(frozenset({int(u), int(v)}))
            order_all.extend(int(n) for n in nx.dfs_preorder_nodes(self, source=root))

        # Non-tree edges are exactly the ones that would close a cycle.
        extra_edges: list[tuple[int, int]] = [
            (int(u), int(v))
            for u, v in self.edges()
            if frozenset({int(u), int(v)}) not in tree_edges
        ]

        # SWC ids: 1..N following the concatenated DFS preorders.
        new_id: dict[int, int] = {}
        for n in order_all:
            if n not in new_id:
                new_id[n] = len(new_id) + 1
        next_index = len(new_id) + 1

        entries: list[tuple[int, int, float, float, float, float, int]] = []
        for n in order_all:
            xyz = np.asarray(self.nodes[n]["xyz"], dtype=float).reshape(3)
            r = float(self.nodes[n].get("radius", 0.0))
            parent_orig = parents_orig.get(n)
            parent_id = -1 if parent_orig is None else new_id.get(parent_orig, -1)
            entries.append(
                (int(new_id[n]), int(tag), xyz[0], xyz[1], xyz[2], r, int(parent_id))
            )

        # (duplicate_swc_id, original_swc_id) for every broken cycle.
        cycle_annotations: list[tuple[int, int]] = []
        for u, v in extra_edges:
            # Duplicate the higher-degree (branching) endpoint; ties pick v.
            dup_orig = v if self.degree[v] >= self.degree[u] else u
            other_orig = u if dup_orig == v else v
            xyz = np.asarray(self.nodes[dup_orig]["xyz"], dtype=float).reshape(3)
            r = float(self.nodes[dup_orig].get("radius", 0.0))
            dup_swc = int(next_index)
            next_index += 1
            parent_swc = int(new_id.get(other_orig, -1))
            entries.append((dup_swc, int(tag), xyz[0], xyz[1], xyz[2], r, parent_swc))
            cycle_annotations.append((dup_swc, int(new_id.get(dup_orig, dup_swc))))

        lines: list[str] = [
            "# generated by mascaf MorphologyGraph.to_swc_file",
            f"# dfs_roots={' '.join(str(new_id.get(r, r)) for r in comp_roots)}",
            f"# nodes={self.number_of_nodes()} extra_edges={len(extra_edges)} duplicates={len(cycle_annotations)}",
            f"# tag={int(tag)}",
        ]
        if annotate_breaks and cycle_annotations:
            for dup_id, orig_id in cycle_annotations:
                lines.append(f"# CYCLE_BREAK reconnect {dup_id} {orig_id}")

        entries.sort(key=lambda t: int(t[0]))
        for nid, T, x, y, z, R, parent in entries:
            lines.append(f"{nid} {T} {x:.6f} {y:.6f} {z:.6f} {R:.6f} {parent}")

        swc_text = "\n".join(lines) + "\n"
        return self._write_text(path, swc_text)
# Public API of this module.
__all__ = [
    "MorphologyGraph",
    "Junction",
]