I have to approximate the maximum of the following function:
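Explicitly, the quantity my code below tries to maximise is

$$\int_0^1\!\!\int_0^1 \varepsilon\!\left(x\,2^{g_1(y)}\right)\,\varepsilon\!\left(y\,2^{g_2(x)}\right)\,dx\,dy,\qquad \varepsilon(t)=\lfloor t\rfloor-2\left\lfloor \tfrac{t}{2}\right\rfloor,$$

over positive-integer-valued functions $g_1, g_2$ on $[0,1]$; note that $\varepsilon(t\,2^{k})$ reads off the $k$-th binary digit of $t$.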
This problem seems intractable to solve exactly, but I thought I could approximate the solution and obtain good functions g1, g2 with PyTorch. It also gives me some practice with this framework, which I don't yet fully master.
I assume that gi(x) ∈ {1, …, max_value} and I use a simple model that produces almost-integer values for the gi via a smooth argmax.
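For reference, the smooth argmax I mean is the temperature-scaled soft argmax,

$$\operatorname{softargmax}_\tau(z)=\sum_{k=0}^{K-1} k\,\frac{e^{z_k/\tau}}{\sum_{j} e^{z_j/\tau}},$$

which approaches the index of the largest logit as $\tau \to 0$ (I use $\tau = 0.01$), so the model output $1+\operatorname{softargmax}_\tau(z)$ is almost an integer in $\{1,\dots,\text{max\_value}\}$.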
I approximate the double integral by a double sum.
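Concretely, with $N$ equally spaced points $t_1,\dots,t_N$ in $[0,1]$ ($N$ = discretisation_step), the integral is replaced by

$$\int_0^1\!\!\int_0^1 h(x,y)\,dx\,dy \;\approx\; \frac{1}{N^2}\sum_{i=1}^{N}\sum_{j=1}^{N} h(t_i,t_j),$$

and I minimise the negative of this sum.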
With this method I obtain about 0.25 as a maximum. My colleagues told me that it is possible to do better and reach about 0.35. Do you know how to improve the code?
import torch
import torch.nn as nn
import torch.optim as optim
def smooth_floor(x, alpha=0.01):
    # differentiable surrogate for torch.floor: only the sigmoid term carries a gradient
    return torch.floor(x) + torch.sigmoid(alpha * (x - torch.floor(x))) - 0.5
def smooth_argmax(x, dim=-1, temperature=0.01):
    # soft argmax: softmax-weighted average of the indices, close to the hard argmax
    # when the temperature is small
    x_softmax = torch.softmax(x / temperature, dim=dim)
    indices = torch.arange(x.size(dim), device=x.device).float()
    return torch.sum(x_softmax * indices, dim=dim)
def eps(x):
    # smoothed parity of the integer part: eps(t * 2**k) approximates the k-th binary digit of t
    return smooth_floor(x) - 2 * smooth_floor(x / 2)
class Function(nn.Module):
    # small MLP over max_value logits; the soft argmax of the logits gives an
    # almost-integer output in {1, ..., max_value}
    def __init__(self, max_value):
        super().__init__()
        self.layer = nn.Linear(1, 100)
        self.layer2 = nn.Linear(100, max_value)

    def forward(self, x):
        x = self.layer(x)
        x = torch.relu(x)
        x = self.layer2(x)
        return 1 + smooth_argmax(x)
def loss(discretisation_step):
    # Riemann-sum approximation of the (negated) double integral of
    # eps(x * 2**g1(y)) * eps(y * 2**g2(x)); f and g play the roles of g1 and g2
    T = torch.linspace(0, 1, discretisation_step).reshape(-1, 1)
    X = f(T)
    Y = g(T)
    T = T.reshape(-1)
    part1 = eps(torch.einsum("b,c->bc", T, 2 ** X))
    part2 = eps(torch.einsum("b,c->bc", T, 2 ** Y)).T
    return -(part1 * part2).sum() / discretisation_step ** 2
max_value = 10
discretisation_step = 500
num_epochs = 10000
f = Function(max_value)
g = Function(max_value)
optimizer = optim.AdamW(f.parameters(), lr=0.1)
optimizer2 = optim.AdamW(g.parameters(), lr=0.1)
for epoch in range(num_epochs):
    f.train()
    g.train()
    optimizer.zero_grad()
    optimizer2.zero_grad()
    my_loss = loss(discretisation_step)
    my_loss.backward()
    # one optimiser step per model, both driven by the same loss
    optimizer.step()
    optimizer2.step()
    if epoch % 10 == 0:
        print(epoch, -my_loss.item())
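As a quick sanity check of the digit-extraction part (this snippet is only illustrative, not part of the training):

# eps(t * 2**k) should be close to the k-th binary digit of t;
# 0.625 = 0.101 in binary, so the first three digits are 1, 0, 1
t = torch.tensor(0.625)
print([round(eps(t * 2 ** k).item(), 3) for k in (1, 2, 3)])  # roughly [1.0, 0.0, 1.0]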