Hi Aidan, thanks for suggesting I use TorchScript! I’m part of the way there in re-writing my custom GRU, though I found out that there’s a problem with TorchScript using getattr
and setattr
to utilize layer names. Here’s a minimum example to replicate the issue:
Example Code
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
import warnings
from collections import namedtuple
from typing import List, Tuple
from torch import Tensor
from torch.nn.functional import one_hot, softmax
import numpy as np
import numbers
import sys
class BSGRUCell(jit.ScriptModule):
def __init__(self, input_size, hidden_size, num_blocks, bias=True):
super(BSGRUCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_blocks = num_blocks
block_size = int(hidden_size / num_blocks)
if (block_size * num_blocks) != hidden_size:
raise ValueError("hidden_size must be divisible by num_blocks")
self.block_size = block_size
self.weight_ik = Parameter(torch.randn(input_size, num_blocks))
self.weight_hk = Parameter(torch.randn(hidden_size, num_blocks))
self.weight_ih = Parameter(torch.randn(
1, num_blocks, input_size, 3 * block_size))
self.weight_hh = Parameter(torch.randn(
1, num_blocks, num_blocks, block_size, 3 * block_size))
self.bias_ik = Parameter(torch.randn(num_blocks))
self.bias_ih = Parameter(torch.randn(3 * hidden_size))
self.bias_hh = Parameter(torch.randn(3 * hidden_size))
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
# "x" and "y" variable names refer to previous and current timesteps
hx, kx = state
# predict the current active block for this time step
ky = (torch.mm(input, self.weight_ik) +
torch.mm(hx, self.weight_hk) +
self.bias_ik)
if self.training:
beta = 5.
ky = softmax(beta*ky, dim=-1)
else:
ky = one_hot(ky.argmax(-1), num_classes=self.num_blocks).float()
# use the current block predictions to sparsely activate
# the input to hidden (W-matrix) weights
mat_W = torch.mul(ky.unsqueeze(-1).unsqueeze(-1), self.weight_ih)
mat_W = mat_W.view(-1, 3 * self.hidden_size, self.input_size)
mat_W = torch.transpose(mat_W, 1, 2)
# use the previous + current block predictions to sparsely activate
# the hidden to hidden (U-matrix) weights
kk = torch.bmm(kx.unsqueeze(2), ky.unsqueeze(1))
mat_U = torch.mul(kk.unsqueeze(-1).unsqueeze(-1), self.weight_hh)
mat_U = mat_U.view(-1, 3 * self.hidden_size, self.hidden_size)
mat_U = torch.transpose(mat_U, 1, 2)
# compute all the gate values
i_gates = torch.bmm(input.unsqueeze(1), mat_W) + self.bias_ih
h_gates = torch.bmm(hx.unsqueeze(1), mat_U) + self.bias_hh
i_z, i_r, i_n = i_gates.squeeze(1).chunk(3, -1)
h_z, h_r, h_n = h_gates.squeeze(1).chunk(3, -1)
updategate = torch.sigmoid(i_z + h_z)
resetgate = torch.sigmoid(i_r + h_r)
newgate = torch.tanh(i_n + h_n * resetgate)
hy = newgate + updategate * (hx - newgate)
return hy, (hy, ky)
class BSGRU(jit.ScriptModule):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
num_blocks: int = 1,
bias: bool = True,
batch_first: bool = False,
bidirectional: bool = False,
beta: float = 1.,
dropout: float = 0.,
device=None,
dtype=None) -> None:
super(BSGRU, self).__init__()
self.num_layers = num_layers
self.batch_first = batch_first
for l in range(num_layers):
setattr(self, "layer_{}".format(l),
BSGRUCell(input_size, hidden_size, num_blocks, bias=bias))
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
# put seq_len as the first axes
if self.batch_first:
input = torch.transpose(input, 0, 1)
in_tensor = input
for l in range(self.num_layers):
in_seq = in_tensor.unbind(0)
out_seq = torch.jit.annotate(List[Tensor], [])
for i in range(len(in_seq)):
out, state = getattr(self, "layer_{}".format(l))(in_seq[i], state)
out_seq += [out]
out_seq = torch.stack(out_seq)
in_tensor = out_seq
return out_seq
seq_len, batch_size, input_size, hidden_size, num_blocks = 11, 5, 7, 4, 2
inp = torch.randn(seq_len, batch_size, input_size)
h0 = torch.rand(batch_size, hidden_size)
k0 = torch.from_numpy(np.tile([0,1],[batch_size, 1])).float()
rnn = BSGRU(input_size, hidden_size, 1, num_blocks)
out = rnn(inp, (h0, k0))
print(out)
The error I get:
RuntimeError:
getattr's second argument must be a string literal:
File "temp.py", line 114
out_seq = torch.jit.annotate(List[Tensor], [])
for i in range(len(in_seq)):
out, state = getattr(self, "layer_{}".format(l))(in_seq[i], state)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
out_seq += [out]
out_seq = torch.stack(out_seq)
I was able to make the BSGRUCell successfully run, but how does one write modules that have a variable amount of layers if they aren’t supposed to use attributes?