Hi everyone!
I want to build a specific implementation with PyTorch and train it with FSDP, and I am facing some problems.
First of all, let me describe my implementation with pseudocode.
Let model be an nn.Module with a shared submodule (also an nn.Module) and several head submodules (also nn.Modules). I have implemented the following pseudocode in PyTorch, and due to memory constraints I have to use FSDP to train it.
```python
z = model.shared(x)          # forward through the shared trunk
d = z.detach()               # cut the graph so the head backwards stop at d
d.requires_grad = True
for i in range(n):
    p = model.heads[i](d)
    loss = loss_function(p, y[i])
    loss.backward()          # gradients accumulate in d.grad
# propagate the accumulated head gradients back through the shared trunk
z.backward(gradient=d.grad)
```
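For concreteness, the module structure I mean is roughly this (a minimal sketch; the names and layer sizes are just illustrative, not my actual model):

```python
import torch.nn as nn

class MultiHeadModel(nn.Module):
    def __init__(self, n_heads: int, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        # shared trunk whose output feeds every head
        self.shared = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
        )
        # one independent head per target
        self.heads = nn.ModuleList(
            [nn.Linear(d_hidden, d_out) for _ in range(n_heads)]
        )
```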
However, there are some limitations.
In my original implementation, all of this code lived inside the model's forward method, but with FSDP we can't call backward inside the forward pass. So I added some if/else branches in forward to choose whether to run the shared part or one of the heads, and I moved the sequential forward-backward logic into my training loop. However, when the code reaches the last line (the backward call with the custom gradient), it breaks with a RuntimeError:
```
setStorage: sizes [4096, 14336], strides [14336, 1], storage offset 159383552, and itemsize 2 requiring a storage size of 436207616 are out of bounds for storage of size 0.
```
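For reference, the restructured code looks roughly like this (a simplified sketch; the `mode` / `head_idx` arguments are illustrative, not my exact interface):

```python
# forward of the (FSDP-wrapped) model: run either the shared trunk or one head
def forward(self, x, mode="shared", head_idx=0):
    if mode == "shared":
        return self.shared(x)
    return self.heads[head_idx](x)

# training loop, with model already wrapped in FSDP
z = model(x, mode="shared")
d = z.detach()
d.requires_grad = True
for i in range(n):
    p = model(d, mode="head", head_idx=i)
    loss = loss_function(p, y[i])
    loss.backward()              # these backwards run fine; grads accumulate in d.grad
z.backward(gradient=d.grad)      # this is the line that raises the RuntimeError above
```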
I suppose the error is due to the custom gradient, since the other backward passes ran without problems. But is there any workaround to fix this?
Or is there any other way to use FSDP to train a model according to this pseudocode?