当前位置: 首页 > news >正文

【SAM2代码解析】training部分-1总体概述

总览

1.1 文件总览

training folder保存了训练SAM2的相关代码,该代码允许使用者们用他们自己的数据集(图像、视频或两者一起)去微调SAM2
文件结构如下:

  • dataset文件夹:保存了包含图像和视频数据集以及数据加载器类及其转换
  • mode文件夹:包含用于训练、微调的主要模型类(SAM2Train)。SAM2Train继承自SAM2Base模型,并提供用于启用SAM2训练或微调的函数。它还接受所有用于模拟用户提示的训练时参数。
  • utils:此文件夹包含训练工具,例如日志记录器和分布式训练工具
  • scripts:此文件夹包含用于提取SA-V数据集帧以用于训练的脚本
  • loss_fns.py:此文件夹包含训练的主要损失类
  • optimizer.py:此文件夹包含支持任意调度器的所有优化器工具
  • trainer.py:此文件包含Trainer类,接收所有的Hydra可配置模块(模型、优化器、数据集等),并实现主要的训练、评估循环。
  • train.py:此脚本用于启动训练作业,它支持单节点和多节点作业。可运行python training/train.y -h的方式查看有关使用方法。
    • 在这里插入图片描述

1.2 训练/微调步骤

以MOSE数据集为例

  • 1、运行pip install -e ".[dev]"安装训练所需的包
  • 2、在configs/sam2.1_training/sam2.1_hiera_b+MOSE_finetune.yaml中设置MOSE数据集的路径
    在这里插入图片描述

dataset:
#数据集路径
img_folder:null #MOSE JPEGImages 文件夹路径
gt_folder:null #MOSE Annotations 文件径
file_list txt:null # 可选路径,包含用于洲练的视频子集的文件列表

  • 3、使用GPU在MOSE上微调基础模型

python training/train.py -c configs/sam2.1_training/sam2.1_hiera_b+MOSE_finetune.yaml --use-cluster 0

  • 4、可以使用存储在实验日志目录下的tensorboard/文件夹中的TensorBoard日志来监控训练损失。我们还为评估目的提供了一个样本验证拆分。
  • 5、训练完成后,可以使用实验日志目录中checkpoints/文件夹中保存的新检查点
  • 6、在图像和视频上进行训练:代码支持在图像和视频上进行训练,我们提供了用于加载SA-1B作为示例图像数据集、SA-V作为示例视频数据集以及任何DAVIS风格的视频数据集的类。注意在SA-V上进行训练,必须先使用提供的提取脚本将所有视频提取为JPEG帧。

data:
train:
target: training.dataset.sam2_datasets.TorchTrainMixedDataset
phases_per_epoch: ${phases_per_epoch} # 将一个epoch划分成更小的阶段
batch_sizes:
- ${bs1} # 数据集1的批量大小
- ${bs2} # 数据集2的批量大小
datasets:
# SA1B 作为图像数据集的示例
- target: training.dataset.vos_dataset.VOSDataset
training: true
video_dataset:
target: training.dataset.vos_raw_dataset.SA1BRawDataset
img_folder: ${path_to_img_folder}
gt_folder: ${path_to_gt_folder}
file_list_txt: ${path_to_train_filelist} # 可选
sampler:
target: training.dataset.vos_sampler.RandomUniformSampler
num_frames: 1
max_num_objects: ${max_num_objects_per_image}
transforms: ${image_transforms}
# SA-V 作为视频数据集
- target: training.dataset.vos_dataset.VOSDataset
training: true
video_dataset:
target: training.dataset.vos_raw_dataset.JSONRawDataset
img_folder: ${path_to_img_folder}
gt_folder: ${path_to_gt_folder}
file_list_txt: ${path_to_train_filelist} # Optional
ann_every: 4
sampler:
target: training.dataset.vos_sampler.RandomUniformSampler
num_frames: 8 # Number of frames per video
max_num_objects: ${max_num_objects_per_video}
reverse_time_prob: ${reverse_time_prob} # probability to reverse video
transforms: ${video_transforms}
shuffle: True
num_workers: ${num_train_workers}
pin_memory: True
drop_last: True
collate_fn:
target: training.utils.data_utils.collate_fn
partial: true
dict_key: all

1.3 config解析

1)scratch

  • resolution:训练时使用的图像分辨率
  • num_frames:每个视频样本中使用的帧数,这里是8
  • max_num_objects:每个视频中最多处理的对象数量,这里是3
  • base_lr和vision_lr:基础学习率和视觉模块的学习率
  • phases_per_epoch:每个epoch分成的阶段数,这里是1
  • num_epochs:总训练周期数

2)dataset

  • img_folder和gt_folder:图像和标注文件夹的路径
  • file_list_txt:包含用于训练的视频子集的文件列表路径
  • multiplier:数据增强的倍数,这里是2

3)数据增强

定义了在训练过程中对视频帧应用的数据增强操作,包括随机水平翻转、随机仿射变换、随机调整大小、颜色抖动、随机灰度化等

4)训练器

  • target:指定使用的训练器类,这里是training.trainer.Trainer
  • mode:训练模式,这里是train_only
  • model:定义了SAM2模型的结构和参数,包括图像编码器、记忆注意力模块、记忆编码器等,这些参数控制了模型的架构和行为

5)数据加载器

  • target:指定使用的数据集类,这里是 TorchTrainMixedDataset。
  • phases_per_epoch 和 batch_sizes:每个 epoch 的阶段数和每个数据集的批次大小。
  • datasets:定义了数据集的具体配置,包括数据增强、采样器等。

6)优化器(optim)

  • amp:自动混合精度设置
  • optimizer:指定使用的优化器,这里是torch.optim.AdamW
  • gradient_clip:梯度裁剪设置,最大范数为0.1

7)损失函数

定义了训练过程中使用的损失函数,包括掩码损失,dice损失,iou损失和分类损失

8)chcekpoint

  • save_dir:保存checkpoints的目录
  • save_freq:保存频率,设置为0意味着只保存最后一个检查点
  • model_weight_initializer:模型权重初始化器的配置。

相关文章:

  • 实时监测+远程管控:ADW300解锁阳台光伏运维新维度
  • Java转Go日记(六):TCP黏包
  • 5个Go接口常见错误及避免方法
  • 初次尝试Ghidra
  • usb2.0的硬件知识(一)
  • 2023蓝帽杯初赛内存取证-3
  • 【Ansible】批量管理 Windows自动化运维
  • 设置开机自启动
  • Cursor 设置规则
  • 遨游通讯发布国产化旗舰三防手机AORO AU1:以自主可控重塑工业安全
  • Curl用法解析
  • 基于华为云 ModelArts 的在线服务应用开发(Requests 模块)
  • drupal7可以从测试环境一键部署到生产环境吗
  • Springboot+Vue实现邮箱验证功能(邮箱登录+忘记密码)
  • Docker 部署 PostgreSQL 数据库
  • 基于龙芯 2K1000处理器和复旦微 FPGA K7 的全国产RapidIO 解决方案研究
  • Maven插件学习(三)——插件配置依赖和执行目标
  • 代码随想录算法训练营Day34
  • 【Java面试笔记:基础】4.强引用、软引用、弱引用、幻象引用有什么区别?
  • LangChain 核心模块学习:Chains
  • 大家聊中国式现代化|刘亮:因地制宜发展新质生产力,推动经济高质量发展
  • 金发科技去年净利增160%,机器人等新领域催生材料新需求
  • 私和人命:清代四川南部县谢相荣投河溺毙一案
  • 一季度减持阿里、美团,加仓顺丰,张坤:与其纠结经济,不如着眼企业
  • 被指违反代理协议遭南航暂停售票资格, 去哪儿网:今起恢复
  • 陈杨梅:刷到“棉花糖爸爸”寻女视频,隐约觉得自己就是爸爸要找的孩子