from sentence_transformers import SentenceTransformer
import torch

# store activations captured by the forward hooks
activation = {}

def get_hook(name):
    def hook(module, inp, output):
        activation[name] = output[0].detach()
    return hook

model = SentenceTransformer('T-Systems-onsite/cross-en-de-roberta-sentence-transformer')
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
cos_sim = torch.nn.CosineSimilarity(dim=0)

optimizer.zero_grad()
query_prediction = model.encode("a man is cutting up a tomato", convert_to_tensor=True)
positive_prediction = model.encode("a man is slicing a tomato", convert_to_tensor=True)
negative_prediction = model.encode("she's brushing her hair", convert_to_tensor=True)

# triplet loss with a margin of 0.7 on the cosine distances
dist_pos = 1 - cos_sim(query_prediction, positive_prediction)
dist_neg = 1 - cos_sim(query_prediction, negative_prediction)
loss = torch.max(torch.tensor(0.), 0.7 + dist_pos - dist_neg)

if loss != 0.:
    loss.backward()

for p in model.parameters():
    print(p.grad)  # all None

# register the hooks before the forward pass so they actually fire
for name, layer in model.named_modules():
    layer.register_forward_hook(get_hook(name))
query_prediction = model.encode("a man is cutting up a tomato", convert_to_tensor=True)
print(activation)
The loss is being calculated, but the gradients are all None, so the model is not training.
Going through the sentence-transformers source, the forward pass inside the encode method appears to run under torch.no_grad(). Could that be the problem?
Any tips or ideas on why the gradients are None would be much appreciated.
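A quick way to verify that suspicion is to inspect the tensor returned by encode directly (minimal check):

emb = model.encode("a man is cutting up a tomato", convert_to_tensor=True)
# if encode runs under torch.no_grad(), this prints: False None
print(emb.requires_grad, emb.grad_fn)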
This would mean that at least the model output is attached to the computation graph, so you could check the .grad_fn attribute of the previous activations and see if any of them is None.
The parameters don't have a grad_fn, as they are leaf nodes, so you would need to check the forward activations instead, either directly in the forward method, e.g. via:
def forward(self, x):
    x = self.layer(x)
    print(x.grad_fn)
    ...
or via forward hooks registered on the submodules, e.g.:

def get_hook(name):
    def hook(module, inp, output):
        # note: .detach() removes the grad_fn; drop it to inspect grad_fn here
        activation[name] = output[0].detach()
    return hook

for name, layer in model.named_modules():
    layer.register_forward_hook(get_hook(name))
query_prediction = model.encode("a man is cutting up a tomato", convert_to_tensor=True)
print(activation)
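If encode does turn out to run under torch.no_grad(), one workaround is to tokenize the sentences and call the model directly, so the forward pass stays on the autograd graph. A minimal sketch, assuming the model.tokenize method and the 'sentence_embedding' output key behave as in recent sentence-transformers versions (the helper name embed_with_grad is just for illustration):

def embed_with_grad(model, sentence):
    # run the underlying modules directly instead of encode()
    features = model.tokenize([sentence])
    features = {k: v.to(model.device) for k, v in features.items()}
    out = model(features)  # forward pass with gradients enabled
    return out['sentence_embedding'][0]

query_prediction = embed_with_grad(model, "a man is cutting up a tomato")
print(query_prediction.grad_fn)  # should now show a backward function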
I've run into the same issue, and I am getting <UnsafeViewBackward0 object at 0x2b2ac389e220> when I access grad_fn. Is this something I should not be expecting?
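For what it's worth, a non-None grad_fn such as UnsafeViewBackward0 is the expected, healthy state: it means autograd recorded the operation that produced the tensor. Only a grad_fn of None on an intermediate activation would indicate that the graph is broken. A minimal illustration:

import torch
x = torch.randn(2, 3, requires_grad=True)
y = x.view(6)     # a view op tracked by autograd
print(y.grad_fn)  # e.g. <ViewBackward0 object at 0x...>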