Source code for mascaf.skeleton

"""
Graph-based skeleton handler.

Provides a `SkeletonGraph` class that inherits from networkx.Graph to:
- Represent skeleton as a graph with xyz coordinates on each node
- Load from polyline arrays, GraphML files, or the legacy polylines text format (one polyline per line: `N x1 y1 z1 x2 y2 z2 ...`); save to GraphML
- Identify terminal nodes (degree 1) and branch nodes (degree 3+)
- Every point from input polylines becomes a node in the SkeletonGraph

Note: This class represents skeleton topology as a graph where:
- Nodes have 'pos' attribute with (x, y, z) coordinates
- Edges connect consecutive nodes along polylines
- Terminal nodes have degree 1
- Branch nodes have degree 3+
- Continuation nodes have degree 2
"""

from __future__ import annotations

import logging
from typing import List, Optional, Sequence, Set

import networkx as nx
import numpy as np
from swctools import PointSet

from .graph3d import Graph3D

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class SkeletonGraph(Graph3D):
    """
    Graph-based skeleton representation with xyz coordinates on nodes.

    Inherits from networkx.Graph (through ``Graph3D``). Each node has a
    ``'pos'`` attribute storing (x, y, z) coordinates. Edges represent
    connections between consecutive points along polylines data.

    Every point from input polylines becomes a node. Endpoints within
    tolerance are merged into single nodes.

    Terminal nodes (degree 1) are isolated endpoints.
    Branch nodes (degree 3+) are where multiple branches meet.
    Continuation nodes (degree 2) are intermediate points along branches.
    """

    def __init__(self, tolerance: float = 1e-6, **attr):
        """Initialize a SkeletonGraph.

        Parameters
        ----------
        tolerance : float, default 1e-6
            Distance threshold below which two endpoints are merged into a
            single node. Stored as the graph-level ``'tolerance'`` attribute.
        **attr
            Additional keyword arguments forwarded to the networkx graph
            constructor as graph-level attributes.
        """
        super().__init__(**attr)
        self.graph["tolerance"] = tolerance
        # Counter used by ``_get_next_node_id`` to hand out fresh integer ids.
        self._next_node_id = 0

    position_attr = "pos"

    # ---------------------------------------------------------------------
    # Construction / loading
    # ---------------------------------------------------------------------

    @classmethod
    def from_polylines(
        cls,
        polylines: Sequence[np.ndarray],
        tolerance: float = 1e-6,
    ) -> "SkeletonGraph":
        """Create a :class:`SkeletonGraph` from a sequence of polyline arrays.

        Every point in every polyline becomes a node. Consecutive points
        within a polyline are connected by edges. Endpoints (first/last point
        of each polyline) whose Euclidean distance to a group's seed endpoint
        is less than *tolerance* are merged into a single node placed at the
        group's centroid.

        Parameters
        ----------
        polylines : sequence of array-like, shape (N_i, 3)
            Each element is an ordered array of 3D points representing one
            branch of the skeleton. Nested lists are accepted as well.
        tolerance : float, default 1e-6
            Maximum distance between two endpoint coordinates for them to be
            merged.

        Returns
        -------
        SkeletonGraph
            Graph with node positions stored under the ``'pos'`` attribute and
            edge ``'length'``, ``'polyline_idx'`` and ``'segment_idx'``
            attributes populated.
        """
        graph = cls(tolerance=tolerance)
        # ``len`` rather than truthiness so a stacked ndarray input does not
        # raise "truth value of an array is ambiguous".
        if polylines is None or len(polylines) == 0:
            return graph

        arrays = [np.asarray(pl, dtype=float) for pl in polylines]

        # Step 1: one node per input point.
        point_to_node: dict = {}  # (poly_idx, point_idx) -> node id
        endpoints: list = []  # (node_id, poly_idx, point_idx)
        for poly_idx, pl in enumerate(arrays):
            n_points = len(pl)
            for point_idx in range(n_points):
                node_id = graph._get_next_node_id()
                graph.add_node(node_id, pos=np.array(pl[point_idx], dtype=float))
                point_to_node[(poly_idx, point_idx)] = node_id
                if point_idx == 0 or point_idx == n_points - 1:
                    endpoints.append((node_id, poly_idx, point_idx))

        # Step 2: merge endpoints within tolerance of each group's seed.
        if endpoints:
            coords = np.array([arrays[p][i] for _, p, i in endpoints])
            used = np.zeros(len(endpoints), dtype=bool)
            for seed in range(len(endpoints)):
                if used[seed]:
                    continue
                dists = np.linalg.norm(coords - coords[seed], axis=1)
                members = (dists < tolerance) & ~used
                members[seed] = True  # a seed always belongs to its own group
                # Every unused index is >= seed, so group[0] is the seed.
                group = np.flatnonzero(members)
                used[group] = True
                if len(group) == 1:
                    continue
                merged_node = endpoints[group[0]][0]
                graph.nodes[merged_node]["pos"] = coords[group].mean(axis=0)
                for k in group[1:]:
                    node_id, poly_idx, point_idx = endpoints[k]
                    point_to_node[(poly_idx, point_idx)] = merged_node
                    if node_id in graph:
                        graph.remove_node(node_id)

        # Step 3: connect consecutive points of each polyline.
        for poly_idx, pl in enumerate(arrays):
            for point_idx in range(len(pl) - 1):
                node_u = point_to_node.get((poly_idx, point_idx))
                node_v = point_to_node.get((poly_idx, point_idx + 1))
                if node_u is None or node_v is None or node_u == node_v:
                    continue
                pos_u = graph.nodes[node_u]["pos"]
                pos_v = graph.nodes[node_v]["pos"]
                graph.add_edge(
                    node_u,
                    node_v,
                    polyline_idx=poly_idx,
                    segment_idx=point_idx,
                    length=float(np.linalg.norm(pos_v - pos_u)),
                )
        return graph

    @classmethod
    def from_txt(cls, path: str, tolerance: float = 1e-6) -> "SkeletonGraph":
        """
        Load a SkeletonGraph from a file.

        Supports two formats:
        1. GraphML format (.graphml or .xml extension) - native graph format
        2. Legacy polylines format (any other extension, e.g. .polylines.txt)
           - for backward compatibility

        Args:
            path: Path to the skeleton file (str or os.PathLike)
            tolerance: Distance threshold for merging nearby endpoints
                (polylines format); stored as graph tolerance for GraphML.

        Returns:
            SkeletonGraph instance
        """
        import os

        path = os.fspath(path)
        if path.endswith((".graphml", ".xml")):
            return cls._from_graphml(path, tolerance)
        return cls._from_polylines_txt(path, tolerance)

    @classmethod
    def _from_graphml(cls, path: str, tolerance: float) -> "SkeletonGraph":
        """Read a GraphML file as written by :meth:`to_txt`."""
        G = nx.read_graphml(path)
        graph = cls(tolerance=tolerance)

        for node_id, node_data in G.nodes(data=True):
            # Positions are stored as "x,y,z" strings.
            pos_str = node_data.get("pos", "0,0,0")
            pos = np.array([float(x) for x in pos_str.split(",")], dtype=float)
            graph.add_node(int(node_id), pos=pos)

        for u, v, data in G.edges(data=True):
            edge_data = {}
            for key, conv in (
                ("length", float),
                ("polyline_idx", int),
                ("segment_idx", int),
            ):
                raw = data.get(key)
                # Older writers emitted the literal "None" for missing
                # metadata; treat it as absent instead of failing to parse.
                if raw is None or raw == "None":
                    continue
                edge_data[key] = conv(raw)
            graph.add_edge(int(u), int(v), **edge_data)

        graph.graph["tolerance"] = tolerance
        # Continue numbering after the loaded ids so new nodes never collide.
        if graph.number_of_nodes():
            graph._next_node_id = max(graph.nodes()) + 1
        return graph

    @classmethod
    def _from_polylines_txt(cls, path: str, tolerance: float) -> "SkeletonGraph":
        """Parse the legacy ``N x1 y1 z1 x2 y2 z2 ...`` format (one polyline per line).

        Blank lines, lines with fewer than four tokens, non-numeric lines and
        lines whose coordinate count is not ``3 * N`` are skipped.
        """
        polylines = []
        with open(path, "r", encoding="utf-8") as f:
            for raw_line in f:
                parts = raw_line.split()
                if len(parts) < 4:  # need N plus at least one (x, y, z)
                    continue
                try:
                    n = int(parts[0])
                    coords = [float(x) for x in parts[1:]]
                except ValueError:
                    continue
                if len(coords) != n * 3:
                    continue
                polylines.append(np.array(coords).reshape(n, 3))
        return cls.from_polylines(polylines, tolerance=tolerance)

    def _get_next_node_id(self) -> int:
        """Get the next available node ID."""
        node_id = self._next_node_id
        self._next_node_id += 1
        return node_id

    # ---------------------------------------------------------------------
    # Topology queries
    # ---------------------------------------------------------------------

    def detect_branch_points(self, tolerance: float = 1e-6) -> dict:
        """
        Detect branch points and endpoints in the skeleton.

        Note: `SkeletonGraph` already merges endpoints within tolerance on import.
        The `tolerance` argument is accepted for API compatibility.

        Args:
            tolerance: Unused; kept for compatibility.

        Returns:
            Dictionary containing:
            - 'branch_points': Sorted list of node IDs for branch nodes
            - 'endpoints': Sorted list of node IDs for terminal nodes
            - 'branch_locations': List of 3D coordinates of branch nodes
            - 'endpoint_locations': List of 3D coordinates of terminal nodes
        """
        _ = tolerance
        branch_nodes = sorted(self.get_branch_nodes())
        terminal_nodes = sorted(self.get_terminal_nodes())
        return {
            "branch_points": branch_nodes,
            "endpoints": terminal_nodes,
            "branch_locations": [self.get_node_position(n) for n in branch_nodes],
            "endpoint_locations": [self.get_node_position(n) for n in terminal_nodes],
        }

    def get_branch_point_indices(self, tolerance: float = 1e-6) -> Set[int]:
        """Return the set of branch node IDs.

        Args:
            tolerance: Unused; kept for compatibility.

        Returns:
            Set of node IDs.
        """
        _ = tolerance
        return set(self.get_branch_nodes())

    def get_true_endpoint_indices(self, tolerance: float = 1e-6) -> Set[int]:
        """Return the set of terminal node IDs.

        Args:
            tolerance: Unused; kept for compatibility.

        Returns:
            Set of node IDs.
        """
        _ = tolerance
        return set(self.get_terminal_nodes())

    def build_graph(self, tolerance: float = 1e-6) -> nx.Graph:
        """
        Build a networkx graph representation with node type annotations.

        This is primarily a compatibility helper for tests. The returned graph
        is a plain `nx.Graph` (not a `SkeletonGraph`).

        Args:
            tolerance: Unused; kept for compatibility.

        Returns:
            A `nx.Graph` with nodes annotated with:
            - 'pos': (x, y, z)
            - 'type': 'endpoint' | 'branch' | 'continuation'
            and edges carrying a copy of this graph's edge data.
        """
        _ = tolerance
        G = nx.Graph()
        for n in self.nodes():
            if self.is_branch_node(n):
                node_type = "branch"
            elif self.is_terminal_node(n):
                node_type = "endpoint"
            else:
                node_type = "continuation"
            G.add_node(n, pos=self.get_node_position(n), type=node_type)
        for u, v, data in self.edges(data=True):
            G.add_edge(u, v, **(data or {}))
        return G

    def _edge_length(self, u: int, v: int) -> float:
        """Return the stored edge ``'length'``, or the node distance if absent."""
        data = self.get_edge_data(u, v) or {}
        length = data.get("length")
        if length is not None:
            return float(length)
        pu = self.get_node_position(u)
        pv = self.get_node_position(v)
        return float(np.linalg.norm(pv - pu))

    def _trace_from_terminal(self, start: int) -> tuple[int, list[int], float]:
        """Trace from a terminal node until the next non-continuation node.

        Returns (end_node, path_nodes, length). If *start* is missing or is
        not a terminal node, returns ``(start, [start], 0.0)``.
        """
        if start not in self or self.degree(start) != 1:
            return start, [start], 0.0

        path = [start]
        prev = None
        current = start
        length = 0.0
        while True:
            nbrs = [n for n in self.neighbors(current) if n != prev]
            if not nbrs:
                break
            nxt = nbrs[0]
            length += self._edge_length(current, nxt)
            prev, current = current, nxt
            path.append(current)
            # Stop at the first terminal or branch node.
            if self.degree(current) != 2:
                break
        return current, path, length

    # ---------------------------------------------------------------------
    # Pruning
    # ---------------------------------------------------------------------

    def prune_short_branches(
        self,
        min_length: Optional[float] = None,
        min_length_fraction: Optional[float] = None,
        tolerance: float = 1e-6,
        iterative: bool = True,
        verbose: bool = False,
    ) -> "SkeletonGraph":
        """Remove short terminal branches and return the pruned copy.

        A terminal branch is a path from a terminal node (degree 1) to the next
        non-continuation node (degree != 2). If the path ends at a branch node
        (degree 3+) and its geometric length is below the threshold, it is
        removed (the branch node itself is kept). Isolated components that
        connect terminal-to-terminal with no branch nodes are removed
        regardless of length.

        Length is computed from edge lengths (or node coordinates if missing),
        so "short" refers to *geometric* length, not node count.

        Args:
            min_length: Absolute length threshold.
            min_length_fraction: Used when *min_length* is None; the threshold
                is this percentile (fraction in [0, 1]) of the original branch
                lengths.
            tolerance: Unused; kept for compatibility.
            iterative: Repeat passes until nothing more is removed.
            verbose: Log the threshold used.

        Returns:
            A new SkeletonGraph; ``self`` is not modified.

        Raises:
            ValueError: If neither *min_length* nor *min_length_fraction* is given.
        """
        _ = tolerance
        if self.number_of_nodes() == 0:
            return self.copy()
        if min_length is None and min_length_fraction is None:
            raise ValueError("Must specify either min_length or min_length_fraction")

        # The threshold is derived from the unpruned graph.
        original_branch_lengths = list(self.compute_branch_lengths().values())
        if len(original_branch_lengths) == 0:
            return self.copy()

        if min_length is not None:
            threshold = float(min_length)
        else:
            threshold = float(
                np.percentile(original_branch_lengths, min_length_fraction * 100.0)
            )
        if verbose:
            logger.info("Pruning branches with length < %.4f", threshold)

        current = self.copy()
        while True:
            nodes_to_remove: Set[int] = set()
            visited_terminals: Set[int] = set()
            for t in sorted(n for n in current.nodes() if current.degree(n) == 1):
                if t in visited_terminals or t not in current:
                    continue
                visited_terminals.add(t)
                end, path, length = current._trace_from_terminal(t)
                if len(path) <= 1:
                    continue

                reaches_other_terminal = end != t and current.degree(end) == 1
                ends_at_branch = end != t and current.degree(end) >= 3
                if reaches_other_terminal:
                    # Isolated terminal-to-terminal run: drop it entirely and
                    # do not trace it again from the other side.
                    visited_terminals.add(end)
                    nodes_to_remove.update(path)
                elif ends_at_branch and length < threshold:
                    # Keep the branch node the short branch hangs off.
                    nodes_to_remove.update(path[:-1])

            if not nodes_to_remove:
                break
            current.remove_nodes_from([n for n in nodes_to_remove if n in current])
            if not iterative:
                break

        return current

    def prune_short_branches_inplace(
        self,
        min_length: Optional[float] = None,
        min_length_fraction: Optional[float] = None,
        tolerance: float = 1e-6,
        iterative: bool = True,
        verbose: bool = False,
    ) -> int:
        """In-place version of `prune_short_branches`.

        All arguments (including *iterative*) are forwarded unchanged.

        Returns:
            Number of nodes removed.
        """
        before = self.number_of_nodes()
        pruned = self.prune_short_branches(
            min_length=min_length,
            min_length_fraction=min_length_fraction,
            tolerance=tolerance,
            iterative=iterative,
            verbose=verbose,
        )
        self.clear()
        self.add_nodes_from(pruned.nodes(data=True))
        self.add_edges_from(pruned.edges(data=True))
        self.graph.update(pruned.graph)
        self._next_node_id = getattr(pruned, "_next_node_id", self._next_node_id)
        return before - self.number_of_nodes()

    # ---------------------------------------------------------------------
    # Conversion
    # ---------------------------------------------------------------------

    def to_point_set(self) -> PointSet:
        """Return all skeleton node locations as a PointSet."""
        positions = self.get_all_positions()
        return PointSet.from_points(positions)

    def to_polylines(self) -> List[np.ndarray]:
        """
        Convert the graph back to a list of polyline arrays.

        Edges are grouped by ``polyline_idx`` (groups emitted in sorted order)
        and each group is chained through shared nodes. The walk starts at the
        free end of the lowest ``segment_idx`` edge and always follows the
        lowest-ranked unused edge. The stored ``(u, v)`` orientation is not
        trusted, because networkx does not preserve it, and ``segment_idx``
        ties are tolerated, because ``resample`` gives every sub-edge of a
        subdivided segment the same index. Edges lacking either index are
        ignored. A group that has become disconnected (e.g. after editing)
        yields one polyline per connected run instead of a polyline with jumps.

        Returns:
            List of (N_i, 3) arrays representing polylines
        """
        if self.number_of_edges() == 0:
            return []

        polyline_edges: dict = {}  # polyline_idx -> [(segment_idx, u, v), ...]
        for u, v, data in self.edges(data=True):
            poly_idx = data.get("polyline_idx")
            seg_idx = data.get("segment_idx")
            if poly_idx is None or seg_idx is None:
                continue
            polyline_edges.setdefault(poly_idx, []).append((seg_idx, u, v))

        polylines: List[np.ndarray] = []
        for poly_idx in sorted(polyline_edges):
            for node_path in self._chain_polyline_edges(polyline_edges[poly_idx]):
                polylines.append(
                    np.array([self.get_node_position(n) for n in node_path])
                )
        return polylines

    @staticmethod
    def _chain_polyline_edges(edges: list) -> list:
        """Order the ``(segment_idx, u, v)`` edges of one polyline into node paths.

        Each connected run of edges yields one list of node ids; a closed loop
        yields a path that ends at its own start node.
        """
        # node -> [(segment_idx, edge_id, other_node), ...]
        adjacency: dict = {}
        for eid, (seg, u, v) in enumerate(edges):
            adjacency.setdefault(u, []).append((seg, eid, v))
            adjacency.setdefault(v, []).append((seg, eid, u))
        used = [False] * len(edges)

        def free_edges(node):
            return [item for item in adjacency[node] if not used[item[1]]]

        def edge_rank(item):
            # (segment_idx, edge_id): edge_id breaks segment_idx ties.
            return item[0], item[1]

        def dead_end_rank(node):
            seg, eid, _ = free_edges(node)[0]
            # On a tie (single-edge run) prefer the edge's stored ``u`` end.
            return seg, eid, node != edges[eid][1]

        def loop_end_rank(node, first):
            return min(edge_rank(item) for item in free_edges(node) if item[1] != first)

        paths = []
        for first in sorted(range(len(edges)), key=lambda k: (edges[k][0], k)):
            if used[first]:
                continue

            # Nodes of the connected run of free edges containing ``first``.
            component = set()
            stack = [edges[first][1]]
            while stack:
                node = stack.pop()
                if node in component:
                    continue
                component.add(node)
                stack.extend(other for _, _, other in free_edges(node))

            dead_ends = [n for n in component if len(free_edges(n)) == 1]
            if dead_ends:
                start = min(dead_ends, key=dead_end_rank)
            else:
                # Closed loop: of the two ends of the lowest-ranked edge, start
                # at the one whose other edges come latest in the polyline.
                _, u, v = edges[first]
                start = max((u, v), key=lambda n: loop_end_rank(n, first))

            path = [start]
            current = start
            while True:
                candidates = free_edges(current)
                if not candidates:
                    break
                _, eid, nxt = min(candidates, key=edge_rank)
                used[eid] = True
                path.append(nxt)
                current = nxt
            paths.append(path)
        return paths

    def to_txt(self, path: str) -> None:
        """
        Save the skeleton to a file.

        Saves in GraphML format (.graphml), which preserves graph structure,
        node positions, and the ``length`` / ``polyline_idx`` / ``segment_idx``
        edge metadata. This is the native format for SkeletonGraph.

        For legacy polylines format, use to_polylines() and save manually.

        Args:
            path: Output file path (str or os.PathLike). A trailing ``.txt``
                is replaced by ``.graphml``; any other name without a
                ``.graphml`` suffix gets ``.graphml`` appended.
        """
        import os

        path = os.fspath(path)
        if not path.endswith(".graphml"):
            if path.endswith(".txt"):
                # Only touch the suffix; str.replace would also rewrite ".txt"
                # occurring in directory names.
                path = path[: -len(".txt")] + ".graphml"
            else:
                path = path + ".graphml"

        G = nx.Graph()
        # GraphML attributes are written as strings; positions as "x,y,z".
        for node in self.nodes():
            pos = self.get_node_position(node)
            G.add_node(str(node), pos=f"{pos[0]},{pos[1]},{pos[2]}")

        for u, v, data in self.edges(data=True):
            # Skip missing/None metadata so the file never contains "None".
            edge_attrs = {
                key: str(data[key])
                for key in ("length", "polyline_idx", "segment_idx")
                if data.get(key) is not None
            }
            G.add_edge(str(u), str(v), **edge_attrs)

        nx.write_graphml(G, path)

    # ---------------------------------------------------------------------
    # Copy
    # ---------------------------------------------------------------------

    def copy(self) -> "SkeletonGraph":
        """
        Create a copy of the skeleton graph.

        Graph-level and per-node attribute dicts are copied (shallowly), with
        each node's ``'pos'`` array duplicated so positions can be edited
        independently. Edge attribute dicts are copied shallowly.

        Returns:
            New SkeletonGraph instance with copied data
        """
        new_graph = SkeletonGraph(tolerance=self.graph.get("tolerance", 1e-6))
        new_graph.graph.update(self.graph)

        for node, data in self.nodes(data=True):
            attrs = dict(data)
            attrs["pos"] = np.array(data["pos"])  # np.array copies
            new_graph.add_node(node, **attrs)

        for u, v, data in self.edges(data=True):
            new_graph.add_edge(u, v, **dict(data))

        new_graph._next_node_id = self._next_node_id
        return new_graph

    # ---------------------------------------------------------------------
    # Statistics
    # ---------------------------------------------------------------------

    def get_statistics(self) -> dict:
        """
        Get statistics about the skeleton graph.

        Edge lengths use the stored ``'length'`` attribute, falling back to
        the node distance when it is missing.

        Returns:
            Dictionary with node/edge counts and, when edges exist,
            total/mean/min/max edge length.
        """
        terminal_nodes = self.get_terminal_nodes()
        branch_nodes = self.get_branch_nodes()
        continuation_nodes = self.get_continuation_nodes()
        edge_lengths = [self._edge_length(u, v) for u, v in self.edges()]

        stats = {
            "num_nodes": self.number_of_nodes(),
            "num_edges": self.number_of_edges(),
            "num_terminal_nodes": len(terminal_nodes),
            "num_branch_nodes": len(branch_nodes),
            "num_continuation_nodes": len(continuation_nodes),
            "total_points": self.number_of_nodes(),
        }
        if edge_lengths:
            stats["total_length"] = sum(edge_lengths)
            stats["mean_edge_length"] = np.mean(edge_lengths)
            stats["min_edge_length"] = min(edge_lengths)
            stats["max_edge_length"] = max(edge_lengths)
        return stats

    def __repr__(self) -> str:
        """String representation of the skeleton graph."""
        stats = self.get_statistics()
        return (
            f"SkeletonGraph(nodes={stats['num_nodes']}, "
            f"edges={stats['num_edges']}, "
            f"terminals={stats['num_terminal_nodes']}, "
            f"branches={stats['num_branch_nodes']})"
        )

    # ---------------------------------------------------------------------
    # Resampling
    # ---------------------------------------------------------------------

    def resample(self, spacing: float) -> "SkeletonGraph":
        """
        Resample the skeleton to have approximately uniform spacing between nodes.

        Every existing node is kept (renumbered from 0). Each edge longer than
        *spacing* is split into ``ceil(length / spacing)`` equal sub-edges by
        inserting linearly interpolated nodes, so topology is preserved.
        Sub-edges inherit the parent edge's ``polyline_idx`` / ``segment_idx``
        when those are set.

        Args:
            spacing: Target distance between consecutive nodes (must be > 0)

        Returns:
            New SkeletonGraph with resampled nodes

        Raises:
            ValueError: If *spacing* is not a positive number.
        """
        if not spacing > 0:
            raise ValueError(f"spacing must be positive, got {spacing!r}")

        new_graph = SkeletonGraph(tolerance=self.graph.get("tolerance", 1e-6))

        node_mapping = {}  # old node -> new node
        for node in self.nodes():
            new_node = new_graph._get_next_node_id()
            new_graph.add_node(new_node, pos=self.get_node_position(node).copy())
            node_mapping[node] = new_node

        for u, v, data in self.edges(data=True):
            pos_u = self.get_node_position(u)
            pos_v = self.get_node_position(v)
            edge_length = np.linalg.norm(pos_v - pos_u)
            n_segments = max(1, int(np.ceil(edge_length / spacing)))
            metadata = {
                key: data[key]
                for key in ("polyline_idx", "segment_idx")
                if data.get(key) is not None
            }

            # Node chain u -> interpolated nodes -> v.
            chain = [node_mapping[u]]
            for i in range(1, n_segments):
                t = i / n_segments
                mid = new_graph._get_next_node_id()
                new_graph.add_node(mid, pos=pos_u + t * (pos_v - pos_u))
                chain.append(mid)
            chain.append(node_mapping[v])

            for a, b in zip(chain, chain[1:]):
                seg_length = float(
                    np.linalg.norm(
                        new_graph.get_node_position(b) - new_graph.get_node_position(a)
                    )
                )
                new_graph.add_edge(a, b, length=seg_length, **metadata)

        return new_graph

    # ---------------------------------------------------------------------
    # Mesh surface projection
    # ---------------------------------------------------------------------

    def snap_to_mesh_surface(
        self,
        mesh,
        project_outside_only: bool = True,
        max_distance: Optional[float] = None,
    ) -> tuple:
        """
        Project node positions to the nearest surface point on mesh.

        Args:
            mesh: trimesh.Trimesh object
            project_outside_only: If True, only project points outside the
                mesh. If the inside/outside test fails, a warning is logged and
                all points are considered.
            max_distance: If provided, only move points whose distance to the
                surface is >= this value

        Returns:
            (n_moved, mean_move_distance) tuple
        """
        if mesh is None or len(getattr(mesh, "vertices", [])) == 0:
            return 0, 0.0
        if self.number_of_nodes() == 0:
            return 0, 0.0

        # Build positions from an explicit node list so index i always refers
        # to the same node when writing positions back.
        nodes = list(self.nodes())
        positions = np.array([self.get_node_position(n) for n in nodes], dtype=float)

        use_mask = None
        if project_outside_only:
            try:
                from trimesh.proximity import signed_distance

                # trimesh convention: negative outside, positive inside.
                use_mask = np.asarray(signed_distance(mesh, positions)) < 0
            except Exception:
                logger.warning(
                    "signed_distance failed; projecting all nodes", exc_info=True
                )
                use_mask = None

        try:
            from trimesh.proximity import closest_point

            closest_positions, distances, _ = closest_point(mesh, positions)
        except Exception:
            logger.debug(
                "closest_point failed; falling back to nearest vertex", exc_info=True
            )
            vertices = np.asarray(mesh.vertices, dtype=float)
            if vertices.size == 0:
                return 0, 0.0
            from scipy.spatial import cKDTree

            distances, idx = cKDTree(vertices).query(positions, k=1)
            closest_positions = vertices[idx]

        mask = np.ones(len(nodes), dtype=bool) if use_mask is None else use_mask
        if max_distance is not None:
            mask = mask & (np.asarray(distances) >= float(max_distance))

        moved = 0
        total_move = 0.0
        for i in np.flatnonzero(mask):
            self.set_node_position(nodes[i], closest_positions[i])
            moved += 1
            total_move += float(distances[i])

        mean_move = total_move / moved if moved > 0 else 0.0
        return moved, mean_move

    # ---------------------------------------------------------------------
    # Branch length computation
    # ---------------------------------------------------------------------

    def compute_branch_lengths(self) -> dict:
        """
        Compute the length of each branch (path between terminal/branch nodes).

        Walks from every terminal/branch node through continuation nodes until
        another terminal/branch node is reached. Edge lengths use the stored
        ``'length'`` attribute, falling back to node distance when missing.
        If several branches join the same pair of nodes, only the first one
        found is recorded.

        Returns:
            Dictionary mapping (start_node, end_node) -> length, with the key
            tuple sorted.
        """
        from collections import deque

        branch_lengths: dict = {}
        special_nodes = self.get_terminal_nodes() | self.get_branch_nodes()

        for start_node in special_nodes:
            visited = {start_node}
            queue = deque([(start_node, 0.0)])
            while queue:
                current, length = queue.popleft()
                for neighbor in self.neighbors(current):
                    if neighbor in visited:
                        continue
                    new_length = length + self._edge_length(current, neighbor)
                    if neighbor in special_nodes:
                        key = tuple(sorted((start_node, neighbor)))
                        branch_lengths.setdefault(key, new_length)
                    else:
                        # Continuation node: keep walking along the branch.
                        visited.add(neighbor)
                        queue.append((neighbor, new_length))
        return branch_lengths

    def get_total_length(self) -> float:
        """
        Get the total length of all edges in the skeleton.

        Uses the stored ``'length'`` attribute, falling back to node distance
        when it is missing.

        Returns:
            Total length
        """
        return float(sum(self._edge_length(u, v) for u, v in self.edges()))
# --------------------------------------------------------------------- # Helper functions # --------------------------------------------------------------------- def _resample_polyline(pl: np.ndarray, spacing: float) -> np.ndarray: """ Resample a polyline at approximately constant arc-length spacing. Includes the first and last vertex; inserts intermediate points every multiple of `spacing` along cumulative arclength. Args: pl: (N, 3) array of points spacing: Target spacing between points Returns: (M, 3) array of resampled points """ P = np.asarray(pl, dtype=float) if P.ndim != 2 or P.shape[1] != 3 or P.shape[0] == 0: return np.zeros((0, 3), dtype=float) if P.shape[0] == 1: return P.copy() seg = np.linalg.norm(P[1:] - P[:-1], axis=1) L = np.concatenate([[0.0], np.cumsum(seg)]) total = float(L[-1]) if total <= 0.0: return P[[0], :].copy() step = float(max(spacing, 1e-12)) # Always include start and end targets = list(np.arange(0.0, total, step)) if targets[-1] != total: targets.append(total) out: List[np.ndarray] = [] si = 0 # segment index for t in targets: # advance si until L[si] <= t <= L[si+1] while si < len(seg) and L[si + 1] < t: si += 1 if si >= len(seg): out.append(P[-1]) continue t0 = L[si] t1 = L[si + 1] if t1 <= t0: out.append(P[si]) continue alpha = (t - t0) / (t1 - t0) Q = (1.0 - alpha) * P[si] + alpha * P[si + 1] out.append(Q) return np.vstack(out) if out else P[[0], :].copy()