Modified multihead attention and its compatibility with Opacus

This is a continuation of my earlier issue: Making a custom transformer architecture work with opacus
Here is another notebook that reproduces the problem with the monotonic multihead attention (see the MultiheadAttention class): Google Colab

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# `device` is defined globally in the notebook.

def attention(q, k, v, d_k, mask, dropout, zero_pad, gamma=None):
    """
    This is called by the multi-head attention object to find the values.
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # BS, 8, seqlen, seqlen
    bs, head, seqlen = scores.size(0), scores.size(1), scores.size(2)

    x1 = torch.arange(seqlen).expand(seqlen, -1).to(device)
    x2 = x1.transpose(0, 1).contiguous()

    with torch.no_grad():
        scores_ = scores.masked_fill(mask == 0, -1e32)
        scores_ = F.softmax(scores_, dim=-1)  # BS, 8, seqlen, seqlen
        scores_ = scores_ * mask.float().to(device)
        distcum_scores = torch.cumsum(scores_, dim=-1)  # bs, 8, sl, sl
        disttotal_scores = torch.sum(
            scores_, dim=-1, keepdim=True)  # bs, 8, sl, 1
        position_effect = torch.abs(
            x1 - x2)[None, None, :, :].type(torch.FloatTensor).to(device)  # 1, 1, seqlen, seqlen
        # bs, 8, sl, sl positive distance
        dist_scores = torch.clamp(
            (disttotal_scores - distcum_scores) * position_effect, min=0.)
        dist_scores = dist_scores.sqrt().detach()
    m = nn.Softplus()
    gamma = -1. * m(gamma).unsqueeze(0)  # 1, 8, 1, 1
    # Apply exp(gamma * distance), then clamp to the range [1e-5, 1e5]
    total_effect = torch.clamp(torch.clamp(
        (dist_scores * gamma).exp(), min=1e-5), max=1e5)
    scores = scores * total_effect

    scores.masked_fill_(mask == 0, -1e32)
    scores = F.softmax(scores, dim=-1)  # BS, 8, seqlen, seqlen
    if zero_pad:
        pad_zero = torch.zeros(bs, head, 1, seqlen).to(device)
        scores = torch.cat([pad_zero, scores[:, :, 1:, :]], dim=2)
    scores = dropout(scores)
    output = torch.matmul(scores, v)
    return output
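
For reference, here is a minimal sketch of the shapes this helper expects. The sizes and the dropout rate are made up for illustration; in the notebook the function is called from the MultiheadAttention class with the projected q, k and v, and the gamma shape is inferred from the # 1, 8, 1, 1 comment above.

# Shape sketch only; the sizes below are illustrative, not the notebook's.
import torch
import torch.nn as nn

device = torch.device("cpu")                     # the notebook defines `device` globally
bs, heads, seqlen, d_k = 4, 8, 16, 32

q = torch.randn(bs, heads, seqlen, d_k)
k = torch.randn(bs, heads, seqlen, d_k)
v = torch.randn(bs, heads, seqlen, d_k)
mask = torch.tril(torch.ones(seqlen, seqlen)).view(1, 1, seqlen, seqlen)  # causal mask
gamma = nn.Parameter(torch.zeros(heads, 1, 1))   # learnable per-head decay rate
dropout = nn.Dropout(p=0.1)

out = attention(q, k, v, d_k, mask, dropout, zero_pad=True, gamma=gamma)
print(out.shape)                                 # torch.Size([4, 8, 16, 32])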

To me it looks like the model parameter (gamma) is only used to compute terms that are then applied to the distance scores. Could this be the issue that prevents backpropagation?
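
If the question is whether gradients can reach that parameter at all, here is a tiny standalone check I put together (shapes are made up): dist_scores is detached, but gamma enters through softplus and exp, which autograd differentiates as usual.

# Standalone sketch with made-up shapes, distilled from the function above
# (the clamps are omitted for brevity).
import torch
import torch.nn.functional as F

gamma = torch.nn.Parameter(torch.zeros(8, 1, 1))   # per-head decay parameter
dist_scores = torch.rand(2, 8, 5, 5)               # stand-in for the detached distances
total_effect = torch.exp(-F.softplus(gamma).unsqueeze(0) * dist_scores)
total_effect.sum().backward()
print(gamma.grad.abs().sum())                      # nonzero, so gradients do reach gamma

Here is the warning and the full traceback I get when training with Opacus: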

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "

TypeError Traceback (most recent call last)
in <cell line: 1>()
----> 1 best_epoch = train_one_dataset(train_q_data, train_qa_data, train_pid, valid_q_data, valid_qa_data, valid_pid)
2

21 frames
in train_one_dataset(train_q_data, train_qa_data, train_pid, valid_q_data, valid_qa_data, valid_pid)
38 for idx in range(max_iter):
39 # Train Model
---> 40 train_loss, train_accuracy, train_auc = train(
41 dp_model, dp_optimizer, train_q_data, train_qa_data, train_pid, accountant, label='Train')
42 # Validation step

in train(net, optimizer, q_data, qa_data, pid_data, accountant, label)
82 loss, pred, true_ct = net(input_q, input_qa, target)
83 pred = pred.detach().cpu().numpy() # (seqlen * batch_size, 1)
---> 84 loss.backward()
85 true_el += true_ct.cpu().numpy()
86

/usr/local/lib/python3.10/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
520 inputs=inputs,
521 )
---> 522 torch.autograd.backward(
523 self, gradient, retain_graph, create_graph, inputs=inputs
524 )

/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
264 # some Python versions print out the first line of a multi-line function
265 # calls in the traceback and some print out the last line
---> 266 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
267 tensors,
268 grad_tensors_,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in __call__(self, *args, **kwargs)
69 if module is None:
70 raise RuntimeError("You are trying to call the hook of a dead Module!")
---> 71 return self.hook(module, *args, **kwargs)
72 return self.hook(*args, **kwargs)
73

/usr/local/lib/python3.10/dist-packages/opacus/grad_sample/grad_sample_module.py in capture_backprops_hook(self, module, _forward_input, forward_output, loss_reduction, batch_first)
336 grad_sampler_fn = ft_compute_per_sample_gradient
337
---> 338 grad_samples = grad_sampler_fn(module, activations, backprops)
339 for param, gs in grad_samples.items():
340 create_or_accumulate_grad_sample(

/usr/local/lib/python3.10/dist-packages/opacus/grad_sample/functorch.py in ft_compute_per_sample_gradient(layer, activations, backprops)
92 prepare_layer(layer)
93
---> 94 per_sample_grads = layer.ft_compute_sample_grad(
95 parameters, activations[0], backprops
96 )

/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py in wrapped(*args, **kwargs)
186 # @functools.wraps(func)
187 def wrapped(*args, **kwargs):
---> 188 return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
189
190 return wrapped

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
276
277 # If chunk_size is not specified.
---> 278 return _flat_vmap(
279 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
280 )

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
42 def fn(*args, **kwargs):
43 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44 return f(*args, **kwargs)
45 return fn
46

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
389 try:
390 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
---> 391 batched_outputs = func(*batched_inputs, **kwargs)
392 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
393 finally:

/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py in wrapper(*args, **kwargs)
361 @functools.wraps(func)
362 def wrapper(*args, **kwargs):
---> 363 return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
364 return wrapper

/usr/local/lib/python3.10/dist-packages/torch/_functorch/eager_transforms.py in grad_impl(func, argnums, has_aux, args, kwargs)
1293 def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
1294 func = lazy_dynamo_disable(func)
---> 1295 results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
1296 if has_aux:
1297 grad, (_, aux) = results

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
42 def fn(*args, **kwargs):
43 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44 return f(*args, **kwargs)
45 return fn
46

/usr/local/lib/python3.10/dist-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs)
1254 tree_map_(partial(_create_differentiable, level=level), diff_args)
1255
---> 1256 output = func(*args, **kwargs)
1257 if has_aux:
1258 if not (isinstance(output, tuple) and len(output) == 2):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
487 dynamo_config_ctx.enter()
488 try:
---> 489 return fn(*args, **kwargs)
490 finally:
491 set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/opacus/grad_sample/functorch.py in compute_loss_stateless_model(params, activations, backprops)
69 batched_backprops = backprops.unsqueeze(1)
70
---> 71 output = flayer(params, batched_activations)
72 loss = (output * batched_backprops).sum()
73 return loss

/usr/local/lib/python3.10/dist-packages/opacus/grad_sample/functorch.py in fmodel(new_params_values, *args, **kwargs)
34 name: value for name, value in zip(params_names, new_params_values)
35 }
---> 36 return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
37
38 if disable_autograd_tracking:

/usr/local/lib/python3.10/dist-packages/torch/_functorch/functional_call.py in functional_call(module, parameter_and_buffer_dicts, args, kwargs, tie_weights, strict)
141 )
142
---> 143 return nn.utils.stateless._functional_call(
144 module,
145 parameters_and_buffers,

/usr/local/lib/python3.10/dist-packages/torch/nn/utils/stateless.py in _functional_call(module, parameters_and_buffers, args, kwargs, tie_weights, strict)
261 module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
262 ):
---> 263 return module(*args, **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
---> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
---> 1520 return forward_call(*args, **kwargs)
1521
1522 try:

TypeError: MultiHeadAttention.forward() missing 4 required positional arguments: 'k', 'v', 'mask', and 'zero_pad'
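
Reading the last frames, the TypeError seems to come from Opacus's functorch-based per-sample-gradient path: ft_compute_per_sample_gradient re-runs the hooked module through torch.func.functional_call with only activations[0] as input, so a forward(q, k, v, mask, zero_pad) signature is missing everything after the first argument. My guess (untested) is that MultiHeadAttention gets this functorch hook because it owns the learnable gamma directly, so one idea would be to move that parameter into a small submodule whose forward takes exactly one tensor, for example:

# Hypothetical refactor, not taken from the notebook and not tested with it:
# keep the learnable gamma in a single-input submodule so that the module
# wrapped by the functorch grad sampler no longer has a multi-argument forward.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GammaDecay(nn.Module):
    """exp(-softplus(gamma) * dist), clamped, with a learnable per-head gamma."""

    def __init__(self, n_heads: int):
        super().__init__()
        self.gammas = nn.Parameter(torch.zeros(n_heads, 1, 1))

    def forward(self, dist_scores: torch.Tensor) -> torch.Tensor:
        gamma = -1.0 * F.softplus(self.gammas).unsqueeze(0)  # 1, heads, 1, 1
        return torch.clamp((dist_scores * gamma).exp(), min=1e-5, max=1e5)

# MultiHeadAttention would then hold `self.gamma_decay = GammaDecay(n_heads)`
# and attention() would compute `total_effect = gamma_decay(dist_scores)`
# instead of the softplus/exp/clamp block, so MultiHeadAttention itself no
# longer owns any parameters directly.

Would a restructuring like this be the right way around the single-activation assumption, or is registering a custom per-sample gradient for the module via opacus.grad_sample.register_grad_sampler the intended route?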

Here is a notebook of the model running without Opacus: Google Colab