pytorch学习笔记
torch.index_select( input, dim, index, out=None )
函数返回的是沿着输入张量的指定维度的指定索引号进行索引的张量子集,其中输入张量、指定维度和指定索引号就是 torch.index_select(input, dim, index, out=None) 函数的三个关键参数,函数参数有:
- input(Tensor) - 需要进行索引操作的输入张量;
- dim(int) - 需要对输入张量进行索引的维度;
- index(LongTensor) - 包含索引号的 1D 张量;
- out(Tensor, optional) - 指定输出的张量。比如执行 torch.zeros([2, 2], out = tensor_a),相当于执行 tensor_a = torch.zeros([2, 2]);
xxx.cuda(non_blocking==True)
使用non_blocking=True来并行处理数据传输
1. x = x.cuda(non_blocking=True) 2. 进行一些和x无关的操作 3. 执行和x有关的操作
在non_blocking=true下,1不会阻塞2,1和2并行。这样将数据从CPU移动到GPU的时候,它是异步的。在它传输的时候,CPU还可以干其他的事情(不依赖于数据的事情)
.cuda()是为了将模型放在GPU上进行训练
non_blocking 默认值为 False
参考:Pytorch的cuda non_blocking (pin_memory)_hxxjxw的博客-CSDN博客_non_blocking pytorch
torch.
exp
(input, out=None)返回具有输入张量输入元素指数的新张量。
>>> torch.exp(torch.tensor([0, math.log(2.)])) tensor([ 1., 2.])
torch.gather()
gather()就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的
本文含有隐藏内容,请 开通VIP 后查看