I have a “Reversible Vision Transformer”. It implements a custom backward pass because it does not cache intermediate activations; they are recomputed from each block's outputs during backpropagation.
It seems to work fine for the most part, except for the following test case (2+ GPUs):
# Re-entrant DDP check: the same DDP-wrapped model is run forward twice on
# independent random inputs, and a single backward is taken through the sum
# of both outputs before one optimizer step.
print("Running more complicated DDP test ...")
input_shape = (1, 3, 224, 224)
# Inputs are sampled first and then moved to this process's device (`rank`).
x1 = torch.randn(input_shape).to(rank)
x2 = torch.randn(input_shape).to(rank)
ddp_model = DDP(model, device_ids=[rank], static_graph=True)
optimizer = optim.SGD(ddp_model.parameters(), lr=0.0001)
# Two forward passes precede the single backward pass below.
outputs1 = ddp_model(x1)
outputs2 = ddp_model(x2)
combined_loss = (outputs1 + outputs2).sum()
combined_loss.backward()
optimizer.step()
print("Re-entrant DDP passed")
Then I get the following error:
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/vedant/Desktop/video-recommendation/model_contrastive/models/rev_vit.py", line 402, in ddp_test_helper
(outputs1 + outputs2).sum().backward()
File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/autograd/function.py", line 253, in apply
return user_fn(self, *args)
File "/home/vedant/Desktop/video-recommendation/model_contrastive/models/rev_vit.py", line 270, in backward
X_1, X_2, dX_1, dX_2 = layer.backward_pass(
File "/home/vedant/Desktop/video-recommendation/model_contrastive/models/rev_vit.py", line 201, in backward_pass
g_Y_1.backward(dY_2, retain_graph=True)
File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Your training graph has changed in this iteration, e.g., one parameter is unused in first iteration, but then got used in the second iteration. this is not compatible with static_graph set to True.
For context, the custom backward pass for a single “ReversibleBlock” in the transformer is:
def backward_pass(
    self,
    Y_1: torch.Tensor,
    Y_2: torch.Tensor,
    dY_1: torch.Tensor,
    dY_2: torch.Tensor,
):
    """Recompute this block's inputs from its outputs and backpropagate.

    Equations for activation recomputation::

        X_2 = Y_2 - G(Y_1), G = MLP
        X_1 = Y_1 - F(X_2), F = Attention

    Args:
        Y_1: First output stream of this block. It is not modified; a
            detached view of it is used for the recomputation.
        Y_2: Second output stream of this block.
        dY_1: Gradient flowing into ``Y_1``.
        dY_2: Gradient flowing into ``Y_2``.

    Returns:
        A tuple ``(X_1, X_2, dX_1, dX_2)``: the recomputed block inputs and
        the gradients with respect to them. Neither ``X_1`` nor ``X_2``
        requires grad.

    Side effects:
        Calls ``.backward()`` on the outputs of ``self.G`` and ``self.F``,
        so gradients are accumulated into whatever parameters those
        callables own. Also prints one debug line per call.
    """
    print(f"Backwards: {self.layer_id} device_id={Y_1.get_device()}")

    # Temporarily record the intermediate activations of G and use them
    # for the gradient calculation of G.
    with torch.enable_grad():
        # Use a detached leaf that shares storage with the caller's tensor
        # instead of flipping `requires_grad` on the caller's tensor: the
        # input is left untouched and a pre-existing `.grad` on it cannot
        # leak into the result.
        Y_1 = Y_1.detach().requires_grad_(True)
        g_Y_1 = self.G(Y_1)
        assert g_Y_1.shape == Y_1.shape
        # NOTE(review): retain_graph=True is kept from the original; it is
        # presumably only required if G reuses autograd state across
        # calls -- confirm before removing.
        g_Y_1.backward(dY_2, retain_graph=True)

    # Activation recomputation is by design and not part of the
    # computation graph of the forward pass.
    with torch.no_grad():
        X_2 = Y_2 - g_Y_1
        del g_Y_1
        dY_1 = dY_1 + Y_1.grad
        Y_1.grad = None

    # Record the activations of F and compute the gradients of F.
    with torch.enable_grad():
        X_2.requires_grad_(True)
        f_X_2 = self.F(X_2)
        # NOTE: drop-path on F's output (seeded via self.seeds["droppath"],
        # using self.drop_path_rate / self.training) is currently disabled
        # here. NOTE(review): if it is re-enabled, it presumably has to
        # mirror the forward pass exactly -- confirm against forward.
        f_X_2.backward(dY_1, retain_graph=True)

    # Propagate the reverse-computed activations to the start of the
    # previous block for backprop.
    with torch.no_grad():
        X_1 = Y_1 - f_X_2
        del f_X_2, Y_1
        dY_2 = dY_2 + X_2.grad
        X_2.grad = None
        X_2 = X_2.detach()

    return X_1, X_2, dY_1, dY_2
My main question is: why am I getting this error? My back-propagation function has no if-statements / conditionals, so it should be executing the same thing every time. Is there something I’m missing about how backprop works?
Full code: Reversible VIT · GitHub
Original paper (my code is a cut/paste/simplification of their implementation): https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf