I would like to define a module in TorchScript that contains other modules. In the forward pass, a matrix is populated by assigning the output of each sub-module to different elements of the matrix.
Here's a simplified example:
```python
import torch
from torch import jit, nn, Tensor


class Module1(jit.ScriptModule):
    def __init__(self, first: torch.nn.Module, second: torch.nn.Module):
        super(Module1, self).__init__()
        self.first = first
        self.second = second

    @jit.script_method
    def forward(self, input: Tensor) -> Tensor:
        out = torch.eye(2)
        out[0, 0] = self.first(input)
        out[1, 1] = self.second(input)
        return out


def test1():
    m1 = Module1(first=nn.Linear(1, 1), second=nn.Linear(1, 1))
    out = m1(torch.ones(1))
    print(out)


if __name__ == '__main__':
    test1()
```
This runs fine with `PYTORCH_JIT=0 python3 test.py`, but an error is thrown with the JIT enabled:
```
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "torch_kalman/rewrite/test.py", line 14, in forward
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = torch.eye(2)
        out[0,0] = self.first(input)
        ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        out[1,1] = self.second(input)
        return out
RuntimeError: output with shape [] doesn't match the broadcast shape [1]
```
I am not sure why the interpreter thinks one of the shapes involved has zero extent. Is this a known unsupported feature? I could not find it in the documentation. Or is this a bug?
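For reference, checking the shapes in eager mode shows where the `[]` vs `[1]` mismatch could come from: a single tensor element is 0-D, while the `nn.Linear(1, 1)` output is 1-D with size 1. Eager indexing assignment accepts the size-1 source anyway, which would explain why only the scripted version fails (this is my reading of the error, not a confirmed explanation):

```python
import torch
from torch import nn

lin = nn.Linear(1, 1)
val = lin(torch.ones(1))
print(val.shape)        # torch.Size([1]) -- the module output is 1-D

out = torch.eye(2)
print(out[0, 0].shape)  # torch.Size([]) -- a single element is 0-D

# Eager mode accepts assigning a size-1 tensor into a 0-D slot:
out[0, 0] = val
print(out)
```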
EDIT: Maybe "output of a module" is not needed in the title – it seems like nothing generated at runtime can be assigned to tensor elements. For example, the following attempted workaround hits the same error:
```python
from typing import List  # needed for jit.annotate

    @jit.script_method
    def forward(self, input: Tensor) -> Tensor:
        outs = jit.annotate(List[Tensor], [])
        outs += [self.first(input)]
        outs += [self.second(input)]
        # use a new name, since TorchScript does not allow re-typing a variable
        stacked = torch.stack(outs)
        out2 = torch.eye(len(stacked))
        for i in range(len(stacked)):
            out2[i, i] = stacked[i]
        return out2
```
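For what it's worth, squeezing each module output down to 0-D before assigning seems to avoid the shape mismatch in my testing, since the source then matches the 0-D destination element. Treat this as a guessed workaround, not an explanation (`Module2` is just an illustrative name):

```python
import torch
from torch import jit, nn, Tensor


class Module2(jit.ScriptModule):
    def __init__(self, first: torch.nn.Module, second: torch.nn.Module):
        super(Module2, self).__init__()
        self.first = first
        self.second = second

    @jit.script_method
    def forward(self, input: Tensor) -> Tensor:
        out = torch.eye(2)
        # squeeze() drops the size-1 dimension, giving a 0-D tensor
        # that matches the shape of the element out[i, i]
        out[0, 0] = self.first(input).squeeze()
        out[1, 1] = self.second(input).squeeze()
        return out


m2 = Module2(first=nn.Linear(1, 1), second=nn.Linear(1, 1))
print(m2(torch.ones(1)))
```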