I am running into a stride error with a compiled (torch.compile) resnet50 model when the input is a repeated tensor. The forward pass runs, but calling backward afterwards fails with a stride assertion. I am using torch 2.1.2 and torchvision 0.16.2.
This is the code:
import torch
import torchvision
# Reproduction script: build a pretrained torchvision ResNet-50, move it to CUDA,
# wrap it with torch.compile, then run forward + backward twice -- once on a
# single-sample input and once on that kind of input repeated 20 times along dim 0.
model = torchvision.models.resnet50(pretrained=True)
model.to('cuda')
model = torch.compile(model)
# NOTE(review): this tensor is never used -- `ref` is rebound on the next line.
ref = torch.randn(1,3,224,224).to('cuda')
ref = torch.randn(1,3,224,224,requires_grad=True).to('cuda')
# Case 1: input of shape (1, 3, 224, 224); sum the output to a scalar and backpropagate.
model(ref).sum().backward()
# This works
ref = torch.randn(1,3,224,224,requires_grad=True).to('cuda')
# Case 2: same kind of input expanded with repeat(20,1,1,1); the traceback below
# points at this line, with the AssertionError raised during backward().
model(ref.repeat(20,1,1,1)).sum().backward()
# This throws an error
This is the printout of the error:
/opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
torch.has_cuda,
/opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
torch.has_cudnn,
/opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
torch.has_mps,
/opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
torch.has_mkldnn,
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[4], line 2
1 ref = torch.randn(1,3,224,224,requires_grad=True).to('cuda')
----> 2 model(ref.repeat(20,1,1,1)).sum().backward()
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
482 if has_torch_function_unary(self):
483 return handle_torch_function(
484 Tensor.backward,
485 (self,),
(...)
490 inputs=inputs,
491 )
--> 492 torch.autograd.backward(
493 self, gradient, retain_graph, create_graph, inputs=inputs
494 )
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
246 retain_graph = create_graph
248 # The reason we repeat the same comment below is that
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
252 tensors,
253 grad_tensors_,
254 retain_graph,
255 create_graph,
256 inputs,
257 allow_unreachable=True,
258 accumulate_grad=True,
259 )
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/autograd/function.py:288, in BackwardCFunction.apply(self, *args)
282 raise RuntimeError(
283 "Implementing both 'backward' and 'vjp' for a custom "
284 "Function is not allowed. You should only implement one "
285 "of them."
286 )
287 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 288 return user_fn(self, *args)
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:3232, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args)
3230 out = CompiledFunctionBackward.apply(*all_args)
3231 else:
-> 3232 out = call_compiled_backward()
3233 return out
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:3204, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward()
3199 with tracing(saved_context), context(), track_graph_compiling(aot_config, "backward"):
3200 CompiledFunction.compiled_bw = aot_config.bw_compiler(
3201 bw_module, placeholder_list
3202 )
-> 3204 out = call_func_with_args(
3205 CompiledFunction.compiled_bw,
3206 all_args,
3207 steal_args=True,
3208 disable_amp=disable_amp,
3209 )
3211 out = functionalized_rng_runtime_epilogue(CompiledFunction.metadata, out)
3212 return tuple(out)
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1506, in call_func_with_args(f, args, steal_args, disable_amp)
1504 with context():
1505 if hasattr(f, "_boxed_call"):
-> 1506 out = normalize_as_list(f(args))
1507 else:
1508 # TODO: Please remove soon
1509 # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
1510 warnings.warn(
1511 "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
1512 "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
1513 "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
1514 )
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:328, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
326 dynamic_ctx.__enter__()
327 try:
--> 328 return fn(*args, **kwargs)
329 finally:
330 set_eval_frame(prior)
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_dynamo/external_utils.py:17, in wrap_inline.<locals>.inner(*args, **kwargs)
15 @functools.wraps(fn)
16 def inner(*args, **kwargs):
---> 17 return fn(*args, **kwargs)
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_inductor/codecache.py:374, in CompiledFxGraph.__call__(self, inputs)
373 def __call__(self, inputs) -> Any:
--> 374 return self.get_current_callable()(inputs)
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:628, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
626 def run(new_inputs):
627 copy_misaligned_inputs(new_inputs, inputs_to_check)
--> 628 return model(new_inputs)
File /opt/conda/envs/gpnnenv/lib/python3.8/site-packages/torch/_inductor/codecache.py:401, in _run_from_cache(compiled_graph, inputs)
391 from .codecache import PyCodeCache
393 compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
394 compiled_graph.cache_key,
395 compiled_graph.artifact_path,
(...)
398 else (),
399 ).call
--> 401 return compiled_graph.compiled_artifact(inputs)
File /tmp/torchinductor_root/rm/crmma3whzelf5smeadzh2ifwfc5wxcl4cj7kbc24iomgdxewltut.py:5364, in call(args)
5362 del primals_4
5363 buf466 = buf465[0]
-> 5364 assert_size_stride(buf466, (s0, 64, 56, 56), (200704, 1, 3584, 64))
5365 buf467 = buf465[1]
5366 assert_size_stride(buf467, (64, 64, 1, 1), (64, 1, 64, 64))
AssertionError: expected size 64==64, stride 3136==1 at dim=1