For example:
import torch
# Feature sizes for the example model's two linear layers:
# line1 maps hidden_dim1 -> hidden_dim2, line2 maps hidden_dim2 -> tagset_size.
hidden_dim1, hidden_dim2, tagset_size = 10, 5, 2
class MyModel(torch.nn.Module):
    """Two independent linear layers, of which only ``line1`` feeds the output.

    ``forward`` still evaluates ``line2`` on ``y``, but that result is thrown
    away. The returned tensor therefore depends on ``x`` and ``line1`` alone.
    """

    def __init__(self):
        super().__init__()
        self.line1 = torch.nn.Linear(hidden_dim1, hidden_dim2)
        self.line2 = torch.nn.Linear(hidden_dim2, tagset_size)

    def forward(self, x, y):
        first = self.line1(x)
        # Computed but never used: nothing downstream consumes this value.
        _discarded = self.line2(y)
        return first
# Example inputs: X is consumed by line1 (last dim hidden_dim1); Y is consumed
# by line2, whose Linear(hidden_dim2, tagset_size) needs a last dim of hidden_dim2.
X = torch.randn(20, hidden_dim1)
Y = torch.randn(hidden_dim1, hidden_dim2)
inputs = (X, Y)
model = MyModel()
f = './model.onnx'
# `example_outputs` is not passed. It defaults to None and only mattered for
# ScriptModules, and later torch.onnx.export releases deprecated and then
# removed it, so passing it explicitly breaks on newer torch.
torch.onnx.export(model, inputs, f,
                  opset_version=9,
                  input_names=["X"], output_names=["Y"], verbose=True)
graph(%X : Float(20, 10, strides=[10, 1], requires_grad=0, device=cpu),
%line1.weight : Float(5, 10, strides=[10, 1], requires_grad=1, device=cpu),
%line1.bias : Float(5, strides=[1], requires_grad=1, device=cpu)):
%Y : Float(20, 5, strides=[5, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%X, %line1.weight, %line1.bias) # /root/.conda/envs/torch1.9/lib/python3.6/site-packages/torch/nn/functional.py:1847:0
return (%Y)
However, the exported graph doesn't contain `line2` — maybe because the output of `MyModel` does not depend on `out2 = self.line2(y)`? I guess the graph is pruned by default.
What should I do if I want to disable this pruning?
Motivation
I want to do something with `self.named_parameters()` inside `model.forward()`, e.g.:
def check_parameters():
    """Placeholder for checks to run on the model's parameters.

    The intent is to apply ops such as OP1, OP2, ... to the parameters. As
    written, it does no work and returns None.
    """
    # NOTE(review): this takes no arguments, so it cannot reach any module's
    # parameters; they would have to be passed in for real checks.
    return None
class MyModel(torch.nn.Module):
    """Single linear layer whose ``forward`` also calls ``check_parameters``.

    The value returned by ``check_parameters()`` is ignored, so nothing it
    computes contributes to the tensor returned by ``forward``.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        self.line = torch.nn.Linear(hidden_dim1, hidden_dim2)

    def forward(self, x, y):
        # The layer is registered as ``self.line``. Referencing ``self.line1``
        # (as before) raised AttributeError on every call.
        # ``y`` is accepted for interface compatibility but is not used.
        out = self.line(x)
        # Return value discarded: any ops run here do not feed ``out``.
        check_parameters()
        return out
However, the exported graph doesn't contain OP1 or OP2 — maybe because the output of `MyModel` does not depend on `check_parameters()`? I guess the graph is pruned by default.