阅读 133

Pytorch系列:(四)IO操作

Pytorch系列:(四)IO操作

首先注意pytorch中模型保存有两种格式,pth和pkl,其中,pth是pytorch默认格式,pkl还支持pickle库,不过一般如果没有特殊需求的时候,推荐使用默认pth格式保存

pytorch中有两种数据保存方法,一种是存储整个模型,一种只存储参数

方法一:存储整个模型

#保存torch.save(model1, 'net.pth')#读取model1 = torch.load('net.pth')

方法二:存储模型参数

#保存torch.save(model.state_dict(), 'checkpoint.pth')#提取state_dict = torch.load('checkpoint.pth')

model.load_state_dict(state_dict)

state_dict说明

state_dict 包含了模型使用的所有参数(Parameter类型),如果自定义的模型参数没有用Parameter封装,那么不会出现在state_dict中, 所以使用的时候,自定义参数一定不要忘记使用Parameter进行封装。

class MLP(nn.Module):
    def __init__(self):        super(MLP, self).__init__()        self.w1 = torch.randn(10,2)        self.w2 = nn.Parameter(torch.randn(2,1))        self.l1 = nn.Linear(10,1)    def forward(self,x):
        pass 


net = MLP()

net.state_dict()

输出,可以发现只有w2和l1

OrderedDict([('w2',
              tensor([[0.9826],                      [0.4665]])),
             ('l1.weight',
              tensor([[ 0.3098,  0.0985, -0.2566, -0.1024,  0.0449, -0.1681, -0.1743,  0.2985,                       -0.0644, -0.0181]])),
             ('l1.bias', tensor([-0.2871]))])

中间状态保存

在训练的时候,可以保存训练中的中间状态,只需要把参数都保存到state字典中就可以了。 例如,在断点续传任务中,可以把epoch,模型状态,优化器状态,初始learning rate 等进行保存。

state = {          'state_dict': net.state_dict(),          'optimizer': optim.optimizer.state_dict(),          'lr_base': optim.lr_base          'epoch': epoch
        }
            
torch.save(
            state,            self.CKPTS_PATH +            'ckpt_' + self.VERSION +            '/epoch'+ str(epoch) +            '.pkl'
          )

加载

state = torch.load(
                    self.CKPTS_PATH +                    'ckpt_' + self.VERSION +                    '/epoch'+ str(epoch) +                    '.pkl'
                   )  

 
net.load_state_dict(state['state_dict'])

optim.optimizer.load_state_dict(state['optimizer'])
optim.lr_base = state['lr_base']
start_epoch = state['epoch']


__EOF__

本文作者Taaccoo 
本文链接:https://www.cnblogs.com/quant-q/p/14737156.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【推荐】一下。您的鼓励是博主的最大动力!


文章分类
后端
文章标签
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐