Using PT2 compile in tandem with the functional API

Background:
Hi everyone! I'm having an awesome time with PT2, and I would like some advice on how to integrate torch.compile further into my code. I do research in differential privacy, so I make heavy use of per-sample gradients and the functional API provided by torch throughout the training loop.

Problem:
Here’s my simplified training loop:

import torch
from torch.func import grad, vmap, functional_call

cfg = {} # Large config object for configuring train loop
device = 'cuda:0'
train_loader = ...  # placeholder: DataLoader over the training set
model = ...  # placeholder: the model (a ResNet9), moved via .to(device)
optim = ...  # placeholder: optimizer
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
criterion = lambda input, target: torch.nn.functional.cross_entropy(
        input, target, reduction='mean'
)

def compute_loss(model, params, buffers, batch, target, loss_fn):
    batch = batch.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = functional_call(model, {**params, **buffers}, batch)
    loss = loss_fn(predictions, targets)
    return loss

# Mark static arguments/constants with None, or give
# the dimension along which the operation ought to
# vectorize, i.e. batch and ground truth map over dim 0
compute_per_sample_grads = vmap(
    grad(compute_loss, argnums=1),
    in_dims=(None, None, None, 0, 0, None),
    randomness="same",
)

# ....

for batch, target in train_loader:
    batch = batch.to(device)
    target = target.to(device)
    y_pred = functional_call(model, params, (batch,))
    per_sample_grads = compute_per_sample_grads(
        model, params, buffers, batch, target, criterion
    )
    for param, grad_sample in zip(params.values(), per_sample_grads.values()):
        if getattr(param, "grad_sample", None) is None:
            param.grad_sample = grad_sample.contiguous()
    # ....
    optim.step()

# ...
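For reference, the per-sample-gradient pattern above does work in plain eager mode. Below is a minimal, self-contained sketch — a toy `nn.Linear` stands in for the real ResNet9, and all names here are illustrative — that cross-checks `vmap(grad(...))` against a per-example autograd loop:

```python
import torch
from torch.func import functional_call, grad, vmap

# Toy stand-in model (illustrative; the real code uses a ResNet9).
toy_model = torch.nn.Linear(3, 2)
params = dict(toy_model.named_parameters())

def compute_loss(params, x, y):
    # Lift a single sample to a batch of one, as in the original compute_loss.
    pred = functional_call(toy_model, params, (x.unsqueeze(0),))
    return torch.nn.functional.cross_entropy(pred, y.unsqueeze(0))

# vmap over the sample dimension of x and y; params are shared (None).
per_sample_grads_fn = vmap(grad(compute_loss), in_dims=(None, 0, 0))

x = torch.randn(4, 3)
y = torch.randint(0, 2, (4,))
per_sample_grads = per_sample_grads_fn(params, x, y)

# Cross-check against a plain per-example autograd loop.
for i in range(4):
    g_i = grad(compute_loss)(params, x[i], y[i])
    for name in params:
        assert torch.allclose(per_sample_grads[name][i], g_i[name], atol=1e-5)
```

Each entry of the returned dict gains a leading sample dimension, e.g. `per_sample_grads["weight"]` has shape `(4, 2, 3)` here.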

The behaviour I expected was to be able to wrap the model with torch.compile and then use functional_call as before, but instead I receive RuntimeError: Cannot access data pointer of Tensor that doesn't have storage. I have also attempted to compile compute_per_sample_grads or compute_loss instead, which results in the same error.

I have found issues regarding compile+vmap or compile+grad, but I could not find one covering the combination of both. I did not want to open a new issue since I suspect I'm misusing the API at the moment.
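One workaround sketch, in case it helps others hitting the same error: instead of compiling the nn.Module that gets called inside the transforms, wrap the fully transformed function with torch.compile from the outside. This is an assumption based on how dynamo composes with torch.func on more recent builds, not a confirmed fix; setting suppress_errors (which the error log itself suggests) makes it degrade to eager instead of crashing, and `backend="eager"` just keeps the sketch lightweight:

```python
import torch
import torch._dynamo
from torch.func import functional_call, grad, vmap

# Assumption: dynamo can trace torch.func transforms when compile wraps the
# *outer* transformed function rather than the nn.Module itself. If tracing
# still fails, suppress_errors falls back to eager instead of raising.
torch._dynamo.config.suppress_errors = True

model = torch.nn.Linear(3, 2)  # stand-in for the real model
params = dict(model.named_parameters())

def compute_loss(params, x, y):
    pred = functional_call(model, params, (x.unsqueeze(0),))
    return torch.nn.functional.cross_entropy(pred, y.unsqueeze(0))

# Compile the transformed function; the model stays uncompiled.
# backend="eager" skips inductor for this sketch; drop it to use the default.
per_sample_grads_fn = torch.compile(
    vmap(grad(compute_loss), in_dims=(None, 0, 0)), backend="eager"
)

x = torch.randn(4, 3)
y = torch.randint(0, 2, (4,))
grads = per_sample_grads_fn(params, x, y)
```

Whether dynamo traces the transforms or falls back to eager, the result is the same dict of per-sample gradients with a leading batch dimension.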

Many thanks for your time! I would be happy to provide other artifacts as needed, e.g. the exact ResNet9 architecture.

Error Log

(venv) bratua@encephalon:~/sparse-dp$ python train.py
Files already downloaded and verified
Files already downloaded and verified
Number of epochs: 113
Noise multiplier: 1.24
[2023-05-05 15:40:49,259][opacus.data_loader][WARNING] - Ignoring drop_last as it is not compatible with DPDataLoader.
Started training run: 1024_BS-1_LR-0.0_MTM-1.24_NOISE-5000_STEPS-1024_BS-2_CLIP-True_DP
Number of parameters: 417,114
Epoch 1: 1%|▋ | 233/45000 [00:00<00:44, 1001.41it/s][2023-05-05 15:40:51,043] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-05-05 15:40:51,214] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2023-05-05 15:40:51,223] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-05-05 15:40:53,190] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 0
[2023-05-05 15:40:54,472] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 0
[2023-05-05 15:40:54,473] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 547, in preserve_rng_state
    yield
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 814, in wrap_fx_proxy_cls
    example_value = wrap_to_fake_tensor_and_record(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 957, in wrap_to_fake_tensor_and_record
    fake_e = wrap_fake_exception(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 808, in wrap_fake_exception
    return fn()
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 958, in <lambda>
    lambda: tx.fake_mode.from_tensor(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1324, in from_tensor
    return self.fake_tensor_converter(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 314, in __call__
    return self.from_real_tensor(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 272, in from_real_tensor
    out = self.meta_converter(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_subclasses/meta_utils.py", line 502, in __call__
    r = self.meta_tensor(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_subclasses/meta_utils.py", line 275, in meta_tensor
    base = self.meta_tensor(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_subclasses/meta_utils.py", line 381, in meta_tensor
    s = t.untyped_storage()
NotImplementedError: Cannot access storage of TensorWrapper

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 299, in transform
    tracer = InstructionTranslator(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1670, in __init__
    self.symbolic_locals = collections.OrderedDict(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1673, in <genexpr>
    VariableBuilder(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 172, in __call__
    return self._wrap(value).clone(**self.options())
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 238, in _wrap
    return self.wrap_tensor(value)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 639, in wrap_tensor
    tensor_variable = wrap_fx_proxy(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 754, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 814, in wrap_fx_proxy_cls
    example_value = wrap_to_fake_tensor_and_record(
  File "/usr/lib/python3.8/contextlib.py", line 131, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 549, in preserve_rng_state
    torch.random.set_rng_state(rng)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/random.py", line 18, in set_rng_state
    default_generator.set_state(new_state)
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 394, in _compile
    raise InternalTorchDynamoError() from e
torch._dynamo.exc.InternalTorchDynamoError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 185, in main
    per_sample_grads = compute_per_sample_grads(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1380, in wrapper
    results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1245, in wrapper
    output = func(*args, **kwargs)
  File "train.py", line 40, in compute_loss
    predictions = functional_call(model, {**params, **buffers}, batch)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/nn/utils/stateless.py", line 262, in _functional_call
    return module(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 109, in _fn
    torch.cuda.set_rng_state(cuda_rng_state)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/cuda/random.py", line 64, in set_rng_state
    _lazy_call(cb)
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/cuda/__init__.py", line 183, in _lazy_call
    callable()
  File "/home/bratua/sparse-dp/venv/lib/python3.8/site-packages/torch/cuda/random.py", line 62, in cb
    default_generator.set_state(new_state_copy)
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.