Speed up applying the same function to different tensors in PyTorch

I have three tensors with very different shapes in the last dimension, say tensor_a: (5, 35), tensor_b: (5, 70) and tensor_c: (5, 10). I need to apply exactly the same transformation function f() to each of them: f(tensor_a), f(tensor_b), f(tensor_c). The function f() consists of a bunch of layers. This is easy to implement in PyTorch.
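Just to make it concrete, here is a minimal sketch of the sequential version. The real f is of course different; for the sketch I just assume it is built from layers that do not depend on the size of the last dimension (e.g. 1D convolutions over that axis):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for f: layers that work for any length of the last
# dimension, so the same module can be applied to (5, 35), (5, 70) and (5, 10).
f = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv1d(8, 1, kernel_size=3, padding=1),
).to(device)

tensor_a = torch.randn(5, 35, device=device)
tensor_b = torch.randn(5, 70, device=device)
tensor_c = torch.randn(5, 10, device=device)

# Sequential application: three separate small launches per layer,
# which is what keeps the GPU utilization low.
out_a = f(tensor_a.unsqueeze(1)).squeeze(1)  # (5, 35)
out_b = f(tensor_b.unsqueeze(1)).squeeze(1)  # (5, 70)
out_c = f(tensor_c.unsqueeze(1)).squeeze(1)  # (5, 10)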

But I noticed that GPU usage is low (around 50%) when running them sequentially like that. I sped it up by first padding tensor_a and tensor_c to the same shape as tensor_b, i.e. (5, 70), then combining the three into one tensor and calling f(combined_tensor) instead. After that I simply slice the result again ([:5], [5:10], [10:]) to get the three resulting tensors I want. With this, GPU usage is around 90% and it is therefore faster. However, it has the downside of requiring a lot of extra GPU memory because of the padding, so it is not a perfect solution.
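Continuing the sketch above, the padded version looks roughly like this (again only a sketch, assuming right-padding with zeros via torch.nn.functional.pad):

import torch.nn.functional as F

# Pad a and c on the right along the last dimension so all three are (5, 70);
# this is where the extra GPU memory goes.
a_pad = F.pad(tensor_a, (0, 70 - tensor_a.shape[-1]))  # (5, 70)
c_pad = F.pad(tensor_c, (0, 70 - tensor_c.shape[-1]))  # (5, 70)

# One batched call to f instead of three small ones.
combined_tensor = torch.cat([a_pad, tensor_b, c_pad], dim=0)  # (15, 70)
out = f(combined_tensor.unsqueeze(1)).squeeze(1)              # (15, 70)

# Slice the batch dimension apart again and drop the padded columns
# (whether the zero padding is harmless near the boundary depends on f's layers).
out_a = out[:5, :35]
out_b = out[5:10, :]
out_c = out[10:, :10]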

I wonder how to do this better in my case? I think I need some sort of embarrassingly parallel execution in PyTorch, but googling does not give me a good answer to my problem. Please let me know if there is a way to improve this. Many thanks!

How about something like this:

concatenated_tensor = torch.cat((a, b, c), dim=1)

This will result in a tensor of shape torch.Size([5, 115]). After computing output = f(concatenated_tensor), just do

a_new = output[:, :35]
b_new = output[:, 35:105]
c_new = output[:, 105:]

Thanks. This actually does not work for my case, because I need the transformation function to process a, b and c separately. Modifying f to take concatenated_tensor as input while keeping its subparts a, b and c separate would be quite complicated to do.

How did you do this? You wrote that you padded a into the shape torch.Size([5, 70]); how did you then combine a and b? Did you do it with torch.cat() or torch.stack()?

Let me assume the two new tensors a_new and c_new have shape (5, 70) after padding. Then it is as simple as combined_tensor = torch.cat([a_new, b, c_new], 0). With that, combined_tensor has shape (15, 70) and the function f takes combined_tensor as input.
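Just for reference, the difference between the two options you mentioned (with a_new, b and c_new all of shape (5, 70)):

torch.cat([a_new, b, c_new], dim=0).shape    # torch.Size([15, 70])
torch.stack([a_new, b, c_new], dim=0).shape  # torch.Size([3, 5, 70])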

OK, I understand your problem now. I don’t see a solution to it… There is a Python module called awkward (Awkward Array) that is better able to deal with non-rectangular data than NumPy, but I’d say that in your case, since you need to calculate f(x), where f is a neural network, you cannot use this module here.

Therefore, I’d go with the sequential approach, even if that results in lower GPU utilization. I’d only switch to the padding option if it is actually faster (you can just time both approaches).
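For the timing, something like this sketch should give a fair comparison on the GPU (synchronize before reading the clock; the function and tensor names are just placeholders for your own):

import time
import torch

def time_fn(fn, n_iters=100):
    # Warm-up so that lazy initialization does not end up in the measurement.
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

# e.g. time_fn(lambda: (f(a), f(b), f(c))) vs. time_fn(lambda: f(combined_tensor))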
