Hello there. I’m new to PyTorch and I’m having problems implementing a custom forward pass. I’m trying to implement a custom linear layer for an MLP with the following forward method:
```python
def forward(self, x, hebb):
    # torch.clamp is not in-place, so the result has to be reassigned
    hebb = torch.clamp(hebb, min=-1.0, max=1.0)
    if self.training:
        unit_noise = torch.rand(self.size_out) * self.noise_scale  # noise to be added to units' outputs
        w_times_x = torch.add(torch.mm(x, (self.weights + hebb).t()), unit_noise)
    else:
        w_times_x = torch.mm(x, (self.weights + hebb).t())
    yout = torch.add(w_times_x, self.bias)
    hebb = torch.einsum('bi,bj->bij', x, yout)  # hebb[b, i, j] = x[b, i] * yout[b, j]
    return yout, hebb
```
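For reference, here is a small standalone check of what the Hebbian update at the end of `forward` produces. The sizes are made up, and it assumes `self.weights` has shape `(size_out, size_in)`, which is what `torch.mm(x, self.weights.t())` implies:

```python
import torch

# Hypothetical sizes, for illustration only.
batch_size, size_in, size_out = 4, 3, 5
x = torch.randn(batch_size, size_in)
yout = torch.randn(batch_size, size_out)

# einsum 'bi,bj->bij' computes one outer product per sample: (batch, size_in, size_out).
hebb = torch.einsum('bi,bj->bij', x, yout)
print(hebb.shape)

# The same result written with broadcasting.
hebb_bc = x.unsqueeze(2) * yout.unsqueeze(1)
print(torch.allclose(hebb, hebb_bc))

# Swapping the output subscripts gives the (batch, size_out, size_in) layout,
# i.e. one matrix per sample shaped like self.weights.
hebb_w_layout = torch.einsum('bi,bj->bji', x, yout)
print(hebb_w_layout.shape)
```

So the trace comes out transposed relative to `self.weights`; unless `size_in == size_out`, `self.weights + hebb` will not even broadcast, independently of the batch dimension.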
The problem is that hebb is a 3D tensor whose first dimension is the batch size, while self.weights is a 2D tensor, so self.weights + hebb cannot be used in the torch.mm call. What I need is for each sample in the batch to use its own slice of this 3D hebb tensor, like this in the training loop:
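One way to give each sample its own weights without a Python loop is to let broadcasting build a `(batch, size_out, size_in)` stack of effective weights and contract it with `x` using `torch.einsum` or `torch.bmm`. A minimal sketch, assuming the per-sample trace is stored in the same `(size_out, size_in)` layout as `self.weights` and using made-up sizes:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes; weights uses the (size_out, size_in) layout implied by
# torch.mm(x, weights.t()) in the forward method.
batch_size, size_in, size_out = 4, 3, 5
x = torch.randn(batch_size, size_in)
weights = torch.randn(size_out, size_in)
hebb = torch.randn(batch_size, size_out, size_in)  # one trace per sample (assumed layout)

# Broadcasting adds the 2D weights to every sample's trace: (batch, size_out, size_in).
w_eff = weights + hebb

# Batched matrix-vector product: out[b, o] = sum_i w_eff[b, o, i] * x[b, i]
out_einsum = torch.einsum('boi,bi->bo', w_eff, x)

# Equivalent with torch.bmm: (B, out, in) @ (B, in, 1) -> (B, out, 1)
out_bmm = torch.bmm(w_eff, x.unsqueeze(2)).squeeze(2)

# Reference: the explicit per-sample loop, only used to check the batched versions.
out_loop = torch.stack([w_eff[b] @ x[b] for b in range(batch_size)])

print(torch.allclose(out_einsum, out_loop, atol=1e-6))
print(torch.allclose(out_bmm, out_loop, atol=1e-6))
```

Either batched form can stand in for `torch.mm(x, (self.weights + hebb).t())`; the noise and bias can then be added to the `(batch, size_out)` result exactly as before.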
```python
hebb = torch.zeros((batch_size, n, m), dtype=torch.float32, requires_grad=False)

for batch_inputs, batch_outputs in dataloader:
    yout, hebb = model(batch_inputs, hebb)
    # inside the forward pass I would need something like:
    # w_times_x = torch.add(torch.mm(x[sample_index], (self.weights + hebb[sample_index]).t()), unit_noise)
```
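Putting it together, the layer and training loop might look like the following sketch. `HebbianLinear`, the sizes, the SGD optimizer, the MSE loss and the random data are all hypothetical stand-ins. The trace is kept in the `(batch, size_out, size_in)` layout, and it is detached after each forward pass on the assumption that it should be carried as state rather than backpropagated through across batches:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class HebbianLinear(nn.Module):
    """Hypothetical sketch of the layer in the question, with one trace per sample."""

    def __init__(self, size_in, size_out, noise_scale=0.01):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        self.noise_scale = noise_scale
        self.weights = nn.Parameter(torch.randn(size_out, size_in) * 0.1)
        self.bias = nn.Parameter(torch.zeros(size_out))

    def forward(self, x, hebb):
        # hebb: (batch, size_out, size_in)
        hebb = torch.clamp(hebb, min=-1.0, max=1.0)
        # Per-sample effective weights, then a batched matrix-vector product.
        w_times_x = torch.einsum('boi,bi->bo', self.weights + hebb, x)
        if self.training:
            # Uniform noise, one value per output unit, as in the question.
            unit_noise = torch.rand(self.size_out, device=x.device) * self.noise_scale
            w_times_x = w_times_x + unit_noise
        yout = w_times_x + self.bias
        # New trace in the weights' layout: H[b, o, i] = yout[b, o] * x[b, i]
        new_hebb = torch.einsum('bo,bi->boi', yout, x)
        return yout, new_hebb


batch_size, n_in, n_out = 8, 3, 2
layer = HebbianLinear(n_in, n_out)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)
dataset = TensorDataset(torch.randn(32, n_in), torch.randn(32, n_out))
# drop_last=True keeps every batch at batch_size so it matches hebb's first dimension.
loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

hebb = torch.zeros(batch_size, n_out, n_in)
for batch_inputs, batch_outputs in loader:
    yout, hebb = layer(batch_inputs, hebb)
    hebb = hebb.detach()  # carry the trace as plain state into the next batch
    loss = nn.functional.mse_loss(yout, batch_outputs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Without `drop_last=True`, a shorter final batch would make the einsum fail because `hebb` still has `batch_size` rows. Note also that `hebb[k]` is carried from position k of one batch to position k of the next, so with a shuffling loader those positions hold unrelated samples.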
I can’t feed the samples one by one, so I’m trying to make this work with batches. Thanks in advance for the help.