I have a function that is used at the end of my model, and I would like to backpropagate the error through it. However, when I check the gradients they are extremely small, which leads me to believe that gradient tracking is not working as intended.
The function itself looks as follows:
z = torch.stack([*map(i.__eq__,torch.unique(i))]).type(torch.float)*(torch.stack([v]*z.shape)).type(torch.float)
It’s quite a complicated function (from this original post) that uses map and .__eq__, which are not native PyTorch functions, so I suspect that one of them is the main issue with tracking gradients.
As the operation is quite complex, I have provided a graphical representation for visualising the operation below to help understand what the line does for given input tensors:
Here, the tensor i dictates which channels of tensor z the elements of tensor v fall into.
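To make the description above concrete, here is a minimal sketch of the same operation on tiny inputs. The values of i and v are hypothetical, and it assumes v is repeated once per unique value in i (the exact repeat count in the one-liner is not clear from the post). The comparison i == c is what i.__eq__(c) calls under the hood.

```python
import torch

# Hypothetical inputs, not taken from the original post.
i = torch.tensor([0, 2, 0, 1])  # channel id for each element of v
v = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)

channels = torch.unique(i)  # sorted unique channel ids: 0, 1, 2

# One boolean row per channel, shape (num_channels, num_elements).
# The comparison itself carries no gradient; it only selects positions.
mask = torch.stack([i == c for c in channels]).type(torch.float)

# Assumption: v is stacked once per channel, then masked.
z = mask * torch.stack([v] * len(channels))
print(z)

# The mask acts as a constant, so the gradient w.r.t. v is the mask itself.
z.sum().backward()
print(v.grad)
```

In this sketch each element of v lands in exactly one channel, so summing z and calling backward gives a gradient of 1 for every element of v. No gradient reaches i, since it is an integer tensor used only for selection.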
Is there a quick and easy way to test the gradient tracking through this function?
Side note: If there is no way to quickly test whether the gradients can be tracked, maybe it is worth me posting a feature request for an almost unit-test style utility that checks whether functions are appropriate to place within networks…?
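For reference, one possible check is sketched below using torch.autograd.gradcheck, which compares analytical gradients against finite-difference estimates. The wrapper name scatter_to_channels and the input values are hypothetical, and the wrapper assumes v is repeated once per unique value in i. gradcheck expects double-precision inputs, so the mask is cast to v.dtype rather than torch.float.

```python
import torch
from torch.autograd import gradcheck


def scatter_to_channels(v, i):
    """Hypothetical wrapper around the one-liner from the post."""
    channels = torch.unique(i)
    # Boolean comparisons are cast to v's dtype so double inputs stay double.
    mask = torch.stack([i == c for c in channels]).type(v.dtype)
    return mask * torch.stack([v] * len(channels))


i = torch.tensor([0, 2, 0, 1])
v = torch.randn(4, dtype=torch.double, requires_grad=True)

# Quick sanity check: is the output attached to the autograd graph?
z = scatter_to_channels(v, i)
print(z.requires_grad, z.grad_fn)

# Numerical vs analytical gradient comparison with respect to v.
ok = gradcheck(lambda x: scatter_to_channels(x, i), (v,), eps=1e-6, atol=1e-4)
print(ok)
```

If z.grad_fn is None, or gradcheck raises an error, the graph is broken somewhere in the function. If both checks pass, small gradients are more likely caused by something else in the model or the loss scaling than by this function.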