Is there any way to train a shared neural network in one step?

Hello, I’m trying to implement a model that is applied multiple times within a single training step. My code looks roughly like the snippet below. What is the proper way to train such a model? Thank you.

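import torch

encoder = Encoder()
# lr=0.01 is a placeholder; the original post did not specify a learning rate
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()
for epoch in range(num_epoch):
    for images, labels in train_loader:
        # here, images.shape is [b, c, h, w]
        # ... some processing ...
        # here, images.shape is [n, b, c, h, w]
        outs = []
        for x in images:
            # the same encoder (shared weights) is applied to every slice
            outs.append(encoder(x))
        # ... some processing, ending with `loss` computed from outs ...
        # update once per batch, after all n forward passes
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()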

I changed part of the code above so that the encoder takes the whole tensor as input, as shown below. In this case, does the learning process differ from the code above?

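import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering any layers
        # ... layers to use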

    def _forward(self, x):
        # x.shape: [b, c, h, w]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        return x

    def forward(self, x):
        # x.shape: [n, b, c, h, w]
        # apply the same layers to each [b, c, h, w] slice along dim 0
        r = []
        for x_i in x:
            r.append(self._forward(x_i))
        return torch.stack(r)

...

for epoch in range(num_epochs):
    for images, labels in train_loader:
        # ...
        outs = encoder(images)  # images.shape: [n, b, c, h, w]
        # ... `loss` is computed from outs here ...
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

I guess there shouldn’t be a difference, but since your code is incomplete, I would recommend running a quick test with a constant input and comparing the outputs as well as the gradients for one iteration.
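
A quick test along those lines could look roughly like the sketch below. It is only an illustration under assumptions: `TinyEncoder` is a hypothetical one-layer stand-in for your `Encoder`, the tensor sizes are made up, and the target is all zeros. It runs the same fixed input through both variants (loop outside the model, loop inside `forward`) and compares outputs and gradients.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the thread's Encoder (one conv + BN layer)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(4)

    def _forward(self, x):
        # x.shape: [b, c, h, w]
        return F.relu(self.bn1(self.conv1(x)))

    def forward(self, x):
        # x.shape: [n, b, c, h, w]; apply the shared layers to each slice
        return torch.stack([self._forward(x_i) for x_i in x])


def grads_of(model):
    # Copy the gradients so a later backward pass cannot overwrite the snapshot
    return [p.grad.detach().clone() for p in model.parameters()]


torch.manual_seed(0)
encoder = TinyEncoder()
criterion = nn.MSELoss()
images = torch.randn(2, 4, 3, 8, 8)  # fixed input, [n, b, c, h, w] (sizes are arbitrary)

# Variant A: loop outside the model, one call to the shared layers per slice
encoder.zero_grad()
outs_a = torch.stack([encoder._forward(x_i) for x_i in images])
criterion(outs_a, torch.zeros_like(outs_a)).backward()
grads_a = grads_of(encoder)

# Variant B: loop inside forward()
encoder.zero_grad()
outs_b = encoder(images)
criterion(outs_b, torch.zeros_like(outs_b)).backward()
grads_b = grads_of(encoder)

outputs_match = torch.allclose(outs_a, outs_b)
grads_match = all(torch.allclose(a, b) for a, b in zip(grads_a, grads_b))
print("outputs match:", outputs_match, "| gradients match:", grads_match)
```

If either comparison comes back `False` with your real model, the two versions are not doing the same computation, and that would be the place to start looking.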

I’m not familiar with your use case, but would it be possible to pass the complete input directly to the model instead of using a loop?

Thank you for the reply! Yes, I know it is possible to pass the entire input to the model by reshaping its dimensions. I will run the test you recommended. :smiley:
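
For anyone finding this later, the reshaping idea mentioned here could be sketched as follows. The `layers` module and the tensor sizes are hypothetical stand-ins, not the model from the thread. One caveat: the stand-in deliberately has no BatchNorm. With BatchNorm in training mode, folding `n` into the batch dimension means the batch statistics are computed over all `n * b` samples instead of per slice, so the result would generally not match the loop version.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the shared layers (no BatchNorm, see the caveat above)
layers = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())

images = torch.randn(2, 4, 3, 8, 8)  # [n, b, c, h, w]
n, b = images.shape[:2]

# Loop version: one call per slice along the first dimension
looped = torch.stack([layers(x_i) for x_i in images])

# Reshaped version: fold n into the batch dimension, run once, unfold again
flat = images.reshape(n * b, *images.shape[2:])  # [n * b, c, h, w]
out = layers(flat)                               # [n * b, c_out, h, w]
batched = out.reshape(n, b, *out.shape[1:])      # [n, b, c_out, h, w]

# Small tolerance, since different batch sizes may take different numeric paths
same = torch.allclose(looped, batched, atol=1e-5)
print("loop and reshape agree:", same)
```

The single call avoids the Python loop, which tends to be faster on a GPU, at the cost of holding all `n * b` samples in one batch.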