While implementing a linear model, after pretraining and saving the SimCLR() model and loading it again for testing the accuracy, getting the NotImplementedError. Not able to identify the mistake.
class LinearHeadModel(nn.Module):
def __init__(self,simclr_model_dict, num_classes=2):
super(LinearHeadModel,self).__init__()
self.num_classes=num_classes
model = SimCLR(LSTMModel, 128, 41)
if simclr_model_dict:
print("loading feature extractor")
# filepath = '/Desktop/Untitled Folder/model.pt'
smclr = SimCLR(LSTMModel, 128, 41)
smclr.load_state_dict(torch.load('model.pt', map_location='cpu'))
# self.features = smclr.f
self.g = nn.Sequential(nn.Linear(512, out_features=self.num_classes, bias=True))
model = LinearHeadModel(simclr_model_dict=simclr_model_dict, num_classes=2)
# model.features.requires_grad_ = False
parameters = [param for param in model.parameters() if param.requires_grad is True] # trainable parameters
optimizer = torch.optim.SGD(
parameters,
0.1, # lr = 0.1 * batch_size / 256, see section B.6 and B.7 of SimCLR paper.
momentum=0.9,
weight_decay=0.,
nesterov=True)
model.train()
# device = model.device
for epoch in range(epochs):
total_loss = []
for batch_idx, (X_tr, Y_tr) in enumerate(train_ldr):
X_tr = X_tr.unsqueeze(0)
print(X_tr.shape)
X_tr = X_tr.view(256, 1, 41)
print(X_tr.shape)
print('TRAINING SAMPLES', X_tr)
Y_tr = Y_tr.type(torch.LongTensor)
optimizer.zero_grad()
# Forward pass
output = model(X_tr)
print('output', output)
loss = loss_func(output)
# Backward pass
loss.backward()
# Optimize the weights
optimizer.step()
total_loss.append(loss.item())
correct = (torch.argmax(out.to("cpu").data,1) == Y_tr.data).float().sum()
acc += float(100.0*(correct))
loader.set_description(f"Epoch: {i}, training_loss: {loss}, accuracy :{acc/total_num}")
print(f"Epoch {i} training loss: {total_loss/total_num} acc : {acc/total_num}")
NotImplementedError Traceback (most recent call last)
<ipython-input-41-a362a7aab1f1> in <module>
26
27 # Forward pass
---> 28 output = model(X_tr)
29 print('output', output)
30
~\anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
~\anaconda3\lib\site-packages\torch\nn\modules\module.py in _forward_unimplemented(self, *input)
199 registered hooks while the latter silently ignores them.
200 """
--> 201 raise NotImplementedError
202
203
NotImplementedError: