阅读 134

PyTorch教程-8:举例详解TensorBoard的使用

笔者PyTorch的全部简单教程请访问:https://www.jianshu.com/nb/48831659

PyTorch教程-8:举例详解TensorBoard的使用

TensroBoard 是一个帮助我们对模型训练、数据处理等很有帮助的可视化工具,虽然这些可视化操作都可以通过代码配合matplotlib这些很好用的绘图库来实现,但是TensorBoard使得它变得更加简单。

官方的一个简单教程:https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html

首先依旧使用最开始的一个简单例子:CIFAR10的分类任务,先引入数据、构建模型、创建优化器、损失函数等任务:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

接入TensorBoard

tensorboard包需要在 torch.utils 中引入,首先我们先通过一个 SummaryWriter 实例来准备一个接入TensorBoard的接口:

from torch.utils.tensorboard import SummaryWriter

summaryWriter = SummaryWriter("./runs/")

这里我们设置了 run 文件夹用来存储记录信息的位置。

如果在引入tensorboard时报错:No module named 'torch.utils.tensorboard',那么需要你手动安装tensorboard,比如pip的方式下可以使用:

pip install tensorboard

启动TensorBoard

启动Tensorboard需要在命令行中使用,其中的 logdir 参数就是我们存放log的文件夹位置:

tensorboard --logdir=runs

TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.4.0 at http://localhost:6006/ (Press CTRL+C to quit)

等待TensorBoard启动后,在浏览器中访问 https://localhost:6006 或者 http://127.0.0.1:6006/ 就能打开TensorBoard了,6006是它的默认端口。此时打开后会发现它提示没有任何数据,是因为我们还没有向 runs 文件夹下写入记录信息的log文件。

向TensorBoard中写入信息

向TensorBoard中写入信息很简单,使用 SummaryWriter 的方法就可以完成,比如我们将第一个batch的图片展示到TensorBoard中的例子:

trainloader_iterator = iter(trainloader)

images, labels = trainloader_iterator.next()

# create grid of images
img_grid = torchvision.utils.make_grid(images)

# write to tensorboard
summaryWriter.add_image("A Batch of Image Samples",img_grid)

运行后打开/刷新TensorBoard页面就可以看到新的效果,图片被显示了出来(这里Tensor格式的图片的RGB通道是不对的,要显示正常的图片,需要进行转换,这里为了展示TensorBoard的效果所以忽略了这一点)

1.png

可以看到,使用了 add_image 将一个图片写入到TensorBoard中,SummaryWriter还提供了很多方法,这里列一下,详细的参数说明和例子请参考:
https://pytorch.org/docs/stable/tensorboard.html#torch-utils-tensorboard

  • add_scalar(tag, scalar_value, global_step=None, walltime=None):添加标量数据
  • add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None):添加多个标量数据
  • add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):添加一个柱状图
  • add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):添加一张图片
  • add_images(tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):添加多个图片
  • add_figure(tag, figure, global_step=None, close=True, walltime=None):渲染一个matplotlib的图片然后添加到TensorBoard
  • add_video(tag, vid_tensor, global_step=None, fps=4, walltime=None):添加视频
  • add_audio(tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):添加音频
  • add_text(tag, text_string, global_step=None, walltime=None):添加文本
  • add_graph(model, input_to_model=None, verbose=False):添加图像
  • add_embedding(mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None):添加嵌入式投影,一个很好的例子就是我们可以将高维数据映射到三维空间中进行直观地展示和可视化
  • add_pr_curve(tag, labels, predictions, global_step=None, num_thresholds=127, weights=None, walltime=None):添加PR曲线
  • add_custom_scalars(layout):添加用户定义的标量
  • add_mesh(tag, vertices, colors=None, faces=None, config_dict=None, global_step=None, walltime=None):添加3D模型
  • add_hparams(hparam_dict, metric_dict, hparam_domain_discrete=None, run_name=None):添加一些可以调节的超参数

在TensorBoard中展示网络模型

TensorBoard不仅可以可视化数据,还可以可视乎模型,使用 add_graph 方法就可以将一个模型写入,他接受的第一个参数用于传入一个模型,第二个参数是要喂给这个模型的数据,这里就是一个batch的图片:

images, labels = trainloader_iterator.next()
summaryWriter.add_graph(net,images)
2.png

使用TensorBoard记录训练过程

我们在之前的例子中展示过,对于训练过程,我们把训练的loss打印到了控制台上,当然如果为了能够更加直观的展示loss的变化过程,我们可以使用一个list保存这些loss,等待训练完成后使用matplotlib等工具对其进行可视化。TensorBoard提供了更简单的方式,我们可以直接将loss写到TensorBoard中,这样更加的简单,只要对之前训练的代码做小小的修改即可:

epochs = 2

# running_loss to record the losses
running_loss = 0.0

for epoch in range(epochs):
    for i,data in enumerate(trainloader,0):
        # get input images and their labels
        inputs, labels = data
        # set optimizer buffer to 0
        optimizer.zero_grad()
        # forwarding
        outputs = net(inputs)
        # computing loss
        loss = loss_function(outputs, labels)
        # loss backward
        loss.backward()
        # update parameters using optimizer
        optimizer.step()

        # printing some information
        running_loss += loss.item()
        # for every 1000 mini-batches, print the loss
        if i % 1000 == 0:
            print("epoch {} - iteration {}: average loss {:.3f}".format(epoch+1, i, running_loss/1000))

            summaryWriter.add_scalar("training_loss",running_loss/1000, epoch * len(trainloader) + i)

            running_loss = 0.0


print("Training Finished!")

然后在TensorBoard中就可以看到我们的loss的变化结果,还可以对其进行一些其他设置,比如横坐标的格式、曲线的光滑程度等。

3.png

作者:超级超级小天才

原文链接:https://www.jianshu.com/p/8020076e72b3

文章分类
后端
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐