Hi,
I created a module that holds a list of different modules and randomly applies one of them on each forward pass. This is my current (ugly!) implementation:
class OneOf(nn.Module):
    """Apply exactly one of several candidate sub-modules per forward pass.

    Selection is driven by a non-trainable weight vector: on each call a single
    layer is either sampled from ``softmax(weights)`` or, when ``sample=False``,
    the highest-weighted layer is taken.

    Args:
        *args: Candidate modules. With no arguments a single ``nn.Identity`` is
            used, so the module acts as a pass-through.
    """

    def __init__(self, *args: nn.Module) -> None:
        super().__init__()
        self.layers = nn.ModuleList(list(args) if len(args) else [nn.Identity()])
        n_layers = len(self.layers)
        # Randomly initialised selection logits, one per layer. They are stored
        # as a Parameter with requires_grad=False, so no gradient is computed.
        self._weights = nn.parameter.Parameter(
            torch.rand(n_layers), requires_grad=False
        )

    def forward(
        self, x: torch.Tensor, sample: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one selected layer on ``x``.

        Args:
            x: Input passed unchanged to the selected layer.
            sample: If True, draw the layer index from ``softmax(weights)``.
                Otherwise pick the layer with the largest weight.

        Returns:
            A tuple ``(y, log_prob)``. ``y`` is the selected layer's output.
            ``log_prob`` is a 0-d tensor holding the log-probability of the
            chosen index.
        """
        log_prob = nn.functional.log_softmax(self._weights, dim=0)
        # multinomial returns a shape-(1,) tensor while argmax returns a 0-d
        # tensor; converting both to a plain int gives a single type for idx
        # and makes log_prob[idx] a 0-d scalar in both cases.
        if sample:
            idx = int(torch.multinomial(log_prob.exp(), 1).item())
        else:
            idx = int(torch.argmax(log_prob).item())

        # Placeholder so `y` is defined on every path (TorchScript requires
        # it). idx is always in [0, len(self.layers)), so exactly one loop
        # iteration overwrites it.
        y = x
        # TorchScript only permits ModuleList indexing with integer literals,
        # so iterate and select the matching layer instead of self.layers[idx].
        # NOTE(review): recent PyTorch releases reportedly allow
        # `layer: SomeInterface = self.layers[idx]` when the variable is
        # annotated with a @torch.jit.interface type; verify against your
        # PyTorch version before switching.
        for i, layer in enumerate(self.layers):
            if i == idx:
                y = layer(x)
        return y, log_prob[idx]
if __name__ == '__main__':
    # Smoke test: an argument-less OneOf wraps only nn.Identity, so the
    # TorchScript-compiled module must echo its input with log-probability 0.
    scripted = torch.jit.script(OneOf())
    inp = torch.rand(2, 3, 5)
    out, lp = scripted(inp)
    assert torch.allclose(inp, out)
    assert np.isclose(lp.item(), 0.0)
This hideous iteration and index comparison is my workaround to avoid indexing the `ModuleList` with a non-literal index. Is there a better way to do this without breaking compilation to TorchScript?
Thanks for your help!