Hi, I have a use case, Where I’m trying to have a convolutional network output a mask,
the way I’m doing it is by using a standard encoder-decoder architecture, and then the pixel locations with the highest x% of values is the mask I pick.
On a high level, this is what I’m doing is (can be seen in the FilterOutMask module in the code)
- Sorting the flattened output array
- Picking the Kth Index and Assigning Threshold as that value
- Creating a mask of all indices greater than the threshold ( by running array > threshold)
and the code for this looks like this
class FilterOutMask(nn.Module):
def __init__(self, sparsity):
super(FilterOutMask, self).__init__()
self.sparsity = sparsity
def forward(self, x):
w_shape=x.shape
print("using")
x = x.view(w_shape[0], -1).transpose(0, 1)
norm = torch.abs(x)
idx = int(0.5 * x.shape[0])
threshold = (torch.sort(norm, dim=0)[0])[idx]
mask = norm>threshold
mask = mask.transpose(0, 1).view(w_shape)
#new = torch.zeros_like(mask,dtype=torch.float32)
#print(new.shape)
#print(mask.shape)
#new[mask]=1
return mask.float()
class SamplerNetwork(nn.Module):
def __init__(self,gpu=True):
super(SamplerNetwork, self).__init__()
self.conv_1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=32,
kernel_size=3,
padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.conv_2 = nn.Sequential(
nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=3,
padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.deconv_1 = nn.Sequential(
nn.ConvTranspose2d(
in_channels=64,
out_channels=32,
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(32),
nn.ReLU()
)
self.deconv_2 = nn.Sequential(
nn.ConvTranspose2d(
in_channels=32,
out_channels=1,
kernel_size = 3,
stride=2,
padding=1,
output_padding=1),
nn.ReLU()
)
self.drop = nn.Dropout2d(0.25)
print(gpu)
def forward(self, x):
input_shape=x.shape
out = self.conv_1(x)
out = self.conv_2(out)
out = self.deconv_1(out)
out = self.deconv_2(out)
out=out.view(input_shape)
return out
input = torch.randn(8,1,32,32,requires_grad=True)
sampler=SamplerNetwork()
outputs = torch.zeros((8,1,32,32),requires_grad=True)
inputs2=sampler(input)
filter=FilterOutMask(0.5)
pred=filter(inputs2)
loss = torch.nn.MSELoss()
output=loss(pred.float(),outputs.float())
output.backward()
print(input.requires_grad)
print(input.grad)```
**the loss Ive picked is arbitrary just for testing, **
**but the problem Im facing is **
**-That gradients for the input are None,**
Im pretty sure the issue is that any operation that outputs a bool tensor doesnt generate gradients, but do you guys have suggested workarounds for it ? do I have to write a custom backward pass?