Vectorizing the forward function for performance


Some background:
After training my model, I want to compute the backpropagation gradient for an input without updating the weights, and reject (by overwriting its label to K+1) any input that would have caused large weight changes if the model were still training.

The way I’m doing it right now is very slow.

I know that I can set reduction='none' in the loss function to get a loss value for every batch element, but the autograd backward pass will not produce per-sample gradients from those multiple values directly.

Is there any way to make the forward function faster?

class MalwareKnown(torch.nn.Module):
    """CNN classifier that, once training is finished, can reject inputs.

    A sample is rejected when the gradient norm a single backward pass
    would have produced exceeds ``self.thresh``; rejected samples get the
    one-hot "unknown" (K+1) row ``[0, 0, 1]`` instead of their logits.
    The weights are never updated (no optimizer step is taken).
    """

    def __init__(self):
        # BUG FIX: nn.Module.__init__ must run before any submodule is
        # assigned, otherwise attribute registration raises.
        super().__init__()
        # BUG FIX: the original Sequential call was missing its closing
        # parenthesis (a syntax error).
        self.CNN = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=2, out_channels=32, kernel_size=3),
            torch.nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3),
            torch.nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3),
        )
        self.clf = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=256, out_features=128),
            torch.nn.Linear(in_features=128, out_features=2),
        )

        self.finished_training = False  # flip to True to enable rejection
        self.thresh = 0.6               # gradient-norm rejection threshold
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the network.

        x: float tensor of shape (batch, 2, L); the Linear(256, ...) layer
           implies 32 * (L - 6) == 256, i.e. L == 14 — TODO confirm.

        Returns a (batch, K+1) tensor: class logits padded with a zero
        extra column; when ``finished_training`` is set, rejected samples
        are replaced by the one-hot "unknown" row [0, 0, 1].
        """
        if self.finished_training:
            o = self.clf(self.CNN(x))
            new_data = torch.zeros_like(F.pad(input=o, pad=(0, 1)))
            # Pseudo-labels: the model's own argmax predictions.
            unknown_labels = torch.max(o, dim=1)
            unknown_dataset = TensorDataset(x, unknown_labels[1])
            unknown_dataloader = DataLoader(unknown_dataset, batch_size=1, shuffle=False)
            for i, data in enumerate(unknown_dataloader):
                inputs, labels = data
                # Forward + backward for this single sample; no optimizer
                # step is ever taken, so the weights stay fixed.
                outputs = self.clf(self.CNN(inputs.float()))
                loss = self.criterion(outputs, labels)
                # BUG FIX: the original read .grad without ever calling
                # backward(), so the gradient was None/stale. Clear any
                # previous gradients first so each sample is independent.
                self.zero_grad()
                loss.backward()
                # BUG FIX: CNN[6] is out of range — the Sequential has only
                # 3 layers; use the last conv layer's bias gradient.
                grad_after = self.CNN[-1].bias.grad

                outputs = outputs.detach()  # no further autograd needed
                if torch.sqrt(torch.sum(torch.pow(grad_after, 2))) > self.thresh:
                    # Reject: overwrite with the one-hot K+1 label [0, 0, 1].
                    outputs[outputs != 0] = 0
                    # BUG FIX: the original assigned new_data[i] twice, the
                    # second time padding the already-padded tensor to 4
                    # columns (a shape mismatch at runtime). The second
                    # assignment was evidently meant as the accepted-sample
                    # branch below.
                    new_data[i] = F.pad(input=outputs, pad=(0, 1), value=1)
                else:
                    # Accept: keep the logits, padded with a zero column.
                    new_data[i] = F.pad(input=outputs, pad=(0, 1), value=0)
            return new_data

        x = self.CNN(x)
        return F.pad(input=self.clf(x), pad=(0, 1))

You might want to check out functorch (GitHub: pytorch/functorch — JAX-like composable function transforms for PyTorch), which provides a vectorizing map (vmap) for PyTorch.