PyTorch: access weights of a specific module in nn.Sequential()(PyTorch:nn.Sequential() 中特定模块的访问权重)
本文介绍了PyTorch:nn.Sequential() 中特定模块的访问权重的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
当我在 PyTorch 中使用预定义模块时,我通常可以相当轻松地访问其权重.但是,如果我先将模块包装在 nn.Sequential()
中,我该如何访问它们?r.g:
When I use a pre-defined module in PyTorch, I can typically access its weights fairly easily. However, how do I access them if I wrapped the module in nn.Sequential()
first? r.g:
class My_Model_1(nn.Module):
def __init__(self,D_in,D_out):
super(My_Model_1, self).__init__()
self.layer = nn.Linear(D_in,D_out)
def forward(self,x):
out = self.layer(x)
return out
class My_Model_2(nn.Module):
def __init__(self,D_in,D_out):
super(My_Model_2, self).__init__()
self.layer = nn.Sequential(nn.Linear(D_in,D_out))
def forward(self,x):
out = self.layer(x)
return out
model_1 = My_Model_1(10,10)
print(model_1.layer.weight)
model_2 = My_Model_2(10,10)
我现在如何打印重量?model_2.layer.0.weight
不起作用.
How do I print the weights now?
model_2.layer.0.weight
doesn't work.
推荐答案
来自 PyTorch 论坛,这是推荐的方式:
From the PyTorch forum, this is the recommended way:
model_2.layer[0].weight
这篇关于PyTorch:nn.Sequential() 中特定模块的访问权重的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
沃梦达教程
本文标题为:PyTorch:nn.Sequential() 中特定模块的访问权重


猜你喜欢
- 使用公司代理使Python3.x Slack(松弛客户端) 2022-01-01
- 计算测试数量的Python单元测试 2022-01-01
- 检查具有纬度和经度的地理点是否在 shapefile 中 2022-01-01
- ";find_element_by_name(';name';)";和&QOOT;FIND_ELEMENT(BY NAME,';NAME';)";之间有什么区别? 2022-01-01
- 我如何卸载 PyTorch? 2022-01-01
- CTR 中的 AES 如何用于 Python 和 PyCrypto? 2022-01-01
- 使用 Cython 将 Python 链接到共享库 2022-01-01
- 我如何透明地重定向一个Python导入? 2022-01-01
- 如何使用PYSPARK从Spark获得批次行 2022-01-01
- YouTube API v3 返回截断的观看记录 2022-01-01