当前位置: 首页 > news >正文

AMP混合精度训练 详细解析

为什么需要 AMP混合精度训练?

PyTorch 1.6版本今天发布了,带来的最大更新就是自动混合精度。release说明的标题是:

Stable release of automatic mixed precision (AMP).New Beta features include a TensorPipe backend for RPC, memory profiler, and several improvements to distributed training for both RPC and DDP.

可见自动混合精度正是PyTorch 1.6的最大更新。这就带来了几个问题:

  • 什么是自动混合精度训练?
  • 为什么需要自动混合精度?
  • 如何在PyTorch中使用自动混合精度?
    一、什么是自动混合精度训练(AMP)
    >>> import torch
    >>> gemfield = torch.zeros(70,30)
    >>> gemfield.type() 'torch.FloatTensor'
    >>> syszux = torch.Tensor([1,2])
    >>> syszux.type() 'torch.FloatTensor'
    我们知道神经网络框架的计算核心是Tensor,也就是那个从scaler -> array -> matrix -> tensor 维度一路丰富过来的tensor。在PyTorch中,我们可以这样创建一个Tensor:
    可以看到默认创建的tensor都是FloatTensor类型。而在PyTorch中,一共有10种类型的tensor:
  • torch.FloatTensor (32-bit floating point)
  • torch.DoubleTensor (64-bit floating point)
  • torch.HalfTensor (16-bit floating point 1)
  • torch.BFloat16Tensor (16-bit floating point 2)
  • torch.ByteTensor (8-bit integer (unsigned))
  • torch.CharTensor (8-bit integer (signed))
  • torch.ShortTensor (16-bit integer (signed))
  • torch.IntTensor (32-bit integer (signed))
  • torch.LongTensor (64-bit integer (signed))
  • torch.BoolTensor (Boolean)

由此可见,默认的Tensor是32-bit floating point,这就是32位浮点型精度的Tensor。

from torch.cuda.amp import autocast as autocast

自动混合精度的关键词有两个:自动、混合精度。这是由PyTorch 1.6的torch.cuda.amp模块带来的:

混合精度预示着有不止一种精度的Tensor,那在PyTorch的AMP模块里是几种呢?2种:torch.FloatTensor和

torch.HalfTensor

自动预示着Tensor的dtype类型会自动变化,也就是框架按需自动调整tensor的dtype(其实不是完全自动,有些 地方还是需要手工干预);

torch.cuda.amp 的名字意味着这个功能只能在cuda上使用,事实上,这个功能正是NVIDIA的开发人员贡献到PyTorch项目中的。而只有支持Tensor core的CUDA硬件才能享受到AMP的好处(比如2080ti显卡)。Tensor Core是一种矩阵乘累加的计算单元,每个Tensor Core每个时钟执行64个浮点混合精度操作(FP16矩阵相乘和FP32累加),英伟达宣称使用Tensor Core进行矩阵运算可以轻易的提速,同时降低一半的显存访问和存储。

因此,在PyTorch中,当我们提到自动混合精度训练,我们说的就是在NVIDIA的支持Tensor core的CUDA设备上使用torch.cuda.amp.autocast (以及torch.cuda.amp.GradScaler)来进行训练。咦?为什么还要有torch.cuda.amp.GradScaler?

二、为什么需要自动混合精度?

这个问题其实暗含着这样的意思:为什么需要自动混合精度,也就是torch.FloatTensor和torch.HalfTensor的混 合,而不全是torch.FloatTensor?或者全是torch.HalfTensor?

如果非要以这种方式问,那么答案只能是,在某些上下文中torch.FloatTensor有优势,在某些上下文中torch.HalfTensor有优势呗。答案进一步可以转化为,相比于之前的默认的torch.FloatTensor,torch.HalfTensor 有时具有优势,有时劣势不可忽视。

torch.HalfTensor的优势就是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快;

torch.HalfTensor的劣势就是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)。

可见,当有优势的时候就用torch.HalfTensor,而为了消除torch.HalfTensor的劣势,我们带来了两种解决方案:

  1. 梯度scale,这正是上一小节中提到的torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度的underflow(这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回 去);
  2. matmul
    addbmm
    addmm
    addmv
    addr
    baddbmm
    bmm
    chain_matmul
    conv1d
    conv2d
    conv3d
    conv_transpose1d
    conv_transpose2d
    conv_transpose3d
    linear
    matmul
    mm
    mv
    prelu
    回落到torch.FloatTensor,这就是混合一词的由来。那怎么知道什么时候用torch.FloatTensor,什么时候用半 精度浮点型呢?这是PyTorch框架决定的,在PyTorch 1.6的AMP上下文中,如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

三、混合精度训练的优点是什么?

  • 减少显存
  • 占用加快训练速度:通信量
  • 减半;计算性能翻倍。
    四、混合精度训练的缺点是什么?
  • 数据溢出
  • 舍入误差
    五、混合精度训练的关键技术是什么?
  • float32主权重备份
  • 动态损失缩放
    六、介绍一下 混合精度训练 动态损失缩放?
  • 介绍:即损失标度。这会将梯度也放大1024倍,大大降只需要将损失乘以某个大数字 (如1024),低了梯度发生下溢的几率。计算出梯度后,只需将其除以1024就可以得到准确值。

动态选择损失标度 介绍:发生溢出,跳过优化器更新,损失标度减半;连续N个steps没有发生溢出,损失标度翻倍

七、如何在PyTorch中使用自动混合精度?

答案就是autocast + GradScaler。

1. autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor model = Net().cuda()

optimizer = optim.SGD(model.parameters(),)

可以使用autocast的context managers语义(如上所示),也可以使用decorators语义。 当进入autocast的上下文后,上面列出来的那些CUDA ops 会把tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算。刚进入autocast的上下文时,tensor可以是任何类型,你不要在model或者input上手工调用.half() ,框架会自动做,这也是自动混合精度中“自动”一词的由来。

另外一点就是,autocast上下文应该只包含网络的前向过程(包括loss的计算),而不要包含反向传播,因为BP 的op会使用和前向op相同的类型。

Traceback (most recent call last):

......

File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl

result = self.forward(*input, **kwargs)

......

RuntimeError: expected scalar type float but found c10::Half

还有的时候呀,你的代码在autocast上下文中会报如下的错误:

对于RuntimeError: expected scalar type float but found c10::Half,这估计是个bug。你可以在tensor上手工调用.float()来让type匹配。

1. GradScaler

但是别忘了前面提到的梯度scaler模块呀,需要在训练最开始之前实例化一个GradScaler对象。因此PyTorch中 经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor model = Net().cuda()

optimizer = optim.SGD(model.parameters(),)

# 在训练最开始之前实例化一个GradScaler对象scaler = GradScaler()

for epoch in epochs:

for input, target in data:

optimizer.zero_grad()

# 前向过程(model + loss)开启 autocast

with autocast():

scaler的大小在每次迭代中动态的估计,为了尽可能的减少梯度underflow,scaler应该更大;但是如果太大的 话,半精度浮点型的tensor又容易overflow(变成inf或者NaN)。所以动态估计的原理就是在不出现inf或者NaN

梯度值的情况下尽可能的增大scaler的值——在每次scaler.step(optimizer)中,都会检查是否又inf或NaN的梯度出 现:

  • 如果出现了inf或者NaN,scaler.step(optimizer)会忽略此次的权重更新(optimizer.step() ),并且将scaler的大小缩小(乘上backoff_factor);
  • 如果没有出现inf或者NaN,那么权重正常更新,并且当连续多次(growth_interval指定)没有出现inf或者

NaN,则scaler.update()会将scaler的大小增加(乘上growth_factor)。

八、如何使用 AMP混合精度训练 ?

  1. AMP混合精度训练 代码
  2. Trainer 训练类
    class Trainer:
    ...
    def train(self, train_loader, dev_loader=None, train_sampler=None):
    ...
    # 设置 AMP混合精度训练
    if self.args.use_amp:
    scaler = torch.cuda.amp.GradScaler()
    if self.args.local_rank == 0:
    start = time.time()
    for epoch in range(1, self.args.epochs + 1):
    train_sampler.set_epoch(epoch)
    for step, batch_data in enumerate(train_loader):
    self.model.train()
    # 使用 AMP混合精度训练
    if self.args.use_amp:
    with torch.cuda.amp.autocast():
    logits, label = self.on_step(batch_data)
    loss = self.criterion(logits, label)
    torch.distributed.barrier()
    scaler.scale(loss).backward()
    scaler.step(self.optimizer)
    scaler.update()
  3. class Args:
    ...
    local_rank = None
    local_world_size = None
    device_ids = None
    rank = None
    dev = False
    use_amp = True
    Args 参数类
  4. main_worker 主函数

def main_worker(local_rank, local_world_size):

# =======================================

# 设置参数

...

dist.init_process_group(backend="nccl", init_method="tcp://localhost:12345", world_size=local_world_size, rank=local_rank)

n = torch.cuda.device_count() // local_world_size

device_ids = [local_rank]

print(

f"[{os.getpid()}] rank = {local_rank}, "

+ f"world_size = {local_world_size}, n = {n}, device_ids = {device_ids}

\n", end=''

)

torch.cuda.set_device(local_rank)

args = Args()

args.local_world_size = local_world_size

args.local_rank = local_rank

args.device_ids = device_ids

args.rank = local_rank

tokenizer = BertTokenizer.from_pretrained(args.model_path)

  1. AMP混合精度训练 完整代码

torch.cuda.manual_seed_all(seed)

def get_data():

with open("data/train.json", "r", encoding="utf-8") as fp:

data = fp.read()

data = json.loads(data)

return data

def load_data():

data = get_data()

return_data = []

# [(文本, 标签id)]

for d in data:

text = d[0]

label = d[1]

return_data.append(("".join(text.split(" ")).strip(), label))

return return_data

class ClsDataset(Dataset):

def init (self, data):

self.data = data

def len (self):

return len(self.data)

def getitem (self, index):

return self.data[index]

class Collate:

def init (self,

tokenizer,

max_seq_len,

):

self.tokenizer = tokenizer

self.max_seq_len = max_seq_len

def collate_fn(self, batch):

input_ids_all = []

token_type_ids_all = []

attention_mask_all = []

label_all = []

for data in batch:

text = data[0]

label = data[1]

inputs = self.tokenizer.encode_plus(text=text,

max_length=self.max_seq_len,

padding="max_length",

truncation="longest_first",

return_attention_mask=True,

return_token_type_ids=True)

input_ids = inputs["input_ids"]

token_type_ids = inputs["token_type_ids"]

attention_mask = inputs["attention_mask"]

input_ids_all.append(input_ids)

token_type_ids_all.append(token_type_ids)

attention_mask_all.append(attention_mask)

label_all.append(label)

input_ids_all = torch.tensor(input_ids_all, dtype=torch.long)

token_type_ids_all = torch.tensor(token_type_ids_all, dtype=torch.long)

attention_mask_all = torch.tensor(attention_mask_all, dtype=torch.long)

label_all = torch.tensor(label_all, dtype=torch.long)

return_data = {

"input_ids": input_ids_all,

"attention_mask": attention_mask_all,

"token_type_ids": token_type_ids_all,

"label": label_all

}

return return_data

def build_optimizer(model, args):

no_decay = ['bias', 'LayerNorm.weight']

optimizer_grouped_parameters = [

{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],

'weight_decay': args.weight_decay},

{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],

'weight_decay': 0.0}

]

# optimizer = AdamW(model.parameters(), lr=learning_rate)

optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

return optimizer

class Trainer:

def init (self,

args,

config,

model,

criterion,

optimizer):

self.args = args

self.config = config,

self.model = model

self.criterion = criterion

self.optimizer = optimizer

def on_step(self, batch_data):

label = batch_data["label"].cuda()

input_ids = batch_data["input_ids"].cuda()

token_type_ids = batch_data["token_type_ids"].cuda()

attention_mask = batch_data["attention_mask"].cuda()

output = self.model(input_ids=input_ids,

token_type_ids=token_type_ids,

attention_mask=attention_mask,

labels=label)

logits = output[1]

return logits, label

def loss_reduce(self, loss):

rt = loss.clone()

dist.all_reduce(rt, op=dist.ReduceOp.SUM)

rt /= self.args.local_world_size

return rt

def output_reduce(self, outputs, targets):

output_gather_list = [torch.zeros_like(outputs) for _ in range(self.args.local_world_size)]

# 把每一个GPU的输出聚合起来

dist.all_gather(output_gather_list, outputs)

outputs = torch.cat(output_gather_list, dim=0)

target_gather_list = [torch.zeros_like(targets) for _ in range(self.args.local_world_size)]

# 把每一个GPU的输出聚合起来

dist.all_gather(target_gather_list, targets)

targets = torch.cat(target_gather_list, dim=0)

return outputs, targets

def train(self, train_loader, dev_loader=None, train_sampler=None):

gloabl_step = 1

best_acc = 0.

if self.args.use_amp:

scaler = torch.cuda.amp.GradScaler()

if self.args.local_rank == 0:

start = time.time()

for epoch in range(1, self.args.epochs + 1):

train_sampler.set_epoch(epoch)

for step, batch_data in enumerate(train_loader):

self.model.train()

if self.args.use_amp:

with torch.cuda.amp.autocast():

logits, label = self.on_step(batch_data)

loss = self.criterion(logits, label)

torch.distributed.barrier()

scaler.scale(loss).backward()

scaler.step(self.optimizer)

scaler.update()

else:

logits, label = self.on_step(batch_data)

loss = self.criterion(logits, label)

torch.distributed.barrier()

loss.backward()

self.optimizer.step()

if self.args.local_rank == 0:

print("【train】 epoch:{}/{} step:{}/{} loss:{:.6f}".format(

epoch, self.args.epochs, gloabl_step, self.args.total_step, loss

))

gloabl_step += 1

if self.args.dev:

if gloabl_step % self.args.eval_step == 0:

loss, accuracy = self.dev(dev_loader)

if self.args.local_rank == 0:

print("【dev】 loss:{:.6f} accuracy:

{:.4f}".format(loss, accuracy))

if accuracy > best_acc:

best_acc = accuracy

print("【best accuracy】 {:.4f}".format(best_acc))

torch.save(self.model.state_dict(), self.args.ckpt_path)

if self.args.local_rank == 0:

end = time.time()

print("耗时:{}分钟".format((end - start) / 60))

if not self.args.dev and self.args.local_rank == 0:

torch.save(self.model.state_dict(), self.args.ckpt_path)

def dev(self, dev_loader):

self.model.eval()

correct_total = 0

num_total = 0

loss_total = 0.

with torch.no_grad():

for step, batch_data in tqdm(enumerate(dev_loader)):

logits, label = self.on_step(batch_data)

loss = self.criterion(logits, label)

torch.distributed.barrier()

loss = self.loss_reduce(loss)

loss_total += loss

logits, label = self.output_reduce(logits, label)

logits = logits.detach().cpu().numpy()

label = label.view(-1).detach().cpu().numpy()

num_total += len(label)

preds = np.argmax(logits, axis=1).flatten()

correct_num = (preds == label).sum()

correct_total += correct_num

return loss_total, correct_total / num_total

def test(self, model, test_loader, labels):

self.model = model

self.model.eval()

preds = []

trues = []

with torch.no_grad():

for step, batch_data in enumerate(test_loader):

logits, label = self.on_step(batch_data)

torch.distributed.barrier()

logits, label = self.output_reduce(logits, label)

label = label.view(-1).detach().cpu().numpy().tolist()

logits = logits.detach().cpu().numpy()

pred = np.argmax(logits, axis=1).flatten().tolist()

trues.extend(label)

preds.extend(pred)

report = classification_report(trues, preds, target_names=labels)

return report

class Args:

model_path = "/mnt/kaimo/data/pretrain/bert-base-chinese"

ckpt_path = "output/multi-gpu-distributed-mp-amp-cls.pt"

max_seq_len = 128

ratio = 0.92

device = torch.device("cuda" if torch.cuda.is_available else "cpu")

train_batch_size = 32

dev_batch_size = 32

weight_decay = 0.01

epochs = 1

learning_rate = 3e-5

eval_step = 50

local_rank = None

local_world_size = None

device_ids = None

rank = None

dev = False

use_amp = True

def main_worker(local_rank, local_world_size):

# =======================================

# 设置参数

set_seed()

label2id = {

"其他": 0,

"喜好": 1,

"悲伤": 2,

"厌恶": 3,

"愤怒": 4,

"高兴": 5,

}

dist.init_process_group(backend="nccl", init_method="tcp://localhost:12345", world_size=local_world_size,

rank=local_rank)

n = torch.cuda.device_count() // local_world_size

device_ids = [local_rank]

print(

f"[{os.getpid()}] rank = {local_rank}, "

+ f"world_size = {local_world_size}, n = {n}, device_ids = {device_ids}

\n", end=''

)

torch.cuda.set_device(local_rank)

args = Args()

args.local_world_size = local_world_size

args.local_rank = local_rank

args.device_ids = device_ids

args.rank = local_rank

tokenizer = BertTokenizer.from_pretrained(args.model_path)

# =======================================

# =======================================

# 加载数据集

data = load_data()

# 取1万条数据出来

data = data[:10000]

random.shuffle(data)

train_num = int(len(data) * args.ratio)

train_data = data[:train_num]

dev_data = data[train_num:]

collate = Collate(tokenizer, args.max_seq_len)

train_dataset = ClsDataset(train_data)

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

train_loader = DataLoader(train_dataset,

batch_size=args.train_batch_size,

num_workers=2,

collate_fn=collate.collate_fn,

sampler=train_sampler)

total_step = len(train_loader) * args.epochs

args.total_step = total_step

dev_dataset = ClsDataset(dev_data)

dev_sampler = torch.utils.data.distributed.DistributedSampler(dev_dataset)

dev_loader = DataLoader(dev_dataset,

batch_size=args.dev_batch_size,

shuffle=False,

num_workers=2,

collate_fn=collate.collate_fn,

sampler=dev_sampler)

test_loader = dev_loader

# =======================================

# =======================================

# 定义模型、优化器、损失函数

config = BertConfig.from_pretrained(args.model_path, num_labels=6)

model = BertForSequenceClassification.from_pretrained(args.model_path,

config=config)

# 第三步:封装模型

model.cuda()

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.device_ids)

criterion = torch.nn.CrossEntropyLoss()

optimizer = build_optimizer(model, args)

# =======================================

# =======================================

# 定义训练器,进行训练、验证和测试

trainer = Trainer(args,

config,

model,

criterion,

optimizer)

trainer.train(train_loader, dev_loader, train_sampler)

labels = list(label2id.keys())

config = BertConfig.from_pretrained(args.model_path, num_labels=6)

相关文章:

  • 代码随想录回文子序列
  • 如何判断你的PyTorch是GPU版还是CPU版?
  • C++泛型编程(一):模板详解
  • 47-dify案例分享-魔搭+Dify王炸组合!10分钟搭建你的专属 生活小助理
  • 利用【指针引用】对【非空单循环链表】进行删除操作
  • C/C++ | 高频手撕整理(1) —— strstr,memcpy,memmove函数模拟实现
  • ARM架构的微控制器总线矩阵优先级与配置
  • DeepSeek 的长上下文扩展机制
  • 【MCP】详细了解MCP协议:和function call的区别何在?如何使用MCP?
  • 0803分页_加载更多-网络ajax请求2-react-仿低代码平台项目
  • 【多线程】五、线程同步 条件变量
  • 逆向|dy|a_bogus|1.0.1.19-fix.01
  • RK3568 Debian调试记录
  • 基于强化学习的智能交通控制系统设计
  • 基于STM32单片机PWM讲解(HAL库)
  • html css js网页制作成品——HTML+CSS+js美甲店网页设计(5页)附源码
  • pytest 技术总结
  • Windows怎样使用curl下载文件
  • 大模型时代的语言格局演变:为什么是 JavaScript?
  • uml类关系(实现、继承,聚合、组合,依赖、关联)
  • 中法共创《海底两万里》,演员保剑锋重回戏剧舞台演船长
  • 2025年全国贸易摩擦应对工作会议在京召开
  • 刘非任中共浙江省委常委、杭州市委书记
  • 儒说︱问世间孝为何物
  • 孙燕姿演唱会本周末开唱,小票根如何在上海释放大活力
  • 鸿蒙智行八大车型亮相上海车展,余承东拉上三家车企老总“直播推销”