# Source code for selective_blur.masking.mask_tools

import matplotlib.pyplot as plt
from scipy.ndimage import label
import ipywidgets as widgets
from jupyter_bbox_widget import BBoxWidget
import io
from PIL import Image
from IPython.display import display, clear_output
import numpy as np
import cv2


def show_image_and_mask(image: np.ndarray, mask: np.ndarray):
    """
    Plot an image side by side with the version of it that keeps only the
    masked pixels.

    Args:
        image (np.ndarray): image in OpenCV BGR channel order; it is converted
            to RGB with ``cv2.cvtColor`` before plotting.
        mask (np.ndarray): mask indexed as ``mask[row, col]``; pixels where it
            is non-zero are kept in the right panel, the others are blacked out.
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Each panel's picture is produced lazily so the masked product is only
    # computed once its subplot has been created, as in a straight-line version.
    panels = (
        ("Original Image", lambda: rgb),
        ("Masked Image", lambda: rgb * mask[:, :, np.newaxis].astype(bool)),
    )

    plt.figure(figsize=(12, 6))
    for position, (caption, render) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(render())
        plt.axis("off")
    plt.tight_layout()
    plt.show()
class MaskEditor:
    """
    Refine a mask generated automatically with SAM2.

    The mask can be cleaned automatically (``auto_denoise``) or edited by hand
    with bounding boxes (``manual_edit``). Both display ipywidgets and
    matplotlib figures, so they are meant to be used inside a Jupyter notebook.

    Attributes:
        mask (np.ndarray): the current (refined) mask.
        original_mask (np.ndarray): the mask given to the constructor.
            ``auto_denoise`` always restarts from it, so it discards any
            previous refinement.
        add (bool): in ``manual_edit``, whether drawn boxes add to (True) or
            remove from (False) the mask.
    """

    def __init__(self, mask: np.ndarray):
        """
        Args:
            mask (np.ndarray): bool array with mask of the original image
        """
        self.mask = mask
        self.original_mask = mask
        self.add = True

    @staticmethod
    def _remove_small_regions(mask: np.ndarray, threshold: float) -> np.ndarray:
        """
        Return a bool copy of ``mask`` without its small connected True regions.

        A region is "small" when its pixel area is strictly below the
        ``threshold`` quantile of all region areas. Connectivity is the
        ``scipy.ndimage.label`` default. A mask with no True region is returned
        unchanged (as a copy) instead of raising.

        Args:
            mask (np.ndarray): boolean mask.
            threshold (float): quantile in [0, 1] of the area distribution.

        Returns:
            np.ndarray: boolean mask with the small regions set to False.
        """
        labeled, num_features = label(mask)
        if num_features == 0:
            # np.quantile raises on an empty sequence; there is nothing to drop.
            return mask.copy()
        # Pixel count of every label in a single pass; index 0 is the background.
        areas = np.bincount(labeled.ravel(), minlength=num_features + 1)[1:]
        cutoff = np.quantile(areas, threshold)
        # Lookup table: keep[k] tells whether label k survives. The background
        # (label 0) is False in the mask already.
        keep = np.concatenate(([False], areas >= cutoff))
        return keep[labeled]

    def _denoise_mask(self, threshold: float):
        """
        Rebuild ``self.mask`` from ``self.original_mask`` with small areas
        removed, then plot the result.

        The principle is the following:
            1. the original mask is smoothed with a 5x5 median blur;
            2. the size (pixel area) of each connected True area is computed;
            3. areas smaller than the ``threshold`` quantile of the area
               distribution are set to False, so they are no longer part of
               the mask.

        Args:
            threshold (float): quantile in [0, 1] of the True-area sizes; areas
                strictly smaller than it are converted to False (0 keeps all).
        """
        # medianBlur with a 5x5 aperture accepts float32, so convert bool/int
        # masks first; working on locals leaves self.mask intact on failure.
        blurred = cv2.medianBlur(self.original_mask.astype("float32"), 5)
        self.mask = self._remove_small_regions(blurred.astype("bool"), threshold)

        plt.imshow(self.mask, cmap="gray", interpolation="none")
        plt.axis("off")
        plt.tight_layout()
        plt.show()

    def auto_denoise(self):
        """
        Clean the mask automatically by removing small segmented areas, with
        the strength chosen through an interactive slider.

        Each time the slider is released, ``self.mask`` is rebuilt from
        ``self.original_mask``. The mask is median-blurred, and its connected
        True areas are measured. Areas smaller than the selected quantile of
        the area distribution are set to False, so they are no longer included
        in the mask.

        The slider moves in steps of 0.1, i.e. deciles:
            - 0 keeps every area;
            - 0.5 drops the areas below the median size;
            - 1 keeps only the largest area(s).
        """

        def _denoise_from_selection(change):
            # Redraw the slider and the new preview in place of the old output.
            clear_output(wait=True)
            display(intensity_slider)
            self._denoise_mask(change["new"])

        intensity_slider = widgets.FloatSlider(
            value=0.5,
            min=0.0,
            max=1.0,
            step=0.1,
            description="Denoise level:",
            continuous_update=False,  # recompute only when the slider is released
        )
        intensity_slider.observe(_denoise_from_selection, names="value")
        display(intensity_slider)
        self._denoise_mask(intensity_slider.value)

    @staticmethod
    def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
        """
        Convert a mask to a uint8 grayscale image (0 = black, 255 = white).

        Bool/float masks are scaled by 255. A uint8 mask that already uses
        values above 1 (e.g. 0/255) is returned as is; scaling it again would
        overflow and wrap around.
        """
        if mask.dtype == np.uint8:
            # A 0/1 uint8 mask still needs stretching to be visible.
            return mask * 255 if mask.max(initial=0) <= 1 else mask
        return (mask * 255).astype(np.uint8)

    def _mask_to_bytes(self, mask: np.ndarray):
        """Encode ``mask`` as PNG bytes, used as the bounding-box widget image."""
        buffer = io.BytesIO()
        Image.fromarray(self._mask_to_uint8(mask)).save(buffer, format="PNG")
        return buffer.getvalue()

    def _toggle_widget(self):
        """
        Build the Add/Remove toggle used by ``manual_edit``.

        Switching mode commits the edits made so far. The latest preview
        (``widget.temp_mask``) becomes the base mask and is shown as the
        widget's image. The drawn boxes are cleared, so they are not re-applied
        with the opposite mode on the next submit.
        """

        def _on_switch_change(change):
            self.add = change["new"]
            if self.add != self.prev_toggle:
                self.widget.image_bytes = self._mask_to_bytes(self.widget.temp_mask)
                self.widget.mask = self.widget.temp_mask
                # Boxes drawn in the previous mode are now part of the base mask.
                self.widget.bboxes = []
                clear_output(wait=True)
                display(self.menu)
                display(self.widget)
            self.prev_toggle = self.add

        self.switch = widgets.ToggleButton(
            value=True,
            description="Add-Mode",
            disabled=False,
            button_style="",
            tooltip="Toggle Switch",
            icon="check",
        )
        self.switch.observe(_on_switch_change, names="value")
        return self.switch

    def _save_mask_widget(self):
        """Build the "Save" button that stores the latest preview in ``self.mask``."""

        def _on_save_button_clicked(b):
            self.mask = self.widget.temp_mask
            clear_output(wait=True)
            print("Mask correctly saved")

        save_button = widgets.Button(
            description="Save",
            disabled=False,
            button_style="",
            tooltip="Click to save",
            icon="save",
        )
        save_button.on_click(_on_save_button_clicked)
        return save_button

    def _menu_widget(self):
        """Lay out the mode toggle and the save button side by side."""
        toggle_button = self._toggle_widget()
        save_button = self._save_mask_widget()
        return widgets.HBox([toggle_button, save_button])

    def manual_edit(self):
        """
        Manually add or remove elements of the mask using bounding boxes
        (Jupyter widget).

        The current mask is shown in a ``BBoxWidget`` together with a menu.
        The toggle selects whether boxes add to (Add-Mode) or remove from the
        mask, and "Save" stores the edited mask in ``self.mask``. Each submit
        previews the base mask with every drawn box applied.
        """
        # The toggle is recreated in Add-Mode, so reset the editing state to
        # match it even if a previous session ended in remove mode.
        self.add = True
        self.prev_toggle = True

        def manual_clean():
            clear_output(wait=True)
            display(self.menu)
            display(self.widget)
            preview = self.widget.mask.copy()
            for box in self.widget.bboxes:
                # Clamp at 0 so a box crossing the top/left edge does not become
                # a negative (wrap-around) slice; upper bounds clip by themselves.
                top = max(0, int(box["y"]))
                bottom = max(0, int(box["y"] + box["height"]))
                left = max(0, int(box["x"]))
                right = max(0, int(box["x"] + box["width"]))
                preview[top:bottom, left:right] = self.add
            self.widget.temp_mask = preview
            plt.imshow(preview, cmap="gray", interpolation="none")
            plt.axis("off")
            plt.tight_layout()
            plt.show()

        self.widget = BBoxWidget()
        self.widget.on_submit(manual_clean)
        # Ad-hoc attributes on the widget: ``mask`` is the base the boxes are
        # applied to, ``temp_mask`` the latest preview (what "Save" stores).
        self.widget.mask = self.mask
        self.widget.temp_mask = self.mask
        self.widget.image_bytes = self._mask_to_bytes(self.mask)
        self.menu = self._menu_widget()
        display(self.menu)
        display(self.widget)