import logging
import cv2
import numpy as np
import os
import matplotlib # added for docs generation
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.cm as cm
import matplotlib.axes
from scipy.spatial import Voronoi, voronoi_plot_2d
from shapely.geometry import Polygon
from typing import Dict, List, Tuple, Union
from matplotlib.lines import Line2D
from matplotlib.colors import ListedColormap
from gridgene.logger import get_logger
from typing import Optional, Tuple, Dict, Any, List
from scipy.ndimage import distance_transform_edt
from skimage.measure import label
from scipy.ndimage import distance_transform_edt
from skimage.measure import regionprops
import cv2
import numpy as np
from shapely.geometry import Polygon, box
[docs]
def timeit(func):
    """
    Decorator that prints how long each call to ``func`` takes.

    Parameters
    ----------
    func : callable
        Function to wrap.

    Returns
    -------
    callable
        Wrapper that forwards all arguments to ``func``, prints the elapsed
        wall-clock time in seconds, and returns ``func``'s result unchanged.
    """
    # Imported locally because the module-level imports provide neither name.
    import time
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} took {elapsed:.4f} seconds")
        return result

    return wrapper
[docs]
class GetMasks:
    """
    Utilities for working with masks: area filtering, creation from contours,
    morphology, subtraction, saving and plotting.

    Parameters
    ----------
    logger : logging.Logger, optional
        Logger used for status messages. When omitted, one is obtained from
        ``get_logger``.
    image_shape : tuple of int, optional
        Image shape as (height, width).
    """

    def __init__(self, logger: Optional[logging.Logger] = None, image_shape: Optional[Tuple[int, int]] = None):
        """
        Store the image shape and set up logging.

        Parameters
        ----------
        logger : logging.Logger, optional
            Logger used for status messages. When omitted, one is obtained
            from ``get_logger``.
        image_shape : tuple of int, optional
            Image shape as (height, width). ``height`` and ``width`` stay
            ``None`` when no shape is given.
        """
        self.image_shape = image_shape
        if image_shape is None:
            self.height = None
            self.width = None
        else:
            self.height = image_shape[0]
            self.width = image_shape[1]

        if not logger:
            logger = get_logger(f'{__name__}.{"GetMasks"}')
        self.logger = logger
        self.logger.info("Initialized GetMasks")
[docs]
def filter_binary_mask_by_area(self, mask: np.ndarray, min_area: int) -> np.ndarray:
    """
    Remove small connected components from a binary mask.

    Components are found with 8-connectivity. Any non-zero pixel counts as
    foreground.

    Parameters
    ----------
    mask : np.ndarray
        Binary mask (zero = background, non-zero = foreground).
    min_area : int
        Minimum area, in pixels, a component needs to be kept.

    Returns
    -------
    np.ndarray
        uint8 mask with the same shape as ``mask``; kept components are 1.
    """
    # Binarise with ``!= 0`` instead of a raw uint8 cast, which would wrap
    # values such as 256 to 0 and silently drop them.
    binary = (mask != 0).astype(np.uint8)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # One boolean per component label, applied to all pixels in a single
    # lookup instead of one full-image comparison per label.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    return keep[labels].astype(np.uint8)
[docs]
def filter_labeled_mask_by_area(self, mask: np.ndarray, min_area: int) -> np.ndarray:
    """
    Keep only the labels of a labeled mask whose pixel count is >= ``min_area``.

    Parameters
    ----------
    mask : np.ndarray
        Labeled mask (integer labels, 0 = background).
    min_area : int
        Minimum number of pixels a label needs to be kept.

    Returns
    -------
    np.ndarray
        int32 mask in which kept labels retain their IDs and everything
        else is 0.
    """
    mask = mask.astype(np.int32)
    unique_labels, counts = np.unique(mask, return_counts=True)
    labels_to_keep = unique_labels[(counts >= min_area) & (unique_labels != 0)]

    # Single membership test over the image rather than one pass per label.
    keep = np.isin(mask, labels_to_keep)
    filtered_mask = np.zeros_like(mask)
    filtered_mask[keep] = mask[keep]

    self.logger.info(f'Filtered labeled mask by area >= {min_area}, kept {len(labels_to_keep)} components.')
    return filtered_mask
[docs]
def create_mask(self, contours: List[np.ndarray]) -> np.ndarray:
    """
    Rasterise contours into a filled binary mask of the configured image size.

    Parameters
    ----------
    contours : list of np.ndarray
        Contours to draw.

    Returns
    -------
    np.ndarray
        uint8 mask of shape (height, width); pixels inside the contours are 1.

    Raises
    ------
    ValueError
        If the image height or width is not defined.
    """
    dimensions = (self.height, self.width)
    if any(dim is None for dim in dimensions):
        raise ValueError("Image shape must be defined to create mask.")

    canvas = np.zeros(dimensions, dtype=np.uint8)
    # contourIdx=-1 draws every contour; FILLED fills their interiors.
    cv2.drawContours(canvas, contours, -1, color=1, thickness=cv2.FILLED)
    return canvas
[docs]
def fill_holes(self, mask: np.ndarray) -> np.ndarray:
    """
    Fill holes inside the objects of a binary mask.

    Only the outermost contours are extracted and redrawn filled, so
    anything enclosed by an outer boundary becomes foreground.

    Parameters
    ----------
    mask : np.ndarray
        Binary mask.

    Returns
    -------
    np.ndarray
        Mask with the same shape and dtype as ``mask``; filled objects are 1.
    """
    outer_contours, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(mask)
    cv2.drawContours(result, outer_contours, -1, 1, cv2.FILLED)
    return result
[docs]
def apply_morphology(self, mask: np.ndarray, operation: str = "open", kernel_size: int = 3) -> np.ndarray:
    """
    Apply a morphological operation to a binary mask.

    Parameters
    ----------
    mask : np.ndarray
        Binary mask to process.
    operation : str, optional
        One of "open", "close", "erode" or "dilate" (default "open").
    kernel_size : int, optional
        Side length of the square structuring element (default 3).

    Returns
    -------
    np.ndarray
        Processed mask. For an unknown ``operation`` a warning is logged and
        the input mask is returned unchanged.
    """
    # Lambdas keep the table cheap to build; only the selected one runs.
    operations = {
        "open": lambda m, k: cv2.morphologyEx(m, cv2.MORPH_OPEN, k),
        "close": lambda m, k: cv2.morphologyEx(m, cv2.MORPH_CLOSE, k),
        "erode": lambda m, k: cv2.erode(m, k, iterations=1),
        "dilate": lambda m, k: cv2.dilate(m, k, iterations=1),
    }
    handler = operations.get(operation)
    if handler is None:
        # Return early so the "Applied ..." message below is never logged
        # for an operation that was not actually applied.
        self.logger.warning(f"Unknown morphological operation '{operation}', returning original mask.")
        return mask

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    result = handler(mask, kernel)
    self.logger.info(f'Applied morphology operation "{operation}" with kernel size {kernel_size}.')
    return result
[docs]
def subtract_masks(self, base_mask: np.ndarray, *masks: np.ndarray) -> np.ndarray:
    """
    Subtract any number of masks from a base mask.

    Parameters
    ----------
    base_mask : np.ndarray
        Mask to subtract from; it is not modified.
    *masks : np.ndarray
        Masks subtracted one after another with ``cv2.subtract``.

    Returns
    -------
    np.ndarray
        Result of the successive subtractions (a copy of ``base_mask`` when
        no masks are given).
    """
    remaining = base_mask.copy()
    for subtrahend in masks:
        remaining = cv2.subtract(remaining, subtrahend)
    self.logger.info('Subtracted masks from base mask.')
    return remaining
[docs]
def save_masks_npy(self, mask: np.ndarray, save_path: str) -> None:
    """
    Write a mask to disk in NumPy ``.npy`` format and log the destination.

    Parameters
    ----------
    mask : np.ndarray
        Mask to save.
    save_path : str
        Destination path passed to ``np.save``.
    """
    np.save(file=save_path, arr=mask)
    self.logger.info(f'Mask saved at {save_path}')
[docs]
def save_masks(self, mask: np.ndarray, path: str) -> None:
    """
    Save a binary mask as an image file (foreground 255, background 0).

    Parameters
    ----------
    mask : np.ndarray
        Binary mask to save; any non-zero pixel is treated as foreground.
    path : str
        Destination image path; the format is chosen by ``cv2.imwrite``
        from the file extension.

    Notes
    -----
    ``cv2.imwrite`` reports some failures by returning ``False`` rather than
    raising. That case is logged as an error instead of being reported as a
    successful save.
    """
    # Build an explicit 8-bit image: ``mask * 255`` overflows for uint8
    # values above 1 and yields a non-uint8 array for bool/int64 input.
    image = np.where(mask != 0, 255, 0).astype(np.uint8)
    if cv2.imwrite(path, image):
        self.logger.info(f'Mask saved at {path}')
    else:
        self.logger.error(f'Failed to save mask at {path}')
[docs]
def plot_masks(
    self,
    masks: List[np.ndarray],
    mask_names: List[str],
    background_color: Tuple[int, int, int] = (0, 0, 0),
    mask_colors: Optional[Dict[str, Tuple[int, int, int]]] = None,
    path: Optional[str] = None,
    show: bool = True,
    ax: Optional[plt.Axes] = None,
    figsize: Tuple[int, int] = (10, 10)
) -> None:
    """
    Overlay several masks on one image, each in its own colour, with a legend.

    Masks are painted in list order, so later masks overwrite earlier ones
    where they overlap.

    Parameters
    ----------
    masks : list of np.ndarray
        Masks to plot; non-zero pixels are painted.
    mask_names : list of str
        Legend label for each mask (same length as ``masks``).
    background_color : tuple of int, optional
        RGB colour (0-255) for pixels covered by no mask (default black).
    mask_colors : dict, optional
        Mapping of mask name to RGB colour (0-255). Names without an entry
        take a colour from the "tab10" colormap ("tab20" for more than
        ten masks).
    path : str, optional
        Directory in which the figure is saved as
        ``masks_<joined lowercase names>.png``.
    show : bool, optional
        Whether to call ``plt.show()`` (default True).
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new figure is created when None.
    figsize : tuple of int, optional
        Size in inches of the figure created when ``ax`` is None.

    Returns
    -------
    None
        If ``masks`` and ``mask_names`` differ in length, an error is logged
        and nothing is drawn.
    """
    if len(masks) != len(mask_names):
        self.logger.error('The number of masks and mask names must be the same.')
        return

    background = np.full((self.height, self.width, 3), background_color)
    legend_patches = []

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which no longer exists
    # in Matplotlib 3.9+.
    colormap = plt.get_cmap('tab10' if len(masks) <= 10 else 'tab20')

    for i, (mask, mask_name) in enumerate(zip(masks, mask_names)):
        if mask_colors and mask_name in mask_colors:
            mask_color = np.array(mask_colors[mask_name])
        else:
            mask_color = (np.array(colormap(i % colormap.N)[:3]) * 255).astype(int)
        background[mask != 0] = mask_color
        # Matplotlib patches expect colours in the 0-1 range.
        legend_patches.append(mpatches.Patch(color=mask_color / 255, label=mask_name))

    # Flip horizontally, then rotate by 90 degrees, before display.
    background = np.rot90(np.fliplr(background), k=1)

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    ax.imshow(background, origin='lower')
    ax.set_axis_off()
    ax.legend(
        handles=legend_patches,
        bbox_to_anchor=(1.05, 1),
        loc='upper left',
        bbox_transform=ax.transAxes
    )

    if path is not None:
        save_path = os.path.join(
            path,
            f'masks_{"_".join(mask_names).replace(" ", "").lower()}.png'
        )
        # Save the figure that owns ``ax``; plt.savefig would save whatever
        # figure happens to be current.
        fig.savefig(save_path, dpi=1000, bbox_inches='tight')
        self.logger.info(f'Plot saved at {save_path}')

    if show:
        plt.show()

    # Only close a figure this function created; a caller-supplied Axes
    # belongs to the caller.
    if created_fig:
        plt.close(fig)
[docs]
def plot_labeled_masks(self, label_mask, mask_name, show=False, save_path=None, dpi=300):
    """
    Plot a labeled mask with one colour per label, bounding boxes and label IDs.

    Parameters
    ----------
    label_mask : np.ndarray
        Integer labeled mask (0 = background) of shape (height, width).
    mask_name : str
        Title of the plot.
    show : bool, optional
        Whether to call ``plt.show()`` (default False).
    save_path : str, optional
        File path at which the figure is saved. Nothing is saved when None.
    dpi : int, optional
        Resolution used when saving (default 300).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure (left open for the caller).
    ax : matplotlib.axes.Axes
        The axes holding the plot.
    """
    unique_labels = np.unique(label_mask)

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which no longer exists
    # in Matplotlib 3.9+. The colormap is resampled to one entry per label.
    colormap = plt.get_cmap('tab10', len(unique_labels))
    colors = {
        label_id: colormap(i)
        for i, label_id in enumerate(unique_labels)
        if label_id != 0
    }

    colored_mask = np.zeros((self.height, self.width, 3), dtype=np.float32)
    for label_id, rgba in colors.items():
        colored_mask[label_mask == label_id] = rgba[:3]

    fig, ax = plt.subplots()
    ax.imshow(colored_mask, origin='lower')

    # Outline every labeled object and write its label at the centroid.
    for region in regionprops(label_mask):
        if region.label == 0:
            continue
        minr, minc, maxr, maxc = region.bbox
        rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                  fill=False, edgecolor=colors[region.label], linewidth=2)
        ax.add_patch(rect)
        y, x = region.centroid
        ax.text(x, y, str(region.label), color='white', fontsize=8, ha='center', va='center')

    ax.set_title(mask_name)

    if save_path is not None:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
    if show:
        plt.show()
    return fig, ax
# Cancer-stroma interface analysis
[docs]
class ConstrainedMaskExpansion(GetMasks):
    """
    Expand a seed mask under an optional constraint mask, producing binary,
    labeled and seed-referenced expansion masks.
    """

    def __init__(
        self,
        seed_mask: np.ndarray,
        constraint_mask: Optional[np.ndarray] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        Label the seed mask and prepare the result containers.

        Parameters
        ----------
        seed_mask : np.ndarray
            Seed mask to expand. It is cast to uint8 and split into
            connected components with ``label``.
        constraint_mask : np.ndarray, optional
            Mask limiting where expansion may go. When None, an all-ones
            mask (no restriction) is used.
        logger : logging.Logger, optional
            Logger passed on to ``GetMasks``.

        Raises
        ------
        ValueError
            If ``seed_mask`` is None.
        """
        if seed_mask is None:
            raise ValueError("Seed mask cannot be None.")

        self.seed_mask_raw = seed_mask.astype(np.uint8)
        # Connected-component labels of the raw seed.
        self.seed_mask = label(self.seed_mask_raw)

        if constraint_mask is None:
            self.constraint_mask = np.ones_like(seed_mask, dtype=np.uint8)
        else:
            self.constraint_mask = constraint_mask.astype(np.uint8)

        super().__init__(logger=logger, image_shape=self.seed_mask.shape)

        # Results, keyed by names such as "expansion_<dist>" and "seed_mask".
        self.binary_expansions: Dict[str, np.ndarray] = {}
        self.labeled_expansions: Dict[str, np.ndarray] = {}
        self.referenced_expansions: Dict[str, np.ndarray] = {}
[docs]
def expand_mask(
    self,
    expansion_pixels: List[int],
    min_area: Optional[int] = None,
    restrict_to_limit: bool = True,
) -> None:
    """
    Build concentric expansion rings around the seed mask.

    For each distance ``d`` (ascending) the ring holds the background pixels
    whose Euclidean distance to the seed is in ``(previous d, d]``. Results
    are stored in ``binary_expansions``, ``labeled_expansions`` and
    ``referenced_expansions`` under ``"expansion_<d>"``. Entries for
    ``"seed_mask"`` and ``"constraint_remaining"`` (constraint area not
    covered by any ring) are stored as well.

    Parameters
    ----------
    expansion_pixels : list of int
        Expansion distances in pixels. Order does not matter and duplicates
        are ignored.
    min_area : int, optional
        Minimum area for connected components kept in each ring.
    restrict_to_limit : bool, optional
        If True, rings are restricted to the constraint mask.

    Returns
    -------
    None
    """
    # De-duplicate: a repeated distance would produce an empty ring and
    # overwrite the real result stored under the same key.
    sorted_dists = sorted(set(expansion_pixels))

    background = self.seed_mask == 0
    dist_map = distance_transform_edt(background)
    allowed = self.constraint_mask.astype(bool)
    previous_mask = np.zeros_like(self.seed_mask, dtype=bool)
    current_labels = self.seed_mask.copy()

    prev_dist = None
    for dist in sorted_dists:
        ring = (dist_map <= dist) & background
        if prev_dist is not None:
            ring &= dist_map > prev_dist
        if restrict_to_limit:
            ring &= allowed
        ring &= ~previous_mask
        if min_area:
            ring = self.filter_binary_mask_by_area(ring.astype(np.uint8), min_area).astype(bool)
        previous_mask |= ring

        key = f"expansion_{dist}"
        self.binary_expansions[key] = ring.astype(np.uint8)
        self.labeled_expansions[key] = label(ring.astype(np.uint8))

        # Propagate from the labels accumulated so far, not just the seed.
        referenced = self.propagate_labels(current_labels, ring)
        referenced[~ring] = 0
        self.referenced_expansions[key] = referenced
        current_labels[referenced > 0] = referenced[referenced > 0]

        prev_dist = dist

    self.binary_expansions["seed_mask"] = (self.seed_mask > 0).astype(np.uint8)
    self.labeled_expansions["seed_mask"] = self.seed_mask.copy()
    self.referenced_expansions["seed_mask"] = self.seed_mask.copy()

    constraint_remaining = (allowed & ~previous_mask).astype(np.uint8)
    self.binary_expansions["constraint_remaining"] = constraint_remaining
    self.labeled_expansions["constraint_remaining"] = label(constraint_remaining)
    self.referenced_expansions["constraint_remaining"] = np.zeros_like(self.seed_mask, dtype=np.int32)
[docs]
def propagate_labels(self, seed_labeled: np.ndarray, expansion_mask: np.ndarray) -> np.ndarray:
    """
    Give every pixel of the expansion region the label of its nearest
    labeled seed pixel (Euclidean distance transform).

    Parameters
    ----------
    seed_labeled : np.ndarray
        Labeled seed mask; non-zero values mark components.
    expansion_mask : np.ndarray
        Binary mask of the region that should receive labels.

    Returns
    -------
    np.ndarray
        int32 mask holding the propagated labels inside the expansion
        region, plus the original positive seed labels.
    """
    target = expansion_mask.astype(bool)
    seeded = seed_labeled > 0

    # For every pixel, the coordinates of the closest pixel with a label.
    nearest_idx = distance_transform_edt(
        seed_labeled == 0, return_distances=False, return_indices=True
    )
    nearest_labels = seed_labeled[tuple(nearest_idx)]

    propagated = np.zeros_like(seed_labeled, dtype=np.int32)
    propagated[target] = nearest_labels[target]
    # Original seed labels take precedence over propagated ones.
    propagated[seeded] = seed_labeled[seeded]
    return propagated
[docs]
class SingleClassObjectAnalysis(GetMasks):
    """
    Build and expand the mask of a single object class using distance-based
    ring expansion.

    Attributes
    ----------
    get_masks_instance : GetMasks
        Utility instance supplying image size, logger and filtering methods.
    height, width : int or None
        Image size copied from ``get_masks_instance``.
    contours_object : list of np.ndarray
        Contours describing the objects.
    contour_name : str
        Optional name identifying the object class.
    mask_object_SA : np.ndarray or None
        Binary object mask; None until it has been generated.
    binary_expansions : Dict[str, np.ndarray]
        Binary masks keyed by expansion name.
    labeled_expansions : Dict[str, np.ndarray]
        Labeled masks keyed by expansion name.
    referenced_expansions : Dict[str, np.ndarray]
        Masks whose labels refer back to the originating seed object.
    """

    def __init__(
        self,
        get_masks_instance: GetMasks,
        contours_object: List[np.ndarray],
        contour_name: str = ""
    ) -> None:
        """
        Initialize from contour data and an existing GetMasks instance.

        Parameters
        ----------
        get_masks_instance : GetMasks
            Instance providing image size, logger and filtering methods.
        contours_object : List[np.ndarray]
            Contours representing the objects.
        contour_name : str, optional
            Optional name identifier for the object class.
        """
        # Initialise the base class so inherited state (e.g. image_shape)
        # exists on this instance too.
        super().__init__(
            logger=get_masks_instance.logger,
            image_shape=getattr(get_masks_instance, "image_shape", None),
        )
        self.get_masks_instance = get_masks_instance
        # Copy height/width explicitly: the source instance may have set
        # them without going through image_shape.
        self.height = get_masks_instance.height
        self.width = get_masks_instance.width

        self.mask_object_SA: Optional[np.ndarray] = None
        self.binary_expansions: Dict[str, np.ndarray] = {}
        self.labeled_expansions: Dict[str, np.ndarray] = {}
        self.referenced_expansions: Dict[str, np.ndarray] = {}
        self.contours_object = contours_object
        self.contour_name = contour_name
[docs]
def get_mask_objects(
    self,
    exclude_masks: Optional[List[np.ndarray]] = None,
    filter_area: Optional[int] = None
) -> None:
    """
    Build the binary object mask from the stored contours.

    The contours are drawn filled, the given masks are optionally
    subtracted, and small connected components are optionally removed.
    The result is stored in ``self.mask_object_SA``.

    Parameters
    ----------
    exclude_masks : list of np.ndarray, optional
        Masks to subtract from the generated object mask.
    filter_area : int, optional
        Minimum area for connected components kept in the object mask.

    Returns
    -------
    None
    """
    mask_object = np.zeros((self.height, self.width), dtype=np.uint8)
    cv2.drawContours(mask_object, self.contours_object, -1, color=1, thickness=cv2.FILLED)

    for mask in exclude_masks or []:
        mask_object = cv2.subtract(mask_object, mask)

    if filter_area is not None:
        self.logger.info(f"Filtering object mask by area: {filter_area}")
        # The mask is binary here, so the binary area filter applies.
        # GetMasks defines no method named ``filter_mask_by_area``.
        mask_object = self.get_masks_instance.filter_binary_mask_by_area(mask_object, min_area=filter_area)

    self.mask_object_SA = mask_object
    self.logger.info("Mask for objects created.")
[docs]
def get_objects_expansion(
    self,
    expansions_pixels: Optional[List[int]] = None,
    filter_area: Optional[int] = None
) -> None:
    """
    Expand the object mask in distance-based rings.

    For each distance ``d`` (ascending) the ring holds the background pixels
    whose Euclidean distance to the objects is in ``(previous d, d]``.
    Binary, labeled and seed-referenced versions of every ring are stored
    under ``"expansion_<d>"``, and the seed itself under ``"seed_mask"``.

    Parameters
    ----------
    expansions_pixels : list of int, optional
        Expansion distances in pixels. Order does not matter and duplicates
        are ignored.
    filter_area : int, optional
        Minimum area for connected components kept in each ring.

    Returns
    -------
    None
        If no object mask has been generated yet, an error is logged and
        nothing is stored.
    """
    if self.mask_object_SA is None:
        self.logger.error("No object mask to expand.")
        return

    # Sort once and drop duplicates: a repeated distance would give an
    # empty ring that overwrites the real result under the same key.
    distances = sorted(set(expansions_pixels or []))

    self.seed_mask = label(self.mask_object_SA)
    background = self.seed_mask == 0
    dist_map = distance_transform_edt(background)
    previous_mask = np.zeros_like(self.seed_mask, dtype=bool)
    current_labels = self.seed_mask.copy()

    # Pair each distance with the one before it (0 for the first ring).
    for prev_dist, dist in zip([0] + distances[:-1], distances):
        raw_ring = (dist_map <= dist) & (dist_map > prev_dist) & background
        if filter_area:
            raw_ring = self.get_masks_instance.filter_binary_mask_by_area(
                raw_ring.astype(np.uint8), filter_area
            ).astype(bool)

        key = f"expansion_{dist}"
        ring = raw_ring & (~previous_mask)

        if not np.any(ring):
            self.logger.warning(f"Expansion ring for distance {dist} is empty.")
            # Separate arrays per dict, so that modifying one result can
            # never change the others.
            self.binary_expansions[key] = np.zeros_like(self.seed_mask, dtype=np.uint8)
            self.labeled_expansions[key] = np.zeros_like(self.seed_mask)
            self.referenced_expansions[key] = np.zeros_like(self.seed_mask, dtype=np.int32)
            continue

        previous_mask |= ring
        self.binary_expansions[key] = ring.astype(np.uint8)
        self.labeled_expansions[key] = label(ring.astype(np.uint8))

        # Propagate from the labels accumulated so far, not just the seed.
        referenced = self.propagate_labels(current_labels, ring)
        referenced[~ring] = 0
        self.referenced_expansions[key] = referenced
        current_labels[referenced > 0] = referenced[referenced > 0]

    # Store the base seed info.
    self.binary_expansions["seed_mask"] = (self.seed_mask > 0).astype(np.uint8)
    self.labeled_expansions["seed_mask"] = self.seed_mask.copy()
    self.referenced_expansions["seed_mask"] = self.seed_mask.copy()
[docs]
def propagate_labels(self, seed_labeled: np.ndarray, expansion_mask: np.ndarray) -> np.ndarray:
    """
    Propagate labels from a labeled seed mask into an expansion region.

    Each pixel of the expansion region receives the label of its nearest
    labeled pixel, found with a Euclidean distance transform.

    Parameters
    ----------
    seed_labeled : np.ndarray
        Labeled mask; non-zero values mark components.
    expansion_mask : np.ndarray
        Binary mask of the region into which labels should expand.

    Returns
    -------
    np.ndarray
        int32 mask holding the propagated labels inside the expansion
        region, plus the original positive seed labels.
    """
    region = expansion_mask.astype(bool)
    output = np.zeros_like(seed_labeled, dtype=np.int32)

    # Only the indices of the nearest labeled pixel are needed, not the
    # distances themselves.
    indices = distance_transform_edt(
        seed_labeled == 0, return_distances=False, return_indices=True
    )
    nearest_labels = seed_labeled[tuple(indices)]

    # Fill only the expansion region with the nearest labels.
    output[region] = nearest_labels[region]
    # Original seed labels take precedence over propagated ones.
    seeded = seed_labeled > 0
    output[seeded] = seed_labeled[seeded]
    return output
# Note: label propagation uses a nearest-neighbour Euclidean distance transform (distance_transform_edt with return_indices=True), so no iterative dilation loop is involved.
[docs]
class MultiClassObjectAnalysis(GetMasks):
    """
    Analyze and expand object contours of several classes, using Voronoi
    regions built from the object centroids to limit spatial expansion.

    Attributes
    ----------
    get_masks_instance : GetMasks
        Utility instance supplying image size and logger.
    multiple_contours : dict[str, list[np.ndarray]]
        Contours grouped by class. Contours with at least 4 points are
        stored with their point order reversed.
    height : int
        Image height.
    width : int
        Image width.
    save_path : str or None
        Optional path to save outputs.
    vor : scipy.spatial.Voronoi or None
        Computed Voronoi diagram.
    list_of_polygons : list or None
        Shapely polygons built from the contours.
    all_centroids : np.ndarray or None
        Coordinates of the centroids of the input objects.
    class_labels : list[str] or None
        Class label for each centroid.
    voronoi_regions, voronoi_vertices : list, np.ndarray or None
        Finite Voronoi polygons (vertex indices) and their vertex
        coordinates.
    masks : None
        Initialised to None here. NOTE(review): this constructor does not
        populate it; presumably later methods do - confirm.
    """

    def __init__(self, get_masks_instance, multiple_contours: dict, save_path: str = None):
        """
        Initialize MultiClassObjectAnalysis instance.

        Parameters
        ----------
        get_masks_instance : GetMasks
            Instance of GetMasks with the base image properties.
        multiple_contours : dict[str, list[np.ndarray]]
            Mapping of class name to its contours. The mapping and its
            contour sequences are not modified.
        save_path : str, optional
            Directory path to save outputs (default is None).
        """
        super().__init__(
            logger=get_masks_instance.logger,
            image_shape=getattr(get_masks_instance, "image_shape", None),
        )
        self.get_masks_instance = get_masks_instance
        self.height = get_masks_instance.height
        self.width = get_masks_instance.width

        self.masks = None
        self.vor = None
        self.list_of_polygons = None
        self.class_labels = None
        self.all_centroids = None
        self.voronoi_regions = None
        self.voronoi_vertices = None
        self.save_path = save_path

        # Build new lists instead of assigning into the caller's containers.
        # In-place assignment fails for tuple sequences and would reverse
        # the caller's contours again on every construction.
        prepared_contours = {}
        for class_label, contours in multiple_contours.items():
            prepared = []
            for contour in contours:
                if contour.shape[0] < 4:
                    self.logger.warning(f"Skipping contour with less than 4 points for class '{class_label}'.")
                    prepared.append(contour)  # kept as-is, only not reversed
                else:
                    prepared.append(contour[::-1])
            prepared_contours[class_label] = prepared
        self.multiple_contours = prepared_contours
[docs]
@staticmethod
def voronoi_finite_polygons_2d(vor, radius=None):
"""
Reconstruct finite Voronoi polygons in 2D by clipping infinite regions.
Parameters
----------
vor : scipy.spatial.Voronoi
The original Voronoi diagram from scipy.spatial.
radius : float, optional
Distance to extend infinite edges (default is twice the maximum image dimension).
Returns
-------
regions : list[list[int]]
List of polygon regions as indices of vertices.
vertices : np.ndarray
Array of Voronoi vertices coordinates.
"""
if vor.points.shape[1] != 2:
raise ValueError("Requires 2D input")
new_regions = []
new_vertices = vor.vertices.tolist()
center = vor.points.mean(axis=0)
if radius is None:
radius = vor.points.ptp().max() * 2
# Map of all ridges for a point
all_ridges = {}
for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
all_ridges.setdefault(p1, []).append((p2, v1, v2))
all_ridges.setdefault(p2, []).append((p1, v1, v2))
# Reconstruct finite polygons
for p1, region_index in enumerate(vor.point_region):
vertices = vor.regions[region_index]
if all(v >= 0 for v in vertices):
# Finite region
new_regions.append(vertices)
continue
ridges = all_ridges[p1]
new_region = [v for v in vertices if v >= 0]
for p2, v1, v2 in ridges:
if v1 >= 0 and v2 >= 0:
continue
t = vor.points[p2] - vor.points[p1] # tangent
t /= np.linalg.norm(t)
n = np.array([-t[1], t[0]]) # normal vector
midpoint = vor.points[[p1, p2]].mean(axis=0)
direction = np.sign(np.dot(midpoint - center, n)) * n
far_point = vor.vertices[v1 if v1 >= 0 else v2] + direction * radius
new_vertices.append(far_point.tolist())
new_region.append(len(new_vertices) - 1)
# Sort region counterclockwise
vs = np.array([new_vertices[v] for v in new_region])
c = vs.mean(axis=0)
angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
new_region = [new_region[i] for i in np.argsort(angles)]
new_regions.append(new_region)
return new_regions, np.asarray(new_vertices)
[docs]
def get_polygons_from_contours(self, contours: List[np.ndarray]) -> List[Polygon]:
    """
    Convert contours into Shapely polygons, skipping unusable ones.

    A contour is skipped when it has fewer than 4 points (before or after
    squeezing), when Shapely cannot build a polygon from it, or when the
    polygon is invalid or has zero area. Skips caused by a failed build or
    an invalid polygon are logged.

    Parameters
    ----------
    contours : list[np.ndarray]
        Contour arrays; singleton dimensions are squeezed away before use.

    Returns
    -------
    polygons : list[Polygon]
        Valid, non-degenerate polygons, in input order.
    """
    polygons = []
    for index, cnt in enumerate(contours):
        if cnt.shape[0] < 4:
            continue  # Too few points to form a polygon
        coords = cnt.squeeze()
        if coords.shape[0] < 4:
            continue  # Still too few after squeezing

        # Ensure the ring is closed (first point == last point).
        if not np.array_equal(coords[0], coords[-1]):
            coords = np.vstack([coords, coords[0]])

        try:
            polygon = Polygon(coords)
        except Exception as exc:
            # Kept broad on purpose (best effort, exception types differ
            # between Shapely versions), but no longer silent.
            self.logger.warning(f"Skipping contour {index}: could not build polygon ({exc}).")
            continue

        if not polygon.is_valid or polygon.area == 0:
            self.logger.debug(f"Skipping contour {index}: invalid or zero-area polygon.")
            continue
        polygons.append(polygon)
    return polygons
[docs]
def derive_voronoi_from_contours(self) -> None:
    """
    Compute a Voronoi diagram from the centroids of the stored contours.

    Stores the contour polygons, the class label and centroid of each
    object, the Voronoi diagram, and its finite polygons (closed at a
    radius of twice the larger image dimension). With fewer than 4
    centroids no diagram is computed and the Voronoi attributes are None.

    Raises
    ------
    ValueError
        If no contour with at least 4 points is available.
    """
    all_contours = [
        contour
        for contours in self.multiple_contours.values()
        for contour in contours
        if contour.shape[0] >= 4
    ]
    if not all_contours:
        raise ValueError("No contours found to derive Voronoi diagram.")

    list_of_polygons = self.get_polygons_from_contours(all_contours)

    centroids = []
    class_labels = []
    for class_label, contours in self.multiple_contours.items():
        for contour in contours:
            points = contour.squeeze()
            # NOTE(review): centroids accept contours with >= 3 points,
            # while the polygon list above requires >= 4 - confirm which
            # threshold is intended. The message states the check made here.
            if len(points) < 3:
                self.logger.warning(f"Skipping contour with less than 3 points for class '{class_label}'.")
                continue
            centroids.append(Polygon(points).centroid)
            class_labels.append(class_label)

    all_centroids = np.array([(c.x, c.y) for c in centroids]) if centroids else None

    if len(centroids) < 4:
        # Not enough points to compute a Voronoi diagram.
        self.logger.warning("Not enough valid centroids for Voronoi diagram. Skipping Voronoi computation.")
        vor = regions = vertices = None
    else:
        vor = Voronoi(all_centroids)
        # Finite polygons closed at a large radius (image max dimension * 2).
        regions, vertices = self.voronoi_finite_polygons_2d(vor, radius=max(self.height, self.width) * 2)

    # Assign only after everything has been computed, so a failure above
    # leaves the previous state untouched.
    self.list_of_polygons = list_of_polygons
    self.class_labels = class_labels
    self.all_centroids = all_centroids
    self.vor = vor
    self.voronoi_regions = regions
    self.voronoi_vertices = vertices
[docs]
def get_voronoi_mask(self, category_name: str) -> np.ndarray:
    """
    Get a mask of the Voronoi regions that belong to one category.

    If no Voronoi regions are available (e.g. too few centroids), the whole
    image is returned as allowed area.

    Parameters
    ----------
    category_name : str
        The category/class name for which the mask is requested.

    Returns
    -------
    mask : np.ndarray
        uint8 mask of shape (height, width); pixels inside the category's
        Voronoi regions are 255, all others 0.
    """
    mask = np.zeros((self.height, self.width), dtype=np.uint8)

    # Without Voronoi regions, allow expansion anywhere.
    if self.voronoi_regions is None or self.voronoi_vertices is None:
        mask[:, :] = 255
        return mask

    image_bounds = box(0, 0, self.width - 1, self.height - 1)
    for class_label, region in zip(self.class_labels, self.voronoi_regions):
        if class_label != category_name or len(region) < 3:
            continue

        # Clip the cell geometrically against the image rectangle.
        # Clamping each vertex coordinate on its own would bend the edges
        # of cells that reach beyond the image. Voronoi cells are convex,
        # so the hull is the cell itself and is always a valid geometry.
        cell = Polygon(self.voronoi_vertices[region]).convex_hull
        visible = cell.intersection(image_bounds)
        if visible.is_empty or visible.geom_type != "Polygon":
            continue  # cell lies outside the image or only touches it

        outline = np.asarray(visible.exterior.coords).astype(np.int32)
        cv2.fillPoly(mask, [outline], color=255)
    return mask
[docs]
def expand_mask(self, mask: np.ndarray, expansion_distance: int) -> np.ndarray:
    """
    Compute the ring of background pixels lying within a given Euclidean
    distance of a binary mask.

    Parameters
    ----------
    mask : np.ndarray
        Binary input mask to expand.
    expansion_distance : int
        Expansion distance in pixels.

    Returns
    -------
    np.ndarray
        uint8 mask of the expansion area only (the original mask pixels are
        excluded). All zeros when ``mask`` is empty.
    """
    background = mask == 0
    if background.all():
        # Nothing to expand from.
        return np.zeros_like(mask, dtype=np.uint8)

    # Distance of every background pixel to the nearest object pixel.
    distance_to_object = distance_transform_edt(background)
    ring = background & (distance_to_object <= expansion_distance)
    return ring.astype(np.uint8)
[docs]
def generate_expanded_masks_limited_by_voronoi(
        self,
        expansion_distances: list[int]
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], dict[str, np.ndarray]]:
    """
    Generate per-class object masks and expansion rings limited by Voronoi regions.

    Every contour in ``self.multiple_contours`` is rasterised into its own
    mask and receives a parent ID from a single counter that starts at 1
    and runs across all classes. For each object and each requested
    distance, an expansion ring is computed with ``self.expand_mask``,
    stripped of the rings already produced for that object, and clipped to
    the mask returned by ``self.get_voronoi_mask`` for the object's class.
    The per-object results are finally merged per class and per distance.

    Parameters
    ----------
    expansion_distances : list[int]
        Pixel distances of the expansion rings. Rings are built
        cumulatively in the order given, so the distances are presumably
        meant to be ascending (a smaller distance after a larger one yields
        an empty ring).

    Returns
    -------
    tuple of dict
        Three dicts sharing the same keys, ``"<class>"`` for the objects
        and ``"<class>_expansion_<distance>"`` for each ring:

        - binary_masks: ``uint8`` masks with values 0/1.
        - labeled_masks: ``int32`` masks holding the parent ID of the
          object each pixel was generated from.
        - referenced_masks: ``int32`` masks holding, for the same pixels,
          the ID currently stored in the global reference map.

    Notes
    -----
    The results are also stored on ``self.binary_masks``,
    ``self.labeled_masks`` and ``self.referenced_masks``.

    The global reference map is written in processing order, so where masks
    of different objects overlap the last writer wins and
    ``referenced_masks`` may differ from ``labeled_masks``.

    Aggregate keys are recorded when each mask is created instead of being
    re-parsed from the per-object key, so class names containing
    underscores are handled correctly.
    """
    frame = (self.height, self.width)

    # Pixel -> parent ID of the object (or ring) that last claimed it.
    referenced_labeled_mask = np.zeros(frame, dtype=np.int32)
    parent_id_counter = 1  # unique ID for each original object across all classes

    # per-object key -> (aggregate key, binary mask, parent ID).
    # Labelled masks are not stored: they are fully determined by the
    # binary mask and the parent ID and are rebuilt during aggregation.
    entries = {}
    # category -> list of (parent_id, mask)
    original_masks_info = {}

    # Step 1: one binary mask and one parent ID per contour.
    for category_name, contours in self.multiple_contours.items():
        if not contours or all(c.shape[0] < 4 for c in contours):
            # Placeholders guarantee that the class and all of its ring
            # keys exist in the output even when nothing usable is found.
            # NOTE(review): as in the original code there is no `continue`
            # here, so contours with fewer than 4 points are still drawn
            # below - confirm whether they were meant to be skipped.
            entries[f"{category_name}"] = (
                category_name, np.zeros(frame, dtype=np.uint8), 0)
            for expansion_distance in expansion_distances:
                exp_key = f"{category_name}_expansion_{expansion_distance}"
                entries[exp_key] = (exp_key, np.zeros(frame, dtype=np.uint8), 0)

        category_masks = []
        for contour in contours:
            mask = np.zeros(frame, dtype=np.uint8)
            cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)
            referenced_labeled_mask[mask > 0] = parent_id_counter
            entries[f'{category_name}_{parent_id_counter}'] = (
                category_name, mask, parent_id_counter)
            category_masks.append((parent_id_counter, mask))
            parent_id_counter += 1
        original_masks_info[category_name] = category_masks

    # Step 2: expansion rings, each mapped back to its parent object.
    for category_name, masks_info in original_masks_info.items():
        voronoi_mask = self.get_voronoi_mask(category_name)
        for parent_id, base_mask in masks_info:
            # Union of the rings already produced for this object.
            claimed = np.zeros(frame, dtype=np.uint8)
            for expansion_distance in expansion_distances:
                ring = self.expand_mask(base_mask, expansion_distance)
                # Drop pixels of earlier rings, then clip to the Voronoi
                # area of the class (uint8 bitwise ops, values stay 0/1).
                ring = ring & ~claimed & voronoi_mask
                # NOTE(review): this overwrites whatever ID was stored for
                # these pixels, including IDs of other original objects.
                referenced_labeled_mask[ring > 0] = parent_id
                agg_key = f'{category_name}_expansion_{expansion_distance}'
                entries[f'{agg_key}_parent_{parent_id}'] = (agg_key, ring, parent_id)
                claimed = claimed | ring

    # Step 3: merge the per-object masks by class / ring.
    aggregate_binary = {}
    aggregate_labeled = {}
    aggregate_referenced = {}
    for agg_key, mask, parent_id in entries.values():
        if agg_key not in aggregate_binary:
            aggregate_binary[agg_key] = np.zeros_like(mask)
            aggregate_labeled[agg_key] = np.zeros_like(mask, dtype=np.int32)
            aggregate_referenced[agg_key] = np.zeros_like(mask, dtype=np.int32)
        occupied = mask > 0
        aggregate_binary[agg_key] = np.bitwise_or(aggregate_binary[agg_key], mask)
        aggregate_labeled[agg_key] = np.maximum(
            aggregate_labeled[agg_key],
            np.where(occupied, np.int32(parent_id), np.int32(0))
        )
        # Referenced IDs come from the final state of the global map.
        aggregate_referenced[agg_key] = np.maximum(
            aggregate_referenced[agg_key],
            np.where(occupied, referenced_labeled_mask, 0)
        )

    self.binary_masks = aggregate_binary
    self.labeled_masks = aggregate_labeled
    self.referenced_masks = aggregate_referenced
    return self.binary_masks, self.labeled_masks, self.referenced_masks
[docs]
def plot_masks_with_voronoi(self,
                            mask_colors: Dict[str, Tuple[int, int, int]],
                            background_color: Tuple[int, int, int] = (255, 255, 255),
                            show: bool = True,
                            axes: Optional["matplotlib.axes.Axes"] = None,
                            figsize: Tuple[int, int] = (8, 8)
                            ) -> Optional["matplotlib.axes.Axes"]:
    """
    Plot the generated masks overlaid with Voronoi edges and centroids.

    All masks in ``self.binary_masks`` are painted onto one RGB image using
    the colour of their class; masks drawn later overwrite earlier ones
    where they overlap.

    Parameters
    ----------
    mask_colors : dict of str to tuple of int
        Mapping from class name to RGB colour (0-255). Classes without an
        entry are drawn in grey (128, 128, 128).
    background_color : tuple of int, optional
        RGB colour of pixels not covered by any mask. Defaults to white.
    show : bool, optional
        If True, ``plt.show()`` is called. Defaults to True.
    axes : matplotlib.axes.Axes, optional
        Existing axes to draw on. A new figure is created when omitted.
    figsize : tuple of int, optional
        Size of the figure created when ``axes`` is None.

    Returns
    -------
    matplotlib.axes.Axes or None
        The supplied ``axes``; None when the figure was created here.
    """
    masks = self.binary_masks
    canvas = np.full((self.height, self.width, 3), background_color, dtype=np.uint8)

    if axes is None:
        _, ax = plt.subplots(figsize=figsize)
    else:
        ax = axes

    legend_patches = []
    seen_classes = set()
    for mask_name, mask in masks.items():
        # Class of a mask: the part before '_expansion_' when that is a
        # known class (supports class names containing underscores),
        # otherwise the first '_'-separated token, e.g. 'gd' for
        # 'gd_expansion_30_0'.
        full_prefix = mask_name.split('_expansion_')[0]
        if full_prefix in mask_colors:
            base_class = full_prefix
        else:
            base_class = mask_name.split('_')[0]

        color = np.array(mask_colors.get(base_class, (128, 128, 128)))
        canvas[mask != 0] = color

        # One legend entry per class, not per mask.
        if base_class not in seen_classes:
            legend_patches.append(mpatches.Patch(color=color / 255, label=base_class))
            seen_classes.add(base_class)

    ax.imshow(canvas, origin='lower')

    # Draw Voronoi edges
    if self.vor:
        voronoi_plot_2d(self.vor, ax=ax, show_vertices=False, line_colors='black', line_alpha=0.6)

    # Plot centroids (small markers)
    if self.all_centroids is not None:
        centroids = np.array(self.all_centroids)
        ax.plot(centroids[:, 0], centroids[:, 1], '*', markersize=1, alpha=0.6)

    ax.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left',
              bbox_transform=ax.transAxes)

    if self.save_path:
        save_path = os.path.join(self.save_path, 'masks_with_voronoi_edges.png')
        # Save the figure that owns `ax`; plt.savefig would write pyplot's
        # *current* figure, which can be a different one when the caller
        # supplied `axes`.
        ax.get_figure().savefig(save_path, dpi=1000, bbox_inches='tight')
        self.logger.info(f'Plot saved at {save_path}')

    if show:
        plt.show()

    # `axes` is None exactly when the figure was created here.
    return axes