In order of difficulty:
- make the batch size smaller,
- make a minimal reproducing example (i.e. just two or three inputs from torch.randn and the call to torch.nn.functional.linear) and file a bug,
- hot-patch torch.nn.functional.linear with a workaround (splitting the operation into multiple linear or matmul calls),
- submit a PR with a fix to PyTorch (or hire someone to do so) and discuss whether you can add a test or whether it would take a prohibitively large amount of GPU memory to run.
Best regards
Thomas