Vectorizing the forward function for performance

Hello,

Some background:
After training my model, I want to compute the gradient an input would produce via backpropagation — without updating the weights — and reject (i.e. overwrite the label to K+1) any input that would have caused large weight changes if the model were still training.

The way I’m doing it right now is very slow.

I know that I can set `reduction='none'` in the loss function to get a separate loss for every batch element, but autograd cannot run a separate backward pass for each of those per-element values in one call.

Is there any way to make the forward function faster?

class MalwareKnown(torch.nn.Module):
    """1-D CNN malware classifier with gradient-based open-set rejection.

    While ``finished_training`` is False, ``forward`` returns the K=2 class
    logits zero-padded to K+1 columns. Once ``finished_training`` is True,
    each input is additionally scored by the gradient it *would* induce if
    it were trained on: a per-sample loss (against the model's own argmax
    prediction) is backpropagated, and if the L2 norm of the resulting
    gradient on the last conv layer's bias exceeds ``thresh`` the sample is
    rejected, i.e. its output row is overwritten with a one-hot on the
    extra K+1 ("unknown") class.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: three conv/ReLU stages, pooling after the
        # first two.
        self.CNN = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=2, out_channels=32, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(kernel_size=3),
            torch.nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(kernel_size=3),
            torch.nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3),
            torch.nn.ReLU(),
        )
        # Classifier head: 256 flattened features -> 2 known classes.
        self.clf = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=256, out_features=128),
            torch.nn.Linear(in_features=128, out_features=2),
        )

        self.finished_training = False  # toggles the rejection path
        self.thresh = 0.6               # gradient-norm rejection threshold
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return (batch, K+1) outputs; reject samples after training.

        Args:
            x: input batch of shape (batch, 2, L); L must be such that the
               CNN produces 256 flattened features (e.g. L == 98).

        Returns:
            Tensor of shape (batch, K+1). Before training is finished this
            is the padded logits; afterwards, rejected rows are replaced by
            the one-hot vector for the unknown class.
        """
        if not self.finished_training:
            # Plain path: K logits padded with a zero "unknown" column so
            # the output shape is always (batch, K+1).
            return F.pad(input=self.clf(self.CNN(x)), pad=(0, 1))

        # Rejection path. Per-sample gradients require one backward pass
        # per sample, but the original TensorDataset/DataLoader machinery
        # is dropped (direct slicing is much cheaper), the useless
        # `x.requires_grad = True` is removed (it mutated the caller's
        # tensor and could raise on a non-leaf; gradients are read from
        # the conv bias, not from x), and grad mode is forced on so this
        # also works when called under torch.no_grad() at inference time.
        with torch.enable_grad():
            batch = x.float()
            logits = self.clf(self.CNN(batch))            # (B, K)
            pseudo_labels = logits.argmax(dim=1)          # (B,), model's own predictions
            result = torch.zeros_like(F.pad(input=logits, pad=(0, 1)))
            last_conv = self.CNN[6]                       # bias grad used as the score

            for i in range(batch.shape[0]):
                self.zero_grad()  # isolate this sample's gradient
                outputs = self.clf(self.CNN(batch[i:i + 1]))
                loss = self.criterion(outputs, pseudo_labels[i:i + 1])
                loss.backward()
                grad_norm = torch.sqrt(torch.sum(torch.pow(last_conv.bias.grad, 2)))

                if grad_norm > self.thresh:
                    # Rejected: one-hot on the extra K+1 ("unknown") class.
                    result[i, -1] = 1
                else:
                    result[i] = F.pad(input=outputs, pad=(0, 1), value=0)

            self.zero_grad()  # don't leave stale per-sample grads on the module
        return result

You might want to check out functorch (GitHub: pytorch/functorch), a library of JAX-like composable function transforms for PyTorch, which provides a vectorizing map (`vmap`) that can compute per-sample gradients without a Python loop.