【SAM2代码解析】training部分-1总体概述
总览
1.1 文件总览
training folder保存了训练SAM2的相关代码,该代码允许使用者们用他们自己的数据集(图像、视频或两者一起)去微调SAM2
文件结构如下:
- dataset文件夹:保存了包含图像和视频数据集以及数据加载器类及其转换
- mode文件夹:包含用于训练、微调的主要模型类(SAM2Train)。SAM2Train继承自SAM2Base模型,并提供用于启用SAM2训练或微调的函数。它还接受所有用于模拟用户提示的训练时参数。
- utils:此文件夹包含训练工具,例如日志记录器和分布式训练工具
- scripts:此文件夹包含用于提取SA-V数据集帧以用于训练的脚本
- loss_fns.py:此文件夹包含训练的主要损失类
- optimizer.py:此文件夹包含支持任意调度器的所有优化器工具
- trainer.py:此文件包含Trainer类,接收所有的Hydra可配置模块(模型、优化器、数据集等),并实现主要的训练、评估循环。
- train.py:此脚本用于启动训练作业,它支持单节点和多节点作业。可运行
python training/train.y -h
的方式查看有关使用方法。
1.2 训练/微调步骤
以MOSE数据集为例
- 1、运行
pip install -e ".[dev]"
安装训练所需的包 - 2、在
configs/sam2.1_training/sam2.1_hiera_b+MOSE_finetune.yaml
中设置MOSE数据集的路径
dataset:
#数据集路径
img_folder:null #MOSE JPEGImages 文件夹路径
gt_folder:null #MOSE Annotations 文件径
file_list txt:null # 可选路径,包含用于洲练的视频子集的文件列表
- 3、使用GPU在MOSE上微调基础模型
python training/train.py -c configs/sam2.1_training/sam2.1_hiera_b+MOSE_finetune.yaml --use-cluster 0
- 4、可以使用存储在实验日志目录下的tensorboard/文件夹中的TensorBoard日志来监控训练损失。我们还为评估目的提供了一个样本验证拆分。
- 5、训练完成后,可以使用实验日志目录中
checkpoints/
文件夹中保存的新检查点 - 6、在图像和视频上进行训练:代码支持在图像和视频上进行训练,我们提供了用于加载SA-1B作为示例图像数据集、SA-V作为示例视频数据集以及任何DAVIS风格的视频数据集的类。注意在SA-V上进行训练,必须先使用提供的提取脚本将所有视频提取为JPEG帧。
data:
train:
target: training.dataset.sam2_datasets.TorchTrainMixedDataset
phases_per_epoch: ${phases_per_epoch} # 将一个epoch划分成更小的阶段
batch_sizes:
- ${bs1} # 数据集1的批量大小
- ${bs2} # 数据集2的批量大小
datasets:
# SA1B 作为图像数据集的示例
- target: training.dataset.vos_dataset.VOSDataset
training: true
video_dataset:
target: training.dataset.vos_raw_dataset.SA1BRawDataset
img_folder: ${path_to_img_folder}
gt_folder: ${path_to_gt_folder}
file_list_txt: ${path_to_train_filelist} # 可选
sampler:
target: training.dataset.vos_sampler.RandomUniformSampler
num_frames: 1
max_num_objects: ${max_num_objects_per_image}
transforms: ${image_transforms}
# SA-V 作为视频数据集
- target: training.dataset.vos_dataset.VOSDataset
training: true
video_dataset:
target: training.dataset.vos_raw_dataset.JSONRawDataset
img_folder: ${path_to_img_folder}
gt_folder: ${path_to_gt_folder}
file_list_txt: ${path_to_train_filelist} # Optional
ann_every: 4
sampler:
target: training.dataset.vos_sampler.RandomUniformSampler
num_frames: 8 # Number of frames per video
max_num_objects: ${max_num_objects_per_video}
reverse_time_prob: ${reverse_time_prob} # probability to reverse video
transforms: ${video_transforms}
shuffle: True
num_workers: ${num_train_workers}
pin_memory: True
drop_last: True
collate_fn:
target: training.utils.data_utils.collate_fn
partial: true
dict_key: all
1.3 config解析
1)scratch
- resolution:训练时使用的图像分辨率
- num_frames:每个视频样本中使用的帧数,这里是8
- max_num_objects:每个视频中最多处理的对象数量,这里是3
- base_lr和vision_lr:基础学习率和视觉模块的学习率
- phases_per_epoch:每个epoch分成的阶段数,这里是1
- num_epochs:总训练周期数
2)dataset
- img_folder和gt_folder:图像和标注文件夹的路径
- file_list_txt:包含用于训练的视频子集的文件列表路径
- multiplier:数据增强的倍数,这里是2
3)数据增强
定义了在训练过程中对视频帧应用的数据增强操作,包括随机水平翻转、随机仿射变换、随机调整大小、颜色抖动、随机灰度化等
4)训练器
- target:指定使用的训练器类,这里是
training.trainer.Trainer
- mode:训练模式,这里是
train_only
- model:定义了SAM2模型的结构和参数,包括图像编码器、记忆注意力模块、记忆编码器等,这些参数控制了模型的架构和行为
5)数据加载器
- target:指定使用的数据集类,这里是 TorchTrainMixedDataset。
- phases_per_epoch 和 batch_sizes:每个 epoch 的阶段数和每个数据集的批次大小。
- datasets:定义了数据集的具体配置,包括数据增强、采样器等。
6)优化器(optim)
- amp:自动混合精度设置
- optimizer:指定使用的优化器,这里是
torch.optim.AdamW
- gradient_clip:梯度裁剪设置,最大范数为0.1
7)损失函数
定义了训练过程中使用的损失函数,包括掩码损失,dice损失,iou损失和分类损失
8)chcekpoint
- save_dir:保存checkpoints的目录
- save_freq:保存频率,设置为0意味着只保存最后一个检查点
- model_weight_initializer:模型权重初始化器的配置。