阅读 86 SEO

解析Pytorch中的torch.gather()函数

本文给大家介绍了Pytorch中的torch.gather()函数,通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧

参数说明

以官方说明为例,gather()函数需要三个参数,输入input,维度dim,以及索引index

input必须为Tensor类型

dim为int类型,代表从哪个维度进行索引

index为LongTensor类型

举例说明

1
2
3
4
5
6
7
8
9
input=torch.tensor([[1,2,3],[4,5,6]]) #作为输入
  
index1=torch.tensor([[0,1,1],[0,1,1]]) #作为索引矩阵
  
# dim=0时,按列进行索引
print (torch.gather(input,dim=0,index=index1))
  
# dim=1时,按行进行索引
print (torch.gather(input,dim=1,index=index1))

 结果如下图所示:


1
2
3
4
5
6
7
# 按列进行索引
tensor([[1, 5, 6],
        [4, 2, 6]])
  
# 按行进行索引
tensor([[1, 2, 2],
        [5, 4, 5]])

画图说明 

官方文档

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def gather(self, input, dim, index, *args, **kwargs):
         
        For a 3-D tensor the output is specified by::
         
            out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
            out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
            out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2       
  
        Args:
            input (Tensor): the source tensor
            dim (int): the axis along which to index
            index (LongTensor): the indices of elements to gather    
       
        Example::
         
            >>> t = torch.tensor([[1, 2], [3, 4]])
            >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
            tensor([[ 11],
                    [ 43]])

到此这篇关于Pytorch中的torch.gather()函数的文章就介绍到这了

原文链接:https://blog.csdn.net/xiaoliujun1999/article/details/121292061


文章分类
代码人生
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 gxwowoo@163.com 举报,一经查实,本站将立刻删除。
相关推荐