Very low mean squared difference but still `TracerWarning: Trace had nondeterministic nodes.`

Hi,

I’m trying to implement a 2D transposed convolution for StyleGAN2 that uses a constant weight instead of a computed one.

A constant weight for the transposed convolution is required for CoreML.

I am using an implementation from this issue in the coremltools repo:

import torch
from torch.nn import functional

def conv_transpose_stride2(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
  dilate = torch.nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=1, stride=2, groups=128, bias=False)
  dilate.weight.data = torch.ones([128, 1, 1, 1])
  pad = torch.nn.ZeroPad2d([1, 1, 1, 1])
  return functional.conv2d(dilate(pad(x)), w.transpose(0, 1).flip(2, 3))

If I compute the mean squared difference, it is very low. The original transposed convolution and the custom implementation even give the same mean squared output for the same input tensor.

torch.manual_seed(0)
x = torch.randn([1, 128, 256, 256])
w = torch.randn([128, 64, 3, 3])

y = functional.conv_transpose2d(x, w, stride=2)
y_ = conv_transpose_stride2(x, w)
size = torch.tensor(y.shape).prod() # tensor(16842816)

with torch.no_grad():
  print((y - y_).square().mean().numpy())
  print(y_.square().mean().numpy())
  print(y.square().mean().numpy())

Output:

3.852411e-11
286.88467
286.88467

But tracing conv_transpose_stride2() raises a nondeterminism warning:

traced_model = torch.jit.trace(conv_transpose_stride2, (x, w))

The tracing warning is as follows:

/opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:828: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
	%60 : Float(128, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=1, device=cpu) = aten::uniform_(%tensor, %57, %58, %59) # /opt/homebrew/lib/python3.9/site-packages/torch/nn/init.py:412:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _check_trace(
/opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:828: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!

Mismatched elements: 16842766 / 16842816 (100.0%)
Greatest absolute difference: 134.11310195922852 at index (0, 23, 324, 138) (up to 1e-05 allowed)
Greatest relative difference: 15173386.579398053 at index (0, 14, 291, 143) (up to 1e-05 allowed)
  _check_trace(

I even tried putting conv_transpose_stride2 in an nn.Module class and tracing it in eval() mode, but I still get the same warning.

How can I fix this?

The warning is raised because you are creating (and thus randomly initializing) the module inside the function before assigning its weights:

def conv_transpose_stride2(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
  dilate = torch.nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=1, stride=2, groups=128, bias=False)
  dilate.weight.data = torch.ones([128, 1, 1, 1])
  ...

The internal weight initialization via .uniform_ is nondeterministic, since it samples random values.
You could thus create the module once, initialize it properly (inside a no_grad guard and without using the deprecated .data attribute), and trace only the functional conv2d op.
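A minimal sketch of that suggestion, assuming a module-based wrapper: the class name `ConvTransposeStride2` and the small channel/spatial sizes below are hypothetical, chosen only so the example runs quickly. The dilation layer is built once in `__init__` and its weight is filled with ones in-place under `torch.no_grad()`, so the random default init happens at construction time and the traced `forward` should only contain deterministic ops.

```python
import warnings

import torch
from torch import nn
from torch.nn import functional


class ConvTransposeStride2(nn.Module):
    """Hypothetical wrapper: stride-2 transposed conv as pad + zero insertion + conv2d."""

    def __init__(self, channels: int = 128):
        super().__init__()
        self.pad = nn.ZeroPad2d([1, 1, 1, 1])
        # Created once here, so its random default init (uniform_) runs at
        # construction time rather than inside the traced call.
        self.dilate = nn.ConvTranspose2d(
            in_channels=channels, out_channels=channels, kernel_size=1,
            stride=2, groups=channels, bias=False,
        )
        # Overwrite the random init with ones in-place, without touching .data.
        with torch.no_grad():
            self.dilate.weight.fill_(1.0)

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Same math as conv_transpose_stride2: pad, insert zeros between pixels,
        # then a regular conv2d with the transposed and spatially flipped kernel.
        return functional.conv2d(self.dilate(self.pad(x)), w.transpose(0, 1).flip(2, 3))


torch.manual_seed(0)
channels = 8  # small sizes (assumption) so tracing and trace checking are fast
x = torch.randn([1, channels, 16, 16])
w = torch.randn([channels, 4, 3, 3])

model = ConvTransposeStride2(channels).eval()

# Record any warnings emitted while tracing (trace checking is on by default).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(model, (x, w))

tracer_warnings = [c for c in caught if issubclass(c.category, torch.jit.TracerWarning)]
print(f"TracerWarnings during tracing: {len(tracer_warnings)}")
```

With the dilation weight fixed before tracing, the traced module and the eager module should agree, and the result should match functional.conv_transpose2d(x, w, stride=2) up to floating-point error.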


Thank you so much for the quick reply, @ptrblck! This works!