"""mascaf.graph3d: an undirected graph whose vertices are embedded in 3D."""

from __future__ import annotations

import logging
from typing import Optional, Set

import networkx as nx
import numpy as np

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class Graph3D(nx.Graph):
    """Undirected graph whose vertices are embedded in 3D.

    Each node stores its coordinates as a float array under the node
    attribute named by ``position_attr``. The node-classification helpers
    (terminal / continuation / branch) are based purely on node degree.
    """

    # Name of the node attribute that holds each vertex's coordinates.
    position_attr = "xyz"

    def get_node_position(self, node: int) -> np.ndarray:
        """Return the 3D position of a node.

        Args:
            node: ID of a node in the graph.

        Returns:
            The stored coordinates as a float array. If the stored value is
            already a float64 ndarray it is returned without copying, so
            in-place edits to the result modify the graph.

        Raises:
            KeyError: If ``node`` is not in the graph or has no position
                attribute.
        """
        return np.asarray(self.nodes[node][self.position_attr], dtype=float)

    def set_node_position(self, node: int, pos: np.ndarray) -> None:
        """Set the 3D position of a node.

        The coordinates are copied into a new float array, so the graph never
        shares memory with the caller's buffer.

        Args:
            node: ID of a node in the graph.
            pos: Coordinates; any array-like convertible to floats.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        # np.array (rather than np.asarray) forces a copy. Storing a view
        # (e.g. a row of a caller-owned positions matrix) would let later
        # mutations of that buffer silently move this node.
        self.nodes[node][self.position_attr] = np.array(pos, dtype=float)

    def get_all_positions(self) -> np.ndarray:
        """Return positions of all nodes ordered by sorted node ID.

        Returns:
            A new float array with one row per node, where row ``i`` is the
            position of the ``i``-th node in sorted-ID order. An empty graph
            yields an array of shape ``(0, 3)``.
        """
        if self.number_of_nodes() == 0:
            return np.zeros((0, 3), dtype=float)
        return np.array(
            [self.get_node_position(node) for node in sorted(self.nodes())],
            dtype=float,
        )

    def set_all_positions(self, positions: np.ndarray) -> None:
        """Set positions of all nodes from an ``(N, 3)`` array.

        Row ``i`` is assigned to the ``i``-th node in sorted-ID order, which
        mirrors the ordering used by ``get_all_positions``. Every row is
        copied, so the graph keeps no references into ``positions``.

        Args:
            positions: Array-like with one row per node (lists are accepted).

        Raises:
            ValueError: If the number of rows differs from the node count.
        """
        # Coerce first so plain sequences (e.g. lists of lists) work too.
        positions = np.asarray(positions, dtype=float)
        n_nodes = self.number_of_nodes()
        if positions.shape[0] != n_nodes:
            raise ValueError(
                f"Position array has {positions.shape[0]} rows but graph has "
                f"{n_nodes} nodes"
            )
        for row, node in zip(positions, sorted(self.nodes())):
            # set_node_position copies, so no row view is retained.
            self.set_node_position(node, row)

    def get_terminal_nodes(self) -> Set[int]:
        """Return nodes of degree 1."""
        return {node for node, deg in self.degree() if deg == 1}

    def get_branch_nodes(self) -> Set[int]:
        """Return nodes of degree at least 3."""
        return {node for node, deg in self.degree() if deg >= 3}

    def get_continuation_nodes(self) -> Set[int]:
        """Return nodes of degree 2."""
        return {node for node, deg in self.degree() if deg == 2}

    def is_terminal_node(self, node: int) -> bool:
        """Return whether a node is terminal (degree 1)."""
        return self.degree(node) == 1

    def is_branch_node(self, node: int) -> bool:
        """Return whether a node is a branch node (degree at least 3)."""
        return self.degree(node) >= 3

    def is_continuation_node(self, node: int) -> bool:
        """Return whether a node is a continuation node (degree 2)."""
        return self.degree(node) == 2

    def bounds(self) -> Optional[dict]:
        """Return axis-aligned bounds of the vertex set.

        Returns:
            ``None`` for an empty graph; otherwise a dict mapping ``"x"``,
            ``"y"`` and ``"z"`` to ``(min, max)`` float tuples computed over
            the node positions.
        """
        if self.number_of_nodes() == 0:
            return None
        # Matched by class name rather than isinstance.
        # NOTE(review): presumably avoids importing a subclass defined in
        # another module; confirm.
        if self.__class__.__name__ == "MorphologyGraph":
            logger.info(
                "MorphologyGraph bounds refer to the vertex set, "
                "not the volumetric model"
            )
        positions = self.get_all_positions()
        lo = positions.min(axis=0)
        hi = positions.max(axis=0)
        return {
            "x": (float(lo[0]), float(hi[0])),
            "y": (float(lo[1]), float(hi[1])),
            "z": (float(lo[2]), float(hi[2])),
        }

    def midpoint(self) -> Optional[np.ndarray]:
        """Return the mean position of the graph's vertex set.

        Returns:
            ``None`` for an empty graph, otherwise the per-axis mean of all
            node positions.
        """
        if self.number_of_nodes() == 0:
            return None
        return self.get_all_positions().mean(axis=0)

    def get_total_length(self) -> float:
        """Return the sum of edge lengths.

        An edge's ``"length"`` attribute is used when present. When it is
        missing (or ``None``), the Euclidean distance between the edge's
        endpoint positions is used instead. An empty edge set yields ``0.0``.
        """
        total = 0.0
        for u, v, data in self.edges(data=True):
            length = data.get("length")
            if length is None:
                # Fall back to the straight-line distance between endpoints.
                delta = self.get_node_position(v) - self.get_node_position(u)
                length = float(np.linalg.norm(delta))
            total += float(length)
        return total