Issue with compiled resnet50 and tensor repeat

I am hitting a stride error with a `torch.compile`d ResNet-50 model when the input is a repeated tensor. The forward pass succeeds, but calling backward afterwards fails with a stride assertion. I am using torch 2.1.2 and torchvision 0.16.2.

This is the code:


import torch
import torchvision
# Minimal repro: backward through a torch.compile'd ResNet-50 fails with a
# stride AssertionError when the input is a repeated tensor (reported on
# torch 2.1.2 / torchvision 0.16.2; reported working on torch 2.3.0).
# `weights=` replaces the deprecated `pretrained=True` flag; IMAGENET1K_V1 is
# the checkpoint that `pretrained=True` used to load.
model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)
model.to('cuda')

model = torch.compile(model)

# Create the input directly on the GPU so it stays a leaf tensor; calling
# `.to('cuda')` after `requires_grad=True` returns a non-leaf copy whose
# `.grad` is never populated.
ref = torch.randn(1, 3, 224, 224, device='cuda', requires_grad=True)
model(ref).sum().backward()
# This works

ref = torch.randn(1, 3, 224, 224, device='cuda', requires_grad=True)
model(ref.repeat(20, 1, 1, 1)).sum().backward()
# This throws an AssertionError from assert_size_stride in the compiled backward

This is the printout of the error:

/opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
---------------------------------------------------------------------------                                                                                                                                    [54/1823]
AssertionError                            Traceback (most recent call last)
Cell In[4], line 2
      1 ref = torch.randn(1,3,224,224,requires_grad=True).to('cuda')
----> 2 model(ref.repeat(20,1,1,1)).sum().backward()

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that                                                                                                                                                           
    249 # some Python versions print out the first line of a multi-line function                                                                                                                                        
    250 # calls in the traceback and some print out the last line                                                                                                                                                       
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass                                                                                                                  
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/autograd/function.py:288, in BackwardCFunction.apply(self, *args)
    282     raise RuntimeError(
    283         "Implementing both 'backward' and 'vjp' for a custom "
    284         "Function is not allowed. You should only implement one "
    285         "of them."
    286     )
    287 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 288 return user_fn(self, *args)

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:3232, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args)
   3230     out = CompiledFunctionBackward.apply(*all_args)
   3231 else:
-> 3232     out = call_compiled_backward()
   3233 return out

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:3204, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward()
   3199     with tracing(saved_context), context(), track_graph_compiling(aot_config, "backward"):
   3200         CompiledFunction.compiled_bw = aot_config.bw_compiler(
   3201             bw_module, placeholder_list
   3202         )
-> 3204 out = call_func_with_args(
   3205     CompiledFunction.compiled_bw,
   3206     all_args,
   3207     steal_args=True,
   3208     disable_amp=disable_amp,
   3209 )
   3211 out = functionalized_rng_runtime_epilogue(CompiledFunction.metadata, out)
   3212 return tuple(out)

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1506, in call_func_with_args(f, args, steal_args, disable_amp)
   1504 with context():
   1505     if hasattr(f, "_boxed_call"):
-> 1506         out = normalize_as_list(f(args))
   1507     else:
   1508         # TODO: Please remove soon
   1509         # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
   1510         warnings.warn(
   1511             "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
   1512             "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
   1513             "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
   1514         )

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:328, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    326 dynamic_ctx.__enter__()
    327 try:
--> 328     return fn(*args, **kwargs)
    329 finally:
    330     set_eval_frame(prior)

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_dynamo/external_utils.py:17, in wrap_inline.<locals>.inner(*args, **kwargs)
     15 @functools.wraps(fn)
     16 def inner(*args, **kwargs):
---> 17     return fn(*args, **kwargs)

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_inductor/codecache.py:374, in CompiledFxGraph.__call__(self, inputs)
    373 def __call__(self, inputs) -> Any:
--> 374     return self.get_current_callable()(inputs)

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:628, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
    626 def run(new_inputs):
    627     copy_misaligned_inputs(new_inputs, inputs_to_check)
--> 628     return model(new_inputs)

File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_inductor/codecache.py:401, in _run_from_cache(compiled_graph, inputs)
    391     from .codecache import PyCodeCache
    393     compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
    394         compiled_graph.cache_key,
    395         compiled_graph.artifact_path,
   (...)
    398         else (),
    399     ).call
--> 401 return compiled_graph.compiled_artifact(inputs)

File /tmp/torchinductor_root/rm/crmma3whzelf5smeadzh2ifwfc5wxcl4cj7kbc24iomgdxewltut.py:5364, in call(args)
   5362 del primals_4
   5363 buf466 = buf465[0]
-> 5364 assert_size_stride(buf466, (s0, 64, 56, 56), (200704, 1, 3584, 64))
   5365 buf467 = buf465[1]
   5366 assert_size_stride(buf467, (64, 64, 1, 1), (64, 1, 64, 64))

AssertionError: expected size 64==64, stride 3136==1 at dim=1

I cannot reproduce the issue with a recent nightly binary (torch==2.4.0.dev20240506+cu124), so you might need to update to 2.3.0 or install a nightly binary.

1 Like

It seems to work with torch 2.3.0+cu121. I will post an update if I run into issues in further tests.