I created model to do custom edge detection and the forward pass works fine and produces the right result.
However training does not work because the gradients of the parameters can not be computed. I made a minimal reproducer below, which throws the error. I think the issue here is that my output tensor has no gradients, but setting it to require_grad=True does not solve the issue, instead no error is thrown, but the gradients do not change even if the loss is super high.
This might be due to masks I use. Maybe you have an idea how to solve this issue since I want to use the model in downstream tasks and train it in combination with a classificator.
Here is the code:
from torch import nn
import torch
from torchvision.transforms import v2
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.x_sob = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), stride=(1, 1), padding="same")
self.y_sob = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), stride=(1, 1), padding="same")
self.x_sob.weight = nn.Parameter(torch.Tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]))
self.y_sob.weight = nn.Parameter(torch.Tensor([[[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]]))
#self.apply_thresh = MaxPool2dStaticSamePadding(kernel_size=(3, 3), stride=1)
self.low_threshold = nn.Parameter(torch.Tensor([0.05]))
self.high_threshold = nn.Parameter(torch.Tensor([0.15]))
self.blur = v2.GaussianBlur(kernel_size=5, sigma=1.4)
def forward(self, x):
# Gaussian blur & sobel filter
x = self.blur(x)
gX = self.x_sob(x)
gY = self.y_sob(x)
# Calculate Magnitude and angle theta
mag = torch.sqrt(torch.pow(gX, 2) + torch.pow(gY, 2))
mag = ((mag / torch.max(mag)) * 255)
theta = torch.arctan2(gY, gX)
# Create nms representation of img
#img = self.non_max_suppression(mag, theta)
img = mag
# calculate low and high threshold
batch_size = img.size(0)
h_t = img.reshape(batch_size, -1).max(dim=1).values.reshape(batch_size, 1, 1, 1) * self.high_threshold
l_t = h_t * self.low_threshold
# accept pixels values higher than h_t, reject pixel value lower l_t, create mask
accept_mask = (img > h_t).bool()
reject_mask = (img < l_t).bool()
# get values between high and low threshold
bet_mask = ((img > l_t) & (img < h_t)).bool()
# set all value in "img" to zero which are not True in the accept mask
img_acc = img.masked_fill(~accept_mask, 0)
# apply hysteresis by using max_pooling -> all pixels with neighbouring "strong" pixels get also "strong"
#img_th = self.apply_thresh(img_acc)
img_th = img_acc
# create a mask where all fields are true which fall between the thresholds
img_b_mask = img.masked_fill(~bet_mask, 0).bool()
# Replace all values < 0 in img_acc with the strong neighbours which are in bet_mask, else 0
out = torch.where(img_acc > 0, img_acc, img_th.masked_fill(~img_b_mask, 0))
out = torch.where(reject_mask, 0, out) # Apply reject mask directly
out = torch.where(out > 0, 255.0, 0.0).to(torch.float32)
# out.requires_grad = True
#THIS WORKS BETTER: return img_th[0, 0, :200, :200]
return out[0,0,:200,:200]
# Init
net = NeuralNetwork()
net.train()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
# Print parameters before backward to see if there is a change
print(list(net.parameters()))
# Dummy input and labels
lbl = torch.randn(3,1,224,224)
inp = torch.randn(3,1,224,224)
optimizer.zero_grad()
res = net(inp)
l = loss_fn(res, lbl[0,0,:200,:200])
print("LOSS:", l)
l.backward()
optimizer.step()
# Print parameters again
print(list(net.parameters()))