Hello there. I’m new to PyTorch and I’m having problems implementing a custom forward pass. I’m trying to implement a custom linear layer for an MLP with the following forward method:
def forward(self, x, hebb):
    # keep the Hebbian trace in [-1, 1] (torch.clamp returns a new tensor)
    hebb = torch.clamp(hebb, min=-1.0, max=1.0)
    if self.training:
        unit_noise = torch.rand(self.size_out) * self.noise_scale  # noise to be added to units' outputs
        w_times_x = torch.add(torch.mm(x, (self.weights + hebb).t()), unit_noise)
    else:
        w_times_x = torch.mm(x, (self.weights + hebb).t())
    yout = torch.add(w_times_x, self.bias)
    hebb = torch.einsum('bi,bj->bij', x, yout)  # H[b, i, j] = x[b, i] * yout[b, j]
    return yout, hebb
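To make the shape problem concrete, here is a small sketch (the sizes are arbitrary assumptions) of what the einsum at the end of the forward pass produces: for each sample b it is the outer product of x[b] and yout[b], so the result has shape (batch, size_in, size_out). Note that each slice is the transpose of the (size_out, size_in) layout implied by torch.mm(x, self.weights.t()).

```python
import torch

batch, size_in, size_out = 4, 3, 5  # arbitrary example sizes
x = torch.randn(batch, size_in)
yout = torch.randn(batch, size_out)

# Same call as in the forward pass: one outer product per sample.
hebb = torch.einsum('bi,bj->bij', x, yout)
print(hebb.shape)  # torch.Size([4, 3, 5])

# Equivalent per-sample computation: H[b] = outer(x[b], yout[b]).
loop_hebb = torch.stack([torch.outer(x[b], yout[b]) for b in range(batch)])
print(torch.allclose(hebb, loop_hebb))  # True

# A weight matrix used as torch.mm(x, weights.t()) has shape (size_out, size_in),
# so each hebb[b] has to be transposed before it lines up with it.
weights = torch.randn(size_out, size_in)
print((weights + hebb[0].t()).shape)  # torch.Size([5, 3])
```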
The problem I’m having is that hebb is a 3D tensor whose first dimension is the batch size, so self.weights + hebb is no longer a 2D matrix, and .t() and torch.mm only work on 2D tensors. What I need is for each sample in the batch to use its own slice of this 3D hebb tensor, like so in the training loop:
hebb = torch.zeros((batch_size, n, m), dtype=torch.float32, requires_grad=False)
for batch_inputs, batch_outputs in dataloader:
    model(batch_inputs, hebb)
    # inside the forward pass I would need something like:
    # w_times_x = torch.add(torch.mm(x[sample_index], (self.weights + hebb[sample_index]).t()), unit_noise)
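One way to express that per-sample computation for a whole batch at once is a batched matrix-vector product. The sketch below is hypothetical: the class name HebbianLinear, its constructor and the weight initialisation are my own assumptions, and it stores the trace as (batch, size_out, size_in), i.e. the same layout as self.weights, rather than the (batch, size_in, size_out) layout of the original einsum. Broadcasting self.weights + hebb then gives one weight matrix per sample, and torch.einsum('boi,bi->bo', ...) applies matrix b to sample b.

```python
import torch
import torch.nn as nn


class HebbianLinear(nn.Module):
    """Hypothetical sketch: a linear layer whose weights get a per-sample Hebbian offset."""

    def __init__(self, size_in, size_out, noise_scale=0.0):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        self.noise_scale = noise_scale
        # assumed initialisation; weights use the (size_out, size_in) layout
        self.weights = nn.Parameter(torch.randn(size_out, size_in) * 0.1)
        self.bias = nn.Parameter(torch.zeros(size_out))

    def forward(self, x, hebb):
        # x: (batch, size_in); hebb: (batch, size_out, size_in), same layout as self.weights
        hebb = torch.clamp(hebb, min=-1.0, max=1.0)
        # broadcasting adds the shared weights to every sample's trace: (batch, size_out, size_in)
        w_eff = self.weights + hebb
        # per-sample matrix-vector product: yout[b] = w_eff[b] @ x[b]
        yout = torch.einsum('boi,bi->bo', w_eff, x) + self.bias
        if self.training and self.noise_scale > 0:
            # same noise vector for every sample, as in the original forward pass
            yout = yout + torch.rand(self.size_out, device=x.device) * self.noise_scale
        # new trace in the (batch, size_out, size_in) layout: H[b, o, i] = yout[b, o] * x[b, i]
        new_hebb = torch.einsum('bo,bi->boi', yout, x)
        return yout, new_hebb


torch.manual_seed(0)
batch, size_in, size_out = 4, 3, 2
layer = HebbianLinear(size_in, size_out)
layer.eval()  # disable the noise so the result can be checked exactly

x = torch.randn(batch, size_in)
hebb = torch.zeros(batch, size_out, size_in)
yout, hebb = layer(x, hebb)    # first step: trace is zero
yout2, hebb2 = layer(x, hebb)  # second step: each sample uses its own trace

# compare against the explicit per-sample loop
clamped = hebb.clamp(-1.0, 1.0)
expected = torch.stack(
    [(layer.weights + clamped[b]) @ x[b] + layer.bias for b in range(batch)]
)
print(torch.allclose(yout2, expected, atol=1e-5))  # True
```

Materialising w_eff costs batch × size_out × size_in memory; by linearity the same output can be computed as x @ self.weights.t() + torch.einsum('boi,bi->bo', hebb, x) without building it. If hebb is carried from one batch to the next while calling backward() every step, detaching it (hebb = hebb.detach()) is probably needed to avoid backpropagating into an already freed graph, and the batch dimension must stay constant (for example via DataLoader(..., drop_last=True)).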
I can’t feed the samples in one at a time, so I’m trying to make this work with whole batches. Thanks in advance for the help.