阅读 84

Pytorch实现线性回归

import torchfrom torch import nn # 第一步  准备数据集 # x,y是矩阵,31列 也就是说总共有3个数据,每个数据只有1个特征 x_data = torch.tensor([[1.0], [2.0], [3.0]]) y_data = torch.tensor([[2.0], [4.0], [6.0]]) # 第二步  设计数据模型 # Module构造出来的对象会自动根据计算图实现backward()操作class LinearModel(torch.nn.Module):    def __init__(self):        # 调用父类的init        super(LinearModel, self).__init__()        # 主要是为了构造对象,输入了(1,1)维度的初始值,        # Linear(1, 1, bias)  bias用于设定是否加上b这个参数        # 会自动进行backward()反向传播        # Linear大概就是执行y = w * x + b        self.linear = torch.nn.Linear(1, 1)    # 这里的forward把父类Module中的forward覆盖了,可以通过__call__魔术函数的形式调用    def forward(self, x):        # 调用了这个类的linear属性        # __call__魔术函数,直接拿实例化的对象作为方法名去调用类内的__call__函数        y_pred = self.linear(x)        return y_pred # 实例化class LinearModelmodel = LinearModel() # 第三步  构造损失函数 和  训练器(优化器) # MSELoss继承自nn下的Module # 参数size_average = True 损失是否求均值, reduce = True 是否降维 criterion = torch.nn.MSELoss(reduction = 'sum')''' pytorch提供的优化器----不同优化器降低的损失是不同的 torch.optim.-----    Adgrad    Adam    Adamax    ASGD    LBFGS    RMSprop    Rprop    SGD '''# 优化器 不会构建计算图 # optimizer是为了更新梯度 # lr 学习率 # model.parameters()自动完成参数的初始化操作,找到需要优化的权重 optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # 第四步 训练循环 # 训练过程 training cyclefor epoch in range(100):    '''    y_hat  --->   loss  --->  优化器梯度归零  --->  backward  --->  优化器更新参数    '''    # 利用__call__调用实现了forward(x_data)    y_pred = model(x_data)    # 计算损失    loss = criterion(y_pred, y_data)  # forward: loss    # print时调用loss 会调用一个___str___()函数,所以不会产生计算图    print(epoch, loss.item())    # 梯度归零    optimizer.zero_grad()    loss.backward()  # backward: autograd,自动计算梯度    optimizer.step()  # update 参数,即更新w和b的值 print('w = ', model.linear.weight.item()) print('b = ', model.linear.bias.item()) # 第五步 测试 # 测试模型 x_test = torch.tensor([[4.0]]) y_test = model(x_test) print('y_pred = ', y_test.data)


作者:DLung
链接:https://juejin.cn/post/7025253227251630088


文章分类
后端
文章标签
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐