Hi,
I’m trying to implement a 2D transposed convolution for StyleGAN2 that uses a constant weight instead of a computed one.
CoreML requires the weight of a transposed convolution to be constant.
I am using an implementation from this issue in the coremltools repo:
import torch
from torch.nn import functional
def conv_transpose_stride2(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    dilate = torch.nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=1, stride=2, groups=128, bias=False)
    dilate.weight.data = torch.ones([128, 1, 1, 1])
    pad = torch.nn.ZeroPad2d([1, 1, 1, 1])
    return functional.conv2d(dilate(pad(x)), w.transpose(0, 1).flip(2, 3))
Now if I compute the mean squared difference between the original transposed convolution and the custom implementation, it is very low. The two outputs also have the same mean of squares for the same input tensor.
torch.manual_seed(0)
x = torch.randn([1, 128, 256, 256])
w = torch.randn([128, 64, 3, 3])
y = functional.conv_transpose2d(x, w, stride=2)
y_ = conv_transpose_stride2(x, w)
size = torch.tensor(y.shape).prod() # tensor(16842816)
with torch.no_grad():
    print((y - y_).square().mean().numpy())
    print(y_.square().mean().numpy())
    print(y.square().mean().numpy())
Output:
3.852411e-11
286.88467
286.88467
But tracing conv_transpose_stride2() gives a non-deterministic warning:
traced_model = torch.jit.trace(conv_transpose_stride2, (x, w))
The tracing warning is as follows:
/opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:828: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
%60 : Float(128, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=1, device=cpu) = aten::uniform_(%tensor, %57, %58, %59) # /opt/homebrew/lib/python3.9/site-packages/torch/nn/init.py:412:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
_check_trace(
/opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:828: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!
Mismatched elements: 16842766 / 16842816 (100.0%)
Greatest absolute difference: 134.11310195922852 at index (0, 23, 324, 138) (up to 1e-05 allowed)
Greatest relative difference: 15173386.579398053 at index (0, 14, 291, 143) (up to 1e-05 allowed)
_check_trace(
I even tried putting conv_transpose_stride2 in an nn.Module class and tracing it in eval() mode, but I still get the same warning.
How can I fix this?
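
One direction that might help, as a rough sketch rather than a confirmed fix: the aten::uniform_ node in the warning appears to come from the default weight initialization that runs when torch.nn.ConvTranspose2d is constructed inside the traced function. The variant below avoids constructing that module during tracing by keeping an all-ones kernel as a buffer and calling functional.conv_transpose2d directly. The class name ConvTransposeStride2, the buffer name dilate_weight, and the small tensor sizes are hypothetical and only for illustration. Whether CoreML accepts the resulting graph is an assumption I have not verified.

```python
import torch
from torch import nn
from torch.nn import functional


class ConvTransposeStride2(nn.Module):
    """Hypothetical module: stride-2 transposed conv via zero insertion + conv2d."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.channels = channels
        # Constant all-ones depthwise 1x1 kernel, created once, outside the trace,
        # so no random initialization is recorded while tracing.
        self.register_buffer("dilate_weight", torch.ones(channels, 1, 1, 1))

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Same zero padding as ZeroPad2d([1, 1, 1, 1]) in the original function.
        x = functional.pad(x, [1, 1, 1, 1])
        # Insert zeros between pixels; plays the role of the `dilate` module.
        x = functional.conv_transpose2d(
            x, self.dilate_weight, stride=2, groups=self.channels
        )
        # Regular convolution with the transposed, spatially flipped kernel.
        return functional.conv2d(x, w.transpose(0, 1).flip(2, 3))


# Small illustrative sizes (assumed), not the 128-channel tensors from the post.
torch.manual_seed(0)
x_small = torch.randn(1, 8, 16, 16)
w_small = torch.randn(8, 4, 3, 3)

model = ConvTransposeStride2(channels=8).eval()
traced = torch.jit.trace(model, (x_small, w_small))

reference = functional.conv_transpose2d(x_small, w_small, stride=2)
print((traced(x_small, w_small) - reference).abs().max())
```

The design choice here is only to move the constant kernel out of the traced call; the padding, zero insertion, and final conv2d follow the same steps as conv_transpose_stride2 above.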