import time
import torch
from transformers import AutoModelForCausalLM
torch.manual_seed(10)
num_tokens = 500
input = torch.randint(0, 30000, torch.Size([8, num_tokens])).cuda()
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
o = model(input_ids=input).logits
##################################
# o[:, :500].detach()
# o = o[:, 500:]
##################################
output = o.sum()
torch.cuda.synchronize()
s = time.time()
output.backward()
torch.cuda.synchronize()
e = time.time()
print(e-s)
I am surprised that uncommenting the two lines between the ### markers does not speed up the backward pass, even though o ends up being an empty tensor. I wonder why that is the case. Thanks!
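(Side note on my own snippet: I realize the first of those two lines is a no-op on its own, since `.detach()` returns a new tensor rather than modifying the original in place. A quick check of that assumption:)

```python
import torch

a = torch.randn(3, requires_grad=True)
b = a * 2            # b is part of the autograd graph
b[:2].detach()       # returns a new detached tensor; b itself is unchanged
print(b.requires_grad)  # still True: b keeps tracking gradients
```

So the slicing in the second line is what actually makes o empty.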
Calling backward on an empty tensor should raise an error:
import torch
from torch import nn

lin = nn.Linear(10, 10)
x = torch.randn(1, 10)
out = lin(x)
out = out[:, 10:]
print(out)
# tensor([], size=(1, 0), grad_fn=<SliceBackward0>)
out.backward()
# RuntimeError: grad can be implicitly created only for scalar outputs
My output is the sum of an empty tensor, which is a scalar, so backward() is allowed. Back-propagating from it takes the same time as back-propagating from the sum of a non-empty tensor, and I wonder why that is the case.
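Here is a minimal sketch of what I mean (my assumption being that autograd still walks the full graph even when the kept slice is empty — the parameters end up with all-zero gradients rather than no gradients at all):

```python
import torch
from torch import nn

torch.manual_seed(0)
lin = nn.Linear(10, 10)
x = torch.randn(1, 10)

# Keep an empty slice of the output; its sum is a 0-d scalar,
# so calling backward() does not raise.
out = lin(x)[:, 10:]     # shape (1, 0)
loss = out.sum()         # tensor(0., grad_fn=<SumBackward0>)
loss.backward()

# The graph was still traversed end to end: the parameters
# received (all-zero) gradients instead of staying None.
print(lin.weight.grad is not None)         # True
print(bool((lin.weight.grad == 0).all()))  # True
```

If autograd short-circuited on the empty slice, I would expect `lin.weight.grad` to stay `None`.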