每天五分钟深度学习框架PyTorch:算法模型的保存和加载(CPU和GPU)
本文重点
我们前面学习了模型的训练,比如线性回归,全连接神经网络,各种经典的卷积神经网络,模型训练完成之后,我们如何将训练的模型保存起来,然后方便之后的使用。pytorch已经封装好了相关的api,下面我们对此进行介绍。
保存模型的两种方式
在pytorch中使用torch.save来保存模型的结构和参数,有两种保存方式:
方式一:
torch.save(model , './model.pth ' )
方式二:
torch.save(model.state_dict(), '. /model_state.pth')
第一种方式:保存整个模型的结构信息和参数信息,保存的对象是模型model
第二种方式:只保存模型的参数,保存的对象是模型的状态
加载模型
当我们使用第一种方式保存模型的时候,我们通过下面的方式来加载模型
load_model = torch.load('model. pth' )
当我们使用第二种方式保存模型的时候,我们通过下面的方式来加载模型(先导入模型的结构再加载模型的参数信息)&#x