# Modified multi-head attention and its compatibility with Opacus

This is a continuation of my first issue, Making a custom transformer architecture work with Opacus.
Here is another notebook to reproduce and understand the problem with monotonic multi-head attention: Google Colab

```python
"""
This is called by the multi-head attention object to find the values.
"""
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # bs, 8, seqlen, seqlen
bs, head, seqlen = scores.size(0), scores.size(1), scores.size(2)

x1 = torch.arange(seqlen).expand(seqlen, -1).to(device)
x2 = x1.transpose(0, 1).contiguous()

scores_ = F.softmax(scores_, dim=-1)  # bs, 8, seqlen, seqlen
distcum_scores = torch.cumsum(scores_, dim=-1)  # bs, 8, sl, sl
disttotal_scores = torch.sum(scores_, dim=-1, keepdim=True)  # bs, 8, sl, 1
position_effect = torch.abs(x1 - x2)[None, None, :, :].type(torch.FloatTensor).to(device)  # 1, 1, seqlen, seqlen
# bs, 8, sl, sl; keep only positive distances
dist_scores = torch.clamp(
    (disttotal_scores - distcum_scores) * position_effect, min=0.)
dist_scores = dist_scores.sqrt().detach()
m = nn.Softplus()
gamma = -1. * m(gamma).unsqueeze(0)  # 1, 8, 1, 1
# Now compute exp(gamma * distance) and clamp it to [1e-5, 1e5]
total_effect = torch.clamp(torch.clamp(
    (dist_scores * gamma).exp(), min=1e-5), max=1e5)
scores = scores * total_effect

scores = F.softmax(scores, dim=-1)  # bs, 8, seqlen, seqlen
scores = torch.cat([pad_zero, scores[:, :, 1:, :]], dim=2)
scores = dropout(scores)
output = torch.matmul(scores, v)
return output
```
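
For context, here is a self-contained sketch of the same distance-decay computation with made-up shapes (batch of 2, 8 heads, sequence length 5). It assumes `gamma` is a learnable per-head parameter, as the snippet suggests, and applies the softmax to the raw `scores` where the full code presumably uses a masked copy `scores_`:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up shapes for illustration: batch 2, 8 heads, sequence length 5.
bs, heads, seqlen, d_k = 2, 8, 5, 16
device = "cpu"

q = torch.randn(bs, heads, seqlen, d_k)
k = torch.randn(bs, heads, seqlen, d_k)
gamma = nn.Parameter(torch.zeros(heads, 1, 1))  # assumed: learnable per-head decay

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

# Pairwise position grid: x1[i, j] = j (key position), x2[i, j] = i (query position).
x1 = torch.arange(seqlen).expand(seqlen, -1).to(device)
x2 = x1.transpose(0, 1).contiguous()

scores_ = F.softmax(scores, dim=-1)            # full code presumably masks first
distcum_scores = torch.cumsum(scores_, dim=-1)                # bs, heads, sl, sl
disttotal_scores = torch.sum(scores_, dim=-1, keepdim=True)   # bs, heads, sl, 1
position_effect = torch.abs(x1 - x2)[None, None, :, :].float().to(device)

dist_scores = torch.clamp(
    (disttotal_scores - distcum_scores) * position_effect, min=0.0)
dist_scores = dist_scores.sqrt().detach()      # distances treated as constants

gamma_eff = -1.0 * nn.Softplus()(gamma).unsqueeze(0)          # 1, heads, 1, 1
total_effect = torch.clamp((dist_scores * gamma_eff).exp(), min=1e-5, max=1e5)

print(total_effect.shape)  # torch.Size([2, 8, 5, 5])
```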

To me it seems more like the model parameter `gamma` is used to compute a decay term that is then applied to the (detached) distance scores. In that case, could this be what prevents backpropagation?
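
A quick way to check this in isolation (a minimal sketch; the shapes and the assumption that `gamma` is the only trainable parameter involved are mine):

```python
import torch
import torch.nn as nn

# Stand-in for the learnable per-head decay parameter from the snippet.
gamma = nn.Parameter(torch.zeros(8, 1, 1))

# Stand-in for dist_scores after .sqrt().detach(): a constant w.r.t. autograd.
dist_scores = torch.rand(2, 8, 5, 5)

gamma_eff = -1.0 * nn.Softplus()(gamma).unsqueeze(0)  # 1, 8, 1, 1
total_effect = (dist_scores * gamma_eff).exp()
total_effect.sum().backward()

print(gamma.grad is not None)     # True: gamma still receives a gradient
print(dist_scores.requires_grad)  # False: no gradient flows into the distances
```

If `gamma.grad` is populated, ordinary backpropagation is not blocked by the `detach()`; it only cuts the gradient path through the distances themselves. The failure under Opacus would then come from its per-sample gradient machinery rather than from autograd itself.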

```
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "

TypeError                                 Traceback (most recent call last)
in <cell line: 1>()
----> 1 best_epoch = train_one_dataset(train_q_data, train_qa_data, train_pid, valid_q_data, valid_qa_data, valid_pid)
      2

21 frames
in train_one_dataset(train_q_data, train_qa_data, train_pid, valid_q_data, valid_qa_data, valid_pid)
     38     for idx in range(max_iter):
     39         # Train Model
---> 40         train_loss, train_accuracy, train_auc = train(
     41             dp_model, dp_optimizer, train_q_data, train_qa_data, train_pid, accountant, label='Train')
     42         # Validation step

in train(net, optimizer, q_data, qa_data, pid_data, accountant, label)
     82     loss, pred, true_ct = net(input_q, input_qa, target)
     83     pred = pred.detach().cpu().numpy()  # (seqlen * batch_size, 1)
---> 84     loss.backward()
     85     true_el += true_ct.cpu().numpy()
     86

/usr/local/lib/python3.10/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    520             inputs=inputs,
    521         )
--> 522     torch.autograd.backward(
    523         self, gradient, retain_graph, create_graph, inputs=inputs
    524     )

    264     # some Python versions print out the first line of a multi-line function
    265     # calls in the traceback and some print out the last line
--> 266     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    267         tensors,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in __call__(self, *args, **kwargs)
     69             if module is None:
     70                 raise RuntimeError("You are trying to call the hook of a dead Module!")
---> 71             return self.hook(module, *args, **kwargs)
     72         return self.hook(*args, **kwargs)
     73

    337
    339         for param, gs in grad_samples.items():

     92     prepare_layer(layer)
     93
     95         parameters, activations[0], backprops
     96     )

/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py in wrapped(*args, **kwargs)
    186     # @functools.wraps(func)
    187     def wrapped(*args, **kwargs):
--> 188         return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    189
    190     return wrapped

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    276
    277     # If chunk_size is not specified.
--> 278     return _flat_vmap(
    279         func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    280     )

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
     42     def fn(*args, **kwargs):
---> 44         return f(*args, **kwargs)
     45     return fn
     46

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    389     try:
    390         batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 391         batched_outputs = func(*batched_inputs, **kwargs)
    392         return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    393     finally:

/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py in wrapper(*args, **kwargs)
    361     @functools.wraps(func)
    362     def wrapper(*args, **kwargs):
--> 363         return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
    364     return wrapper

/usr/local/lib/python3.10/dist-packages/torch/_functorch/eager_transforms.py in grad_impl(func, argnums, has_aux, args, kwargs)
   1293 def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
   1294     func = lazy_dynamo_disable(func)
-> 1295     results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
   1296     if has_aux:
        , aux) = results

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
     42     def fn(*args, **kwargs):
---> 44         return f(*args, **kwargs)
     45     return fn
     46

/usr/local/lib/python3.10/dist-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs)
   1254     tree_map_(partial(_create_differentiable, level=level), diff_args)
   1255
-> 1256     output = func(*args, **kwargs)
   1257     if has_aux:
   1258         if not (isinstance(output, tuple) and len(output) == 2):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    487             dynamo_config_ctx.__enter__()
    488         try:
--> 489             return fn(*args, **kwargs)
    490         finally:
    491             set_eval_frame(prior)

     69     batched_backprops = backprops.unsqueeze(1)
     70
---> 71     output = flayer(params, batched_activations)
     72     loss = (output * batched_backprops).sum()
     73     return loss

     34         name: value for name, value in zip(params_names, new_params_values)
     35     }
     37

/usr/local/lib/python3.10/dist-packages/torch/_functorch/functional_call.py in functional_call(module, parameter_and_buffer_dicts, args, kwargs, tie_weights, strict)
    141         )
    142
--> 143     return nn.utils.stateless._functional_call(
    144         module,
    145         parameters_and_buffers,

/usr/local/lib/python3.10/dist-packages/torch/nn/utils/stateless.py in _functional_call(module, parameters_and_buffers, args, kwargs, tie_weights, strict)
    261         module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
    262     ):
--> 263         return module(*args, **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521
   1522         try:
```
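
Reading the frames, the failure happens inside Opacus's functorch fallback (`prepare_layer`, then `vmap`/`grad`/`functional_call`), which Opacus uses to compute per-sample gradients for layer types it has no registered grad sampler for; the non-full backward hook warning comes from the hook-based part of the same machinery and is, as far as I can tell, separate from the TypeError. One possible workaround is to register an explicit per-sample gradient formula with `opacus.grad_sample.register_grad_sampler`, so the functorch path is skipped for the custom layer. Below is a toy sketch for a hypothetical `DecayScale` module isolating the `gamma`/`Softplus` part of the attention; the module name, shapes, and closed-form gradient are mine, not from the notebook:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from opacus.grad_sample import register_grad_sampler


class DecayScale(nn.Module):
    """Toy stand-in for the gamma/Softplus part of the attention layer."""

    def __init__(self, heads: int = 8):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(heads, 1, 1))

    def forward(self, dist_scores: torch.Tensor) -> torch.Tensor:
        gamma_eff = -1.0 * F.softplus(self.gamma).unsqueeze(0)  # 1, heads, 1, 1
        return (dist_scores * gamma_eff).exp()


@register_grad_sampler(DecayScale)
def decay_scale_grad_sampler(layer, activations, backprops):
    # activations[0]: the dist_scores input, (bs, heads, sl, sl);
    # backprops: dL/d(output), same shape.
    x = activations[0]
    gamma_eff = -1.0 * F.softplus(layer.gamma).unsqueeze(0)
    out = (x * gamma_eff).exp()                        # recompute the forward
    dgamma = -torch.sigmoid(layer.gamma).unsqueeze(0)  # d(gamma_eff)/d(layer.gamma)
    # Chain rule, reduced over all dims except batch and the parameter's dims.
    grad = (backprops * out * x * dgamma).sum(dim=(2, 3), keepdim=True)
    return {layer.gamma: grad}                         # (bs, heads, 1, 1)
```

With a sampler registered, Opacus no longer needs to `vmap`/`grad` through the custom forward, which is exactly the path that fails above. The formula here only covers the toy module; the real attention layer would need the same treatment for each of its trainable parameters.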