I am using torch.cuda.amp for mixed precision. My model's forward pass calls many submodules, each of which has its own forward method. I tried decorating all of those nested forward methods with torch.cuda.amp.autocast(enabled=True), but the error persists.
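This is roughly how I applied the decorator to the nested forward methods (simplified sketch; the module and layer names here are placeholders, not my real model):

import torch
import torch.nn as nn

class FrontendBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 16, kernel_size=3, padding=1)

    # decorator on the nested forward, in addition to the outer autocast context
    @torch.cuda.amp.autocast(enabled=True)
    def forward(self, x):
        return self.conv(x)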
Forward pass:
with torch.cuda.amp.autocast(enabled=True):
    h, chunk, preds, labels = model.forward(batch, alphaSG, device)
    label = labels
    for worker in model.classification_workers:
        loss = worker.loss_weight * worker.loss(preds[worker.name], label[worker.name])
        losses[worker.name] = loss
        tot_loss += loss
    for worker in model.regression_workers:
        loss = worker.loss_weight * worker.loss(preds[worker.name], label[worker.name])
        losses[worker.name] = loss
        tot_loss += loss
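For context, the rest of my training step follows the standard torch.cuda.amp pattern (sketch only; the loader and optimizer names are illustrative, and losses/tot_loss are reset each iteration):

scaler = torch.cuda.amp.GradScaler()

for batch in loader:
    optimizer.zero_grad()
    losses = {}
    tot_loss = 0.0
    with torch.cuda.amp.autocast(enabled=True):
        # forward pass and per-worker losses as shown above
        ...
    scaler.scale(tot_loss).backward()
    scaler.step(optimizer)
    scaler.update()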
Error I am getting:
RuntimeError Traceback (most recent call last)
<ipython-input-10-f4ce5bf32b0d> in <module>()
2327
2328 with torch.cuda.amp.autocast(enabled=True):
-> 2329 h, chunk, preds, labels = model.forward(batch, alphaSG, device)
2330 label = labels
2331 for worker in model.classification_workers:
10 frames
/usr/local/lib/python3.6/dist-packages/torch/cuda/amp/autocast_mode.py in decorate_autocast(*args, **kwargs)
133 def decorate_autocast(*args, **kwargs):
134 with self:
--> 135 return func(*args, **kwargs)
136 return decorate_autocast
137
<ipython-input-10-f4ce5bf32b0d> in forward(self, x, alpha, device)
1945 # remove key if it exists
1946 x_.pop('cchunk', None)
-> 1947 h = self.frontend(x_, device)
1948 if len(h) > 1:
1949 assert len(h) == 2, len(h)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
556 result = self._slow_forward(*input, **kwargs)
557 else:
--> 558 result = self.forward(*input, **kwargs)
559 for hook in self._forward_hooks.values():
560 hook_result = hook(self, input, result)
/usr/local/lib/python3.6/dist-packages/torch/cuda/amp/autocast_mode.py in decorate_autocast(*args, **kwargs)
133 def decorate_autocast(*args, **kwargs):
134 with self:
--> 135 return func(*args, **kwargs)
136 return decorate_autocast
137
<ipython-input-10-f4ce5bf32b0d> in forward(self, batch, device, mode)
1827 dskips = []
1828 for n, block in enumerate(self.blocks):
-> 1829 h = block(h)
1830 if denseskips and (n + 1) < len(self.blocks):
1831 # denseskips happen til the last but one layer
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
556 result = self._slow_forward(*input, **kwargs)
557 else:
--> 558 result = self.forward(*input, **kwargs)
559 for hook in self._forward_hooks.values():
560 hook_result = hook(self, input, result)
/usr/local/lib/python3.6/dist-packages/torch/cuda/amp/autocast_mode.py in decorate_autocast(*args, **kwargs)
133 def decorate_autocast(*args, **kwargs):
134 with self:
--> 135 return func(*args, **kwargs)
136 return decorate_autocast
137
<ipython-input-10-f4ce5bf32b0d> in forward(self, x)
1494 P = (pad, pad)
1495 x = F.pad(x, P, mode=self.pad_mode)
-> 1496 h = self.conv(x)
1497 if hasattr(self, 'norm'):
1498 h = forward_norm(h, self.norm)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
556 result = self._slow_forward(*input, **kwargs)
557 else:
--> 558 result = self.forward(*input, **kwargs)
559 for hook in self._forward_hooks.values():
560 hook_result = hook(self, input, result)
/usr/local/lib/python3.6/dist-packages/torch/cuda/amp/autocast_mode.py in decorate_autocast(*args, **kwargs)
133 def decorate_autocast(*args, **kwargs):
134 with self:
--> 135 return func(*args, **kwargs)
136 return decorate_autocast
137
<ipython-input-10-f4ce5bf32b0d> in forward(self, waveforms)
1334 band=(high-low)[:,0]
1335
-> 1336 f_times_t_low = torch.matmul(low, self.n_)
1337 f_times_t_high = torch.matmul(high, self.n_)