Find weights with PyTorch

I have ground truth data (G) in the form of a 3D numeric array of shape (1000, 25, 25), where 1000 is the number of time steps and (25, 25) are the x and y dimensions of each 2D time frame.

I also have 5 surrogate datasets (S1, S2, S3, S4, S5) which are approximations of the above ground truth data (G).

I need to learn weights (w1, w2, w3, w4, w5) such that a linear combination of the surrogate datasets gives the best possible approximation of the ground truth dataset, i.e.:

w1*S1 + w2*S2 + w3*S3 + w4*S4 + w5*S5 ≈ G

Is there some way to do this with PyTorch? Thanks!

This is a pretty simple use case for PyTorch.

import torch
import torch.nn as nn
import torch.optim as opt

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # one learnable weight tensor per surrogate, initialized to ones
        self.w1 = nn.Parameter(torch.ones((1000, 25, 25), dtype=torch.float))
        self.w2 = nn.Parameter(torch.ones((1000, 25, 25), dtype=torch.float))
        self.w3 = nn.Parameter(torch.ones((1000, 25, 25), dtype=torch.float))
        self.w4 = nn.Parameter(torch.ones((1000, 25, 25), dtype=torch.float))
        self.w5 = nn.Parameter(torch.ones((1000, 25, 25), dtype=torch.float))
    def forward(self, s1, s2, s3, s4, s5):
        # elementwise weighted sum of the five surrogates
        return self.w1*s1 + self.w2*s2 + self.w3*s3 + self.w4*s4 + self.w5*s5

# dummy data standing in for G and S1..S5
ground_truth = torch.rand((1, 1000, 25, 25))

s1 = torch.rand(1, 1000, 25, 25)
s2 = torch.rand(1, 1000, 25, 25)
s3 = torch.rand(1, 1000, 25, 25)
s4 = torch.rand(1, 1000, 25, 25)
s5 = torch.rand(1, 1000, 25, 25)

model = Model()

epochs = 1000
criterion = nn.MSELoss()
optimizer = opt.Adam(model.parameters(), lr=0.01)

for i in range(epochs):
    optimizer.zero_grad()                # clear gradients from the previous step
    out = model(s1, s2, s3, s4, s5)
    loss = criterion(out, ground_truth)
    loss.backward()                      # backpropagate the MSE loss
    optimizer.step()                     # update the weights
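Once the loop has run, you can check the final loss and pull the learned weights off the model, for example:

print(f"final loss: {loss.item():.6f}")
w1_learned = model.w1.detach()   # learned weight tensor for S1 (similarly w2..w5)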


Edit: Just noticed your equation at the bottom. I've updated the above to reflect that. If you'd prefer scalar values, you can just change each weight to nn.Parameter(torch.ones(1, dtype=torch.float)); see the sketch below.
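For reference, here is a minimal sketch of that scalar-weight variant (the class name ScalarModel is just for illustration; it drops into the same training loop as above):

class ScalarModel(nn.Module):
    def __init__(self):
        super(ScalarModel, self).__init__()
        # one learnable scalar coefficient per surrogate dataset
        self.w1 = nn.Parameter(torch.ones(1, dtype=torch.float))
        self.w2 = nn.Parameter(torch.ones(1, dtype=torch.float))
        self.w3 = nn.Parameter(torch.ones(1, dtype=torch.float))
        self.w4 = nn.Parameter(torch.ones(1, dtype=torch.float))
        self.w5 = nn.Parameter(torch.ones(1, dtype=torch.float))
    def forward(self, s1, s2, s3, s4, s5):
        # scalars broadcast over the (1, 1000, 25, 25) surrogate tensors
        return self.w1*s1 + self.w2*s2 + self.w3*s3 + self.w4*s4 + self.w5*s5

After training, model.w1.item() through model.w5.item() give the fitted coefficients directly.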


Hi @J_Johnson, thanks a lot for the detailed explanation and solution. I am trying this on my dataset and will let you know how it goes. Thanks again!
