Backward pass became extremely slow when batch size increased

I changed the batch size of the input (from 1 to 2); however, loss.backward() became much slower (from 0.4s to 1s).
I am really confused; any help would be appreciated.

Can you share a minimal reproducible example of this behavior?

Thanks for your reply!

Scenario one:

# depths and poses are the outputs of the NN, each with batch size 2*b along dim 0
# mr_loss is the loss function

# split the batch into a "forward" half and a "backward" half
forward_depths, backward_depths = torch.split(depths, [b, b], dim=0)
forward_poses, backward_poses = torch.split(poses, [b, b], dim=0)
forward_mr_loss, forward_flows = mr_loss(forward_depths, forward_poses, img_seq, intrinsics)
backward_mr_loss, backward_flows = mr_loss(backward_depths, backward_poses, img_seq.flip(dims=[1]), intrinsics)
all_mr_loss = (forward_mr_loss + backward_mr_loss) / 2

all_mr_loss.backward()

Scenario two:

# run the whole batch through mr_loss in one call (original and flipped image sequences concatenated)
all_mr_loss, flows = mr_loss(depths, poses, torch.cat([img_seq, img_seq.flip(dims=[1])], dim=0), intrinsics.repeat(2, 1, 1))

all_mr_loss.backward()

Both scenarios produce the same all_mr_loss value; the only difference is the size of the input passed to mr_loss. In scenario one, the batch is split into two halves, each half is run through mr_loss separately, and the two losses are averaged.

However, the backward time is 0.4s in scenario one but 1s in scenario two.

Would it be possible to share a snippet that is executable (e.g., with dummy/random tensors if the dataset is not available)? It would also be useful to see how the timing is being done, to pinpoint the issue.
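
For reference, here is a minimal, self-contained sketch of how the two scenarios could be timed with dummy random tensors. It replaces mr_loss with a made-up stand-in (dummy_loss), since the real loss function is not shown, and all tensor shapes are assumptions. Note the synchronization around backward(): CUDA kernels launch asynchronously, so without torch.cuda.synchronize() the measured times can be misleading.

import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def sync():
    # CUDA kernels run asynchronously; synchronize before reading the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def dummy_loss(depths, poses, imgs, intrinsics):
    # made-up stand-in for mr_loss, only so the snippet runs end to end
    flows = depths.mean(dim=(-1, -2), keepdim=True) * poses.mean()
    loss = ((imgs - flows) ** 2).mean() + (intrinsics ** 2).mean()
    return loss, flows

def time_backward(loss):
    sync()
    start = time.time()
    loss.backward()
    sync()
    return time.time() - start

b = 1  # half of the total batch size (the full batch is 2*b)

# dummy tensors with assumed shapes
depths = torch.rand(2 * b, 1, 64, 64, device=device, requires_grad=True)
poses = torch.rand(2 * b, 6, device=device, requires_grad=True)
img_seq = torch.rand(b, 3, 64, 64, device=device)
intrinsics = torch.rand(b, 3, 3, device=device)

# scenario one: split the batch and call the loss twice
forward_depths, backward_depths = torch.split(depths, [b, b], dim=0)
forward_poses, backward_poses = torch.split(poses, [b, b], dim=0)
f_loss, _ = dummy_loss(forward_depths, forward_poses, img_seq, intrinsics)
bk_loss, _ = dummy_loss(backward_depths, backward_poses, img_seq.flip(dims=[1]), intrinsics)
print("scenario one backward:", time_backward((f_loss + bk_loss) / 2))

depths.grad = None
poses.grad = None

# scenario two: one call on the concatenated batch
loss, _ = dummy_loss(depths, poses, torch.cat([img_seq, img_seq.flip(dims=[1])], dim=0), intrinsics.repeat(2, 1, 1))
print("scenario two backward:", time_backward(loss))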