Could you post a small dummy code snippet so that we can see how you are using this module?
If you want to manipulate the weights of your conv layer, you can wrap the assignment in a `with torch.no_grad():` block:
class MaskedConv1d(nn.Module):
    """1-D convolution with a causally masked kernel (PixelCNN-style).

    Mask ``'A'`` zeroes the centre kernel tap and every tap to its right;
    mask ``'B'`` keeps the centre tap and zeroes only the taps to its right.
    The mask is re-applied on every forward pass, so masked taps receive a
    zero gradient and stay disabled during training.

    Args:
        mask_type: ``'A'`` or ``'B'``.
        *args, **kwargs: Forwarded to :class:`torch.nn.Conv1d`. When neither
            is given, the layer falls back to ``nn.Conv1d(3, 6, 3, 1, 1)``
            (3 -> 6 channels, kernel size 3, stride 1, padding 1).

    Raises:
        ValueError: If ``mask_type`` is not ``'A'`` or ``'B'``, or if the
            wrapped conv uses a ``padding_mode`` other than ``'zeros'``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if mask_type not in ('A', 'B'):
            raise ValueError(f"mask_type must be 'A' or 'B', got {mask_type!r}")
        if not args and not kwargs:
            # Preserve the original hard-coded layout when no conv config is given.
            args = (3, 6, 3, 1, 1)
        self.conv = nn.Conv1d(*args, **kwargs)
        # forward() calls the functional conv directly, which only supports
        # zero padding; refuse other modes rather than silently ignoring them.
        if self.conv.padding_mode != 'zeros':
            raise ValueError(
                f"only padding_mode='zeros' is supported, got {self.conv.padding_mode!r}"
            )

        kw = self.conv.weight.size(-1)
        # 'A' masks from the centre tap onward; 'B' keeps the centre tap.
        start = kw // 2 + (mask_type == 'B')
        mask = torch.ones_like(self.conv.weight)
        mask[:, :, start:] = 0
        # Buffer so it follows .to()/.cuda(); non-persistent so state_dict
        # keys are unchanged and previously saved checkpoints still load.
        self.register_buffer('mask', mask, persistent=False)

        # Also zero the stored weights so inspecting `self.conv.weight`
        # shows the masked kernel, as before.
        with torch.no_grad():
            self.conv.weight.mul_(self.mask)

    def forward(self, x):
        """Convolve ``x`` with the masked kernel.

        Multiplying by the mask here (rather than only once at init) makes
        the gradient of every masked tap exactly zero, so optimizer updates
        cannot re-enable those taps.
        """
        conv = self.conv
        return nn.functional.conv1d(
            x,
            conv.weight * self.mask,
            conv.bias,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
        )
if __name__ == "__main__":
    # Demo: run one forward/backward pass through a type-'A' masked conv and
    # print the kernel gradient so the masked taps can be inspected.
    # Guarded so importing this module does not run the model or print.
    model = MaskedConv1d('A')
    x = torch.randn(1, 3, 24)  # 3 input channels, matching the default nn.Conv1d(3, 6, 3, 1, 1)
    output = model(x)
    output.mean().backward()
    print(model.conv.weight.grad)