Hi Rachel!
Good luck! This is likely to be, at best, annoying. It’s probably doable, with some
work-arounds.
Not that I am aware of.
No. Sometimes the functionality in question doesn’t make sense with complex
numbers (for example, ReLU), and sometimes things that do make sense haven’t
been implemented yet.
Complex support in pytorch is still a work in progress. There are also some
nuances to working with complex tensors. These are not pytorch’s fault, but you
need to understand them and how they interact with your use case.
In general, look at pytorch’s documentation, but, in your particular context, your best
bet will be to go through your complex U-Net step by step, see what still works, and
implement work-arounds as necessary for what doesn’t.
Some comments:
Complex numbers are not ordered. That is, for two complex numbers u and v,
it doesn’t make sense to say u < v. This means that things like ReLU can’t
work in a natural way for complex numbers.
Note, for example, that for this reason MaxPool2d – often used in convolutional
neural networks – doesn’t work with complex tensors. A common work-around is to
apply such operations to the real and imaginary parts separately or, for pooling,
to select values by magnitude; a sketch follows.
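For illustration, here is a minimal sketch of two such work-arounds. Note that
split_relu and magnitude_max_pool2d are hypothetical helper names made up for
this example, not pytorch functions, and that applying ReLU to the real and
imaginary parts separately is just one of several conventions used in the
complex-network literature:

import torch

def split_relu (z):
    # apply ReLU to the real and imaginary parts separately
    return torch.complex (torch.relu (z.real), torch.relu (z.imag))

def magnitude_max_pool2d (z, kernel_size):
    # pool by magnitude: keep, in each window, the complex value
    # with the largest absolute value
    _, idx = torch.nn.functional.max_pool2d (z.abs(), kernel_size, return_indices = True)
    return z.flatten (2).gather (2, idx.flatten (2)).view_as (idx)

z = torch.randn (1, 1, 4, 4, dtype = torch.complex64)
print (split_relu (z).shape)                  # torch.Size([1, 1, 4, 4])
print (magnitude_max_pool2d (z, 2).shape)     # torch.Size([1, 1, 2, 2])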
Pytorch’s autograd and optimizers have built into them the assumption that you
are training your (complex) network with a real loss. The current version of pytorch
seems to be good at warning you if you don’t follow this rule, but this restriction has
the potential to lead to some unexpected results.
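To make the real-loss point concrete, here is a minimal sketch of a single
training step. complex_net is just a stand-in Linear layer (not your U-Net),
and this assumes an optimizer that accepts complex parameters (recent versions
of Adam do); the essential point is that the scalar you call backward() on is
real:

import torch

complex_net = torch.nn.Linear (4, 4, dtype = torch.complex64)   # stand-in model
opt = torch.optim.Adam (complex_net.parameters(), lr = 0.001)

inp = torch.randn (8, 4, dtype = torch.complex64)
targ = torch.randn (8, 4, dtype = torch.complex64)

pred = complex_net (inp)
loss = (pred - targ).abs().pow (2).mean()   # squared-magnitude error -- real
opt.zero_grad()
loss.backward()   # okay -- loss is real
opt.step()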
Here are some illustrations of pytorch’s complex support and limitations:
>>> import torch
>>> print (torch.__version__)
2.3.0
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> r = torch.ones (1, 1, 4, 4)
>>> c = torch.ones (1, 1, 4, 4, dtype = torch.complex64)
>>>
>>> r
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
>>> c
tensor([[[[1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
          [1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
          [1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
          [1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j]]]])
>>>
>>> lin = torch.nn.Linear (4, 2, dtype = torch.complex64)
>>> conv = torch.nn.Conv2d (1, 1, 3, dtype = torch.complex64)
>>>
>>> relu = torch.nn.ReLU() # won't work with complex -- complex not ordered
>>> maxp = torch.nn.MaxPool2d (2) # won't work with complex -- complex not ordered
>>> avgp = torch.nn.AvgPool2d (2) # should work with complex, but doesn't
>>>
>>> lin (c)
tensor([[[[-0.3101+0.3275j,  0.0007+0.5325j],
          [-0.3101+0.3275j,  0.0007+0.5325j],
          [-0.3101+0.3275j,  0.0007+0.5325j],
          [-0.3101+0.3275j,  0.0007+0.5325j]]]], grad_fn=<ViewBackward0>)
>>> conv (c)
tensor([[[[-0.0212-0.0528j, -0.0212-0.0528j],
          [-0.0212-0.0528j, -0.0212-0.0528j]]]], grad_fn=<AddBackward0>)
>>>
>>> relu (r)
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
>>> maxp (r)
tensor([[[[1., 1.],
          [1., 1.]]]])
>>> avgp (r)
tensor([[[[1., 1.],
          [1., 1.]]]])
>>>
>>> # relu (c) # would raise error
>>> # maxp (c) # would raise error
>>> # avgp (c) # would raise error
>>>
>>> # lin (c).sum().backward() # would raise error -- "loss" not real
>>> # conv (c).sum().backward() # would raise error -- "loss" not real
>>>
>>> lin (c).sum().abs().backward() # works -- "loss" is real
>>> lin.weight.grad
tensor([[-1.3540+3.7639j, -1.3540+3.7639j, -1.3540+3.7639j, -1.3540+3.7639j],
        [-1.3540+3.7639j, -1.3540+3.7639j, -1.3540+3.7639j, -1.3540+3.7639j]])
>>>
>>> conv (c).sum().abs().backward() # works -- "loss" is real
>>> conv.weight.grad
tensor([[[[-1.4866-3.7135j, -1.4866-3.7135j, -1.4866-3.7135j],
          [-1.4866-3.7135j, -1.4866-3.7135j, -1.4866-3.7135j],
          [-1.4866-3.7135j, -1.4866-3.7135j, -1.4866-3.7135j]]]])
>>>
>>> # implement AvgPool2d with Conv2d
>>>
>>> avgp_conv = torch.nn.Conv2d (1, 1, 2, 2, bias = False, dtype = torch.complex64)
>>> avgp_conv.weight.requires_grad = False
>>> avgp_conv.weight.copy_ (torch.ones (1, 1, 2, 2, dtype = torch.complex64) / 4)
Parameter containing:
tensor([[[[0.2500+0.j, 0.2500+0.j],
          [0.2500+0.j, 0.2500+0.j]]]])
>>>
>>> s = torch.randn (1, 1, 4, 4)
>>> t = torch.randn (1, 1, 4, 4, dtype = torch.complex64)
>>>
>>> s
tensor([[[[ 0.7203, -1.4108, -0.4384,  0.3551],
          [ 0.3730, -1.3050, -0.7983,  1.0442],
          [-0.1227,  0.4022, -1.4295, -0.5656],
          [ 0.6971,  0.1258, -0.0434,  0.5366]]]])
>>> t
tensor([[[[ 0.2242-1.3234j, -0.0633+0.5293j,  0.4339-0.3124j,  0.0696-0.1074j],
          [-0.3859+0.6665j,  0.2333+1.0793j, -0.1318+0.6501j, -0.2297+0.9898j],
          [-0.8779-0.0999j, -0.4100-0.2612j, -0.8163+0.8257j,  0.1332+0.9607j],
          [ 0.1305+0.8744j,  0.6930+0.1200j, -0.7032+0.7892j,  0.6485+1.2661j]]]])
>>>
>>> avgp (s)
tensor([[[[-0.4056,  0.0407],
          [ 0.2756, -0.3755]]]])
>>> avgp_conv (s.type (torch.complex64))
tensor([[[[-0.4056+0.j,  0.0407+0.j],
          [ 0.2756+0.j, -0.3755+0.j]]]])
>>>
>>> avgp_conv (t)
tensor([[[[ 0.0021+0.2379j,  0.0355+0.3050j],
          [-0.1161+0.1583j, -0.1844+0.9604j]]]])
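As an aside, an alternative to the Conv2d work-around is to run pytorch’s real
AvgPool2d on the real and imaginary parts separately and reassemble the result.
Average pooling is linear, so this agrees with avgp_conv above.
(complex_avg_pool2d is a hypothetical helper, not a pytorch function.)

import torch

avgp = torch.nn.AvgPool2d (2)

def complex_avg_pool2d (z):
    # AvgPool2d is linear, so pooling the parts separately is exact
    return torch.complex (avgp (z.real), avgp (z.imag))

t = torch.randn (1, 1, 4, 4, dtype = torch.complex64)
print (complex_avg_pool2d (t).shape)   # torch.Size([1, 1, 2, 2])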
Best.
K. Frank