When F.interpolate is used model is no longer optimizing

Hi!

I’m trying to visualize the features of a cnn. I first register a forward hook on the layer I want to visualize the feature maps. Afterwards I create a noisy image with required_grad=True which I feed into the cnn. As loss I use the negative mean from the activation map I want to optimize the given input image.
To get better results I try to rescale the image after each optimization phase by using F.interpolate. But as soon as I use it the input image is no longer optimized. Why is that the case?

class SaveFeatures():
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.features = output
    def close(self):
        self.hook.remove()

class FilterVisualizer():
    def __init__(self, encoder, decoder, encoder_path=None, decoder_path=None, upscaling_steps=12, upscaling_factor=1.2):
        self.upscaling_steps, self.upscaling_factor = upscaling_steps, upscaling_factor
        self.encoder = encoder
        self.decoder = decoder
        if encoder_path and decoder_path:
            self.encoder.load_state_dict(torch.load(encoder_path, map_location=torch.device('cpu')))
            self.decoder.load_state_dict(torch.load(decoder_path, map_location=torch.device('cpu')))
        self.encoder.eval()
        self.decoder.eval()
        set_trainable(self.encoder, False)
        set_trainable(self.decoder, False)
        
    def visualize(self, layer, filter, PATH='data/snail.jpg', imsize=-1, lr=0.1, opt_steps=30, sigma=25):
        # the image that was used for training the model
        img = pil_to_np(crop_image(get_image(PATH, imsize)[0], d=32))

        # create noisy img with same shape as image that was sued for training
        noisy_img = np.clip(np.random.normal(scale=sigma, size=img.shape), 0, 1).astype(np.float32)

        # make it a variable
        img_var = np_to_torch(noisy_img, requires_grad=True)
        shape = torch_to_np(img_var).shape[1:]

        # register hook on specific layer
        activations = SaveFeatures(list(list(self.decoder.children())[0].children())[layer])

        optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)

        # scale the image down to extract low frequency patterns
        img_var = F.interpolate(img_var, size=(int(shape[0]*0.2), int(shape[1]*0.2)))

        for factor in range(3, 11):
            scale_factor = factor/10
            for n in range(opt_steps):  # optimize pixel values for opt_steps times
                optimizer.zero_grad()

                out = self.decoder(self.encoder(img_var))

                loss = -activations.features[0, filter].mean()
                loss.backward()
                optimizer.step()

            # scale image up for next optimization step
            img_var = F.interpolate(img_var, size=(int(shape[0]*scale_factor), int(shape[1]*scale_factor)))

layer = 19
filter = -1
FV = FilterVisualizer(encoder=Encoder(), decoder=Decoder(), encoder_path='saved_models/encoder_4.pth', decoder_path='saved_models/decoder_4.pth', upscaling_steps=12, upscaling_factor=2)
FV.visualize(layer, filter)

F.interpolate won’t detach the computation graph and the output tensor will have UpsampleNearest2DBackward as its .grad_fn.
You could check the model for valid gradients by printing them inside the loop just to make sure the parameters are still being updated.

1 Like

Thank you very much @ptrblck for your reply!

What do you mean by valid gradients? requires_grad is always set to True.

You could print the gradient norm or sum after calling loss.backward() to check, the magnitude etc. and more importantly, if gradients are calculated at all or if somehow the computation graph was detached.

loss.backward()
print(model.layer.weight.grad.norm())
1 Like

First the architecture of my model:

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), 
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(True)
        )

    def forward(self, x):
        x = self.encoder(x)
        return x


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReLU(True),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReLU(True),
            nn.Conv2d(32, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReLU(True),
            nn.Conv2d(16, 3, 3, stride=1, padding=1),
            nn.BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReLU(True),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

I use the output of the ReLU module in the decoder as my loss. If I try the proposed print(model.layer.weight.grad.norm() I get AttributeError: 'ReLU' object has no attribute 'weight' which also makes sense. But If I try to print the .weight.grad.norm() of the Conv2d layer before the ReLU activation I get None as output. I get also None when I don’t use F.interpolate on my input image but I get a loss to optimize.