Implementing captum with pytorch-lightning

Hi there,

I also posted this question here: Implementing captum with pytorch-lightning · Issue #726 · pytorch/captum · GitHub

I am trying to use LayerGradCam in captum to interpret a particular layer in my model.

Part of the problem/complication seems to be that my model and forward method are defined in a pytorch-lightning module.

My pytorch-lightning module is:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from captum.attr import LayerGradCam

class model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()
        self.learning_rate = learning_rate
        self.criterion = nn.BCEWithLogitsLoss()
        # set up LayerGradCam, pointing it at a layer by its dotted name
        self.cam = LayerGradCam(self.forward, 'model.5')
        # cnnBlock1/2/3 and linearBlock are my own nn.Module subclasses, defined elsewhere
        self.model = nn.Sequential(cnnBlock1(), cnnBlock2(), cnnBlock3(), linearBlock())

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        train_loss = self.criterion(y_hat, y)
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        val_loss = self.criterion(y_hat, y)
        self.log('val_loss', val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        attr = self.cam.attribute(x)  # this is the call that raises the error below
        return attr

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
        return optimizer

However, when I run the test step, I get the following error:

AttributeError: 'str' object has no attribute 'register_forward_hook'
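
From poking around in captum, my guess is that the layer argument of LayerGradCam is supposed to be the actual nn.Module object (captum registers a forward hook on it when attribute() is called), whereas I am passing the string 'model.5', so captum ends up calling register_forward_hook on a str. If that is right, rewriting my __init__ along these lines should avoid the error (the index 2 is just an example; my Sequential only has children model.0 to model.3 anyway):

    def __init__(self, learning_rate: float):
        super().__init__()
        self.learning_rate = learning_rate
        self.criterion = nn.BCEWithLogitsLoss()
        self.model = nn.Sequential(cnnBlock1(), cnnBlock2(), cnnBlock3(), linearBlock())
        # build the model first, then hand LayerGradCam the module object itself,
        # e.g. the third block -- not its dotted name as a string
        self.cam = LayerGradCam(self.forward, self.model[2])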

I have two questions then:

  1. What does this error mean, and how do I fix it? (Is my guess above on the right track?)
  2. More generally, what is best practice for using captum with pytorch-lightning? (I have sketched one idea I am considering below.)
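
The idea for question 2 is to keep captum entirely outside the LightningModule and only run it after training, so Lightning just trains the model and captum gets a plain callable plus a concrete layer. This is only a rough sketch; the checkpoint path, the learning rate, and test_loader are all placeholders:

from captum.attr import LayerGradCam

# load the trained LightningModule (path and learning_rate are placeholders)
net = model.load_from_checkpoint('path/to/checkpoint.ckpt', learning_rate=1e-3)
net.eval()

# hand LayerGradCam the module object, e.g. the third block of the Sequential
cam = LayerGradCam(net.forward, net.model[2])

x, y = next(iter(test_loader))  # test_loader is assumed to be a DataLoader over my test set
attr = cam.attribute(x)         # same call as in my test_step, but outside Lightning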

Thanks for your help!