Segment Anything Model (SAM) Unveiled: A Deep Dive into Object Segmentation with PyTorch and SAM

Shankar Shamra
Nov 13, 2023


This blog takes you on an immersive journey, unraveling the nuances of setting up the environment, installing dependencies, and harnessing SAM’s capabilities for advanced object segmentation.

The Segment Anything Model (SAM), released by Meta AI, is a deep learning model for promptable image segmentation. It was trained on the Segment Anything dataset (SA-1B), which contains about 11 million images and more than 1.1 billion masks. SAM is neither a U-Net nor an object detector. It combines three parts: a heavyweight Vision Transformer (ViT) image encoder that computes an image embedding once per image, a prompt encoder, and a lightweight transformer-based mask decoder that turns the embedding and the prompts into masks.

What distinguishes SAM is its prompt-based approach, inspired by the prompting paradigm of Large Language Models (LLMs): users guide segmentation by clicking on points, drawing a bounding box, or supplying a rough mask of an object. (The original paper also explores free-form text prompts, but the released model accepts only these point, box and mask prompts.) This makes SAM a foundation model for image segmentation that adapts to new tasks through prompts rather than retraining. Just as an LLM is expected to give a coherent answer to an ambiguous prompt, SAM is trained to return a valid mask even when a prompt is ambiguous, for example when a single click could refer to a shirt or to the person wearing it. Its promptable design also lets it be combined with existing object detectors, whose boxes can serve as prompts, for diverse applications.
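The three prompt types map onto plain arrays. The NumPy sketch below documents the shapes that `SamPredictor.predict` in the official `segment_anything` package takes for each of them: points as an (N, 2) array of (x, y) pixel coordinates plus an (N,) label array (1 = foreground, 0 = background), a box as a length-4 array [x0, y0, x1, y1], and a rough mask as a (1, 256, 256) array of low-resolution logits. `check_prompts` is a hypothetical helper written only to make those shapes explicit; it is not part of the library.

```python
import numpy as np


def check_prompts(point_coords=None, point_labels=None, box=None, mask_input=None):
    """Hypothetical helper: check prompt arrays against the shapes SamPredictor.predict expects."""
    if point_coords is not None:
        point_coords = np.asarray(point_coords, dtype=np.float32)
        point_labels = np.asarray(point_labels)
        # One (x, y) pixel coordinate per point ...
        if point_coords.ndim != 2 or point_coords.shape[1] != 2:
            raise ValueError("point_coords must have shape (N, 2)")
        # ... and one label per point: 1 = foreground click, 0 = background click.
        if point_labels.shape != (point_coords.shape[0],):
            raise ValueError("point_labels must have shape (N,)")
        if not np.isin(point_labels, [0, 1]).all():
            raise ValueError("point_labels must be 0 or 1")
    if box is not None:
        box = np.asarray(box, dtype=np.float32)
        # A box is given by its top-left and bottom-right corners: [x0, y0, x1, y1].
        if box.shape != (4,) or box[0] > box[2] or box[1] > box[3]:
            raise ValueError("box must be [x0, y0, x1, y1] with x0 <= x1 and y0 <= y1")
    if mask_input is not None:
        mask_input = np.asarray(mask_input, dtype=np.float32)
        # A rough mask prompt is passed as 256x256 low-resolution logits.
        if mask_input.shape != (1, 256, 256):
            raise ValueError("mask_input must have shape (1, 256, 256)")
    return point_coords, point_labels, box, mask_input


# A foreground click, a background click, and a box around the object.
coords, labels, box, _ = check_prompts(
    point_coords=[[1648, 640], [1711, 206]],
    point_labels=[1, 0],
    box=[1302, 178, 1961, 1631],
)
print(coords.shape, labels.shape, box.shape)  # (2, 2) (2,) (4,)
```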

Building the PyTorch Playground

Begin your exploration by setting up the PyTorch playground. Navigate through the essentials of importing libraries, configuring your workspace, and establishing the groundwork for seamless development.


import os  # interact with the operating system (environment variables, working directory, ...)

import cv2  # OpenCV, for reading and color-converting images
import matplotlib.pyplot as plt
import numpy as np
import torch  # PyTorch

HOME = os.getcwd()  # current working directory
print("HOME:", HOME)
# Change the notebook's working directory to HOME
%cd {HOME}


using_colab = True
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib onnx onnxruntime
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

!mkdir images
!wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg

!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

Let’s explore three essential helper functions that bring object segmentation results to life: show_mask, show_points, and show_box.

show_mask: Coloring the Canvas

def show_mask(mask, ax):
    # Takes a binary mask and an axis object, and displays the mask as an image on the axis.
    color = np.array([30/255, 144/255, 255/255, 0.6])  # RGBA: dodger blue at 60% opacity
    h, w = mask.shape[-2:]
    # Create a colored mask_image overlay for highlighting regions or objects.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    # Display the colored mask_image on the axis.
    ax.imshow(mask_image)
This function turns a binary mask into a semi-transparent blue RGBA image. The resulting mask_image serves as an overlay: masked pixels are tinted blue while all other pixels stay fully transparent, spotlighting specific areas of the image.
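The overlay relies on NumPy broadcasting: an (h, w, 1) mask multiplied by a (1, 1, 4) RGBA color gives an (h, w, 4) image that holds the color where the mask is 1 and all zeros (including a zero alpha, so fully transparent) elsewhere. A quick check on a toy 2×3 mask:

```python
import numpy as np

# Same RGBA color as show_mask: dodger blue at 60% opacity.
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

# Toy binary mask with shape (h, w) = (2, 3).
mask = np.array([[1, 0, 0],
                 [1, 1, 0]])

h, w = mask.shape[-2:]
# (h, w, 1) * (1, 1, 4) broadcasts to (h, w, 4).
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

print(mask_image.shape)     # (2, 3, 4)
print(mask_image[0, 0, 3])  # 0.6 -> masked pixel keeps the 60% alpha
print(mask_image[0, 1, 3])  # 0.0 -> unmasked pixel is fully transparent
```

Because `mask.shape[-2:]` reads only the last two dimensions, the same code also accepts boolean masks with extra leading axes of size 1, such as the (1, 1, H, W) array the ONNX session returns later in this post.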

show_points: Plotting Stars in the Sky

def show_points(coords, labels, ax, marker_size=375):
    # Takes coordinates, labels, and an axis object, and plots the points according to their labels.
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]

    # Plot positive points in green and negative points in red, using stars as markers.
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

This function transforms coordinates and labels into a celestial display of stars. Positive points shine in green, negatives twinkle in red, all contributing to the visual narrative.
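The split into positive and negative points is plain boolean indexing: comparing the label array with 1 or 0 yields a boolean array that selects the matching rows of `coords`. A quick check with a foreground/background pair of clicks:

```python
import numpy as np

# Two clicks in (x, y) pixels: one foreground (label 1), one background (label 0).
coords = np.array([[1707, 312], [1711, 206]])
labels = np.array([1, 0])

# Boolean arrays over the rows select the points of each class.
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]

print(pos_points.tolist())  # [[1707, 312]]
print(neg_points.tolist())  # [[1711, 206]]
# Column 0 holds x values and column 1 holds y values, matching ax.scatter(x, y).
print(pos_points[:, 0].tolist(), pos_points[:, 1].tolist())  # [1707] [312]
```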

show_box: Framing the Scene

def show_box(box, ax):
    # Takes a bounding box [x0, y0, x1, y1] and an axis object, and draws the box on the axis.
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]

    # Add a transparent rectangle to represent the bounding box.
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=1))

Here, the function paints a transparent rectangle on the canvas, framing the specified region with a touch of green.
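SAM boxes use the corner format [x0, y0, x1, y1], while `plt.Rectangle` wants an anchor corner plus a width and a height. The conversion inside show_box can be isolated as follows; `xyxy_to_xywh` is a hypothetical name used only for this illustration:

```python
def xyxy_to_xywh(box):
    """Hypothetical helper: convert [x0, y0, x1, y1] into ((x0, y0), width, height) for plt.Rectangle."""
    x0, y0, x1, y1 = (float(v) for v in box)
    # Width and height are the differences between opposite corners.
    return (x0, y0), x1 - x0, y1 - y0


anchor, width, height = xyxy_to_xywh([1302, 178, 1961, 1631])
print(anchor, width, height)  # (1302.0, 178.0) 659.0 1453.0
```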

Setting the Stage with a Pre-Trained ViT-H Model

checkpoint = "/content/sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Checkpoint contains the path to a pre-trained ViT-H model.
# Let's initialize our SAM model using the specified checkpoint and model type.
sam = sam_model_registry[model_type](checkpoint=checkpoint)

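Here, the checkpoint variable points to the weights of a pre-trained SAM model with the ViT-H image encoder, the largest of the three released variants. SAM's model registry maps the model-type string ("vit_h") to a builder function that constructs the model and loads the checkpoint weights.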
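The registry is essentially a dictionary from model-type strings to builder functions. The sketch below imitates that pattern in plain Python so it runs without downloading a checkpoint. The encoder widths and depths are those of the three released ViT variants (ViT-H: 1280/32, ViT-L: 1024/24, ViT-B: 768/12); `build_config` and `registry` are hypothetical stand-ins that return a config dict, whereas the real builders construct the PyTorch modules and load the weights.

```python
from functools import partial


def build_config(encoder_embed_dim, encoder_depth, checkpoint=None):
    """Hypothetical stand-in for a SAM builder: returns a config dict instead of a model."""
    return {
        "encoder_embed_dim": encoder_embed_dim,
        "encoder_depth": encoder_depth,
        "checkpoint": checkpoint,
    }


# Model-type string -> builder, mirroring how sam_model_registry is keyed.
registry = {
    "vit_h": partial(build_config, encoder_embed_dim=1280, encoder_depth=32),
    "vit_l": partial(build_config, encoder_embed_dim=1024, encoder_depth=24),
    "vit_b": partial(build_config, encoder_embed_dim=768, encoder_depth=12),
}
# "default" is an alias for the largest variant.
registry["default"] = registry["vit_h"]

cfg = registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
print(cfg["encoder_embed_dim"], cfg["encoder_depth"])  # 1280 32
```

Switching `model_type` to "vit_b" (with its matching checkpoint file) is the usual way to trade some accuracy for a smaller, faster model.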

SAM Model Instance

# Now, we wrap the SAM model in a SamPredictor, which computes the image embedding once per image and then predicts masks from prompts.
predictor = SamPredictor(sam)

From SAM to ONNX: A Journey of Model Export and Quantization

We navigate through the intricacies of exporting a Segment Anything Model (SAM) to the Open Neural Network Exchange (ONNX) format, followed by the transformative process of quantization.

Exporting SAM to ONNX

import warnings

onnx_model_path = "sam_onnx_example.onnx"
# Wrap SAM's prompt encoder and mask decoder for export (the image encoder is not part of the ONNX model).
onnx_model = SamOnnxModel(sam, return_single_mask=True)

# Defining dynamic axes for variable-length input data
dynamic_axes = {
    "point_coords": {1: "num_points"},
    "point_labels": {1: "num_points"},
}

# Shapes needed for the dummy inputs, read from the prompt encoder
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]

# Preparing dummy inputs for export
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

# Specifying output tensor names
output_names = ["masks", "iou_predictions", "low_res_masks"]

# Exporting the ONNX model
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )
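The dummy-input shapes follow from SAM's fixed input resolution. Assuming the standard configuration (a 1024×1024 encoder input, 16×16 patches and 256 embedding channels), the arithmetic behind `embed_size` and `mask_input_size` works out as below; the constants are hard-coded here rather than read from `sam.prompt_encoder`.

```python
# Assumed standard SAM configuration (hard-coded, not read from a model).
img_size = 1024    # side length the image encoder expects after resizing and padding
patch_size = 16    # ViT patch size
embed_dim = 256    # channels in the image embedding

# The encoder emits one embedding vector per patch.
embed_size = (img_size // patch_size, img_size // patch_size)
# The mask prompt is 4x the embedding grid in each direction.
mask_input_size = [4 * x for x in embed_size]

embedding_shape = (1, embed_dim, *embed_size)
print(embedding_shape)  # (1, 256, 64, 64)
print(mask_input_size)  # [256, 256]
```

These match the (1, 256, 64, 64) image embedding and the (1, 1, 256, 256) mask input used with the ONNX session later in this post.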

Quantizing the ONNX Model

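onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"

# Quantizing the ONNX model: weights are stored as unsigned 8-bit integers.
# Note: newer onnxruntime releases may no longer accept optimize_model; drop it if you get a TypeError.
quantize_dynamic(
    model_input=onnx_model_path,
    model_output=onnx_model_quantized_path,
    optimize_model=True,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)

# From here on, use the quantized model
onnx_model_path = onnx_model_quantized_path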

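Quantization is performed to speed up inference and reduce the model's memory footprint. With quantize_dynamic, the weights are stored as 8-bit unsigned integers (QuantType.QUInt8) instead of 32-bit floats, so the quantized weights take roughly a quarter of the space, while activations are quantized on the fly at inference time. The trade-off is a small loss of precision, which makes the quantized ONNX model well suited to deployment in resource-constrained environments.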
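To see where the memory saving comes from, here is a NumPy sketch of asymmetric 8-bit quantization: the general scale-plus-zero-point scheme behind unsigned 8-bit weight quantization. It is a simplified illustration on random weights, not ONNX Runtime's exact implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.05, size=(256, 256)).astype(np.float32)

# Map the float range [w_min, w_max] onto the uint8 range [0, 255].
# Including 0.0 in the range keeps zero exactly representable.
w_min = min(float(weights.min()), 0.0)
w_max = max(float(weights.max()), 0.0)
scale = (w_max - w_min) / 255.0
zero_point = int(round(-w_min / scale))

# Quantize: x -> round(x / scale) + zero_point, clipped to the uint8 range.
q = np.clip(np.round(weights / scale) + zero_point, 0, 255).astype(np.uint8)

# Dequantize: the approximate float value recovered at inference time.
dequantized = (q.astype(np.float32) - zero_point) * scale

max_error = float(np.abs(weights - dequantized).max())
print(weights.nbytes, q.nbytes)  # 262144 65536 -> 4x smaller
print(max_error <= scale)        # True: the error stays within one quantization step
```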

Loading the image


image = cv2.imread('/content/230121121737-01-nfl-playoffs-preview.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()

# Create an ONNX Runtime session that loads the (quantized) ONNX model from onnx_model_path
ort_session = onnxruntime.InferenceSession(onnx_model_path)

# Use the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=device)
predictor = SamPredictor(sam)  # wrap the SAM model in a predictor
predictor.set_image(image)  # run the image encoder once on this image
# Fetch the image embedding, move it to the CPU and convert it to a NumPy array for ONNX Runtime
image_embedding = predictor.get_image_embedding().cpu().numpy()
image_embedding.shape  # (1, 256, 64, 64)

Selecting a single mask


input_point = np.array([[1648, 640]])  # shape (1, 2): the (x, y) pixel coordinates of one point
input_label = np.array([1])  # shape (1,): label 1 marks a foreground point

# Add a padding point (0, 0) with label -1, which the ONNX model expects when no box prompt is given,
# then add a batch axis with None: the model expects batched input even for a single point.
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
# Rescale the coordinates from the original image size to the resized frame
# (longest side 1024 px) that SAM works in.
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

#Create an empty mask input and an indicator for no mask.

onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
#Package the inputs to run in the onnx model

ort_inputs = {
"image_embeddings": image_embedding,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

#Predict a mask and threshold it.

masks, _, low_res_logits = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold
print(masks.shape)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
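The ONNX model returns mask logits rather than probabilities, and `predictor.model.mask_threshold` is 0.0 in the released models, so `masks > predictor.model.mask_threshold` keeps pixels with a positive logit. Because the sigmoid is monotonic and sigmoid(0) = 0.5, this is the same as keeping pixels whose probability exceeds 0.5, as this toy check (with the threshold assumed to be 0.0) shows:

```python
import numpy as np

# Toy mask logits standing in for the (1, 1, H, W) "masks" output.
logits = np.array([[[[-3.0, -0.1, 0.0],
                     [0.2, 1.5, 4.0]]]], dtype=np.float32)

mask_threshold = 0.0  # assumed value of predictor.model.mask_threshold

binary_from_logits = logits > mask_threshold
probabilities = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
binary_from_probs = probabilities > 0.5

print(binary_from_logits.astype(int).tolist())  # [[[[0, 0, 0], [1, 1, 1]]]]
print(np.array_equal(binary_from_logits, binary_from_probs))  # True
```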

Refining the mask with multiple points


# Example point input: two foreground points
input_point = np.array([[1648,640],[1609,931]])
input_label = np.array([1,1])

# Use the mask output from the previous run. It is already in the correct form for input to the ONNX model.
onnx_mask_input = low_res_logits
#Transform the points as in the previous example.

onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
#The has_mask_input indicator is now 1.

onnx_has_mask_input = np.ones(1, dtype=np.float32)
#Package inputs, then predict and threshold the mask.

ort_inputs = {
"image_embeddings": image_embedding,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
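Both ONNX examples repeat the same preprocessing: append a padding point (0, 0) with label -1, add a batch axis, and rescale the coordinates from the original image size to the resized frame whose longest side is 1024 pixels. The sketch below bundles those steps into a hypothetical `prepare_onnx_points` helper and re-implements the longest-side scaling in NumPy, assuming it scales by 1024 / max(h, w) and rounds the new size to whole pixels; `predictor.transform.apply_coords` remains the authoritative version. The image size (1500, 2250) is an assumption for illustration.

```python
import numpy as np


def resized_shape(old_h, old_w, long_side=1024):
    """New (h, w) after scaling so the longest side equals long_side, rounded to whole pixels."""
    scale = long_side / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)


def prepare_onnx_points(points, labels, image_hw, long_side=1024):
    """Hypothetical helper: pad, batch and rescale point prompts for the SAM ONNX decoder."""
    points = np.asarray(points, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    # Padding point at (0, 0) with label -1, used when no box prompt is supplied.
    coords = np.concatenate([points, [[0.0, 0.0]]], axis=0)[None, :, :]
    onnx_labels = np.concatenate([labels, [-1.0]], axis=0)[None, :].astype(np.float32)
    # Rescale x by the width ratio and y by the height ratio.
    old_h, old_w = image_hw
    new_h, new_w = resized_shape(old_h, old_w, long_side)
    coords[..., 0] *= new_w / old_w
    coords[..., 1] *= new_h / old_h
    return coords.astype(np.float32), onnx_labels


coords, labels = prepare_onnx_points([[1648, 640], [1609, 931]], [1, 1], image_hw=(1500, 2250))
print(coords.shape, labels.shape)  # (1, 3, 2) (1, 3)
print(labels.tolist())             # [[1.0, 1.0, -1.0]]
```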
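# The remaining examples use the PyTorch SamPredictor directly instead of the ONNX session.
# To exclude the helmet and select just the face, a background point (label 0, shown in red) can be supplied.
# First, run the foreground point alone to get candidate masks, their scores and low-resolution logits.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[1707, 312]]),
    point_labels=np.array([1]),
    multimask_output=True,
)

input_point = np.array([[1707, 312], [1711, 206]])
input_label = np.array([1, 0])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask as a mask prompt
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)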
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
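The line `mask_input = logits[np.argmax(scores), :, :]` picks the candidate with the highest predicted IoU score. With `multimask_output=True`, `predict` returns three candidates, so `scores` has shape (3,) and `logits` has shape (3, 256, 256), while the `mask_input` argument expects (1, 256, 256); hence the extra leading axis. A toy version with made-up scores and random logits in place of real model output:

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up stand-ins for predictor.predict(..., multimask_output=True) outputs.
scores = np.array([0.71, 0.93, 0.88], dtype=np.float32)     # predicted IoU per candidate
logits = rng.normal(size=(3, 256, 256)).astype(np.float32)  # low-res logits per candidate

best = int(np.argmax(scores))         # index of the highest-scoring candidate
mask_input = logits[best, :, :]       # shape (256, 256)
mask_prompt = mask_input[None, :, :]  # shape (1, 256, 256), as passed to predict

print(best, mask_prompt.shape)  # 1 (1, 256, 256)
```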

Selecting an object with a box

# Specify a single object with a box in XYXY format: [x0, y0, x1, y1]
input_box = np.array([1302,178,1961,1631])
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

Selecting an object with a box and points

#Combining points and boxes
input_box = np.array([1302,178,1961,1631])
input_point = np.array([[1648,640],[1609,931]])
input_label = np.array([1,1])
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()

Selecting multiple objects with boxes


#Batched prompt inputs
input_boxes = torch.tensor([
[1334, 171, 1919, 1618],
[1103,790,1174,854],
[1992,778,2835,1687,],
[689,1536,871,1659],
], device=predictor.device)


transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
masks.shape # (batch_size) x (num_predicted_masks_per_input) x H x W
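`apply_boxes_torch` rescales boxes the same way points are rescaled: each box is treated as two corner points, (x0, y0) and (x1, y1), and each corner is scaled like a point prompt. The NumPy sketch below reproduces that idea for the first two boxes above; `apply_boxes_np` is a hypothetical name rather than a library function, and the image size (1500, 2250) is an assumption for illustration.

```python
import numpy as np


def apply_boxes_np(boxes, image_hw, long_side=1024):
    """Hypothetical NumPy sketch: rescale XYXY boxes to the longest-side-1024 frame."""
    old_h, old_w = image_hw
    scale = long_side / max(old_h, old_w)
    new_h, new_w = int(old_h * scale + 0.5), int(old_w * scale + 0.5)
    # Copy the boxes and view each one as two (x, y) corners: shape (B, 2, 2).
    corners = np.array(boxes, dtype=np.float64).reshape(-1, 2, 2)
    corners[..., 0] *= new_w / old_w  # x coordinates
    corners[..., 1] *= new_h / old_h  # y coordinates
    return corners.reshape(-1, 4)


boxes = np.array([[1334, 171, 1919, 1618],
                  [1103, 790, 1174, 854]])
scaled = apply_boxes_np(boxes, image_hw=(1500, 2250))
print(scaled.shape)  # (2, 4)
```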


plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca())
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
