from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import brentq
from .graph3d import Graph3D
from swctools import SWCModel, FrustaSet, plot_model
[docs]
@dataclass
class Junction:
    """
    One sample point of a traced skeleton.

    The tracing pipeline in `trace.py` builds one of these for every sample
    along a polyline. The geometry that matters is carried by `xyz` (XYZ)
    and `radius`; remaining fields exist for diagnostics/bookkeeping.

    Attributes
    ----------
    id : int
        Identifier of the sample.
    xyz : np.ndarray
        Position of the sample as an XYZ vector.
    radius : float
        Radius associated with the sample.
    """

    id: int
    xyz: np.ndarray
    radius: float
[docs]
class MorphologyGraph(Graph3D):
"""
Graph representation of neuronal morphology with radii.
This class subclasses `networkx.Graph` and can contain cycles. Nodes should be
keyed directly by their junction `id` and store at least the attributes `xyz`
(3-vector) and `radius` (float). Use `to_swc_file()` to export to SWC format,
which breaks cycles by duplicating nodes.
"""
[docs]
def __init__(self, *args, **kwargs) -> None:
# Forward every positional and keyword argument unchanged to the base-class constructor.
super().__init__(*args, **kwargs)
# Populated by trace.build_traced_skeleton_graph for optional SWC export adjustments
# NOTE(review): nothing is assigned after the comment above in this listing; it looks
# like a leftover from a removed attribute — confirm.
# NOTE(review): indentation is not visible here, so it cannot be told whether the
# assignment below is a class-level attribute or a local inside __init__. It is
# presumably a class attribute naming the node key that holds coordinates — confirm
# against Graph3D.
position_attr = "xyz"
[docs]
@classmethod
def from_swc_file(cls, path: str) -> "MorphologyGraph":
    """Load a MorphologyGraph from an SWC file, restoring cycles.

    Reads an SWC file and creates a MorphologyGraph. If the file contains
    CYCLE_BREAK annotations in the header (generated by to_swc_file),
    this method will restore the original cycles by reconnecting duplicate
    nodes to their originals.

    Parameters
    ----------
    path : str
        Path to the SWC file to load.

    Returns
    -------
    MorphologyGraph
        Graph with nodes containing 'xyz' and 'radius' attributes.
        If CYCLE_BREAK annotations are present, cycles are restored.

    Notes
    -----
    Annotation parsing is best-effort: malformed ``CYCLE_BREAK`` lines are
    skipped, and if the header cannot be re-read a ``RuntimeWarning`` is
    issued and the graph is returned without (further) cycle restoration.
    An annotation is ignored, leaving the duplicate node in place, when
    either id is absent from the graph or when both ids are equal, so no
    node or edge is dropped because of an inconsistent annotation.

    Examples
    --------
    >>> graph = MorphologyGraph.from_swc_file("output.swc")
    >>> graph.print_attributes()
    """
    # Node/edge data comes from swctools; the header is scanned separately
    # below because only the comment lines carry the cycle annotations.
    swc_model = SWCModel.from_swc_file(path)

    # Collected (duplicate_id, original_id) pairs from
    # "# CYCLE_BREAK reconnect dup_id orig_id" header lines.
    cycle_breaks: list[tuple[int, int]] = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # A blank line does not end the header block.
                    continue
                if not line.startswith("#"):
                    # First data row: the header is over.
                    break
                if "CYCLE_BREAK" not in line or "reconnect" not in line:
                    continue
                parts = line.split()
                if len(parts) < 4:
                    continue
                try:
                    cycle_breaks.append((int(parts[-2]), int(parts[-1])))
                except ValueError:
                    # Non-integer ids: skip this malformed annotation only.
                    continue
    except (OSError, UnicodeDecodeError) as exc:
        import warnings

        warnings.warn(
            f"Could not read CYCLE_BREAK annotations from {path!r}: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )

    graph = cls()

    # Copy nodes, converting the SWC x/y/z/r columns to 'xyz' and 'radius'.
    for node_id, node_data in swc_model.nodes(data=True):
        graph.add_node(
            int(node_id),
            xyz=np.array(
                [node_data["x"], node_data["y"], node_data["z"]], dtype=float
            ),
            radius=float(node_data["r"]),
        )

    # Copy edges (SWCModel edges are iterated as plain (u, v) pairs).
    for u, v in swc_model.edges():
        graph.add_edge(int(u), int(v))

    # Restore each cycle: whatever the duplicate was attached to gets
    # attached to the original instead, then the duplicate is removed.
    for dup_id, orig_id in cycle_breaks:
        if dup_id == orig_id or dup_id not in graph or orig_id not in graph:
            # Inconsistent annotation: keep the file's data as loaded.
            continue
        for neighbor in list(graph.neighbors(dup_id)):
            if neighbor != orig_id:
                graph.add_edge(orig_id, neighbor)
        # remove_node also drops every edge still incident to the duplicate.
        graph.remove_node(dup_id)

    return graph
[docs]
def add_junction(self, j: Junction) -> None:
    """Add a :class:`Junction` as a graph node.

    The node key is ``int(j.id)``; stored attributes are ``xyz`` and
    ``radius``.

    Parameters
    ----------
    j : Junction
        The junction to add. ``j.xyz`` is converted with
        ``np.asarray(..., dtype=float)``, so lists and tuples are accepted;
        an array that already has float dtype is stored as-is, without
        copying.
    """
    self.add_node(
        int(j.id),
        # Coerce so later vector arithmetic on 'xyz' (e.g. edge lengths)
        # works regardless of how the caller supplied the position.
        xyz=np.asarray(j.xyz, dtype=float),
        radius=float(j.radius),
    )
[docs]
def copy(self) -> "MorphologyGraph":
    """Return a copy of the graph with all node arrays copied.

    The copy is created as ``type(self)()``, so a subclass instance yields
    a copy of the same subclass. Graph-level attributes, node attributes
    and edge attributes are placed in fresh dictionaries. Every node
    attribute that is a NumPy array is duplicated, and ``xyz`` is always
    stored as a new float array; all other attribute values are shared by
    reference with the source graph.

    Returns
    -------
    MorphologyGraph
        A new graph with the same topology and independent node arrays.
    """
    new_graph = type(self)()
    new_graph.graph.update(self.graph)

    for node, data in self.nodes(data=True):
        node_data = {}
        for key, value in data.items():
            if key == "xyz":
                # np.array always allocates, so this is a real copy.
                node_data[key] = np.array(value, dtype=float)
            elif isinstance(value, np.ndarray):
                node_data[key] = value.copy()
            else:
                node_data[key] = value
        new_graph.add_node(node, **node_data)

    for u, v, data in self.edges(data=True):
        new_graph.add_edge(u, v, **data)

    return new_graph
[docs]
def to_swc_model(self) -> SWCModel:
    """Build a SWCModel mirroring this graph.

    Every node is added to a new SWCModel under its integer id with the
    SWC attributes ``x``, ``y``, ``z``, ``r`` and ``t``, followed by every
    edge. SWCModel is itself a NetworkX Graph, so this is a conversion
    between two graph types.

    Missing node attributes fall back to defaults: ``xyz`` to the origin,
    ``radius`` to 1.0 and the SWC type ``t`` to 3 (dendrite).

    Returns
    -------
    SWCModel
        A SWCModel instance with the same topology and attributes.

    Examples
    --------
    >>> graph = MorphologyGraph()
    >>> # ... add nodes and edges ...
    >>> swc_model = graph.to_swc_model()
    >>> frusta = FrustaSet.from_swc_model(swc_model)
    """
    model = SWCModel()

    # Nodes first, translating our attribute names to the SWC ones.
    for key, data in self.nodes(data=True):
        position = data.get("xyz", np.array([0.0, 0.0, 0.0]))
        node_radius = data.get("radius", 1.0)
        node_type = data.get("t", 3)  # 3 is the SWC code for dendrite
        model.add_node(
            int(key),
            x=float(position[0]),
            y=float(position[1]),
            z=float(position[2]),
            r=float(node_radius),
            t=int(node_type),
        )

    # Then the connectivity.
    for a, b in self.edges():
        model.add_edge(int(a), int(b))

    return model
[docs]
def compute_volume(self, account_for_overlaps: bool = False) -> float:
    """Compute the total volume of the morphology at its current radii.

    Every edge is treated as a truncated cone (frustum) between its two
    nodes, with volume ``V = (π*h/3) * (r1² + r1*r2 + r2²)``, where ``h``
    is the distance between the nodes and ``r1``, ``r2`` are their radii.
    The result is the sum over all edges, optionally reduced by a
    branch-point correction.

    Parameters
    ----------
    account_for_overlaps : bool, default False
        If True, each node with degree > 2 subtracts ``π*r³/3`` (a quarter
        of the volume of a ball of that node's radius) for every edge
        beyond two.

    Returns
    -------
    float
        Total volume of all segments in the morphology.

    Examples
    --------
    >>> graph = MorphologyGraph()
    >>> # ... add nodes and edges ...
    >>> volume = graph.compute_volume()
    """
    # A scale of 1.0 evaluates the metric with the radii exactly as stored.
    unscaled = 1.0
    return self._metric_at_uniform_radius_scale(
        unscaled,
        metric="volume",
        account_for_overlaps=account_for_overlaps,
    )
[docs]
def compute_surface_area(self, account_for_overlaps: bool = False) -> float:
    """Compute the total surface area of the morphology at its current radii.

    Every edge is treated as a truncated cone (frustum) between its two
    nodes, contributing its lateral area ``A = π * (r1 + r2) * s``, where
    ``s = sqrt(h² + (r1 - r2)²)`` is the slant height. Each terminal node
    (degree 1) additionally contributes an end cap of area ``π*r²``.

    Parameters
    ----------
    account_for_overlaps : bool, default False
        If True, each node with degree > 2 subtracts ``π*r²`` (a quarter of
        the surface area of a ball of that node's radius) for every edge
        beyond two.

    Returns
    -------
    float
        Lateral area of all segments plus terminal end caps, minus the
        overlap correction when requested.

    Examples
    --------
    >>> graph = MorphologyGraph()
    >>> # ... add nodes and edges ...
    >>> area = graph.compute_surface_area()
    """
    # A scale of 1.0 evaluates the metric with the radii exactly as stored.
    unscaled = 1.0
    return self._metric_at_uniform_radius_scale(
        unscaled,
        metric="surface_area",
        account_for_overlaps=account_for_overlaps,
    )
def _metric_at_uniform_radius_scale(
    self,
    k: float,
    *,
    metric: str,
    account_for_overlaps: bool,
) -> float:
    """Return SA or volume if every node radius were multiplied by ``k`` (read-only).

    Node positions are left unchanged; only the radii are scaled, and
    nothing is written back to the graph.

    Per edge, with ``h`` the distance between the endpoints and ``r1``,
    ``r2`` the scaled radii:

    - ``"volume"`` adds the frustum volume
      ``(π*h/3) * (r1² + r1*r2 + r2²)``.
    - ``"surface_area"`` adds the lateral frustum area
      ``π * (r1 + r2) * sqrt(h² + (r1 - r2)²)``.

    Per node, with ``r`` the scaled radius:

    - ``"surface_area"`` adds an end cap ``π*r²`` for each degree-1 node.
    - With ``account_for_overlaps`` a node of degree ``d > 2`` subtracts
      ``(d - 2) * π*r³/3`` from the volume or ``(d - 2) * π*r²`` from the
      surface area.

    Parameters
    ----------
    k : float
        Uniform multiplier applied to every node radius.
    metric : str
        Either ``"surface_area"`` or ``"volume"``.
    account_for_overlaps : bool
        Whether to apply the branch-point correction described above.

    Returns
    -------
    float
        The requested metric at radius scale ``k``.

    Raises
    ------
    ValueError
        If ``metric`` is neither ``"surface_area"`` nor ``"volume"``.
    """
    if metric not in ("surface_area", "volume"):
        raise ValueError(
            f"metric must be 'surface_area' or 'volume', got {metric!r}"
        )

    k = float(k)
    want_volume = metric == "volume"
    total = 0.0

    # Edge contributions: one frustum per edge.
    for u, v in self.edges():
        attrs_u = self.nodes[u]
        attrs_v = self.nodes[v]
        r_u = attrs_u["radius"] * k
        r_v = attrs_v["radius"] * k
        # Coerce positions so list/tuple 'xyz' values work like arrays.
        delta = np.asarray(attrs_v["xyz"], dtype=float) - np.asarray(
            attrs_u["xyz"], dtype=float
        )
        h = np.linalg.norm(delta)
        if want_volume:
            total += (np.pi * h / 3.0) * (r_u**2 + r_u * r_v + r_v**2)
        else:
            slant = np.sqrt(h**2 + (r_u - r_v) ** 2)
            total += np.pi * (r_u + r_v) * slant

    # Node contributions: end caps (area only) and branch-point corrections.
    for node_id in self.nodes():
        degree = self.degree[node_id]
        if degree == 1 and not want_volume:
            r = self.nodes[node_id]["radius"] * k
            total += np.pi * r**2
        elif degree > 2 and account_for_overlaps:
            r = self.nodes[node_id]["radius"] * k
            extra_edges = degree - 2
            if want_volume:
                total -= extra_edges * (np.pi * r**3 / 3.0)
            else:
                total -= extra_edges * (np.pi * r**2)

    return float(total)
def _solve_uniform_radius_scale_factor(
    self,
    target: float,
    *,
    metric: str,
    account_for_overlaps: bool,
    rtol: float = 1e-9,
    atol: float = 1e-12,
) -> float:
    """Solve for k>0 such that _metric_at_uniform_radius_scale(k, ...) == target.

    The search starts at ``k = 1`` and repeatedly doubles (metric below the
    target) or halves (metric above the target) the scale, for at most 100
    steps, until the target lies between two consecutive trial scales. The
    root inside that bracket is then found with ``brentq``.

    Parameters
    ----------
    target : float
        Metric value to reach.
    metric : str
        Forwarded to ``_metric_at_uniform_radius_scale``.
    account_for_overlaps : bool
        Forwarded to ``_metric_at_uniform_radius_scale``.
    rtol, atol : float
        Tolerances used both for the ``np.isclose`` "already at target"
        checks and as ``rtol`` / ``xtol`` of ``brentq``.

    Returns
    -------
    float
        The scale factor ``k``.

    Raises
    ------
    ValueError
        If no bracket is found, if the metric starts decreasing while the
        scale is being grown, if the scale would have to drop below 1e-30,
        or if the final bracket does not show a sign change.
    """

    def evaluate(scale: float) -> float:
        return self._metric_at_uniform_radius_scale(
            scale, metric=metric, account_for_overlaps=account_for_overlaps
        )

    def offset(scale: float) -> float:
        return evaluate(scale) - target

    def hits_target(value: float) -> bool:
        return np.isclose(value, target, rtol=rtol, atol=atol)

    baseline = evaluate(1.0)
    if hits_target(baseline):
        return 1.0

    factor = 2.0
    if baseline < target:
        # Grow: slide the window [lower, upper] upwards by doubling.
        upper, at_upper = 1.0, baseline
        for _ in range(100):
            lower, at_lower = upper, at_upper
            upper = lower * factor
            at_upper = evaluate(upper)
            if at_upper >= target or hits_target(at_upper):
                break
            if at_upper < at_lower:
                # Metric turned around before reaching the target.
                raise ValueError(
                    "Target morphology metric exceeds the maximum achievable "
                    "with the current cable model (e.g. volume with "
                    "account_for_overlaps=True can decrease after a peak as "
                    "radii grow). Try a different metric or overlap setting."
                )
        else:
            raise ValueError(
                "Could not bracket a uniform radius scale factor for the target metric."
            )
    else:
        # Shrink: slide the window [lower, upper] downwards by halving.
        lower, at_lower = 1.0, baseline
        for _ in range(100):
            upper, at_upper = lower, at_lower
            lower = upper / factor
            if lower < 1e-30:
                raise ValueError(
                    "Could not scale radii down enough to reach the target metric."
                )
            at_lower = evaluate(lower)
            if at_lower <= target or hits_target(at_lower):
                break
        else:
            raise ValueError(
                "Could not bracket a uniform radius scale factor for the target metric."
            )

    # An endpoint that already matches needs no root solve.
    if hits_target(at_lower):
        return float(lower)
    if hits_target(at_upper):
        return float(upper)

    a, b = float(lower), float(upper)
    if offset(a) * offset(b) > 0:
        raise ValueError(
            "Failed to bracket the scale factor (non-monotonic metric or numerical issue)."
        )
    return float(brentq(offset, a, b, xtol=atol, rtol=rtol, maxiter=200))
[docs]
def scale_radii_to_match_mesh(
    self,
    mesh,
    metric: str = "surface_area",
    account_for_overlaps: bool = False,
) -> float:
    """Rescale every node radius so the cable model matches a mesh.

    A single factor ``k`` is determined such that the surface area or
    volume of the cable model (as defined by :meth:`compute_surface_area` /
    :meth:`compute_volume`) equals the corresponding quantity of ``mesh``,
    and every stored radius is then multiplied by ``k`` in place. Because
    edge lengths stay fixed, lateral frustum area and volume are not pure
    powers of ``k``, so ``k`` comes from a one-dimensional root solve
    rather than from ``sqrt`` / ``cbrt`` of a single ratio.

    Parameters
    ----------
    mesh : trimesh.Trimesh or MeshManager
        The mesh to match. A MeshManager is unwrapped through its ``mesh``
        attribute; anything else is used directly and must expose ``area``
        and ``volume``.
    metric : str, default "surface_area"
        Which metric to match. Options are:
        - "surface_area": Match total surface area
        - "volume": Match total volume
    account_for_overlaps : bool, default False
        If True, subtract branch-point overlap corrections when computing
        the morphology's surface area or volume (same as for
        :meth:`compute_surface_area` / :meth:`compute_volume`).

    Returns
    -------
    float
        The scaling factor applied to all radii.

    Raises
    ------
    ValueError
        If metric is not "surface_area" or "volume", or if the mesh has
        zero area/volume, or if the morphology has zero area/volume.

    Examples
    --------
    >>> from mascaf import MeshManager, MorphologyGraph
    >>> mesh_mgr = MeshManager(mesh_path="neuron.obj")
    >>> graph = MorphologyGraph.from_swc_file("output.swc")
    >>> scale_factor = graph.scale_radii_to_match_mesh(mesh_mgr)
    >>> print(f"Radii scaled by factor: {scale_factor:.3f}")
    """
    # MeshManager is optional: if it cannot be imported, the argument is
    # taken to be a mesh object already.
    try:
        from .mesh import MeshManager
    except ImportError:
        MeshManager = None

    is_manager = MeshManager is not None and isinstance(mesh, MeshManager)
    mesh_obj = mesh.mesh if is_manager else mesh

    if metric not in ("surface_area", "volume"):
        raise ValueError(
            f"metric must be 'surface_area' or 'volume', got '{metric}'"
        )

    # Read the mesh quantity first, and reject it before touching the graph.
    if metric == "surface_area":
        target_value = float(mesh_obj.area)
        empty_mesh_message = "Mesh has zero or negative surface area"
        measure = self.compute_surface_area
    else:
        target_value = float(mesh_obj.volume)
        empty_mesh_message = "Mesh has zero or negative volume"
        measure = self.compute_volume

    if target_value <= 0.0:
        raise ValueError(empty_mesh_message)

    current_value = measure(account_for_overlaps=account_for_overlaps)
    if current_value <= 0.0:
        raise ValueError(
            f"Morphology has zero or negative {metric.replace('_', ' ')}"
        )

    scale_factor = self._solve_uniform_radius_scale_factor(
        target_value,
        metric=metric,
        account_for_overlaps=account_for_overlaps,
    )

    # Commit the solved factor to the stored radii.
    for node_id in self.nodes():
        self.nodes[node_id]["radius"] *= scale_factor

    return scale_factor
[docs]
def print_attributes(
    self, *, node_info: bool = False, edge_info: bool = False
) -> None:
    """Print graph attributes and optional node/edge details.

    A one-line summary is always printed: node and edge counts, number of
    connected components, number of independent cycles, branch points
    (degree > 2), leaves (degree == 1), self-loops and density. An empty
    graph reports ``cycles=0``; ``cycles=?`` is printed only if networkx
    itself fails to analyse the graph.

    Parameters
    ----------
    node_info: bool
        If True, print per-node attributes (xyz, radius, and any other attributes).
    edge_info: bool
        If True, print all edges (u -- v) with edge attributes if any.
    """
    nodes = self.number_of_nodes()
    edges = self.number_of_edges()
    components = nx.number_connected_components(self)

    if nodes == 0:
        # nx.is_forest raises NetworkXPointlessConcept on an empty graph;
        # an empty graph simply has no cycles.
        cycles = 0
    else:
        try:
            cycles = 0 if nx.is_forest(self) else len(nx.cycle_basis(self))
        except nx.NetworkXException:
            # Diagnostics only: never let cycle counting abort the printout.
            cycles = "?"

    degrees = [self.degree[n] for n in self.nodes()]
    branch_points = sum(1 for d in degrees if d > 2)
    leaves = sum(1 for d in degrees if d == 1)
    self_loops = nx.number_of_selfloops(self)

    density = nx.density(self) if nodes > 0 else 0.0

    header = (
        f"MorphologyGraph: nodes={nodes}, edges={edges}, "
        f"components={components}, cycles={cycles}, "
        f"branch_points={branch_points}, "
        f"leaves={leaves}, self_loops={self_loops}, density={density:.4f}"
    )
    print(header)

    if node_info:
        print("Nodes:")
        for n, attrs in self.nodes(data=True):
            parts = []
            # Geometry first, in a fixed order, then everything else.
            if "xyz" in attrs:
                xyz = attrs["xyz"]
                parts.append(f"xyz=({xyz[0]:.4f}, {xyz[1]:.4f}, {xyz[2]:.4f})")
            if "radius" in attrs:
                parts.append(f"r={attrs['radius']:.4f}")
            parts.extend(
                f"{k}={v}" for k, v in attrs.items() if k not in ("xyz", "radius")
            )
            print(f"  {n}: " + ", ".join(parts))

    if edge_info:
        print("Edges:")
        for u, v, attrs in self.edges(data=True):
            if attrs:
                print(f"  {u} -- {v}: {dict(attrs)}")
            else:
                print(f"  {u} -- {v}")
# ------------------------------------------------------------------
# SWC Export
# ------------------------------------------------------------------
[docs]
def to_swc_file(
    self,
    path: str | None = None,
    *,
    tag: int = 3,
    annotate_breaks: bool = True,
) -> str:
    """Export the skeleton to SWC format, breaking cycles by duplicating nodes.

    SWC is a line-based format whose columns are ``n T x y z R parent``:
    ``n`` is the node id, ``T`` the SWC type index, ``x,y,z`` the
    coordinates, ``R`` the radius and ``parent`` the id of the parent node
    (-1 for a root).

    A DFS spanning forest is built over the undirected graph, and nodes
    are renumbered from 1 in DFS visitation order. Every remaining
    (non-tree) edge would close a cycle; for each one, one endpoint is
    duplicated and the duplicate is written as a child of the other
    endpoint.

    Parameters
    ----------
    path : str or None
        If provided, write the SWC text to this file. If ``None``, return
        the SWC text as a string.
    tag : int, default 3
        Integer placed in the SWC ``T`` (type) column for all nodes.
    annotate_breaks : bool, default True
        If ``True``, include header comment lines indicating how to
        reconnect duplicates to restore each broken cycle.

    Returns
    -------
    str
        The SWC text (whether or not it was also written to a file).

    Raises
    ------
    KeyError
        If any node is missing the ``xyz`` or ``radius`` attribute.
    """

    def emit(text: str) -> str:
        # Write only when a destination was given; always hand the text back.
        if path is not None:
            with open(path, "w", encoding="utf-8") as f:
                f.write(text)
        return text

    if self.number_of_nodes() == 0:
        return emit("# Empty skeleton\n")

    # Fail early if any node lacks the geometry needed for an SWC row.
    for n in self.nodes:
        attrs = self.nodes[n]
        if not ("xyz" in attrs and "radius" in attrs):
            raise KeyError(
                f"Node {n} missing required attributes 'xyz' and 'radius'"
            )

    # Components as sorted id lists, ordered by their smallest id so the
    # output is deterministic.
    components = sorted(
        (sorted(int(x) for x in comp) for comp in nx.connected_components(self)),
        key=lambda members: (len(members) == 0, members[0] if members else -1),
    )

    parent_of: dict[int, int] = {}
    tree_edges: set[frozenset[int]] = set()
    roots: list[int] = []
    visit_order: list[int] = []
    for members in components:
        if not members:
            continue
        # Root at the smallest terminal (degree 1) if the component has one.
        leaves = [n for n in members if int(self.degree[n]) == 1]
        root = min(leaves or members)
        roots.append(root)
        parent_of[int(root)] = -1

        for parent, child in nx.dfs_edges(self, source=root):
            parent, child = int(parent), int(child)
            parent_of[child] = parent
            tree_edges.add(frozenset((parent, child)))

        # Preorder visitation fixes the order in which SWC ids are handed out.
        visit_order.extend(
            int(n) for n in nx.dfs_preorder_nodes(self, source=root)
        )

        # Anything the DFS did not assign a parent to becomes a root.
        for n in members:
            parent_of.setdefault(int(n), -1)

    # Edges outside the spanning forest are the ones that close cycles.
    extra_edges: list[tuple[int, int]] = [
        (int(u), int(v))
        for u, v in self.edges()
        if frozenset((int(u), int(v))) not in tree_edges
    ]

    # Sequential SWC ids, starting at 1, in visitation order.
    new_id: dict[int, int] = {}
    for n in visit_order:
        new_id.setdefault(n, len(new_id) + 1)
    next_index = len(new_id) + 1

    # One row per original node.
    entries: list[tuple[int, int, float, float, float, float, int]] = []
    for n in visit_order:
        xyz = np.asarray(self.nodes[n]["xyz"], dtype=float).reshape(3)
        radius = float(self.nodes[n].get("radius", 0.0))
        parent_swc = new_id.get(parent_of.get(n, -1), -1)
        entries.append(
            (int(new_id[n]), int(tag), xyz[0], xyz[1], xyz[2], radius, int(parent_swc))
        )

    # One extra row per cycle-closing edge: a duplicate of one endpoint,
    # parented to the other. Pairs are (duplicate_swc_id, original_swc_id).
    cycle_annotations: list[tuple[int, int]] = []
    for u, v in extra_edges:
        # Duplicate the endpoint with the higher degree (ties go to v).
        if self.degree[v] >= self.degree[u]:
            dup_orig, other_orig = v, u
        else:
            dup_orig, other_orig = u, v
        xyz = np.asarray(self.nodes[dup_orig]["xyz"], dtype=float).reshape(3)
        radius = float(self.nodes[dup_orig].get("radius", 0.0))
        dup_swc = int(next_index)
        entries.append(
            (
                dup_swc,
                int(tag),
                xyz[0],
                xyz[1],
                xyz[2],
                radius,
                int(new_id.get(other_orig, -1)),
            )
        )
        cycle_annotations.append((dup_swc, int(new_id.get(dup_orig, dup_swc))))
        next_index += 1

    # Header comments, then the rows ordered by SWC id.
    header: list[str] = [
        "# generated by mascaf MorphologyGraph.to_swc_file",
        f"# dfs_roots={' '.join(str(new_id.get(r, r)) for r in roots)}",
        f"# nodes={self.number_of_nodes()} extra_edges={len(extra_edges)} duplicates={len(cycle_annotations)}",
        f"# tag={int(tag)}",
    ]
    if annotate_breaks:
        header.extend(
            f"# CYCLE_BREAK reconnect {dup_id} {orig_id}"
            for dup_id, orig_id in cycle_annotations
        )

    entries.sort(key=lambda row: int(row[0]))
    rows = [
        f"{nid} {T} {x:.6f} {y:.6f} {z:.6f} {R:.6f} {parent}"
        for nid, T, x, y, z, R, parent in entries
    ]

    return emit("\n".join(header + rows) + "\n")
__all__ = ["MorphologyGraph", "Junction"]