I would like to define a TorchScript module that contains other modules. In the forward pass, a matrix is populated by assigning the output of each sub-module to a different element of the matrix.
Here’s a simplified example:
```python
import torch
from torch import jit, nn, Tensor


class Module1(jit.ScriptModule):
    def __init__(self, first: torch.nn.Module, second: torch.nn.Module):
        super(Module1, self).__init__()
        self.first = first
        self.second = second

    @jit.script_method
    def forward(self, input: Tensor) -> Tensor:
        out = torch.eye(2)
        out[0,0] = self.first(input)
        out[1,1] = self.second(input)
        return out


def test1():
    m1 = Module1(first=nn.Linear(1, 1), second=nn.Linear(1, 1))
    out = m1(torch.ones(1))
    print(out)


if __name__ == '__main__':
    test1()
```
This runs fine with `PYTORCH_JIT=0 python3 test.py`, but an error is thrown when the JIT is enabled:
```
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "torch_kalman/rewrite/test.py", line 14, in forward
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = torch.eye(2)
        out[0,0] = self.first(input)
        ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        out[1,1] = self.second(input)
        return out
RuntimeError: output with shape  doesn't match the broadcast shape 
```
I am not sure why the interpreter thinks that the module outputs have zero extent. Is this a known unsupported feature? I could not find anything about it in the documentation. Or is this a bug?
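To narrow this down, here is a quick eager-mode check (no TorchScript involved) of the two shapes in the failing line: the output of `nn.Linear(1, 1)` applied to `torch.ones(1)`, and the destination `out[0,0]`. If they differ as shown, the error may be a broadcasting mismatch between a 0-dim destination and a one-element source, rather than the module outputs being empty. This is a diagnostic sketch, not a confirmed explanation of the interpreter's behaviour.

```python
import torch
from torch import nn

# Same sub-module configuration and input as in the example above.
first = nn.Linear(1, 1)
inp = torch.ones(1)

module_out = first(inp)      # Linear keeps the trailing feature dimension
dest = torch.eye(2)[0, 0]    # indexing with [0, 0] yields a 0-dim tensor

print(module_out.shape)  # torch.Size([1])
print(dest.shape)        # torch.Size([])
```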
EDIT: Maybe “output of a module” is not needed in the title; it seems like nothing generated at runtime can be assigned to tensor elements? For example, the following attempted workaround hits the same error:
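```python
# Replacement forward() for Module1; requires `from typing import List`
@jit.script_method
def forward(self, input: Tensor) -> Tensor:
    out = jit.annotate(List[Tensor], [])
    out += [self.first(input)]
    out += [self.second(input)]
    # new name: TorchScript does not allow rebinding `out` from List[Tensor] to Tensor
    stacked = torch.stack(out)
    out2 = torch.eye(len(stacked))
    for i in range(len(stacked)):
        out2[i, i] = stacked[i]
    return out2
```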
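In both versions each assigned value (`self.first(input)` or `stacked[i]`) has shape `[1]`, while each destination `out[i, i]` is 0-dim. Assuming that mismatch is the cause, one possible workaround is to squeeze the module output to 0-dim before assigning it. The sketch below also swaps `jit.ScriptModule` for a plain `nn.Module` compiled with `torch.jit.script`; that swap is my choice for brevity, not something the original setup requires, and the assumption that each sub-module returns shape `[1]` is noted in the code.

```python
import torch
from torch import nn, Tensor


class Module1(nn.Module):
    def __init__(self, first: nn.Module, second: nn.Module):
        super().__init__()
        self.first = first
        self.second = second

    def forward(self, input: Tensor) -> Tensor:
        out = torch.eye(2)
        # Assumption: each sub-module returns shape [1] (true for nn.Linear(1, 1)
        # on a shape-[1] input). squeeze(-1) makes it 0-dim, matching out[i, i].
        out[0, 0] = self.first(input).squeeze(-1)
        out[1, 1] = self.second(input).squeeze(-1)
        return out


eager = Module1(first=nn.Linear(1, 1), second=nn.Linear(1, 1))
scripted = torch.jit.script(eager)  # compile the same instance with TorchScript

x = torch.ones(1)
result = scripted(x)
print(result)
print(torch.allclose(result, eager(x)))  # scripted and eager outputs should agree
```

Indexing with `self.first(input)[0]` instead of `squeeze(-1)` should give the same 0-dim value here.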