How to modify the weights of a layer before the layer is applied to the input?

Hi! I’m new to PyTorch and I’m trying to modify the weights of a layer before the layer is applied to the input, but I don’t know how to get the gradients right.

I have searched the forum and found some related discussions, but they are a bit different from my case: there, the weights are not modified at run time during training.

Here’s an example. I want to apply a function (say, sigmoid) to the weights of a Conv2d before it is applied to an image; that is, I want to do the convolution with the sigmoid of the weights instead of the weights themselves.

Here is my code, taking 28×28 MNIST images as input.

My intention is to save the original weights in self.conv_weight and, in the forward pass, replace the weights of the conv layers with f(weights), here sigmoid(self.conv_weight), while still preserving the original weights for backpropagation. I was expecting autograd to update self.conv_weight after opt.step().

But the problem is that no gradients seem to be attached to self.conv_weight (self.conv_weight.grad is None after computing the loss).

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        # stash the raw weight tensors and mark them as trainable
        self.conv1_weight = self.conv1.weight.data
        self.conv1_weight.requires_grad_()
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv2_weight = self.conv2.weight.data
        self.conv2_weight.requires_grad_()
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)
        self.conv3_weight = self.conv3.weight.data
        self.conv3_weight.requires_grad_()

    def forward(self, xb):
        relu = F.relu
        sigmoid = torch.sigmoid
        xb = xb.view(-1, 1, 28, 28)
        # overwrite each layer's weight with sigmoid(raw weight) before applying it
        self.conv1.weight.data = sigmoid(self.conv1_weight)
        xb = relu(self.conv1(xb))
        self.conv2.weight.data = sigmoid(self.conv2_weight)
        xb = relu(self.conv2(xb))
        self.conv3.weight.data = sigmoid(self.conv3_weight)
        xb = relu(self.conv3(xb))
        xb = F.avg_pool2d(xb, 4)
        return xb.view(-1, xb.size(1))

My code for fitting and the loss is below:

import numpy as np
from torch import optim

def get_model():
    model = Net()
    return model, optim.SGD(model.parameters(), lr=LR, momentum=0.9)

def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)
    # check whether gradients reach the raw weights
    print(model.conv1_weight.grad)
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)


def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train()
        orig = model.conv1.weight.data.clone()  # snapshot to see if the weights change
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(epoch, val_loss)

train_dl, valid_dl = get_dataloaders(train_ds, valid_ds, BATCH_SIZE)
model, opt = get_model()
fit(10, model, F.cross_entropy, opt, train_dl, valid_dl)

My code is a “workaround”, though it seems to have failed. Could you please help me with it, and possibly explain why it failed and how nn.Module actually works? THANKS A LOT!

Don’t use the .data attribute, as it might break your code in various ways. Assignments through .data are invisible to autograd, so the sigmoid is never recorded in the computation graph and no gradient can flow back to self.conv1_weight; the gradients land on conv1.weight instead. Also note that self.conv1_weight is a plain tensor, not an nn.Parameter, so model.parameters() does not return it and the optimizer never updates it. (Additionally, printing the gradient before calling loss.backward() will show None even in a correct setup, since gradients are only populated by the backward pass.)
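
Here is a minimal sketch of the failure mode (the names are made up for illustration):

import torch

w_raw = torch.randn(3, requires_grad=True)   # the tensor we hope to train
p = torch.nn.Parameter(torch.empty(3))       # stands in for conv.weight
p.data = torch.sigmoid(w_raw)                # the .data assignment is invisible to autograd
loss = (p * 2).sum()
loss.backward()
print(w_raw.grad)  # None -- the sigmoid op was bypassed
print(p.grad)      # the gradient landed on the Parameter instead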

For your use case, you could use the functional API as given in this example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # pass sigmoid(weight) to the functional conv; the sigmoid is part of
        # the graph, so gradients flow back to the raw weight parameters
        x = F.conv2d(x, torch.sigmoid(self.conv1.weight), self.conv1.bias,
                     stride=2, padding=1)
        x = F.relu(x)
        x = F.conv2d(x, torch.sigmoid(self.conv2.weight), self.conv2.bias,
                     stride=2, padding=1)
        x = F.relu(x)
        x = F.conv2d(x, torch.sigmoid(self.conv3.weight), self.conv3.bias,
                     stride=2, padding=1)
        x = F.relu(x)
        x = F.avg_pool2d(x, 4)
        return x


model = Net()
data = torch.randn(1, 1, 28, 28)
out = model(data)
out.mean().backward()

for name, param in model.named_parameters():
    print(name, param.grad)
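
If you are on a recent PyTorch release (1.9+), torch.nn.utils.parametrize offers a built-in way to do this, so you can keep calling the module directly instead of switching to the functional API. A rough sketch:

import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class SigmoidWeight(nn.Module):
    # parametrization: maps the raw weight to the weight the layer actually uses
    def forward(self, w):
        return torch.sigmoid(w)

conv = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
parametrize.register_parametrization(conv, "weight", SigmoidWeight())

out = conv(torch.randn(1, 1, 28, 28))   # conv.weight is now sigmoid(raw weight)
out.mean().backward()
# the raw tensor lives in conv.parametrizations.weight.original and receives the gradients
print(conv.parametrizations.weight.original.grad.shape)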