"""
Main mesh class
"""
import logging
import multiprocessing
import traceback
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import trimesh
# Module-level logger
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
[docs]
def example_mesh(
kind: str = "cylinder",
*,
# Cylinder params
radius: float = 1,
height: float = 10,
sections: int | None = 16,
# Torus params
major_radius: float = 4,
minor_radius: float = 1,
major_sections: int | None = 32,
minor_sections: int | None = 16,
**kwargs,
) -> trimesh.Trimesh:
"""Create a simple demo mesh using trimesh primitives.
Parameters
----------
kind : {"cylinder", "torus"}
Type of primitive to generate. Default "cylinder".
radius : float
Cylinder radius (when kind="cylinder"). Default 1.
height : float
Cylinder height (when kind="cylinder"). Default 10.
sections : int or None
Cylinder radial resolution (pie wedges). Default 16.
major_radius : float
Torus major radius (center of hole to centerline of tube). Default 4.
minor_radius : float
Torus minor radius (tube radius). Default 1.
major_sections : int or None
Torus resolution around major circle. Default 32.
minor_sections : int or None
Torus resolution around tube section. Default 16.
**kwargs : dict
Passed through to Trimesh constructor via trimesh.creation.* helpers
(e.g., process=False).
Returns
-------
trimesh.Trimesh
Generated primitive mesh.
Examples
--------
>>> m = example_mesh("cylinder", radius=0.4, height=1.5)
>>> t = example_mesh("torus", major_radius=1.0, minor_radius=0.25)
"""
k = (kind or "cylinder").lower()
out = None
if k == "cylinder":
out = trimesh.creation.cylinder(
radius=float(radius),
height=float(height),
sections=None if sections is None else int(sections),
**kwargs,
)
elif k == "torus":
# trimesh.creation.torus parameters
out = trimesh.creation.torus(
major_radius=float(major_radius),
minor_radius=float(minor_radius),
major_sections=None if major_sections is None else int(major_sections),
minor_sections=None if minor_sections is None else int(minor_sections),
**kwargs,
)
# Rotate torus 90 degrees, so it's standing on its side
out.apply_transform(
trimesh.transformations.rotation_matrix(
angle=np.pi / 2, direction=[1, 0, 0], point=out.centroid
)
)
else:
raise ValueError("example_mesh kind must be 'cylinder' or 'torus'")
return out
[docs]
class MeshManager:
"""
Unified mesh class handling loading, processing, and analysis of triangle meshes.
Wraps a :class:`trimesh.Trimesh` with convenience methods for loading,
saving, copying, and analyzing mesh geometry. Loading can be deferred
by passing a ``mesh_path`` to the constructor or by calling
:meth:`load_mesh` directly.
Parameters
----------
mesh : trimesh.Trimesh or None
An already-loaded mesh to wrap. Either this or ``mesh_path`` must be
provided.
mesh_path : str or None
Path to a mesh file. If provided, :meth:`load_mesh` is called
immediately during construction.
verbose : bool, default True
If True, log informational messages when loading.
Examples
--------
>>> mgr = MeshManager(mesh_path="neuron.obj")
>>> mgr.bounding_box_diagonal()
"""
[docs]
def __init__(
self,
mesh: Optional[trimesh.Trimesh] = None,
mesh_path: Optional[str] = None,
verbose: bool = True,
):
# Core mesh attributes
self.mesh = mesh
self.mesh_path = mesh_path
# Attributes
self.verbose = verbose
self.stats = {
"processed": 0,
"successful": 0,
"failed": 0,
"volume_fixed": 0,
"watertight_fixed": 0,
"degenerate_removed": 0,
}
if mesh_path is not None:
self.load_mesh(mesh_path)
# =================================================================
# MESH LOADING AND BASIC OPERATIONS
# =================================================================
[docs]
def load_mesh(
self, filepath: str, file_format: Optional[str] = None
) -> trimesh.Trimesh:
"""Load a mesh from file and store it as :attr:`mesh`.
Parameters
----------
filepath : str
Path to the mesh file.
file_format : str or None
Format hint (e.g. ``"obj"``). Auto-detected from the file extension
when ``None``.
Returns
-------
trimesh.Trimesh
The loaded mesh (also stored as ``self.mesh``).
Raises
------
ValueError
If the file cannot be loaded or does not contain a single mesh.
"""
try:
if file_format:
mesh = trimesh.load(filepath, file_type=file_format)
else:
mesh = trimesh.load(filepath)
# Ensure we have a single mesh
if isinstance(mesh, trimesh.Scene):
# If it's a scene, try to get the first geometry
geometries = list(mesh.geometry.values())
if geometries:
mesh = geometries[0]
else:
raise ValueError("No geometry found in mesh scene")
if not isinstance(mesh, trimesh.Trimesh):
raise ValueError(f"Loaded object is not a mesh: {type(mesh)}")
self.mesh = mesh
if self.verbose:
logger.info(
"Loaded mesh: %d vertices, %d faces",
len(mesh.vertices),
len(mesh.faces),
)
return mesh
except Exception as e:
raise ValueError(f"Failed to load mesh from {filepath}: {str(e)}")
[docs]
def save(self, filepath: str, file_format: str = "obj") -> None:
"""Export the mesh to a file.
Parameters
----------
filepath : str
Destination file path.
file_format : str, default "obj"
Format string understood by ``trimesh.Trimesh.export``.
"""
self.mesh.export(filepath, file_type=file_format)
[docs]
def copy(self) -> "MeshManager":
"""Return a new :class:`MeshManager` wrapping a deep copy of the mesh."""
return MeshManager(self.mesh.copy())
[docs]
def to_trimesh(self) -> trimesh.Trimesh:
"""Return the underlying :class:`trimesh.Trimesh` object."""
return self.mesh
[docs]
def bounding_box_diagonal(self) -> float:
"""Return the length of the bounding box space diagonal.
Computed as the Euclidean distance between the minimum and maximum
corners of the axis-aligned bounding box of the mesh. Useful as
a scale reference when choosing ``max_edge_length`` for
:class:`~mascaf.FitOptions`.
Returns
-------
float
Length of the bounding box diagonal.
"""
bounds = self.mesh.bounds
return float(np.linalg.norm(bounds[1] - bounds[0]))
# combining the functions from utils into this class
[docs]
def analyze_mesh(self) -> dict:
"""Analyze mesh properties for diagnostic purposes without modifying the mesh.
Returns
-------
dict
Dictionary with keys: ``face_count``, ``vertex_count``, ``bounds``,
``is_watertight``, ``is_winding_consistent``, ``volume``,
``is_manifold``, ``euler_characteristic``, ``genus``,
``normal_stats``, and ``issues`` (list of warning strings).
"""
mesh = self.to_trimesh()
# Initialize results dictionary
results = {
"face_count": len(mesh.faces),
"vertex_count": len(mesh.vertices),
"bounds": mesh.bounds.tolist() if hasattr(mesh, "bounds") else None,
"is_watertight": mesh.is_watertight,
"is_winding_consistent": mesh.is_winding_consistent,
"issues": [],
}
# Calculate volume (report actual value, even if negative)
try:
results["volume"] = mesh.volume
if mesh.volume < 0:
results["issues"].append(
"Negative volume detected - face normals may be inverted"
)
except Exception as e:
results["volume"] = None
results["issues"].append(f"Volume calculation failed: {str(e)}")
# Check for non-manifold edges
try:
if hasattr(mesh, "is_manifold"):
results["is_manifold"] = mesh.is_manifold
if not mesh.is_manifold:
results["issues"].append("Non-manifold edges detected")
except Exception:
results["is_manifold"] = None
# Calculate topological properties using trimesh's built-in methods
try:
# Use trimesh's built-in euler_number property for correct topology calculation
# For a sphere: euler_number = 2
# For a torus: euler_number = 0
# For a double torus: euler_number = -2
# Genus = (2 - euler_number) / 2
results["euler_characteristic"] = mesh.euler_number
# Only calculate genus for closed (watertight) meshes
if mesh.is_watertight:
# For a closed orientable surface: genus = (2 - euler_number) / 2
results["genus"] = int((2 - mesh.euler_number) / 2)
# Sanity check - genus should be non-negative for simple shapes
if results["genus"] < 0:
results["genus"] = (
0 # Default to 0 for simple shapes like spheres, cylinders
)
results["issues"].append(
"Calculated negative genus, defaulting to 0"
)
else:
# For non-watertight meshes, genus is not well-defined
results["genus"] = None
results["issues"].append("Genus undefined for non-watertight mesh")
except Exception as e:
results["genus"] = None
results["euler_characteristic"] = None
results["issues"].append(f"Topology calculation failed: {str(e)}")
# Analyze face normals
try:
if hasattr(mesh, "face_normals") and mesh.face_normals is not None:
# Get statistics on face normal directions
results["normal_stats"] = {
"mean": mesh.face_normals.mean(axis=0).tolist(),
"std": mesh.face_normals.std(axis=0).tolist(),
"sum": mesh.face_normals.sum(axis=0).tolist(),
}
# Check if normals are predominantly pointing inward (negative volume)
if results.get("volume", 0) < 0:
results["normal_direction"] = "inward"
else:
results["normal_direction"] = "outward"
except Exception as e:
results["normal_stats"] = None
results["issues"].append(f"Normal analysis failed: {str(e)}")
# Check for duplicate vertices and faces
try:
unique_verts = np.unique(mesh.vertices, axis=0)
results["duplicate_vertices"] = len(mesh.vertices) - len(unique_verts)
if results["duplicate_vertices"] > 0:
results["issues"].append(
f"Found {results['duplicate_vertices']} duplicate vertices"
)
except Exception:
results["duplicate_vertices"] = None
# Check for degenerate faces (zero area)
try:
if hasattr(mesh, "area_faces"):
degenerate_count = np.sum(mesh.area_faces < 1e-8)
results["degenerate_faces"] = int(degenerate_count)
if degenerate_count > 0:
results["issues"].append(
f"Found {degenerate_count} degenerate faces"
)
except Exception:
results["degenerate_faces"] = None
# Check for connected components
try:
components = mesh.split(only_watertight=False)
results["component_count"] = len(components)
if len(components) > 1:
results["issues"].append(
f"Mesh has {len(components)} disconnected components"
)
except Exception:
results["component_count"] = None
return results
[docs]
def print_mesh_analysis(self, verbose: bool = False) -> None:
"""
Analyze a mesh and print a formatted report of its properties.
Args:
verbose: Whether to print detailed information
"""
analysis = self.analyze_mesh()
logger.info("Mesh Analysis Report")
logger.info("====================")
# Basic properties
logger.info("\nGeometry:")
logger.info(" * Vertices: %s", analysis["vertex_count"])
logger.info(" * Faces: %s", analysis["face_count"])
if analysis.get("component_count") is not None:
logger.info(" * Components: %s", analysis["component_count"])
if analysis.get("volume") is not None:
logger.info(" * Volume: %.2f", analysis["volume"])
if analysis.get("bounds") is not None:
min_bound, max_bound = analysis["bounds"]
logger.info(
" * Bounds: [%.1f, %.1f, %.1f] to [%.1f, %.1f, %.1f]",
min_bound[0],
min_bound[1],
min_bound[2],
max_bound[0],
max_bound[1],
max_bound[2],
)
# Mesh quality
logger.info("\nMesh Quality:")
logger.info(" * Watertight: %s", analysis["is_watertight"])
logger.info(" * Winding Consistent: %s", analysis["is_winding_consistent"])
if analysis.get("is_manifold") is not None:
logger.info(" * Manifold: %s", analysis["is_manifold"])
if analysis.get("normal_direction") is not None:
logger.info(" * Normal Direction: %s", analysis["normal_direction"])
if analysis.get("duplicate_vertices") is not None:
logger.info(" * Duplicate Vertices: %s", analysis["duplicate_vertices"])
if analysis.get("degenerate_faces") is not None:
logger.info(" * Degenerate Faces: %s", analysis["degenerate_faces"])
# Topology
if (
analysis.get("genus") is not None
or analysis.get("euler_characteristic") is not None
):
logger.info("\nTopology:")
if analysis.get("genus") is not None:
logger.info(" * Genus: %s", analysis["genus"])
if analysis.get("euler_characteristic") is not None:
logger.info(
" * Euler Characteristic: %s", analysis["euler_characteristic"]
)
# Issues
if analysis["issues"]:
logger.info("\nIssues Detected (%d):", len(analysis["issues"]))
for i, issue in enumerate(analysis["issues"]):
logger.info(" %d. %s", i + 1, issue)
else:
logger.info("\nNo issues detected")
# Detailed stats
if verbose and analysis.get("normal_stats") is not None:
logger.info("\nNormal Statistics:")
mean = analysis["normal_stats"]["mean"]
sum_val = analysis["normal_stats"]["sum"]
logger.info(" * Mean: [%.4f, %.4f, %.4f]", mean[0], mean[1], mean[2])
logger.info(
" * Sum: [%.4f, %.4f, %.4f]", sum_val[0], sum_val[1], sum_val[2]
)
logger.info("\nRecommendation:")
if analysis["issues"]:
logger.info(" Consider using repair_mesh() to fix the detected issues.")
else:
logger.info(" Mesh appears to be in good condition.")
logger.info("====================")
[docs]
def repair_mesh(
self,
fix_holes: bool = True,
remove_duplicates: bool = True,
fix_normals: bool = True,
remove_degenerate: bool = True,
fix_negative_volume: bool = True,
keep_largest_component: bool = False,
verbose: bool = True,
) -> trimesh.Trimesh:
"""
Attempt to repair common mesh issues to improve watertightness and quality.
Args:
mesh_data: Either a Trimesh object or (vertices, faces) tuple
fix_holes: Whether to attempt filling holes
remove_duplicates: Whether to remove duplicate faces and vertices
fix_normals: Whether to fix face normal consistency
remove_degenerate: Whether to remove degenerate faces
fix_negative_volume: Whether to invert faces if mesh has negative volume
keep_largest_component: Whether to keep only the largest connected component
verbose: Whether to print repair summary
Returns:
Repaired mesh (new copy, original is not modified)
"""
mesh = self.to_trimesh()
repair_log = []
# Fix negative volume by inverting faces if needed
if fix_negative_volume:
try:
# Check if the mesh has a negative volume
if hasattr(mesh, "volume") and mesh.volume < 0:
initial_volume = mesh.volume
mesh.invert()
repair_log.append(
f"Inverted faces to fix negative volume: {initial_volume:.2f} → {mesh.volume:.2f}"
)
except Exception as e:
repair_log.append(f"Failed to fix negative volume: {e}")
# Remove duplicate and degenerate faces
if remove_duplicates:
try:
initial_faces = len(mesh.faces)
mesh.remove_duplicate_faces()
removed_faces = initial_faces - len(mesh.faces)
if removed_faces > 0:
repair_log.append(f"Removed {removed_faces} duplicate faces")
except Exception as e:
repair_log.append(f"Failed to remove duplicate faces: {e}")
if remove_degenerate:
try:
initial_faces = len(mesh.faces)
mesh.remove_degenerate_faces()
removed_faces = initial_faces - len(mesh.faces)
if removed_faces > 0:
repair_log.append(f"Removed {removed_faces} degenerate faces")
except Exception as e:
repair_log.append(f"Failed to remove degenerate faces: {e}")
# Fix winding consistency
if fix_normals:
try:
if not mesh.is_winding_consistent:
mesh.fix_normals()
if mesh.is_winding_consistent:
repair_log.append("Fixed face normal winding consistency")
else:
repair_log.append(
"Attempted to fix normals but still inconsistent"
)
except Exception as e:
repair_log.append(f"Failed to fix normals: {e}")
# Attempt to fill holes
if fix_holes:
try:
if not mesh.is_watertight:
initial_watertight = mesh.is_watertight
mesh.fill_holes()
if mesh.is_watertight and not initial_watertight:
repair_log.append(
"Successfully filled holes - mesh is now watertight"
)
elif mesh.is_watertight:
repair_log.append("Mesh was already watertight")
else:
repair_log.append(
"Attempted to fill holes but mesh still not watertight"
)
except Exception as e:
repair_log.append(f"Failed to fill holes: {e}")
# Keep only the largest component if requested
if keep_largest_component:
try:
components = mesh.split(only_watertight=False)
if len(components) > 1:
# Keep the largest component by volume or face count
volumes = [
abs(c.volume) if hasattr(c, "volume") else len(c.faces)
for c in components
]
largest_idx = np.argmax(volumes)
mesh = components[largest_idx]
repair_log.append(
f"Kept largest of {len(components)} components (volume: {volumes[largest_idx]:.2f})"
)
except Exception as e:
repair_log.append(f"Failed to isolate largest component: {e}")
# Final processing to ensure consistency
try:
mesh.process(validate=True)
repair_log.append("Applied final mesh processing and validation")
except Exception as e:
repair_log.append(f"Final processing failed: {e}")
# Store repair log as mesh metadata
if not hasattr(mesh, "metadata"):
mesh.metadata = {}
mesh.metadata["repair_log"] = repair_log
# Log repair summary
if verbose:
if repair_log:
logger.info("Mesh Repair Summary:")
for log_entry in repair_log:
logger.info(" • %s", log_entry)
# Final mesh status
logger.info("\nFinal Mesh Status:")
logger.info(
" • Volume: %s",
mesh.volume if hasattr(mesh, "volume") else "N/A",
)
logger.info(" • Watertight: %s", mesh.is_watertight)
logger.info(" • Winding consistent: %s", mesh.is_winding_consistent)
logger.info(" • Faces: %d", len(mesh.faces))
logger.info(" • Vertices: %d", len(mesh.vertices))
else:
logger.info("No repairs needed - mesh is in good condition")
self.mesh = mesh
return mesh
[docs]
def visualize_mesh_3d(
self,
title: str = "3D Mesh Visualization",
color: str = "lightblue",
backend: str = "auto",
show_axes: bool = True,
show_wireframe: bool = False,
width: int = 800,
height: int = 600,
*,
eye_scale: float = 1.25,
skel: Optional[Union["SkeletonGraph", List["SkeletonGraph"]]] = None,
skel_color: Union[str, List[str]] = "crimson",
skel_line_width: float = 3.0,
skel_opacity: float = 0.95,
) -> Optional[object]:
"""
Create a 3D visualization of a mesh.
Args:
title: Plot title
color: Mesh color (named color or RGB tuple)
backend: Visualization backend ('plotly' or 'matplotlib')
show_axes: Whether to show coordinate axes
show_wireframe: Whether to show wireframe overlay
skel: Optional SkeletonGraph or list of SkeletonGraph to overlay as 3D lines
skel_color: Color(s) for skeleton overlay. Can be a single color or list of colors (one per skeleton)
skel_line_width: Line width for skeleton overlay
skel_opacity: Opacity for skeleton overlay (plotly only)
Returns:
Figure object (backend-dependent) or None if visualization fails
"""
if backend == "auto":
# Try plotly first, then fallback to matplotlib
try:
import plotly.graph_objects as go # noqa: F401
backend = "plotly"
except ImportError:
try:
import matplotlib.pyplot as plt # noqa: F401
backend = "matplotlib"
except ImportError:
backend = "plotly"
if backend == "plotly":
return self._visualize_mesh_plotly(
title,
color,
show_axes,
show_wireframe,
width,
height,
eye_scale=eye_scale,
skel=skel,
skel_color=skel_color,
skel_line_width=skel_line_width,
skel_opacity=skel_opacity,
)
elif backend == "matplotlib":
return self._visualize_mesh_matplotlib(
title,
color,
show_axes,
show_wireframe,
skel=skel,
skel_color=skel_color,
skel_line_width=skel_line_width,
)
else:
raise ValueError(f"Unknown backend: {backend}")
def _visualize_mesh_plotly(
self,
title,
color,
show_axes,
show_wireframe,
width=800,
height=600,
*,
eye_scale: float = 1.25,
skel: Optional[Union["SkeletonGraph", List["SkeletonGraph"]]] = None,
skel_color: Union[str, List[str]] = "crimson",
skel_line_width: float = 3.0,
skel_opacity: float = 0.95,
):
"""Plotly-based mesh visualization with optional SkeletonGraph overlay."""
try:
import plotly.graph_objects as go
vertices = self.mesh.vertices
faces = self.mesh.faces
# Create mesh trace
mesh_trace = go.Mesh3d(
x=vertices[:, 0],
y=vertices[:, 1],
z=vertices[:, 2],
i=faces[:, 0],
j=faces[:, 1],
k=faces[:, 2],
opacity=0.8,
color=color,
name="Mesh",
)
fig = go.Figure(data=[mesh_trace])
# Add wireframe if requested
if show_wireframe:
edge_x = []
edge_y = []
edge_z = []
for face in faces:
for i in range(3):
v1, v2 = face[i], face[(i + 1) % 3]
edge_x += [vertices[v1][0], vertices[v2][0], None]
edge_y += [vertices[v1][1], vertices[v2][1], None]
edge_z += [vertices[v1][2], vertices[v2][2], None]
fig.add_trace(
go.Scatter3d(
x=edge_x,
y=edge_y,
z=edge_z,
mode="lines",
line=dict(color="black", width=1),
name="Wireframe",
)
)
# Add skeleton overlay if provided
if skel is not None:
# Normalize to list
if isinstance(skel, (list, tuple)):
skel_list = skel
else:
skel_list = [skel]
# Normalize colors to list
if isinstance(skel_color, str):
colors = [skel_color] * len(skel_list)
else:
colors = skel_color
if len(colors) < len(skel_list):
colors = list(colors) + [colors[-1]] * (
len(skel_list) - len(colors)
)
# Add each skeleton
for skel_idx, skeleton in enumerate(skel_list):
if skeleton is None:
continue
# Draw edges directly from the graph
color = colors[skel_idx]
# Collect all edge segments for this skeleton
edge_x = []
edge_y = []
edge_z = []
for u, v in skeleton.edges():
pos_u = skeleton.get_node_position(u)
pos_v = skeleton.get_node_position(v)
# Add edge as a line segment (with None separator for discontinuous lines)
edge_x.extend([pos_u[0], pos_v[0], None])
edge_y.extend([pos_u[1], pos_v[1], None])
edge_z.extend([pos_u[2], pos_v[2], None])
# Add all edges as a single trace
if edge_x:
fig.add_trace(
go.Scatter3d(
x=edge_x,
y=edge_y,
z=edge_z,
mode="lines",
line=dict(color=color, width=float(skel_line_width)),
opacity=float(skel_opacity),
name=f"Skeleton {skel_idx}",
showlegend=False,
)
)
# Configure layout
_e = float(eye_scale)
fig.update_layout(
title=title,
autosize=False,
width=width,
height=height,
margin=dict(l=0, r=0, t=40, b=0),
scene=dict(
aspectmode="data",
camera=dict(eye=dict(x=_e, y=_e, z=_e)),
xaxis=dict(visible=show_axes),
yaxis=dict(visible=show_axes),
zaxis=dict(visible=show_axes),
),
)
return fig
except ImportError:
print("Plotly not available")
return None
except Exception as e:
print(f"Plotly visualization failed: {e}")
return None
def _visualize_mesh_matplotlib(
self,
title,
color,
show_axes,
show_wireframe,
*,
skel: Optional[Union["SkeletonGraph", List["SkeletonGraph"]]] = None,
skel_color: Union[str, List[str]] = "crimson",
skel_line_width: float = 3.0,
):
"""Matplotlib-based mesh visualization with optional SkeletonGraph overlay."""
try:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")
vertices = self.mesh.vertices
faces = self.mesh.faces
# Create mesh surface
poly3d = Poly3DCollection(
vertices[faces],
alpha=0.7,
facecolor=color,
edgecolor="black" if show_wireframe else None,
)
ax.add_collection3d(poly3d)
# Add skeleton overlay if provided
if skel is not None:
# Normalize to list
if isinstance(skel, (list, tuple)):
skel_list = skel
else:
skel_list = [skel]
# Normalize colors to list
if isinstance(skel_color, str):
colors = [skel_color] * len(skel_list)
else:
colors = skel_color
if len(colors) < len(skel_list):
colors = list(colors) + [colors[-1]] * (
len(skel_list) - len(colors)
)
# Add each skeleton
for skel_idx, skeleton in enumerate(skel_list):
if skeleton is None:
continue
# Draw edges directly from the graph
color = colors[skel_idx]
for u, v in skeleton.edges():
pos_u = skeleton.get_node_position(u)
pos_v = skeleton.get_node_position(v)
# Draw edge as a line segment
ax.plot(
[pos_u[0], pos_v[0]],
[pos_u[1], pos_v[1]],
[pos_u[2], pos_v[2]],
color=color,
linewidth=float(skel_line_width),
)
ax.set_xlim(vertices[:, 0].min(), vertices[:, 0].max())
ax.set_ylim(vertices[:, 1].min(), vertices[:, 1].max())
ax.set_zlim(vertices[:, 2].min(), vertices[:, 2].max())
ax.set_xlabel("X (µm)")
ax.set_ylabel("Y (µm)")
ax.set_zlabel("Z (µm)")
ax.set_title(title)
if not show_axes:
ax.set_axis_off()
plt.tight_layout()
return fig
except ImportError:
print("Matplotlib not available")
return None
except Exception as e:
print(f"Matplotlib visualization failed: {e}")
return None
[docs]
def visualize_mesh_slice_interactive(
self,
title: str = "Interactive Mesh Slice",
z_range: Optional[Tuple[float, float]] = None,
num_slices: int = 50,
slice_color: str = "red",
mesh_color: str = "lightblue",
mesh_opacity: float = 0.3,
) -> Optional[object]:
"""
Create an interactive 3D visualization of a mesh with a controllable slice plane.
This function displays a 3D mesh and calculates the intersection of the mesh
with an xy-plane at a user-controlled z-value. The intersection is shown as a
colored line on the mesh. A slider allows the user to interactively change the
z-value of the intersection plane.
Args:
title: Plot title
z_range: Tuple of (min_z, max_z) for slice range. Auto-detected if None.
num_slices: Number of positions for the slider
slice_color: Color for the intersection line
mesh_color: Color for the 3D mesh
mesh_opacity: Opacity of the 3D mesh (0-1)
Returns:
Plotly figure with interactive slider for controlling the z-value
"""
try:
import plotly.graph_objects as go
except ImportError:
print("Plotly is required for interactive visualization")
return None
mesh = self.mesh
# Determine z-range if not provided
if z_range is None:
z_min, z_max = mesh.vertices[:, 2].min(), mesh.vertices[:, 2].max()
# Add small padding
padding = (z_max - z_min) * 0.05
z_min -= padding
z_max += padding
else:
z_min, z_max = z_range
# Create the base figure with the mesh
fig = go.Figure()
# Add the mesh to the figure
fig.add_trace(
go.Mesh3d(
x=mesh.vertices[:, 0],
y=mesh.vertices[:, 1],
z=mesh.vertices[:, 2],
i=mesh.faces[:, 0],
j=mesh.faces[:, 1],
k=mesh.faces[:, 2],
opacity=mesh_opacity,
color=mesh_color,
name="Mesh",
)
)
# Function to create a slice at a given z-value
def create_slice_trace(z_value):
# Calculate intersection with plane at z_value
section = mesh.section(plane_origin=[0, 0, z_value], plane_normal=[0, 0, 1])
# If no intersection, return None
if (
section is None
or not hasattr(section, "entities")
or len(section.entities) == 0
):
return None
# Process all entities in the section to get 3D coordinates
all_points = []
for entity in section.entities:
if hasattr(entity, "points") and len(entity.points) > 0:
# Get the actual 2D coordinates
points_2d = section.vertices[entity.points]
# Convert to 3D by adding z_value
points_3d = np.column_stack(
[points_2d, np.full(len(points_2d), z_value)]
)
# Add closing point if needed (to complete the loop)
if len(points_2d) > 2 and not np.array_equal(
points_2d[0], points_2d[-1]
):
closing_point = np.array(
[points_2d[0][0], points_2d[0][1], z_value]
)
points_3d = np.vstack([points_3d, closing_point])
# Add to all points list
all_points.extend(points_3d.tolist())
# Add None to create a break between separate entities
all_points.append([None, None, None])
# If we have points, create a scatter trace
if all_points:
x_coords = [p[0] if p is not None else None for p in all_points]
y_coords = [p[1] if p is not None else None for p in all_points]
z_coords = [p[2] if p is not None else None for p in all_points]
return go.Scatter3d(
x=x_coords,
y=y_coords,
z=z_coords,
mode="lines",
line=dict(color=slice_color, width=5),
name=f"Slice at z={z_value:.2f}",
)
return None
# Create initial slice
initial_z = (z_min + z_max) / 2
initial_slice = create_slice_trace(initial_z)
# Add initial slice to figure if it exists
if initial_slice:
fig.add_trace(initial_slice)
# Create frames for animation
frames = []
for i, z_val in enumerate(np.linspace(z_min, z_max, num_slices)):
# Create a slice at this z-value
slice_trace = create_slice_trace(z_val)
# If we have a valid slice, add it to frames
if slice_trace:
frame_data = [fig.data[0], slice_trace] # Mesh and slice
else:
frame_data = [fig.data[0]] # Just the mesh
frames.append(
go.Frame(
data=frame_data,
name=f"frame_{i}",
traces=[0, 1], # Update both traces
)
)
# Create slider steps
steps = []
for i, z_val in enumerate(np.linspace(z_min, z_max, num_slices)):
step = dict(
args=[
[f"frame_{i}"],
{"frame": {"duration": 0, "redraw": True}, "mode": "immediate"},
],
label=f"{z_val:.2f}",
method="animate",
)
steps.append(step)
# Configure the slider
sliders = [
dict(
active=num_slices // 2, # Start in the middle
currentvalue={
"prefix": "Z-value: ",
"visible": True,
"xanchor": "right",
},
pad={"t": 50, "b": 10},
len=0.9,
x=0.1,
y=0,
steps=steps,
)
]
# Configure the figure layout
fig.update_layout(
title=title,
scene=dict(aspectmode="data", camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))),
height=800, # Taller to make room for slider
margin=dict(l=50, r=50, b=100, t=100), # Add margin at bottom for slider
sliders=sliders,
# Add animation controls
updatemenus=[
dict(
type="buttons",
showactive=False,
y=0,
x=0,
xanchor="left",
yanchor="top",
pad=dict(t=60, r=10),
buttons=[
dict(
label="Play",
method="animate",
args=[
None,
{
"frame": {"duration": 200, "redraw": True},
"fromcurrent": True,
},
],
),
dict(
label="Pause",
method="animate",
args=[
[None],
{
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
},
],
),
dict(
label="Reset View",
method="relayout",
args=[{"scene.camera.eye": dict(x=1.5, y=1.5, z=1.5)}],
),
],
)
],
)
# Set frames
fig.frames = frames
return fig