Minimizing a custom function

I need to approximate the maximum of the following function:

image

This problem seems unsolvable, but I thought we could approximate the solution and obtain good functions g1, g2 by using PyTorch. Additionally, it gives me practice with this framework, which I don’t yet fully master.

I assume that g1(x) ∈ {1, …, max_value}, and I use a simple model that produces nearly integer values for each gi by means of a smooth argmax.
I approximate the double integral by a double sum.

With this method I obtain a maximum of about 0.25. My colleagues told me that a better value, about 0.35, can be reached. Do you know how to improve the code?

    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    def smooth_floor(x, alpha=0.01):
        """Return ``floor(x)`` plus a small sigmoid ramp on the fractional part.

        Computes ``floor(x) + sigmoid(alpha * (x - floor(x))) - 0.5``. At integer
        inputs the sigmoid term is exactly 0.5, so the result equals ``x``.

        Args:
            x: A tensor, or a plain Python number (converted to a tensor).
            alpha: Slope applied to the fractional part inside the sigmoid.

        Returns:
            A tensor with the same shape as ``x`` (0-d for a scalar input).
        """
        # Accept Python scalars as well; an existing tensor is passed through
        # as-is, so autograd history is preserved.
        x = torch.as_tensor(x)
        floored = torch.floor(x)
        return floored + torch.sigmoid(alpha * (x - floored)) - 0.5
    
    
    def smooth_argmax(x, dim=-1, temperature=0.01):
        """Differentiable approximation of ``argmax`` along ``dim``.

        A softmax of ``x / temperature`` is taken along ``dim`` and used as
        weights for the positions ``0 .. x.size(dim) - 1``; the weighted sum of
        those positions is returned. A small ``temperature`` sharpens the softmax
        so the result approaches the index of the largest entry.

        Args:
            x: Input tensor of scores.
            dim: Axis along which to take the soft argmax.
            temperature: Divisor applied to ``x`` before the softmax.

        Returns:
            A float tensor with ``dim`` reduced away.
        """
        weights = torch.softmax(x / temperature, dim=dim)
        # Give the index vector a singleton size on every axis except `dim`, so
        # it lines up with `dim` instead of always broadcasting against the
        # last axis.
        index_shape = [1] * x.dim()
        index_shape[dim] = x.size(dim)
        indices = torch.arange(x.size(dim), device=x.device).float()
        indices = indices.view(index_shape)
        return torch.sum(weights * indices, dim=dim)
    
    
    def eps(x):
        """Smooth surrogate for the parity of ``floor(x)``.

        Evaluates ``smooth_floor(x) - 2 * smooth_floor(x / 2)``. With an exact
        floor this expression would be ``floor(x) mod 2``; here both floors are
        the smoothed version, so the result is only approximately 0 or 1.
        """
        whole = smooth_floor(x)
        halved = smooth_floor(x / 2)
        return whole - 2 * halved
        
    class Function(nn.Module):
        """Two-layer network mapping a scalar to a nearly integer value.

        The network produces ``max_value`` scores per input and reduces them
        with ``smooth_argmax``, then adds 1, so outputs lie in
        ``[1, max_value]``.

        Args:
            max_value: Number of scores produced by the output layer.
            hidden_size: Width of the hidden layer (defaults to 100).
        """

        def __init__(self, max_value, hidden_size=100):
            super().__init__()
            self.layer = nn.Linear(1, hidden_size)
            self.layer2 = nn.Linear(hidden_size, max_value)

        def forward(self, x):
            """Return ``1 + smooth_argmax(scores)`` for inputs of shape ``(..., 1)``."""
            # Functional ReLU: no need to build a new nn.ReLU module per call.
            hidden = torch.relu(self.layer(x))
            scores = self.layer2(hidden)
            return 1 + smooth_argmax(scores)
        
    def loss(discretisation_step):
        """Negated grid average of the product of the two ``eps`` terms.

        Builds a uniform grid ``t`` of ``discretisation_step`` points on [0, 1]
        and evaluates the module-level networks ``f`` and ``g`` on it. The
        returned value is

            -(1 / N**2) * sum_{b, c} eps(t[b] * 2**f(t[c])) * eps(t[c] * 2**g(t[b]))

        with ``N = discretisation_step``. The sign is flipped so that minimising
        this loss maximises the double sum.
        """
        grid = torch.linspace(0, 1, discretisation_step)
        column = grid.reshape(-1, 1)  # the networks expect a trailing size-1 axis
        pow_f = 2 ** f(column)
        pow_g = 2 ** g(column)
        # first[b, c] = eps(t[b] * 2**f(t[c]))
        first = eps(torch.outer(grid, pow_f))
        # Transposed so that second[b, c] = eps(t[c] * 2**g(t[b])).
        second = eps(torch.outer(grid, pow_g)).T
        return -(first * second).sum() / discretisation_step**2
    
    # Hyper-parameters.
    max_value = 10  # size of each network's output layer
    discretisation_step = 500  # grid points per axis in the double sum
    num_epochs = 10000

    # The two networks being optimised; `loss` reads them as module globals.
    f = Function(max_value)
    g = Function(max_value)

    optimizer = optim.AdamW(f.parameters(), lr=0.1)
    optimizer2 = optim.AdamW(g.parameters(), lr=0.1)

    # Training mode is loop-invariant, so set it once rather than every epoch.
    f.train()
    g.train()

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        optimizer2.zero_grad()
        my_loss = loss(discretisation_step)
        my_loss.backward()
        optimizer.step()
        optimizer2.step()
        # `loss` returns the negated objective, so negate it again for display.
        if epoch % 10 == 0:
            print(epoch, -my_loss)