Optimization with unitary constraints

Hi. I want to solve a constrained optimization problem. Specifically, I want to minimize a function f(U_1, U_2, …), where each U_i is a unitary matrix.
For example,

import torch
from torch import nn
import numpy as np

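# Build four random 4x4 unitary matrices U = exp(iH) from Hermitian H,
# storing each one as a (2, 4, 4) real array [Re(U), Im(U)].
Ui = []
for i in range(4):
    H = np.random.rand(4, 4)
    H = 0.5 * (H + H.T.conjugate())   # make H Hermitian
    val, vec = np.linalg.eigh(H)      # eigh returns orthonormal eigenvectors for Hermitian H
    unitary = (vec * np.exp(1.j * val)) @ vec.T.conjugate()   # V diag(e^{i*lambda}) V^H
    Ui.append(np.stack([unitary.real, unitary.imag], axis=0))
# Shape (1, 2, 4, 4, 4): batch, [real, imag], matrix index i, row, column
guess_U = torch.from_numpy(np.stack(Ui, axis=1)).float().unsqueeze(0)
guess_U.requires_grad_(True)

layer1 = nn.Conv3d(2, 16, kernel_size=3, stride=1)   # output shape (1, 16, 2, 2, 2)
layer2 = nn.Linear(16*2*2*2, 1)
func = lambda x: layer2(nn.ReLU()(layer1(x)).view(1, -1))

optimizer = torch.optim.Adam([guess_U], lr=3e-4)
for epoch in range(1000):
    optimizer.zero_grad()
    out = func(guess_U)
    out.sum().backward()   # out has shape (1, 1); reduce to a scalar loss
    optimizer.step()
    with torch.no_grad():
        ...  # Unitary matrix constraints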
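A quick sanity check of the generation step, as a minimal NumPy sketch: it rebuilds one matrix the same way as above (the helper name `random_unitary` is hypothetical), confirms that U U^H = I, and confirms that the stacked [real, imag] layout decodes back to U.

```python
import numpy as np

def random_unitary(n, rng):
    """Hypothetical helper: U = V diag(exp(i*lambda)) V^H = exp(iH) for a random Hermitian H."""
    H = rng.random((n, n))
    H = 0.5 * (H + H.conj().T)       # Hermitian (here real symmetric)
    val, vec = np.linalg.eigh(H)     # orthonormal eigenvectors for Hermitian input
    return (vec * np.exp(1j * val)) @ vec.conj().T

rng = np.random.default_rng(0)
U = random_unitary(4, rng)
packed = np.stack([U.real, U.imag], axis=0)   # the (2, 4, 4) layout used in the post
U_back = packed[0] + 1j * packed[1]           # decode [real, imag] back to a complex matrix

print(np.allclose(U @ U.conj().T, np.eye(4)))  # True
print(np.allclose(U_back, U))                  # True
```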

Since PyTorch only supports real values, I write each U as A + iB, where A and B are real matrices, and recast the problem as
\min_{A, B} f(A, B),
s.t. A A^T + B B^T = I and B A^T - A B^T = 0, where I is the identity matrix.
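The two constraints follow from expanding U U^\dagger with U = A + iB (A, B real, so U^\dagger = A^T - iB^T) and separating real and imaginary parts:

```latex
U U^\dagger = (A + iB)(A^\top - iB^\top)
            = (A A^\top + B B^\top) + i\,(B A^\top - A B^\top) = I
\;\Longleftrightarrow\;
A A^\top + B B^\top = I, \qquad B A^\top - A B^\top = 0.
```

Because U is square, U U^\dagger = I already implies U^\dagger U = I, so the companion conditions A^T A + B^T B = I and A^T B - B^T A = 0 do not need to be imposed separately.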
Since the constraints involve matrix products and transposes, I don't know how to implement them in PyTorch.
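One possible approach, sketched here under stated assumptions and not taken from the original post: the residuals themselves are plain batched matrix products in PyTorch (`@` and `.transpose(-2, -1)`). To enforce the constraint, this sketch swaps in a projection step: after each optimizer step, replace U by the unitary factor W V^H of its SVD U = W S V^H, which is the nearest unitary matrix in Frobenius norm. It assumes a PyTorch version with complex tensors, `torch.complex`, and `torch.linalg.svd` (recent releases provide these, so the real/imaginary split is only needed for the network input). The helper names `unitary_residuals` and `project_to_unitary` are hypothetical, and the random `guess_U` stands in for the post's tensor with the same (1, 2, 4, 4, 4) layout.

```python
import torch

def unitary_residuals(A, B):
    """Hypothetical helper: residuals of A A^T + B B^T = I and B A^T - A B^T = 0.

    A, B: real tensors of shape (..., n, n), the real and imaginary parts of U.
    """
    At, Bt = A.transpose(-2, -1), B.transpose(-2, -1)
    eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    return A @ At + B @ Bt - eye, B @ At - A @ Bt

def project_to_unitary(A, B):
    """Hypothetical helper: map U = A + iB to its nearest unitary W V^H (via SVD)."""
    U = torch.complex(A, B)
    W, _, Vh = torch.linalg.svd(U)   # batched SVD over the leading dimensions
    Q = W @ Vh
    return Q.real, Q.imag

torch.manual_seed(0)
# Stand-in for the post's tensor: dim 1 holds [real, imag], dim 2 indexes the matrices.
guess_U = torch.randn(1, 2, 4, 4, 4, requires_grad=True)

# Projection step, e.g. inside the `with torch.no_grad():` block after optimizer.step()
with torch.no_grad():
    A, B = project_to_unitary(guess_U[:, 0], guess_U[:, 1])
    guess_U[:, 0].copy_(A)
    guess_U[:, 1].copy_(B)

r_real, r_imag = unitary_residuals(guess_U[:, 0].detach(), guess_U[:, 1].detach())
print(r_real.abs().max().item(), r_imag.abs().max().item())  # both at float32 round-off level
```

Projecting after each Adam step is a heuristic rather than an exact constrained optimizer. Alternatively, the squared residuals can be added to the loss as a penalty term, which enforces the constraints only approximately.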