Can someone explain the benefits of batches?

For example, suppose I have:

x = torch.randn(100, 3, 3)  # 300 3-dim vectors (or 100 3x3 matrices, but in this context vectors make more sense)
W = torch.randn(100, 3, 2)  # 100 3x2 matrices (the parameters)
h = torch.bmm(x, W)  # shape (100, 3, 2), i.e. 300 2-dim vectors

If I want to achieve the same using linear algebra, I could do:

x = torch.randn(300, 3)  # 300 3-dim vectors
W = torch.randn(3, 2)  # far fewer parameters
h = torch.matmul(x, W)  # 300 2-dim vectors

Both operations essentially transform 300 3-dim vectors into 300 2-dim vectors. One uses many more parameters to pull this off, which could be an advantage of batch processing, but I'm not sure. It could also be related to parallelization, but it's hard for me to distinguish where parallelization makes something fast by design (e.g. using multi-head rather than single-head attention in Transformers) from where it makes something fast because it was implemented under the hood somewhere.
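Counting the parameters directly (a quick check, reusing the shapes above) makes the gap concrete:

import torch

W_batched = torch.randn(100, 3, 2)  # one 3x2 matrix per batch element
W_shared = torch.randn(3, 2)        # one 3x2 matrix for all the data

print(W_batched.numel())  # 600 parameters
print(W_shared.numel())   # 6 parameters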

In the optimal case, you would estimate the loss over all instances of your data. In most cases this is unrealistic. The opposite extreme would be to compute the loss for every single data point, which results in worse convergence, since estimating the fit of the model from a single observation is not reliable; this is intuitively clear when you consider noisy real-world data such as weather or finances. So you take a batch, a fixed number of instances, process them in parallel, estimate the loss for the batch, and then adjust the model parameters for a better fit. This is done for all batches in your data. When every batch has been seen once, you have completed one epoch of training. If the loss is still too large, you do another epoch, and so on.

What it comes down to is that batches are important to prevent the loss, and hence the parameter updates (look into SGD and backprop for this), from jumping back and forth without any improvement in the long run. The target is to identify the parameter set that best fits your data, and that is best done in batches. As far as I know, you should choose the largest batch size your hardware can cope with; this way your model will converge best.
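Here is a minimal sketch of that loop (toy linear model and made-up data, purely to illustrate the batching; the shapes mirror your second example):

import torch

# made-up data: 300 3-dim inputs and 2-dim targets, just for illustration
X = torch.randn(300, 3)
Y = torch.randn(300, 2)

model = torch.nn.Linear(3, 2)                  # one shared 3x2 weight matrix
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

batch_size = 30
for epoch in range(5):
    perm = torch.randperm(len(X))              # shuffle so batches differ per epoch
    for i in range(0, len(X), batch_size):
        idx = perm[i:i + batch_size]           # one batch of indices
        loss = loss_fn(model(X[idx]), Y[idx])  # loss estimated over the batch
        optimizer.zero_grad()
        loss.backward()                        # backprop
        optimizer.step()                       # adjust parameters for a better fit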

In regard to your examples:

In the first, the dimensions [100, 3, 3] tell you that you have 100 elements in your batch, and each element is a 3x3 matrix. In the second line, you have 100 3x2 matrices. The last line computes the matrix product between the 3x3 and 3x2 matrices for each of the 100 batch elements. Your GPU can do this in parallel, provided the VRAM is large enough to cope with the memory requirements.
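In other words (a quick sketch), torch.bmm is just a per-element matrix product done in one call:

import torch

x = torch.randn(100, 3, 3)
W = torch.randn(100, 3, 2)

h = torch.bmm(x, W)
h_loop = torch.stack([x[i] @ W[i] for i in range(100)])  # same thing, element by element

print(torch.allclose(h, h_loop))  # True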

In the second example, you have one 300x3 matrix, and you compute the matrix product with a single 3x2 matrix. You see the difference? It's the same 3x2 matrix for all the data; in the former, you have 100 different ones.

If, in the first example, you used the same 3x2 matrix for all 100 batch elements (i.e. repeated it along the batch dimension), then the two approaches would be equivalent.
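To make that concrete, a quick sketch that expands one shared 3x2 matrix across the batch dimension:

import torch

x = torch.randn(100, 3, 3)                 # 100 batch elements, each 3x3
W = torch.randn(3, 2)                      # one shared 3x2 matrix

h_bmm = torch.bmm(x, W.expand(100, 3, 2))  # batched, same W for every element
h_mm = torch.matmul(x.reshape(300, 3), W).reshape(100, 3, 2)

print(torch.allclose(h_bmm, h_mm))         # True: equivalent when W is shared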


Thanks for the response, here are some follow-up questions:

If you do it in batches:

W = torch.randn(100, 3, 2)

Does this refer to 100 3x2 matrices, with each matrix performing the computation for a single batch element?

After one epoch, do we adjust every single matrix based on what we expected the output for that batch to look like?

In the next epoch, will the same matrix multiplication be performed on the same batch? In the code it's just a for-loop, so presumably yes. But with SGD, the batch should be a random one, if I'm not mistaken…

No, I don't. I wrote in the comments what x should be: it refers to 300 vectors, just split up into 100 batches, with each batch being of size 3. So it's not 100 3x3 matrices but 300 3-dim vectors, because a 3x3 matrix can be interpreted either as 3 row vectors or as a matrix in its own right. It's just a different way of representing the same information. That's why [100, 3, 3] and [300, 3] are equivalent to me.

x1 = torch.tensor([[[1, 2, 3], [2, 3, 4]], [[3, 1, 2], [1, 2, 1]]])  # yes, two 2x3 matrices, but to me it's 4 3-dim vectors!

x2 = torch.tensor([[1, 2, 3], [2, 3, 4], [3, 1, 2], [1, 2, 1]])  # the exact same information, just represented with fewer dimensions
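You can check that they hold the same data:

print(torch.equal(x1.reshape(4, 3), x2))  # True: same numbers, different shape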

My main question was: if you do it in batches, you have more parameters involved in the linear transformation that acts on your information. If you do it with fewer dimensions, you use fewer parameters, which reduces the capacity of what the model is able to learn. At least that's what I wanted to confirm with this post, and then also find out what additional benefits torch.bmm yields, which you nicely explained with the parallel processing on the GPU.