Source code for selective_blur.masking.segmentation
import cv2
import torch
import io
import numpy as np
import supervision as sv
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from jupyter_bbox_widget import BBoxWidget
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, clear_output
[docs]
class Segmenter:
    """
    Class to segment an image from user selection in a jupyter notebook leveraging Segment Anything 2

    Typical usage: build the instance, call ``select_from_image`` and submit a
    selection in the widget, then call ``choose_mask`` to keep one of the
    proposed masks in ``self.best_mask``.
    """

    # Public model name -> suffix used in the SAM2 config file name
    # (``sam2_hiera_<suffix>.yaml``). "base_plus" uses "b+", so the first
    # letter of the model name is not enough to build the config name.
    _CONFIG_SUFFIXES = {"tiny": "t", "small": "s", "base_plus": "b+", "large": "l"}

    @classmethod
    def _validate_model(cls, model: str) -> str:
        """
        Return the config-file suffix matching ``model``.

        Args:
            model (str): SAM2 model version name
        Raises:
            ValueError: if model is not one of ["tiny", "small", "base_plus", "large"]
        """
        if model not in cls._CONFIG_SUFFIXES:
            raise ValueError(
                "Available models are ['tiny', 'small', 'base_plus', 'large']"
            )
        return cls._CONFIG_SUFFIXES[model]

    @staticmethod
    def _boxes_to_points(boxes):
        """
        Convert widget selections into an (N, 2) array of x, y point prompts.

        Each selection is reduced to its centre. A selection with zero width
        and height therefore maps to its own x, y, while a drawn box maps to
        the middle of the box rather than to its top-left corner.

        Args:
            boxes: iterable of dicts with "x" and "y" keys and optional
                "width" and "height" keys
        """
        return np.array(
            [
                [
                    box["x"] + box.get("width", 0) / 2,
                    box["y"] + box.get("height", 0) / 2,
                ]
                for box in boxes
            ]
        )

    def _config_cuda(self) -> None:
        """
        Turn on bfloat16 autocast and, when device 0 reports a compute
        capability major version of 8 or higher, TF32 matmul/cuDNN kernels.

        Does nothing when CUDA is not available, so that the CPU fallback
        used by ``_setup_predictor`` keeps working.
        """
        if not torch.cuda.is_available():
            return
        # Entered on purpose and never exited: autocast has to stay active for
        # the predictor calls made later on.
        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    def _setup_predictor(self, model: str) -> None:
        """
        Build the SAM2 model on CUDA when available (CPU otherwise) and store
        the resulting image predictor in ``self.predictor``.

        Args:
            model (str): SAM2 model version name
        Raises:
            ValueError: if model is not one of ["tiny", "small", "base_plus", "large"]
        """
        suffix = self._validate_model(model)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The checkpoint path is relative to the current working directory.
        checkpoint = f"checkpoints/sam2_hiera_{model}.pt"
        config = f"sam2_hiera_{suffix}.yaml"
        sam2_model = build_sam2(
            config, checkpoint, device=device, apply_postprocessing=False
        )
        self.predictor = SAM2ImagePredictor(sam2_model)

    def __init__(self, image_path: str, model: str = "tiny"):
        """
        Args:
            image_path (str): path to the image that will be segmented
            model (str, optional): SAM2 model version (choice between "tiny", "small", "base_plus", "large"). Defaults to "tiny".
        Raises:
            ValueError: if model is not one of ["tiny", "small", "base_plus", "large"]
            FileNotFoundError: if the image at image_path cannot be read
        """
        # Fail fast on a bad model name, before any image or device work.
        self._validate_model(model)
        self.image_bgr = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if self.image_bgr is None:
            raise FileNotFoundError(f"Could not read image at '{image_path}'")
        self.image_rgb = cv2.cvtColor(self.image_bgr, cv2.COLOR_BGR2RGB)
        # Filled in by select_from_image / choose_mask respectively.
        self.masks = None
        self.best_mask = None
        self._config_cuda()
        self._setup_predictor(model)

    def _image_to_bytes(self) -> None:
        """Encode ``self.image_rgb`` as PNG and store the bytes in ``self.image_bytes``."""
        img = Image.fromarray(self.image_rgb)
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        self.image_bytes = buffer.getvalue()

    def select_from_image(self):
        """
        Allows to select with a point or a bounding box the subjects that will be segmented from an image (Jupyter Widget).
        After user submits the selection, the candidate masks returned by the predictor are saved in self.masks and displayed
        along with their confidence scores. Every selection is used as a positive point prompt located at its centre.
        """
        self._image_to_bytes()
        widget = BBoxWidget()

        def segmentation():
            # Redraw the widget so the mask grid is shown right below it.
            clear_output(wait=True)
            display(widget)
            boxes = widget.bboxes
            if not boxes:
                # Nothing to prompt the predictor with; keep previous masks.
                print("No selection submitted: select at least one point or box.")
                return
            input_point = self._boxes_to_points(boxes)
            # Label 1 for every point: all selections are positive prompts.
            input_label = np.ones(input_point.shape[0])
            self.predictor.set_image(self.image_rgb)
            self.masks, scores, _ = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )
            sv.plot_images_grid(
                images=self.masks,
                titles=[f"score: {score:.2f}" for score in scores],
                grid_size=(1, 3),
                size=(12, 12),
            )

        widget.on_submit(segmentation)
        widget.image_bytes = self.image_bytes
        display(widget)

    def choose_mask(self, mask_number: int):
        """
        Store the selected mask proposal, squeezed, in ``self.best_mask``.

        Args:
            mask_number (int): choice between the three mask proposals displayed when select_from_image is executed (choices available are 0, 1 or 2)
        Raises:
            RuntimeError: if no masks have been generated yet
            ValueError: if mask_number is not 0, 1 or 2
        """
        masks = getattr(self, "masks", None)
        if masks is None:
            raise RuntimeError(
                "No masks available: run select_from_image and submit a selection first"
            )
        # Reject negative numbers too: they would otherwise index from the end.
        if not 0 <= mask_number < len(masks):
            raise ValueError("Choice can be either 0, 1 or 2")
        self.best_mask = masks[mask_number].squeeze()