Forward method in custom layer with 3D tensor multiplying weights

Hello there. I’m new to PyTorch and I’m having problems implementing a custom forward pass. I’m trying to implement a custom linear layer for an MLP with the following forward method:

def forward(self, x, hebb):
    # torch.clamp is not in-place, so the result has to be assigned back
    hebb = torch.clamp(hebb, min=-1.0, max=1.0)

    if self.training:
        # noise to be added to units' outputs
        unit_noise = torch.rand(self.size_out) * self.noise_scale
        w_times_x = torch.add(torch.mm(x, (self.weights + hebb).t()), unit_noise)
    else:
        w_times_x = torch.mm(x, (self.weights + hebb).t())

    yout = torch.add(w_times_x, self.bias)

    # H[b, i, j] = x[b, i] * yout[b, j]
    hebb = torch.einsum('bi,bj->bij', x, yout)

    return yout, hebb

The problem I’m having is that hebb is a 3D tensor whose first dimension is the batch size, so I can’t add it to self.weights, which is a 2D tensor. What I need is for each sample in the batch to use its own slice of the 3D hebb tensor, something like this in the training loop:

hebb = torch.zeros((batch_size, n, m), dtype=torch.float32, requires_grad=False)

for batch_inputs, batch_outputs in dataloader:
    model(batch_inputs, hebb)
    # inside the forward pass I would need something like:
    # w_times_x = torch.add(torch.mm(x[sample_index], (self.weights + hebb[sample_index]).t()), unit_noise)
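
For reference, the per-sample computation described above could be sketched with broadcasting and torch.einsum instead of torch.mm. This is only a sketch under assumptions that are not in the code above: per_sample_linear is a hypothetical helper name, and hebb is assumed to have shape (batch, size_out, size_in) so that each slice matches self.weights. The einsum in forward produces (batch, size_in, size_out), so a hebb.transpose(1, 2) may be needed first.

```python
import torch


def per_sample_linear(x, weights, hebb, bias):
    """Compute y[b] = (weights + hebb[b]) @ x[b] + bias for every sample b.

    x:       (batch, size_in)
    weights: (size_out, size_in), shared by all samples
    hebb:    (batch, size_out, size_in), one matrix per sample (assumed layout)
    bias:    (size_out,)
    """
    # Broadcasting adds the shared 2D weights to each per-sample 3D slice.
    effective = weights.unsqueeze(0) + hebb  # (batch, size_out, size_in)
    # Contract the input dimension i separately for each batch index b.
    return torch.einsum('bi,boi->bo', x, effective) + bias


torch.manual_seed(0)
batch, size_in, size_out = 4, 3, 5
x = torch.randn(batch, size_in)
weights = torch.randn(size_out, size_in)
bias = torch.randn(size_out)
hebb = torch.randn(batch, size_out, size_in).clamp(-1.0, 1.0)

y = per_sample_linear(x, weights, hebb, bias)

# Reference: the one-sample-at-a-time loop that the batched version replaces.
y_loop = torch.stack([
    torch.mv(weights + hebb[b], x[b]) + bias for b in range(batch)
])
```

Here y and y_loop should agree up to floating-point tolerance, and no Python loop over samples is needed in the batched version. If the last batch from the dataloader is smaller than batch_size, hebb would also have to be sliced to the actual batch length.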

I can’t really provide the samples one by one, so I’m trying to make this work with batches. Thanks in advance for the help.