PyTorch autograd -- grad can be implicitly created only for scalar outputs(PyTorch autograd -- 只能为标量输出隐式创建 grad)
问题描述
我在 PyTorch 中使用了 autograd 工具,并且发现自己需要通过整数索引访问一维张量中的值.像这样:
I am using the autograd tool in PyTorch, and have found myself in a situation where I need to access the values in a 1D tensor by means of an integer index. Something like this:
def basic_fun(x_cloned):
res = []
for i in range(len(x)):
res.append(x_cloned[i] * x_cloned[i])
print(res)
return Variable(torch.FloatTensor(res))
def get_grad(inp, grad_var):
A = basic_fun(inp)
A.backward()
return grad_var.grad
x = Variable(torch.FloatTensor([1, 2, 3, 4, 5]), requires_grad=True)
x_cloned = x.clone()
print(get_grad(x_cloned, x))
我收到以下错误消息:
[tensor(1., grad_fn=<ThMulBackward>), tensor(4., grad_fn=<ThMulBackward>), tensor(9., grad_fn=<ThMulBackward>), tensor(16., grad_fn=<ThMulBackward>), tensor(25., grad_fn=<ThMulBackward>)]
Traceback (most recent call last):
File "/home/mhy/projects/pytorch-optim/predict.py", line 74, in <module>
print(get_grad(x_cloned, x))
File "/home/mhy/projects/pytorch-optim/predict.py", line 68, in get_grad
A.backward()
File "/home/mhy/.local/lib/python3.5/site-packages/torch/tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
总的来说,我有点怀疑使用变量的克隆版本应该如何在梯度计算中保持该变量.在A 的计算中实际上不使用变量本身,因此当您调用A.backward() 时,它不应成为该操作的一部分.
I am in general, a bit skeptical about how using the cloned version of a variable is supposed to keep that variable in gradient computation. The variable itself is effectively not used in the computation of A, and so when you call A.backward(), it should not be part of that operation.
感谢您对这种方法的帮助,或者是否有更好的方法来避免丢失梯度历史并仍然通过 requires_grad=True 索引通过一维张量!
I appreciate your help with this approach or if there is a better way to avoid losing the gradient history and still index through a 1D tensor with requires_grad=True!
res 是一个包含 1 到 5 平方值的零维张量列表.为了连接一个包含 [1.0, 4.0, ..., 25.0] 的张量,我改变了 return Variable(torch.FloatTensor(res)) 到 torch.stack(res, dim=0),产生 tensor([ 1., 4., 9., 16., 25.], grad_fn=.
res is a list of zero-dimensional tensors containing squared values of 1 to 5. To concatenate in a single tensor containing [1.0, 4.0, ..., 25.0], I changed return Variable(torch.FloatTensor(res)) to torch.stack(res, dim=0), which produces tensor([ 1., 4., 9., 16., 25.], grad_fn=<StackBackward>).
但是,我收到了这个由 A.backward() 行引起的新错误.
However, I am getting this new error, caused by the A.backward() line.
Traceback (most recent call last):
File "<project_path>/playground.py", line 22, in <module>
print(get_grad(x_cloned, x))
File "<project_path>/playground.py", line 16, in get_grad
A.backward()
File "/home/mhy/.local/lib/python3.5/site-packages/torch/tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 84, in backward
grad_tensors = _make_grads(tensors, grad_tensors)
File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 28, in _make_grads
raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs
推荐答案
我将我的 basic_fun 更改为以下内容,解决了我的问题:
I changed my basic_fun to the following, which resolved my problem:
def basic_fun(x_cloned):
res = torch.FloatTensor([0])
for i in range(len(x)):
res += x_cloned[i] * x_cloned[i]
return res
此版本返回标量值.
这篇关于PyTorch autograd -- 只能为标量输出隐式创建 grad的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
本文标题为:PyTorch autograd -- 只能为标量输出隐式创建 grad
- 使用公司代理使Python3.x Slack(松弛客户端) 2022-01-01
- 如何使用PYSPARK从Spark获得批次行 2022-01-01
- 使用 Cython 将 Python 链接到共享库 2022-01-01
- 我如何卸载 PyTorch? 2022-01-01
- ";find_element_by_name(';name';)";和&QOOT;FIND_ELEMENT(BY NAME,';NAME';)";之间有什么区别? 2022-01-01
- 我如何透明地重定向一个Python导入? 2022-01-01
- 计算测试数量的Python单元测试 2022-01-01
- 检查具有纬度和经度的地理点是否在 shapefile 中 2022-01-01
- YouTube API v3 返回截断的观看记录 2022-01-01
- CTR 中的 AES 如何用于 Python 和 PyCrypto? 2022-01-01
