# Segment Anything Model (SAM) -- Comprehensive Multimodal Segmentation

## Overview
A comprehensive skill for zero-shot image segmentation using Meta AI's Segment Anything Model (SAM). SAM enables segmenting any object in any image without task-specific training, using flexible prompt types including points, bounding boxes, and masks. Trained on the SA-1B dataset containing over 1.1 billion masks from 11 million images, SAM delivers state-of-the-art segmentation quality across domains -- from natural photos to medical imaging, satellite imagery, and microscopy. This skill covers the original SAM, SAM 2 for video segmentation, and integration with both the native library and HuggingFace Transformers.
## When to Use
- Segmenting any object in images without task-specific training or fine-tuning
- Building interactive annotation and labeling tools with click-based prompts
- Generating high-quality training data for downstream vision models
- Processing medical, satellite, or domain-specific images with zero-shot transfer
- Creating automatic segmentation masks for entire images
- Building object cutout tools, background removal, or compositing pipelines
- Combining with text-based detectors (GroundingDINO) for text-prompted segmentation
## Quick Start

```bash
# Install SAM from GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git

# Required dependencies
pip install opencv-python pycocotools matplotlib

# Download checkpoint (ViT-H -- most accurate, 2.4 GB)
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```

```python
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor

# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")

# Create predictor and set image
predictor = SamPredictor(sam)
image = cv2.imread("photo.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Segment with a single point click
input_point = np.array([[500, 375]])
input_label = np.array([1])  # 1 = foreground
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # Returns 3 candidate masks
)
best_mask = masks[np.argmax(scores)]
print(f"Best mask IoU score: {scores.max():.3f}")
```
## Core Concepts

### Architecture Overview

SAM uses a three-component design that separates heavyweight image encoding from lightweight prompt processing:

```
Input Image ──► Image Encoder (ViT) ──► Image Embeddings (computed once)
                                               │
Prompts (points/boxes/masks) ──► Prompt Encoder ──► Prompt Embeddings
                                               │
                                               ▼
                               Image + Prompt Embeddings
                                               │
                                               ▼
                        Mask Decoder (lightweight transformer)
                                               │
                                               ▼
                               Output Masks + IoU Scores
```
The image encoder runs once per image and produces reusable embeddings. Multiple prompt queries can then be answered efficiently without re-encoding the image.
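The amortization this enables can be sketched with a small cache wrapper. This is a hypothetical helper, not part of the SAM API: `encode_fn` stands in for the heavy ViT image encoder and `decode_fn` for the lightweight mask decoder, mirroring what `SamPredictor.set_image()` does internally.

```python
class EmbeddingCache:
    """Sketch of SAM's encode-once / prompt-many pattern (hypothetical helper)."""

    def __init__(self, encode_fn):
        self.encode_fn = encode_fn
        self._key = None
        self._embedding = None
        self.encode_calls = 0  # Track how often the heavy encoder actually runs

    def predict(self, image_key, image, prompt, decode_fn):
        if image_key != self._key:
            # Heavy step: run the image encoder once per new image
            self._embedding = self.encode_fn(image)
            self._key = image_key
            self.encode_calls += 1
        # Cheap step: answer any number of prompts against the cached embedding
        return decode_fn(self._embedding, prompt)
```

With this shape, ten prompts against the same image trigger exactly one encode call, which is why interactive click-to-segment tools built on SAM feel responsive.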
### Model Variants
| Model | Checkpoint | Parameters | Size | Relative Speed | Accuracy |
|---|---|---|---|---|---|
| ViT-H | sam_vit_h_4b8939.pth | 636M | 2.4 GB | Slowest | Best |
| ViT-L | sam_vit_l_0b3195.pth | 308M | 1.2 GB | Medium | Very Good |
| ViT-B | sam_vit_b_01ec64.pth | 91M | 375 MB | Fastest | Good |
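A small lookup helper keeps variant selection in one place. The registry keys and checkpoint filenames come from the table above; the helper itself is hypothetical and assumes the checkpoints sit in the working directory:

```python
# Registry keys and checkpoint filenames from the model variants table
CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}

def resolve_variant(name: str) -> tuple[str, str]:
    """Return (registry_key, checkpoint_filename) for a SAM variant name."""
    key = name.lower().replace("-", "_")
    if key not in CHECKPOINTS:
        raise ValueError(f"unknown variant {name!r}; choose from {sorted(CHECKPOINTS)}")
    return key, CHECKPOINTS[key]
```

Usage would then be `key, ckpt = resolve_variant("ViT-B")` followed by `sam_model_registry[key](checkpoint=ckpt)`.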
### Prompt Types
| Prompt Type | Input Format | Best Use Case |
|---|---|---|
| Foreground Point | (x, y) with label 1 | Clicking on an object to select it |
| Background Point | (x, y) with label 0 | Excluding unwanted regions |
| Bounding Box | [x1, y1, x2, y2] | Selecting larger or ambiguous objects |
| Previous Mask | Low-res logits from prior prediction | Iterative refinement of results |
| Combined | Any mix of points, boxes, masks | Precise multi-cue segmentation |
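The label convention in the table can be wrapped in a small helper (hypothetical, not part of the SAM API) that assembles the `point_coords`/`point_labels` arrays from separate foreground and background click lists:

```python
import numpy as np

def build_point_prompt(foreground, background=()):
    """Stack clicks into the (N, 2) coords and (N,) labels arrays SAM expects.

    foreground/background are iterables of (x, y) pixel coordinates;
    foreground clicks get label 1, background clicks label 0.
    """
    fg = list(foreground)
    bg = list(background)
    coords = np.array(fg + bg, dtype=np.float64)
    labels = np.array([1] * len(fg) + [0] * len(bg), dtype=np.int64)
    return coords, labels
```

The two arrays feed directly into `predictor.predict(point_coords=coords, point_labels=labels, ...)`.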
## Point Prompt Segmentation

```python
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda")
predictor = SamPredictor(sam)
predictor.set_image(image_rgb)

# Single foreground point
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)

# Multiple points: 2 foreground + 1 background for precision
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375], [520, 390], [300, 200]]),
    point_labels=np.array([1, 1, 0]),
    multimask_output=False,  # Single mask when prompts are unambiguous
)
```
## Bounding Box Prompts

```python
# Box prompt: [x_min, y_min, x_max, y_max]
input_box = np.array([425, 600, 700, 875])
masks, scores, logits = predictor.predict(
    box=input_box,
    multimask_output=False,
)
```
## Combined Prompts for Precision

```python
# Combine box + point for maximum control
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    box=np.array([400, 300, 700, 600]),
    multimask_output=False,
)
```
## Iterative Mask Refinement

```python
# First pass: coarse segmentation
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)

# Second pass: refine using previous mask logits + additional points
best_logit = logits[np.argmax(scores)]
masks_refined, scores_refined, _ = predictor.predict(
    point_coords=np.array([[500, 375], [550, 400]]),
    point_labels=np.array([1, 0]),       # Add background exclusion
    mask_input=best_logit[None, :, :],   # Feed prior mask
    multimask_output=False,
)
```
## Automatic Mask Generation
Generate segmentation masks for every object in an image without manual prompts:
```python
from segment_anything import SamAutomaticMaskGenerator

mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,            # Density of point grid (32x32 = 1024 points)
    pred_iou_thresh=0.88,          # Minimum predicted IoU quality
    stability_score_thresh=0.95,   # Mask stability threshold
    crop_n_layers=1,               # Multi-scale cropping layers
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,      # Remove masks smaller than 100 pixels
)
masks = mask_generator.generate(image_rgb)

# Each mask dict contains:
# - "segmentation": np.ndarray boolean mask
# - "bbox": [x, y, w, h]
# - "area": pixel count
# - "predicted_iou": model confidence score
# - "stability_score": robustness under perturbation
# - "point_coords": the generating sample point

# Sort by area and filter
large_masks = sorted(masks, key=lambda m: m["area"], reverse=True)
high_quality = [m for m in masks if m["predicted_iou"] > 0.92]
```
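The sorting and filtering above can be combined into one reusable pass over the generator's mask dicts (a hypothetical helper; the field names match the output documented above):

```python
def filter_masks(masks, min_iou=0.92, min_area=100):
    """Keep masks meeting quality thresholds, ordered largest-area first."""
    kept = [
        m for m in masks
        if m["predicted_iou"] >= min_iou and m["area"] >= min_area
    ]
    return sorted(kept, key=lambda m: m["area"], reverse=True)
```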
## HuggingFace Transformers Integration

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to("cuda")

image = Image.open("photo.jpg")

# Point-prompted segmentation
input_points = [[[450, 600]]]  # Batch of point sets
inputs = processor(image, input_points=input_points, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
with torch.no_grad():
    outputs = model(**inputs)

# Post-process to original image resolution
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

# Box-prompted segmentation
input_boxes = [[[400, 300, 700, 600]]]
inputs = processor(image, input_boxes=input_boxes, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
with torch.no_grad():
    outputs = model(**inputs)
```
## Text-Prompted Segmentation with GroundingDINO
Combine SAM with a text-based detector for open-vocabulary segmentation:
```python
import cv2
import numpy as np
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import sam_model_registry, SamPredictor

# Step 1: Detect objects with text prompt using GroundingDINO
dino_model = load_model(
    "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "weights/groundingdino_swint_ogc.pth",
)
image_source, image_transformed = load_image("photo.jpg")
boxes, logits, phrases = predict(
    model=dino_model,
    image=image_transformed,
    caption="cat . dog . person",
    box_threshold=0.35,
    text_threshold=0.25,
)

# Step 2: Segment detected boxes with SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda")
predictor = SamPredictor(sam)
image_rgb = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image_rgb)

h, w, _ = image_rgb.shape
for box in boxes:
    # GroundingDINO returns normalized (cx, cy, w, h) boxes;
    # SAM needs pixel [x1, y1, x2, y2]
    cx, cy, bw, bh = box.numpy() * np.array([w, h, w, h])
    box_px = np.array([cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2]).astype(int)
    masks, scores, _ = predictor.predict(box=box_px, multimask_output=False)
    # Use masks for downstream tasks
```
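The coordinate conversion in the loop deserves isolating, since it is a common source of bugs: GroundingDINO's `predict` returns boxes as normalized `(cx, cy, w, h)`, while SAM expects pixel `[x1, y1, x2, y2]`. A minimal pure-NumPy converter (hypothetical helper name):

```python
import numpy as np

def dino_box_to_sam(box, img_w, img_h):
    """Convert one normalized (cx, cy, w, h) box to integer pixel [x1, y1, x2, y2]."""
    cx, cy, bw, bh = np.asarray(box, dtype=np.float64)
    half_w, half_h = bw * img_w / 2, bh * img_h / 2
    x_center, y_center = cx * img_w, cy * img_h
    return np.array([
        x_center - half_w, y_center - half_h,
        x_center + half_w, y_center + half_h,
    ]).astype(int)
```

For example, a centered box covering half the image in each dimension, `(0.5, 0.5, 0.5, 0.5)` on a 100x100 image, maps to `[25, 25, 75, 75]`.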
## ONNX Export for Edge Deployment

```python
import torch
from segment_anything.utils.onnx import SamOnnxModel

# Export mask decoder to ONNX (lightweight, ~15 MB)
onnx_model = SamOnnxModel(sam, return_single_mask=True)
dummy_inputs = {
    "image_embeddings": torch.randn(1, 256, 64, 64),
    "point_coords": torch.randint(0, 1024, (1, 2, 2), dtype=torch.float),
    "point_labels": torch.randint(0, 2, (1, 2), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 256, 256),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1024, 1024], dtype=torch.float),
}
torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_decoder.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
)
```
## Configuration Reference

### SamPredictor.predict Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `point_coords` | `np.ndarray` | `None` | Nx2 array of (x, y) point coordinates |
| `point_labels` | `np.ndarray` | `None` | N-length array, 1=foreground, 0=background |
| `box` | `np.ndarray` | `None` | Length-4 array [x1, y1, x2, y2] |
| `mask_input` | `np.ndarray` | `None` | 1x256x256 low-res mask from prior prediction |
| `multimask_output` | `bool` | `True` | Return 3 candidate masks (True) or 1 (False) |
| `return_logits` | `bool` | `False` | Return raw logits instead of binary masks |
### SamAutomaticMaskGenerator Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `points_per_side` | `int` | `32` | Grid sampling density per side |
| `points_per_batch` | `int` | `64` | Batch size for point processing |
| `pred_iou_thresh` | `float` | `0.88` | Minimum predicted IoU to keep mask |
| `stability_score_thresh` | `float` | `0.95` | Minimum stability score |
| `stability_score_offset` | `float` | `1.0` | Offset for stability calculation |
| `crop_n_layers` | `int` | `0` | Number of multi-scale crop layers |
| `crop_n_points_downscale_factor` | `int` | `1` | Point reduction per crop layer |
| `min_mask_region_area` | `int` | `0` | Remove masks below this pixel area |
| `output_mode` | `str` | `"binary_mask"` | `"binary_mask"`, `"uncompressed_rle"`, or `"coco_rle"` |
## Best Practices

- Compute image embeddings once -- Call `predictor.set_image()` a single time per image, then run multiple prompt queries against the cached embeddings for interactive workflows.
- Use `multimask_output=True` for ambiguous prompts -- When a single point could match multiple objects, get three candidates and pick the highest-scoring one. Switch to `multimask_output=False` when prompts are specific (box + points).
- Start with ViT-B for prototyping -- The 375 MB ViT-B model is 3-4x faster than ViT-H and sufficient for initial development. Upgrade to ViT-H only when accuracy is critical.
- Combine prompts for precision -- A bounding box plus a foreground point consistently outperforms either prompt type alone, especially for irregularly shaped objects.
- Filter automatic masks aggressively -- Raise `pred_iou_thresh` and `stability_score_thresh` to reduce noisy small masks. Use `min_mask_region_area` to discard tiny fragments.
- Use iterative refinement for difficult objects -- Feed the best logits from a first prediction back as `mask_input` along with additional corrective points for complex shapes.
- Leverage ONNX export for deployment -- Export only the lightweight mask decoder (~15 MB) to ONNX for browser or edge deployment. Pre-compute image embeddings server-side.
- Pair with GroundingDINO for text prompts -- SAM itself has no text understanding. Combine it with GroundingDINO or OWLv2 for open-vocabulary segmentation driven by natural language.
- Apply GPU memory management -- For batch processing large images, move the model to GPU only during inference and clear CUDA cache between images using `torch.cuda.empty_cache()`.
- Consider SAM 2 for video -- If your use case involves video or temporal consistency, use SAM 2, which extends the architecture with memory-based tracking across frames.
## Troubleshooting
Model runs out of GPU memory with large images: SAM resizes images internally to 1024x1024. If memory is still tight, use ViT-B instead of ViT-H, or process on CPU for non-interactive workloads.
Automatic mask generator produces too many overlapping masks: Increase `stability_score_thresh` to 0.97 and `pred_iou_thresh` to 0.92. Reduce `points_per_side` from 32 to 16 for coarser coverage.
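Collected as keyword arguments, the stricter settings look like this (values from the advice above; pass them to `SamAutomaticMaskGenerator` alongside `model=sam`):

```python
# Stricter generator settings to suppress overlapping/noisy masks
strict_settings = dict(
    points_per_side=16,            # Coarser sampling grid (down from 32)
    pred_iou_thresh=0.92,          # Raised from the 0.88 default
    stability_score_thresh=0.97,   # Raised from the 0.95 default
)
# mask_generator = SamAutomaticMaskGenerator(model=sam, **strict_settings)
```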
Point prompt selects the wrong object: Add a background point (label=0) on the unwanted object to exclude it. Alternatively, provide a bounding box around the intended target.
ONNX export fails with custom model modifications: The ONNX exporter expects the standard SAM architecture. If you have modified layers, trace with `torch.jit.trace` first or manually adapt the `SamOnnxModel` wrapper.
HuggingFace Transformers gives different results than native SAM: The Transformers implementation normalizes inputs differently. Ensure you use `SamProcessor` for preprocessing and `post_process_masks` for output conversion to match native behavior.