TorchScript does not support assigning output of module to element of tensor

I would like to define a TorchScript module that contains other modules. In the forward pass, a matrix is populated by assigning the output of each sub-module to a different element of the matrix.

Here’s a simplified example:

import torch
from torch import jit, nn, Tensor


class Module1(jit.ScriptModule):
    def __init__(self, first: torch.nn.Module, second: torch.nn.Module):
        super(Module1, self).__init__()
        self.first = first
        self.second = second

    @jit.script_method
    def forward(self, input: Tensor) -> Tensor:
        out = torch.eye(2)
        out[0,0] = self.first(input)
        out[1,1] = self.second(input)
        return out


def test1():
    m1 = Module1(first=nn.Linear(1, 1), second=nn.Linear(1, 1))
    out = m1(torch.ones(1))
    print(out)


if __name__ == '__main__':
    test1()

This runs fine with PYTORCH_JIT=0 python3 test.py, but with the JIT enabled it fails with:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "torch_kalman/rewrite/test.py", line 14, in forward
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = torch.eye(2)
        out[0,0] = self.first(input)
        ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        out[1,1] = self.second(input)
        return out
RuntimeError: output with shape [] doesn't match the broadcast shape [1]

I am not sure why the interpreter thinks the module outputs have zero extent. Is this a known unsupported feature? I could not find it in the documentation. Or is this a bug?
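A quick eager-mode shape check suggests that the zero-dimensional shape in the message belongs to the assignment target `out[0, 0]`, not to the module output. This is a sketch; the `nn.Linear(1, 1)` layer and `torch.ones(1)` input simply mirror the example above.

```python
import torch
from torch import nn

layer = nn.Linear(1, 1)
value = layer(torch.ones(1))   # output of a sub-module, as in Module1
target = torch.eye(2)[0, 0]    # integer indexing on both dims gives a 0-dim view

print(value.shape)   # torch.Size([1])
print(target.shape)  # torch.Size([])
```

So `out[0, 0] = self.first(input)` copies a shape-[1] value into a shape-[] slot, which lines up with "output with shape [] doesn't match the broadcast shape [1]".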

EDIT: Perhaps “output of a module” does not belong in the title; it seems that nothing computed at runtime can be assigned to individual tensor elements. For example, the following attempted workaround hits the same error:

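    # assumes "from typing import List" at the top of test.py
    @jit.script_method
    def forward(self, input: Tensor) -> Tensor:
        out = jit.annotate(List[Tensor], [])
        out += [self.first(input)]
        out += [self.second(input)]
        # new name: TorchScript does not let a variable change type (List[Tensor] -> Tensor)
        stacked = torch.stack(out)
        out2 = torch.eye(len(stacked))
        for i in range(len(stacked)):
            out2[i, i] = stacked[i]
        return out2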
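The list-and-stack detour appears to keep the same mismatch: stacking two shape-[1] outputs yields a [2, 1] tensor, so each row `stacked[i]` is still one-dimensional while `out2[i, i]` is zero-dimensional. An eager-mode check, assuming the same `nn.Linear(1, 1)` sub-modules as the original example:

```python
import torch
from torch import nn

first, second = nn.Linear(1, 1), nn.Linear(1, 1)
x = torch.ones(1)

# Each sub-module output has shape [1], so stacking gives shape [2, 1]
stacked = torch.stack([first(x), second(x)])
out2 = torch.eye(len(stacked))

print(stacked.shape)     # torch.Size([2, 1])
print(stacked[0].shape)  # torch.Size([1]), still 1-dim
print(out2[0, 0].shape)  # torch.Size([]), the 0-dim target
```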

The following seems to work:

    @jit.script_method
    def forward(self, input: Tensor) -> Tensor:
        out = torch.eye(2)
        out[slice(0, 1), slice(0, 1)] = self.first(input)
        out[slice(1, 2), slice(1, 2)] = self.second(input)
        return out
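Below is a standalone sketch of the working pattern as a scripted function instead of a ScriptModule. It assumes TorchScript treats ordinary `0:1` slice syntax the same way as `slice(0, 1)`; the target `out[0:1, 0:1]` then has shape [1, 1], and a shape-[1] value broadcasts into it. The function name `fill_diagonal_slices` is hypothetical.

```python
import torch
from torch import Tensor


@torch.jit.script
def fill_diagonal_slices(a: Tensor, b: Tensor) -> Tensor:
    out = torch.eye(2)
    # out[0:1, 0:1] has shape [1, 1]; a value of shape [1] broadcasts into it
    out[0:1, 0:1] = a
    out[1:2, 1:2] = b
    return out


# One-element inputs stand in for the outputs of nn.Linear(1, 1)
result = fill_diagonal_slices(torch.tensor([3.0]), torch.tensor([5.0]))
print(result)
```

The diagonal becomes 3 and 5; the off-diagonal zeros from `torch.eye(2)` are untouched.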

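That’s a quirk of zero-dimensional tensors: out[0,0] and out[0,0:1] are not equivalent (the first is a 0-dim view, the second has shape [1]). Writing a 1-dim value into a 0-dim target is a special case: eager mode handles it by dropping leading size-1 dimensions from the value, but TorchScript apparently does not.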
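Another option, sketched under the same assumptions, is to keep integer indexing but give the value the zero-dimensional shape the target expects, for example by taking element `[0]` of each one-element module output (inside `Module1.forward` this would read `out[0, 0] = self.first(input)[0]`). The helper name `fill_diagonal_scalars` is hypothetical. If a module returned more than one element, `[0]` would silently keep only the first, so this only fits the single-output case in the question.

```python
import torch
from torch import Tensor


@torch.jit.script
def fill_diagonal_scalars(a: Tensor, b: Tensor) -> Tensor:
    out = torch.eye(2)
    # a[0] is 0-dim, matching the 0-dim target out[0, 0]
    out[0, 0] = a[0]
    out[1, 1] = b[0]
    return out


# One-element inputs stand in for the outputs of nn.Linear(1, 1)
result = fill_diagonal_scalars(torch.tensor([3.0]), torch.tensor([5.0]))
print(result)
```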