Hello everyone,
I’m trying to get a script of a model.
Here is the code(I removed most of the code in init function to avoid verbosity):
(spec.matrix is a square matrix that its shape is not predetermined.)
class Cell(nn.Module):
def __init__(self, spec, in_channels, out_channels):
super(Cell, self).__init__()
self.spec = spec
self.num_vertices = self.spec.matrix.shape[0]
# vertex_channels[i] = number of output channels of vertex i
self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.spec.matrix)
def forward(self, x: Any) -> Any:
tensors = [x]
out_concat = []
for t in range(1, self.num_vertices - 1):
fan_in = []
for src in range(1, t):
if self.spec.matrix[src, t]:
fan_in.append(Truncate(tensors[src], torch.tensor(self.vertex_channels[t])))
if self.spec.matrix[0, t]:
fan_in.append(self.input_op[t](x))
# perform operation on node
vertex_input = sum(fan_in)
vertex_output = self.vertex_op[t](vertex_input)
tensors.append(vertex_output)
if self.spec.matrix[t, self.num_vertices - 1]:
out_concat.append(tensors[t])
if not out_concat:
assert self.spec.matrix[0, self.num_vertices - 1]
outputs = self.input_op[self.num_vertices - 1](tensors[0])
else:
if len(out_concat) == 1:
outputs = out_concat[0]
else:
outputs = torch.cat(out_concat, 1)
if self.spec.matrix[0, self.num_vertices - 1]:
outputs += self.input_op[self.num_vertices - 1](tensors[0])
return outputs
def Truncate(inputs, channels):
"""Slice the inputs to channels if necessary."""
input_channels = inputs.size()[1]
if input_channels < channels:
raise ValueError('input channel < output channels for truncate')
elif input_channels == channels:
return inputs # No truncation necessary
else:
# Truncation should only be necessary when channel division leads to
# vertices with +1 channels. The input vertex should always be projected to
# the minimum channel count.
assert input_channels - channels == 1
return inputs[:, :channels, :, :]
It returns the following RuntimeError:
Truncate(Tensor inputs, Tensor channels) -> (Tensor):
Expected a value of type 'Tensor (inferred)' for argument 'inputs' but instead found type 'Any'.
Inferred 'inputs' to be of type 'Tensor' because it was not annotated with an explicit type.:
File "model.py", line 136
for src in range(1, t):
if self.spec.matrix[src, t]:
fan_in.append(Truncate(tensors[src], torch.tensor(self.vertex_channels[t])))
~~~~~~~~ <--- HERE
if self.spec.matrix[0, t]:
As error declares, I should annotate the input in tensors. However, doing that will result in RuntimeError:Expected integer literal for index ...
.
As I understood from this post I tried adding @ torch.jit.interface
before the class Cell(nn.Module)
. It returned RuntimeError:interface declarations must have a return type annotated.
and it underlines all the lines in the init function.
Although I can shun the problem by adding a @torch.jit.ignore
before the forward function in the Cell Module, it seems a bit of an unreliable solution.
Thank you in advance for any guidance.