"""Utility functions for working with meshes."""
import shutil
from brainglobe_utils.general.system import get_num_processes
from rich.progress import track
from scipy.ndimage import binary_closing, binary_fill_holes
from brainglobe_atlasapi.structure_tree_util import (
get_structures_tree,
preorder_depth_first_search,
)
try:
from vedo import Mesh, Volume, write
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Mesh generation with these utils requires vedo\n"
+ ' please install with "pip install vedo -U"'
)
try:
import mcubes
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Mesh generation with these utils requires PyMCubes\n"
+ ' please install with "pip install PyMCubes -U"'
)
import multiprocessing as mp
from pathlib import Path
import numpy as np
import zarr
from loguru import logger
from brainglobe_atlasapi.atlas_generation.volume_utils import (
create_masked_array,
)
# ----------------- #
# MESH CREATION #
# ----------------- #
[docs]
def create_region_mesh(args):
    """
    Create and save the mesh of a single brain region.

    Select the voxels of the annotation volume labelled with the region of
    interest or any of its descendants in the structures tree, and extract a
    surface mesh from that mask via ``extract_mesh_from_mask``. All arguments
    are packed into a single tuple so the function can be passed directly to
    ``multiprocessing.Pool.map`` / ``imap``.

    Note: existing ``.obj`` files are NOT skipped; the mesh is always
    (re)extracted to ``meshes_dir_path / f"{node.identifier}.obj"``.

    Parameters
    ----------
    args : tuple
        Positional arguments, in this order:

        meshes_dir_path : pathlib.Path
            Folder where the mesh is saved.
        node : treelib.Node
            Tree node of the region whose mesh is being created.
        tree : treelib.Tree
            Tree with the hierarchical structures information.
        labels : sequence of int
            Unique label values present in the annotation volume
            (e.g. ``np.unique(annotated_volume)``).
        annotated_volume : np.ndarray or str or pathlib.Path
            3D annotation array, or a path to a zarr store holding it.
        ROOT_ID : int
            Identifier of the root structure. For the root, ``closing_n_iters``
            is not passed, so ``extract_mesh_from_mask``'s own default applies.
        closing_n_iters : int or None
            Iterations of morphological closing used for non-root regions.
        decimate_fraction : float
            Forwarded to ``extract_mesh_from_mask``.
        smooth : bool
            If True the surface mesh is smoothed.
        verbosity : int, optional
            If > 0, progress messages are emitted. Defaults to 0 if omitted.

    Raises
    ------
    TypeError
        If ``annotated_volume`` is neither a np.ndarray nor a str/Path.
    """
    meshes_dir_path = args[0]
    node = args[1]
    tree = args[2]
    labels = args[3]
    annotated_volume = args[4]
    ROOT_ID = args[5]
    closing_n_iters = args[6]
    decimate_fraction = args[7]
    smooth = args[8]
    verbosity = args[9] if len(args) > 9 else 0

    if verbosity > 0:
        logger.debug(f"Creating mesh for region {node.identifier}")

    savepath = meshes_dir_path / f"{node.identifier}.obj"

    if not isinstance(annotated_volume, np.ndarray):
        # In parallel mode the caller passes a path to a zarr store instead
        # of the array itself, so open it lazily here.
        if isinstance(annotated_volume, (str, Path)):
            annotated_volume = zarr.open(annotated_volume, mode="r")
        else:
            raise TypeError(
                "Argument annotated_volume should be a np.ndarray"
                " or a path to a zarr store"
            )

    # Ids of the region and all of its descendants.
    ids = list(tree.subtree(node.identifier).nodes.keys())
    # Keep only ids that actually occur in the annotation volume; a set makes
    # each membership test O(1) instead of scanning ``labels``.
    available_labels = set(labels)
    matched_labels = [i for i in ids if i in available_labels]

    if not matched_labels:
        # Neither the region nor any of its descendants is annotated.
        if verbosity > 0:
            print(f"No labels found for {node.tag}")
        return

    # Ids absent from the volume cannot contribute voxels, so the mask only
    # needs the matched ones.
    mask = create_masked_array(annotated_volume, matched_labels)
    if np.sum(mask) == 0:
        print(f"Empty mask for {node.tag}")
        return

    mesh_kwargs = {
        "obj_filepath": savepath,
        "smooth": smooth,
        "decimate_fraction": decimate_fraction,
    }
    if node.identifier != ROOT_ID:
        # The root mesh uses extract_mesh_from_mask's default closing.
        mesh_kwargs["closing_n_iters"] = closing_n_iters
    extract_mesh_from_mask(mask, **mesh_kwargs)
[docs]
def construct_meshes_from_annotation(
    save_path: Path,
    volume: np.ndarray,
    structures_list,
    closing_n_iters=2,
    decimate_fraction=0,
    smooth=False,
    parallel: bool = True,
    num_threads: int = -1,
    verbosity: int = 0,
):
    """
    Construct region meshes for an annotation volume.

    A mesh is generated (via ``create_region_mesh``) for every node of the
    structure tree built from ``structures_list`` and written to
    ``save_path / "meshes"`` as ``<structure id>.obj``. Existing files are not
    skipped, so pre-existing meshes may be overwritten. Afterwards, every
    structure whose ``.obj`` file exists and is at least 512 bytes is
    collected into the returned dictionary; missing or smaller files are
    reported and ignored.

    In parallel mode the annotation volume is temporarily copied to a
    compressed zarr store at ``save_path / "temp_annotations.zarr"``; it is
    removed afterwards, also when mesh creation fails.

    Parameters
    ----------
    save_path : Path
        Directory where the ``meshes`` folder (and the temporary zarr store)
        is created.
    volume : np.ndarray
        3D annotation volume.
    structures_list : list
        List of structure dictionaries containing id information.
    closing_n_iters: int
        number of iterations of closing morphological operation.
        set to None to avoid applying morphological operations
    decimate_fraction: float in range [0, 1].
        What fraction of the original number of vertices is to be kept.
        EG .5 means that 50% of the vertices are kept,
        the others are removed.
    smooth: bool
        if True the surface mesh is smoothed
    parallel: bool
        If True, uses a multiprocessing pool to speed up mesh creation.
    num_threads: int
        Number of worker processes for parallel processing.
        If -1, it is chosen from the available memory and CPU count.
        If > 0, uses that many processes.
    verbosity: int
        Level of verbosity for logging. 0 for no output, 1 for basic info.

    Returns
    -------
    dict
        Dictionary of structure IDs and paths to their .obj mesh files.

    Raises
    ------
    ValueError
        If ``num_threads`` is 0, or (when ``parallel``) below -1.
    """
    if num_threads == 0:
        raise ValueError("Number of threads cannot be 0")
    if parallel and num_threads < -1:
        # Pool() would reject this too; fail before copying the volume to disk.
        raise ValueError(
            f"num_threads must be -1 or a positive integer, got {num_threads}"
        )

    meshes_dir_path = save_path / "meshes"
    meshes_dir_path.mkdir(parents=True, exist_ok=True)

    tree = get_structures_tree(structures_list)
    labels = np.unique(volume).astype(np.int32)

    # Flag each tree node with whether its id occurs in the annotation.
    # A set gives O(1) membership instead of scanning the array per node.
    label_set = set(labels.tolist())
    for key, node in tree.nodes.items():
        node.data = Region(key in label_set)

    # Captured now because ``volume`` may be replaced by a path below.
    volume_size = volume.size
    # Only used for parallel processing.
    ann_path = save_path / "temp_annotations.zarr"
    root_id = tree.root

    try:
        if parallel:
            # Each task then carries only a path instead of the whole array.
            compressor = zarr.codecs.BloscCodec(
                cname="zstd",
                clevel=6,
                shuffle=zarr.codecs.BloscShuffle.bitshuffle,
            )
            if ann_path.exists():
                shutil.rmtree(ann_path)
            ann_store = zarr.storage.LocalStore(ann_path)
            zarr.create_array(
                ann_store,
                data=volume,
                compressors=compressor,
            )
            ann_store.close()
            volume = ann_path

        # One argument tuple per region, in the order create_region_mesh
        # unpacks them.
        args_list = [
            (
                meshes_dir_path,
                node,
                tree,
                labels,
                volume,
                root_id,
                closing_n_iters,
                decimate_fraction,
                smooth,
                verbosity,
            )
            for node in preorder_depth_first_search(tree)
        ]

        if parallel:
            if num_threads == -1:
                # Each process uses ~ 7 times the number of voxels in the
                # volume. Keep at least one process on single-core machines.
                num_threads = get_num_processes(
                    ram_needed_per_process=7 * volume_size,
                    n_max_processes=max(1, mp.cpu_count() - 1),
                    fraction_free_ram=0.05,
                )
            logger.info(f"Using {num_threads} threads for mesh creation")
            with mp.Pool(num_threads) as pool:
                for _ in track(
                    pool.imap(create_region_mesh, args_list),
                    total=len(args_list),
                    description="Creating meshes",
                ):
                    pass
        else:
            for args in track(
                args_list, total=len(args_list), description="Creating meshes"
            ):
                create_region_mesh(args)
    finally:
        # Clean up the temporary annotations zarr store, even on failure.
        if parallel and ann_path.exists():
            shutil.rmtree(ann_path)

    meshes_dict = {}
    structures_with_mesh = []
    for structure in structures_list:
        mesh_path = meshes_dir_path / f'{structure["id"]}.obj'
        if not mesh_path.exists():
            print(f"No mesh file exists for: {structure}, ignoring it")
            continue
        # Files under 512 bytes are treated as unusable and dropped.
        if mesh_path.stat().st_size < 512:
            print(f"obj file for {structure} is too small, ignoring it.")
            continue
        structures_with_mesh.append(structure)
        meshes_dict[structure["id"]] = mesh_path

    print(
        f"In the end, {len(structures_with_mesh)} "
        "structures with mesh are kept"
    )
    return meshes_dict
[docs]
class Region:
    """
    Lightweight metadata container attached to treelib.Tree nodes.

    Used during atlas creation so that tree nodes can later be filtered on
    the stored ``has_label`` flag.

    Parameters
    ----------
    has_label : bool
        Flag stored unchanged on the instance as ``has_label``.
    """

    def __init__(self, has_label):
        # Stored as given; no validation or coercion is performed.
        self.has_label = has_label