当前位置: 首页 > news >正文

Sam算法基本原理解析

        Sam是Segmantation anything model的缩写,也就是可以分割一切的模型,2023年4月份横空出世,记忆中这是第一个引爆CV圈的大模型,如今两年过去了,来拜读一下Sam的源码。

1.Sam算法推理流程

        Sam算法的推理流程可以用下面这张图来概括,输入一张图,经过3个Step,可以按照中途输入的提示点(五角星),其中表示标记的正向点(前景),表示标记的负向点(背景),最后输出正向点解码得到的分割结果。

        

        Step1就是载入模型权重,这是官方使用极大量的数据集训练出来的模型,在不进行微调训练的前提下也能对各种场景下的物体有很好的分割效果。

        我使用的是sam_vit_b_01ec64.pth模型,有357MB,这是最小的一个模型,官方提供了"vit_h"、"vit_l"、"vit_b"三种模型,,模型大小是从大到小。另外,从模型名字我们还可以知道,sam模型里面会包含vit网络,也就是应用在视觉领域的Transformer。

        Step2包括图像预处理(Preprocess)和图像编码网络(Image_encoder)。

        Step3包括提示词编码网络(Prompt_encoder)和掩码解码网络(Mask_decoder)。

       其中Step2和Step3中包含最重要的三个部分:Image_encoder、Prompt_encoder和Mask_decoder,他们的关系可以用下图来表示:

        Image_encoder负责对预处理后的图片进行图像编码,得到(1,256,64,64)的特征图,此时features特征图的分辨率已经从(1024,1024)缩放到(64,64);

        Prompt_encoder负责对输入的提示(可以是point,box和mask,我这里暂时只对提示点作为例子说明),提示点包括了点的坐标(point_coords),以及点的类别(point_labels)。点的坐标是相对于原图的(x,y)坐标,比如输入图像的宽高分辨率是( 967,546),那么输入的点坐标就是基于( 967,546)分辨率的坐标,而不是经过预处理后的(1024,1024)分辨率的坐标。点的类别就是1和0,分别表示前景和背景。Prompt_encoder会输出稀疏嵌入sparse_embeddings和密集嵌入dense_embeddings。

        Mask_decoder主要就是对Image_encoder和Prompt_encoder的输出结果进行解码,最后得到掩码结果low_res_masks和置信度iou_predictions(这个暂时就先理解为置信度吧)。这个是multimask_output=False的结果,后面再详细看看multimask_output=True的区别。

2.Preprocess

        Preprocess过程就是要把任意输入图像的分辨率调整成统一的(1024,1024)分辨率,同时将图片的数据进行标准化,即将均值和方差标准化为0和1。

        分辨率调整有点像YOLO中的letterbox,先选择长边然后等比例调整到1024,然后在短边填充黑边,像素值为(0,0,0),如下图所示。

3.Image Encoder

3.1patch embedding

        在正式对图片数据进行Transformer编码前,还需要对图片进行patch_embed处理和pos_embed处理,这个是Vit的常规操作,这部分推荐去看看大佬的博客(Vision Transformer详解-CSDN博客)。

       下面我来谈谈我对patch_embed这部分的理解:

        还是先看看输入和输出,输入是一张经过预处理后的3通道且分辨率为(1024,1024)的图片,Shape为(1,3,1024,1024);输出是Shape为(1,64,64,768)的特征图,在这里我们要清楚,1x3x1024x1024=1x64x64x768,说明了经过patch_embed只是对输入的特征图的Shape进行了某种变换,并没有让他的元素有任何损失,下面来看看进行了什么样的Shape变换。

        现在假设我想将图片划分出一个个宽高都为16像素的正方形,那么宽高为(1024,1024)的图片在水平方向有1024/16=64个正方形,垂直方向有1024/16=64个正方形,且每个正方形的Shape为(16,16,3),拉直后变成(768,1)的向量,因此我们就能看到Shape为(1,3,1024,1024)的特征图经过patch_embed后得到Shape为(1,64,64,768)的特征图。

        接着再来看看是怎么实现的,实际上就是通过一个普通的卷积核实现的,代码如下。卷积核的大小为k=16x16,卷积核的步长s=16,卷积核的通道数为c=768.通过这样一个卷积就可以对Shape为(1,3,1024,1024)特征图进行划分出若干个正方形,并且正方形之间没有重合,最后的输出特征图Shape为(1,768,64,64),再交换下维度就变成了(1,64,64,768)。

 #Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

        接着就是pos_embed,这块是给(1,64,64,768)特征图对应位置上的元素加上一个绝对位置编码,这个相对位置编码是经过训练而学习来的,在代码中对应的是self.pos_embed.

# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))        

        接下来就是Image Encoder的重头戏了—Transformer Encoder,输入Shape为(1,64,64,768)的特征图,经过Transformer Encoder之后,将得到Shape为(1,256,64,64),下一节将详细介绍下里面的细节。

3.2Transformer Encoder

        未完待续。。。

 

相关文章:

  • CPU与GPU的功能与区别解析
  • 运维面试情景题:如果有一块新的硬盘要加入机架如何配置;如果新加了一台服务器,如何配置安全措施
  • DeepSeek预训练追求极致的训练效率的做法
  • 2025.04.26-淘天春招笔试题-第三题
  • MQL5教程 06 EA开发实战
  • 【OSG学习笔记】Day 11: 文件格式与数据交换
  • Dify中的文本分词处理技术详解
  • 财务管理域——企业风控系统设计
  • Channel如何安全地尝试发送数据
  • win11右键菜单改回win10模式
  • 基于 RAG 的 Text2SQL 全过程的 Python 实现详解,结合 LangChain 框架实现自然语言到 SQL 的转换
  • 20250426在ubuntu20.04.2系统上解决问题mkfs.exfat command not found
  • function,bind,lambda的用法
  • 力扣刷题Day 31:删除链表的倒数第N个结点(19)
  • 数据库原理(1)
  • 贝叶斯算法学习
  • 【LeetCode 热题 100】链表 系列
  • [实战] 卡尔曼滤波:原理、推导与卫星导航应用仿真(完整代码)
  • 深入剖析 TypeScript 基础类型:string、number、boolean 的声明与使用
  • lnmp1.5+centos7版本安装php8
  • 退休11年后,71岁四川厅官杨家卷被查
  • 娱见 | 为了撕番而脱粉,内娱粉丝为何如此在乎番位
  • 市场监管总局发布《城镇房屋租赁合同(示范文本)》
  • “穿越看洪武”,明太祖及其皇后像台北故宫博物院南园展出
  • 特朗普:不谋求第三个总统任期,中意万斯鲁比奥“接棒”
  • 17家城商行去年年报盘点:西安银行营收增速领跑,青岛银行净利增速领跑