I’m trying to get the following code to work:
from torch.nn.modules import Module
class Sine(Module):
def forward(self, z):
return torch.sin(100 * z)
act_fn = Sine()
m = nn.Sequential(
nn.Linear(1, 256, bias=True),
act_fn,
nn.Linear(256, 256, bias=True),
act_fn,
nn.Linear(256, 1, bias=True),
)
m = nn.Sequential(torch.ao.quantization.QuantStub(),
*m,
torch.ao.quantization.DeQuantStub())
m.train()
m.qconfig = torch.ao.quantization.get_default_qconfig('x86')
torch.ao.quantization.prepare_qat(m, inplace=True)
m.cuda()
n_epochs = 500
optim = torch.optim.AdamW(lr=1e-4, params=m.parameters(), weight_decay=1e-5)
for epoch in tqdm(range(n_epochs)):
x = timestamps
out = m(x)
loss = F.mse_loss(out, ground_truth)
optim.zero_grad()
loss.backward()
optim.step()
m.eval()
m.cpu()
torch.ao.quantization.convert(m, inplace=True)
res = m(timestamps.cpu())
The code runs until res = m(timestamps.cpu())
. There, I get the error:
Traceback (most recent call last):
File "qat.py", line 132, in <module>
res = m(model_input.cpu())
File "anaconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "anaconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "anaconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "qat.py", line 94, in forward
return torch.sin(100 * z)
RuntimeError: empty_strided not supported on quantized tensors yet see https://github.com/pytorch/pytorch/issues/74540
How do I solve this? I’d really like to test quantization aware training on this net…