First of all, thank you. Here is my question.
I have a Sequential model with a custom Reshape nn.Module inserted:
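```python
import torch

model = torch.nn.Sequential(
    # Conv1d(in_channels, out_channels, kernel_size)
    torch.nn.Conv1d(batch_size, batch_size * filters_multiplier, 9),
    torch.nn.ReLU(),
    torch.nn.MaxPool1d(3),
    Reshape(-1),  # custom flatten module, defined below
    # ... followed by the Linear layer(s)
)
```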
Reshape is defined as follows (it flattens the input):
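```python
from torch import nn

class Reshape(nn.Module):
    """Views its input as a fixed target shape; Reshape(-1) flattens it."""

    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args  # target shape passed to Tensor.view

    def forward(self, x):
        return x.view(self.shape)
```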
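To make the flattening step concrete, here is a small shape check. It assumes hypothetical values `batch_size = 4`, `filters_multiplier = 2`, an input length of 30, and a hypothetical `Linear(56, 1)` head standing in for the elided layers. Note that `view(-1)` merges every dimension, including the leading batch dimension, whereas PyTorch's built-in `nn.Flatten` (default `start_dim=1`) keeps it.

```python
import torch
from torch import nn


class Reshape(nn.Module):
    # Same module as in the question: view the input as the stored shape.
    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


# Hypothetical sizes, chosen only to make the shapes easy to follow.
batch_size, filters_multiplier, length = 4, 2, 30

features = nn.Sequential(
    nn.Conv1d(batch_size, batch_size * filters_multiplier, 9),  # (1, 4, 30) -> (1, 8, 22)
    nn.ReLU(),
    nn.MaxPool1d(3),  # stride defaults to 3: (1, 8, 22) -> (1, 8, 7)
)

x = torch.randn(1, batch_size, length)
h = features(x)

flat_custom = Reshape(-1)(h)    # every dimension merged: (56,)
flat_builtin = nn.Flatten()(h)  # leading dimension kept: (1, 56)

head = nn.Linear(8 * 7, 1)  # hypothetical stand-in for the elided Linear layer(s)
out = head(flat_custom)
print(h.shape, flat_custom.shape, flat_builtin.shape, out.shape)
```

With a batch of several samples, `Reshape(-1)` would mix them into one vector, so a shape such as `Reshape(x.size(0), -1)` or `nn.Flatten()` is usually what a following `Linear` layer expects.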
Everything works great, and I like being able to embed the flatten step directly in the Sequential. The model is saved with torch.save(),
but when I try to torch.load() it in a Flask app, I get this error:
```
AttributeError: Can't get attribute 'Reshape' on <module '__main__' from '/home/srg/anaconda3/bin/flask'>
```
And I can’t think of any workaround other than not using Sequential in the first place.
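For context on the error: `torch.save(model)` pickles the whole module object, and pickle stores custom classes by reference (module path plus class name, here `__main__.Reshape`). When the Flask app unpickles the file, `__main__` is the `flask` launcher, which has no `Reshape`, hence the `AttributeError`. A commonly used alternative, sketched below, is to save only the `state_dict`, which holds tensors rather than class references, and rebuild the architecture on the loading side. The `build_model()` helper, the sizes, and the `Linear` head are hypothetical; an in-memory buffer stands in for the saved file.

```python
import io

import torch
from torch import nn


class Reshape(nn.Module):
    # Same custom flatten module as in the question.
    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


def build_model(batch_size=4, filters_multiplier=2):
    # Hypothetical helper: both the training script and the Flask app
    # would call it to construct the same architecture.
    channels = batch_size * filters_multiplier
    return nn.Sequential(
        nn.Conv1d(batch_size, channels, 9),
        nn.ReLU(),
        nn.MaxPool1d(3),
        Reshape(-1),
        nn.Linear(channels * 7, 1),  # 7 positions for a hypothetical input length of 30
    )


# Training side: save only the parameters, not the pickled module.
trained = build_model()
buffer = io.BytesIO()  # stands in for a file such as "model.pt"
torch.save(trained.state_dict(), buffer)

# Serving side (e.g. the Flask app): rebuild the model, then load the weights.
buffer.seek(0)
served = build_model()
served.load_state_dict(torch.load(buffer, weights_only=True))  # weights_only: PyTorch 1.13+
served.eval()

x = torch.randn(1, 4, 30)
with torch.no_grad():
    same = torch.equal(trained(x), served(x))
print(same)
```

If you prefer to keep saving the whole model, defining `Reshape` in its own importable module (imported by both the training script and the Flask app) should also work, because the pickled reference then points to that module instead of `__main__`.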