Hi,
For my use case, I need to take a PyTorch module and interpret the sequence of layers in it, so that I can write out the “connections” between the layers in some file format. Say I have a simple module like the one below:
```python
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super(mymodel, self).__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x  # residual add: a plain tensor op, not an nn.Module
        return out


if __name__ == "__main__":
    net = mymodel(5)
    for mod in net.modules():
        print(mod)
```
This prints:
```
mymodel(
  (fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)
```
As you can see, the += (addition) in forward is not captured, because it is a plain tensor operation rather than an nn.Module. My goal is to go from the PyTorch module object to a description of the graph connections, for example in JSON like this:
```json
{
  "layers": {
    "fc": {
      "inputTensor": "t0",
      "outputTensor": "t1"
    },
    "addOp": {
      "inputTensor": "t1",
      "outputTensor": "t2"
    }
  }
}
```
The tensor names are arbitrary, but this captures the essence of the graph: the connections between layers. Is there a way to extract this information from a PyTorch model object? I was thinking of using .modules(), but then realized that handwritten operations in forward are not captured as modules this way. I guess if everything were an nn.Module, .modules() might give me the layer arrangement. Any help is appreciated. Thanks!
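One possible approach, assuming PyTorch 1.8 or newer (where torch.fx is available), is symbolic tracing. torch.fx.symbolic_trace runs forward on proxy objects and records both submodule calls (like fc) and plain tensor operations (like the +=) as nodes in a graph. The sketch below repeats mymodel so it runs on its own, then walks the traced graph and builds a dict shaped like the JSON above. The helper graph_to_layers and the t0, t1, ... naming are hypothetical. In this sketch, inputTensor is a list because the addition consumes both t1 and the original input t0 (the skip connection), and operation nodes are keyed by their fx node name rather than "addOp". Symbolic tracing also only works for models whose forward has no data-dependent control flow.

```python
import json

import torch.fx as fx
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


def graph_to_layers(module: nn.Module) -> dict:
    """Hypothetical helper: trace `module` with torch.fx and map each
    operation node to the names of the tensors it reads and writes."""
    traced = fx.symbolic_trace(module)
    tensor_names = {}  # fx.Node -> "t0", "t1", ...
    layers = {}
    for node in traced.graph.nodes:
        if node.op in ("placeholder", "get_attr"):
            # Model inputs and fetched attributes are named but not layers.
            tensor_names[node] = f"t{len(tensor_names)}"
        elif node.op in ("call_module", "call_function", "call_method"):
            out_name = f"t{len(tensor_names)}"
            tensor_names[node] = out_name
            # all_input_nodes lists the fx.Nodes this node consumes, in order.
            inputs = [tensor_names[n] for n in node.all_input_nodes]
            # Submodules are keyed by their attribute path (e.g. "fc"),
            # other ops by the node name fx generated (e.g. "add").
            key = node.target if node.op == "call_module" else node.name
            layers[key] = {"inputTensor": inputs, "outputTensor": out_name}
        # The final "output" node is skipped.
    return {"layers": layers}


if __name__ == "__main__":
    print(json.dumps(graph_to_layers(mymodel(5)), indent=2))
```

Running it should list fc with input t0 and output t1, and the addition under its traced node name (typically "add") with inputs t1 and t0 and output t2. Printing traced.graph directly is another quick way to inspect what was captured.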