Pytorch系列:(四)IO操作
Pytorch系列:(四)IO操作
首先注意pytorch中模型保存有两种格式,pth和pkl,其中,pth是pytorch默认格式,pkl还支持pickle库,不过一般如果没有特殊需求的时候,推荐使用默认pth格式保存
pytorch中有两种数据保存方法,一种是存储整个模型,一种只存储参数
方法一:存储整个模型
#保存torch.save(model1, 'net.pth')#读取model1 = torch.load('net.pth')
方法二:存储模型参数
#保存torch.save(model.state_dict(), 'checkpoint.pth')#提取state_dict = torch.load('checkpoint.pth') model.load_state_dict(state_dict)
state_dict说明
state_dict 包含了模型使用的所有参数(Parameter类型),如果自定义的模型参数没有用Parameter封装,那么不会出现在state_dict中, 所以使用的时候,自定义参数一定不要忘记使用Parameter进行封装。
class MLP(nn.Module): def __init__(self): super(MLP, self).__init__() self.w1 = torch.randn(10,2) self.w2 = nn.Parameter(torch.randn(2,1)) self.l1 = nn.Linear(10,1) def forward(self,x): pass net = MLP() net.state_dict()
输出,可以发现只有w2和l1
OrderedDict([('w2', tensor([[0.9826], [0.4665]])), ('l1.weight', tensor([[ 0.3098, 0.0985, -0.2566, -0.1024, 0.0449, -0.1681, -0.1743, 0.2985, -0.0644, -0.0181]])), ('l1.bias', tensor([-0.2871]))])
中间状态保存
在训练的时候,可以保存训练中的中间状态,只需要把参数都保存到state字典中就可以了。 例如,在断点续传任务中,可以把epoch,模型状态,优化器状态,初始learning rate 等进行保存。
state = { 'state_dict': net.state_dict(), 'optimizer': optim.optimizer.state_dict(), 'lr_base': optim.lr_base 'epoch': epoch } torch.save( state, self.CKPTS_PATH + 'ckpt_' + self.VERSION + '/epoch'+ str(epoch) + '.pkl' )
加载
state = torch.load( self.CKPTS_PATH + 'ckpt_' + self.VERSION + '/epoch'+ str(epoch) + '.pkl' ) net.load_state_dict(state['state_dict']) optim.optimizer.load_state_dict(state['optimizer']) optim.lr_base = state['lr_base'] start_epoch = state['epoch']
__EOF__
本文作者:Taaccoo
本文链接:https://www.cnblogs.com/quant-q/p/14737156.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【推荐】一下。您的鼓励是博主的最大动力!