关于Torch中的 scatter_ 函数

关于Torch中的 scatter_ 函数

scatter_ 和 one hot

看了许多博客,中国人写博客有一个特点就是复制来复制去,根本没有讲到重点,好了废话不多扯,今天讲下 scatter_ 函数。

操作一:

import torch  # 导入 torch模块,这里操作的都是张量数据
src = torch.arange(1, 11).reshape((2, 5)) # 这里创建一个 2行5列的数据
print(src) # 打印出来

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])

上面这个是准备数据,是一个两行五列的数据。再创建一个索引数据

index = torch.tensor([[0, 1, 2, 0, 2]])
print(index)

tensor([[0, 1, 2, 0, 2]])

在这之前都是很简单的,相比读者肯定能看到,无非就是两个数据,请耐心往下看

result_1 = torch.zeros(3, 5, dtype=src.dtype) # 创建一个3行5列的数据全是0
print(result_1)
tensor([[0, 0, 0, 0 0],
        [0, 0 0, 0, 0],
        [0, 0, 0, 0, 0]])

解析来就是使用 scatter_函数: 也就是根据相关索引,把result_1的指定位置填充下

result = result_1.scatter_(0, index, src)

这里是什么意思呢, 0 表明按列来处理,result_1 是需要被更改的数据,index是索引位置, src数用来填充的数据,举例子: 如上面描述:

result_1 = tensor([[0, 0, 0, 0 0],

[0, 0 0, 0, 0],

[0, 0, 0, 0, 0]])

index = tensor([[0, 1, 2, 0, 2]])

tensor([[ 1, 2, 3, 4, 5],

[ 6, 7, 8, 9, 10]])

第一个参数 0 表明按列来处理

索引第1个值为0,这表明第1列的第1个数据设置为scr中的第2个数据

ensor([[1, 0, 0, 0 0],

[0, 0 0, 0, 0],

[0, 0, 0, 0, 0]])

索引第2个值为1,这表明第2列的第2个数据设置为scr中的第2个数据

ensor([[1, 0, 0, 0 0],

[0, 2 0, 0, 0],

[0, 0, 0, 0, 0]])

索引第3个值为2,这表明第3列的第3个数据设置为scr中的第3个数据

ensor([[1, 0, 0, 0 0],

[0, 2 0, 0, 0],

[0, 0, 3, 0, 0]])

索引第4个值为0,这表明第4列的第1个数据设置为scr中的第4个数据

ensor([[1, 0, 0, 4 0],

[0, 2 0, 0, 0],

[0, 0, 3, 0, 0]])

索引第5个值为2,这表明第5列的第3个数据设置为scr中的第5个数据

ensor([[1, 0, 0, 4 0],

[0, 2 0, 0, 0],

[0, 0, 3, 0, 5]])

以上就是详细的计算流程

操作2:

idx = torch.tensor([[0, 1, 2, 3,4]])
last = torch.zeros(3, 5, dtype=src.dtype).scatter_(dim=1, index=idx, value=2)

这里第一步我信任大家都熟悉,就是创建一个数据而已,这里我们理解为索引数据

1、torch.zeros(3, 5, dtype=src.dtype). 表明的是创建一个3行5列的数据矩阵,全是0

tensor([[0, 0, 0, 0 0],

[0, 0 0, 0, 0],

[0, 0, 0, 0, 0]])

2、dim=1,表明是按行计算

3、value,表明相应的位置上设置为某个值

idx = torch.tensor([[0, 1, 2, 3,4]])

表明的是第一行的第 0 1 2 3 4 的位置上全是设置为2,也就是

tensor([[2, 2, 2, 2, 2],

[0, 0, 0, 0, 0],

[0, 0, 0, 0, 0]])

当然,我信任某些人还是一脸懵逼,再继续往下看

idx = torch.tensor([[0, 1, 2, 3,4],[0,0,0,0,4]])
last = torch.zeros(3, 5, dtype=src.dtype).scatter_(dim=1, index=idx, value=2)

这里我们看到idx为 torch.tensor([[0, 1, 2, 3,4],[0,0,0,0,4]])

这个idx有两行,那么他对应的也是 torch.zeros(3, 5, dtype=src.dtype)中的两行数据,

[0, 1, 2, 3,4] 表明的是第一行的第 0 1 2 3 4 的位置上全是设置为2

[0,0,0,0,4]]表明的是第二行的第 0 、4 的位置上设置为2,其他地方不变

因此整体数据变成了

tensor([[2, 2, 2, 2, 2],

[2, 0, 0, 0, 2],

[0, 0, 0, 0, 0]])

好了,这个函数介绍到此为止,希望能帮到大家

© 版权声明
THE END
如果内容对您有所帮助,就支持一下吧!
点赞0 分享
娴雯怡的头像 - 鹿快
评论 抢沙发

请登录后发表评论

    暂无评论内容