Any difference between batch-style and for-loop-style processing of time-axis data?

Hello all

I have a naive question: is there any difference, in terms of training or the gradient update, between processing data that has a time axis, such as video, batch-style or with a for loop?

As I understand it, the batch style works like this:

class network(nn.Module):
    def __init__(self):
        self.batch = 16
        self.T_length = 15
        self.conv = nn.Conv2d(...)
        self.cls = nn.Linear(...)

    def forward(self, Input):
        # Input has size [batch * time_length, 3, 224, 224]
        out = self.conv(Input) # [batch * time_length, 128, 28, 28]
        out = out.view(batch * time_length, -1)
        score = self.cls(out)

        score = score.view(batch, time_length, -1)
        # score for single video is the average over all frames
        score = torch.mean(score, axis=1)
        return 

The for-loop style, on the other hand, works like this:

class network(nn.Module):
    def __init__(self):
        self.batch = 16
        self.T_length = 15
        self.conv = nn.Conv2d(...)
        self.cls = nn.Linear(...)

    def forward(self, Input):
        # Input has size [time_length, batch, 3, 224, 224]
        score = []
        for i in input:
            score.append(self.cls(self.conv(Input).view(batch, -1)))

        # concate list or tensor to new tensor
        score = torck.stack(score, axis=1) # [ batch, time_length, num_class]
        # average over time axis
        score = torch.mean(score, axis=1)
        return score
        # Then optimize

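Apart from performance (the batched version typically runs faster because it processes all frames in a single call), both approaches yield the same outputs and gradients: the conv and linear layers process every frame independently, so only the order in which the frames are handled changes.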
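One caveat: layers whose output depends on the other samples in the batch, such as nn.BatchNorm2d in training mode, do not give identical results in the two styles, because the batched call normalizes over time_length * batch frames while each loop iteration normalizes over only batch frames. A minimal sketch (the shapes and the helper names batched/looped are arbitrary assumptions for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
time_length, batch_size = 4, 3
x = torch.randn(time_length, batch_size, 3, 5, 5)  # [time_length, batch, C, H, W]
bn = nn.BatchNorm2d(3)

def batched(layer, x):
    # Batch style: a single call over all time_length * batch frames
    return layer(x.flatten(0, 1)).view(x.shape)

def looped(layer, x):
    # For-loop style: one call per time step, i.e. only batch frames per call
    return torch.stack([layer(frame) for frame in x], dim=0)

bn.train()  # normalizes with the statistics of the current call's batch
train_match = torch.allclose(batched(bn, x), looped(bn, x))

bn.eval()  # normalizes with the stored running statistics instead
eval_match = torch.allclose(batched(bn, x), looped(bn, x))

print(train_match, eval_match)  # False True
```

In eval mode the layer no longer looks at the batch, so the two styles agree again.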
That said, there are a few minor issues in your code.
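For example, if both models are fed the same [time_length, batch, ...] data, flattening it for the first model gives a time-major layout, so the view call there should be score = score.view(time_length, batch_size, -1) to get identical values for both models. Otherwise, frames from different videos get mixed up.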
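Also, some arguments are wrong, e.g. axis instead of dim (the keyword PyTorch documents for torch.mean and torch.stack), and super().__init__() is never called in __init__.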
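To see why the view order matters, here is a small sketch with hypothetical "scores" whose values encode (time step, video), so it is easy to tell which frames end up together:

```python
import torch

time_length, batch_size = 3, 2
# Hypothetical per-frame "scores" where the value 10 * t + b encodes (time step t, video b)
scores = torch.tensor([[10.0 * t + b for b in range(batch_size)] for t in range(time_length)])
# scores has shape [time_length, batch_size]; flattening keeps the time-major order
flat = scores.view(-1)  # values 0, 1, 10, 11, 20, 21

# Correct: restore (time, batch), then average over time for each video
per_video = flat.view(time_length, batch_size).mean(dim=0)  # values 10, 11
# Also correct: restore (time, batch) first, then transpose to (batch, time)
per_video_t = flat.view(time_length, batch_size).t().mean(dim=1)

# Wrong: reading the same memory as (batch, time) mixes frames of different videos
mixed = flat.view(batch_size, time_length)
print(mixed[0])  # the row for "video 0" holds 0, 1, 10: frames from video 0 AND video 1
```

view only reinterprets memory, so the target shape has to list the dimensions in the order they were flattened; a transpose or permute afterwards is the way to change that order.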

Here is a fixed version with some comparisons:

import torch
import torch.nn as nn


class network(nn.Module):
    def __init__(self, batch_size, time_length):
        super(network, self).__init__()
        self.batch_size = batch_size
        self.time_length = time_length
        self.conv = nn.Conv2d(3, 128, 3, 1, 1)
        self.cls = nn.Linear(128*5*5, 2)  # expects 5x5 frames, as in the test below

    def forward(self, Input):
        # Input has size [time_length * batch, 3, 5, 5] (time-major)
        out = self.conv(Input)  # [time_length * batch, 128, 5, 5]
        out = out.view(self.time_length * self.batch_size, -1)
        score = self.cls(out)
        score = score.view(self.time_length, self.batch_size, -1)
        # score for a single video is the average over all frames (time axis)
        score = torch.mean(score, dim=0)  # [batch, num_class]
        return score


class network2(nn.Module):
    def __init__(self):
        super(network2, self).__init__()
        self.conv = nn.Conv2d(3, 128, 3, 1, 1)
        self.cls = nn.Linear(128*5*5, 2)

    def forward(self, Input):
        # Input has size [time_length, batch, 3, 5, 5]
        score = []
        for x in Input:  # x: [batch, 3, 5, 5], one time step
            score.append(self.cls(self.conv(x).view(x.size(0), -1)))
        # stack the per-step scores along a new time axis
        score = torch.stack(score, dim=0)  # [time_length, batch, num_class]
        # average over the time axis
        score = torch.mean(score, dim=0)  # [batch, num_class]
        return score


batch_size = 5
time_length = 10
x = torch.randn(time_length, batch_size, 3, 5, 5)

modelA = network(batch_size, time_length)
modelB = network2()
modelB.load_state_dict(modelA.state_dict())  # identical weights in both models

# flatten time and batch (time-major) for the batched model
outputA = modelA(x.view(time_length*batch_size, *x.size()[2:]))
outputB = modelB(x)
outputA.mean().backward()
outputB.mean().backward()

print(torch.allclose(outputA, outputB))
print(torch.allclose(modelA.conv.weight.grad, modelB.conv.weight.grad))
print(torch.allclose(modelA.cls.weight.grad, modelB.cls.weight.grad))
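If your data loader yields batch-first clips [batch, time_length, C, H, W], flattening them gives a batch-major layout, and then view(batch, time_length, -1) followed by a mean over dim=1 (as in the original question) is the right choice. A minimal sketch under that assumption; BatchFirstVideoNet, frame_scores and forward_loop are hypothetical names, and the 5x5 frames and 2 classes mirror the example above. Reading the sizes from the input also avoids relying on global variables:

```python
import torch
import torch.nn as nn

class BatchFirstVideoNet(nn.Module):
    """Hypothetical batch-style model for batch-first input [batch, time_length, C, H, W]."""
    def __init__(self, height=5, width=5, num_classes=2):
        super().__init__()
        self.conv = nn.Conv2d(3, 128, 3, 1, 1)
        self.cls = nn.Linear(128 * height * width, num_classes)

    def frame_scores(self, frames):
        # frames: [N, C, H, W] -> per-frame class scores [N, num_classes]
        return self.cls(self.conv(frames).flatten(1))

    def forward(self, x):
        batch_size, time_length = x.shape[:2]
        # flattening batch-first data gives a batch-major layout, so view as (batch, time)
        score = self.frame_scores(x.flatten(0, 1)).view(batch_size, time_length, -1)
        return score.mean(dim=1)  # average over time -> [batch, num_classes]

    def forward_loop(self, x):
        # for-loop style over the time axis (dim 1 for batch-first data)
        score = torch.stack([self.frame_scores(x[:, t]) for t in range(x.size(1))], dim=1)
        return score.mean(dim=1)

torch.manual_seed(0)
model = BatchFirstVideoNet()
x = torch.randn(5, 10, 3, 5, 5)  # [batch, time_length, C, H, W]
out_batched = model(x)
out_loop = model.forward_loop(x)
print(out_batched.shape)  # torch.Size([5, 2])
print(torch.allclose(out_batched, out_loop, atol=1e-5))
```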

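Note that you might occasionally get False in these comparisons, although they should pass most of the time. This is due to floating point precision: the two versions may perform the operations in a different order, so the results can differ by tiny rounding errors, which can most likely be ignored.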
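One way to confirm that such mismatches are only rounding noise is to repeat the comparison in float64, where the two styles should agree far more tightly. A minimal sketch; the compare helper is hypothetical, and the bare conv/cls layers stand in for the layers of network and network2:

```python
import torch
import torch.nn as nn

def compare(dtype):
    # Max absolute difference between batch style and for-loop style,
    # over the outputs and the conv/linear weight gradients
    torch.manual_seed(0)
    conv = nn.Conv2d(3, 128, 3, 1, 1).to(dtype)
    cls = nn.Linear(128 * 5 * 5, 2).to(dtype)
    x = torch.randn(10, 5, 3, 5, 5, dtype=dtype)  # [time_length, batch, C, H, W]

    # Batch style: all frames in one call (time-major flattening)
    out_a = cls(conv(x.flatten(0, 1)).flatten(1)).view(10, 5, -1).mean(dim=0)
    grads_a = torch.autograd.grad(out_a.mean(), [conv.weight, cls.weight])

    # For-loop style: one time step per call
    out_b = torch.stack([cls(conv(f).flatten(1)) for f in x], dim=0).mean(dim=0)
    grads_b = torch.autograd.grad(out_b.mean(), [conv.weight, cls.weight])

    diffs = [(out_a - out_b).abs().max()]
    diffs += [(ga - gb).abs().max() for ga, gb in zip(grads_a, grads_b)]
    return max(d.item() for d in diffs)

err32 = compare(torch.float32)
err64 = compare(torch.float64)
print(f"max abs difference, float32: {err32:.2e}, float64: {err64:.2e}")
```

torch.autograd.grad returns the gradients directly instead of accumulating them into .grad, so the same layers can be reused for both styles.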


Thanks a lot for your response. I understand it much better now!

I should have mentioned that the code above is not strictly correct; I only wrote it to illustrate the idea. Sorry for the bother.

Best.
