Source code for selective_blur.blurring.midasmodel
import torch
import numpy as np
[docs]
class MidasModel:
    """Load an Intel MiDaS model from torch.hub and run depth prediction.

    Attributes:
        midas: the MiDaS network returned by ``torch.hub.load``, moved to
            ``device`` and switched to eval (inference) mode.
        device (torch.device): CUDA when available, otherwise CPU.
        transform: MiDaS preprocessing transform matching the loaded model.
    """

    # Model names accepted by ``__init__``.
    AVAILABLE_MODELS = ("MiDaS_small", "DPT_Hybrid", "DPT_Large")

    def __init__(self, model: str = "DPT_Hybrid") -> None:
        """Download/load the requested MiDaS model and its matching transform.

        Args:
            model (str): MiDaS model version, available options are
                ["MiDaS_small", "DPT_Hybrid", "DPT_Large"].
                Defaults to "DPT_Hybrid".

        Raises:
            ValueError: if model is not one of
                ["MiDaS_small", "DPT_Hybrid", "DPT_Large"]
        """
        if model not in self.AVAILABLE_MODELS:
            raise ValueError(
                "Available models are ['MiDaS_small', 'DPT_Hybrid', 'DPT_Large']"
            )
        self.midas = torch.hub.load("intel-isl/MiDaS", model)
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.midas.to(self.device)
        # Inference mode: without this, BatchNorm layers use (and update)
        # per-batch statistics and dropout stays active, which changes the
        # predictions. The MiDaS hub example calls ``midas.eval()`` as well.
        self.midas.eval()
        transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        self.transform = self._pick_transform(transforms, model)

    @staticmethod
    def _pick_transform(transforms, model: str):
        """Return the preprocessing transform that matches ``model``.

        The MiDaS hub documentation pairs DPT_Large / DPT_Hybrid with
        ``dpt_transform`` and MiDaS_small with ``small_transform``; the two
        differ in input resolution and normalization, so using the wrong one
        degrades the prediction.
        """
        if model in ("DPT_Large", "DPT_Hybrid"):
            return transforms.dpt_transform
        return transforms.small_transform

    def predict_adjust(self, image: np.ndarray) -> np.ndarray:
        """Predict a depth map and resize it to the input image's resolution.

        Args:
            image (np.ndarray): image indexed as (height, width, ...);
                presumably an RGB array as in the MiDaS hub example —
                confirm against callers.

        Returns:
            np.ndarray: 2-D array of shape ``image.shape[:2]``. Per the MiDaS
            documentation the values are relative inverse depth (larger
            means closer), not metric depth.
        """
        batch = self.transform(image).to(self.device)
        with torch.no_grad():
            prediction = self.midas(batch)
            # The network predicts at the transform's working resolution;
            # upsample back to the original (height, width).
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
        # Index the (batch, channel) axes explicitly instead of a bare
        # squeeze(), which would also drop a height or width of 1.
        return prediction[0, 0].cpu().numpy()