Simple Optimization/Embedding Example

I’m looking to do a simple optimization, similar to the collaborative filtering example that Jeremy Howard (fast.ai) showed in his Deep Learning MOOC (Lesson 4 @ 1:08).

Given a matrix blockData with shape (20, 14) filled with random numbers, I want to start with two matrices, also filled with random numbers, with shapes vert.shape=(20, 10) and hori.shape=(10, 14), such that at the end of the optimization I have minimized the quantity:

mse(dot(vert, hori), blockData) = mean((blockData - dot(vert, hori))**2)

by modifying hori and vert simultaneously using autograd.
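
Written out with the shapes, the quantity I want to minimize is roughly the following (just a sketch of the objective, with @ standing for matrix multiplication, not the optimization itself):

import torch

blockData = torch.rand(20, 14)   # the target block of random numbers
vert = torch.rand(20, 10)        # left factor
hori = torch.rand(10, 14)        # right factor

# vert @ hori has shape (20, 14), matching blockData
loss = ((blockData - vert @ hori) ** 2).mean()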

If I write the backward() call

loss = mse(tensor(numpy.dot(vert,hori)),blockData)
loss.backward()

I get the error RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn, which, based on my internet searching, most commonly appears when .requires_grad isn’t set to True.

It appears I have a case where the gradient isn’t computed because (maybe) there isn’t a variable to set requires_grad=True for.

Two main questions are:

  1. How can I display the contents of tensors so I can see what element 0 is referring to?
  2. Can autograd do what I’m asking it to? And if so, how?

Here is the code I have, which isn’t what I’m looking for, but is close, and it runs without error:

from fastai.basics import tensor, nn
import torch, numpy, pandas

def hypothesis(vert,hori):
    return numpy.dot(vert,hori)

def mse(y_hat,y):
    return ((y_hat-y)**2).mean()

def update(y_hat):
    # perform gradient descent
    loss = mse(y_hat,blockData)
    loss.backward()
    if t%10 == 0:
        print(t,'-------------',loss)
    with torch.no_grad():
        y_hat.sub_(lr * y_hat.grad)
        y_hat.grad.zero_()

vecSize = 10
shape = (20,14)
# random large block of data
blockData = tensor(numpy.random.random_sample(shape))

hori = nn.Parameter(tensor(numpy.random.random_sample((vecSize,shape[1]))))
vert = nn.Parameter(tensor(numpy.random.random_sample((shape[0],vecSize))))


lr = 1e-1
y_hat = tensor(hypothesis(vert,hori))
y_hat.requires_grad_(True)
for t in range(101):
    update(y_hat)

The code above will drive y_hat to the correct answer, but that’s not really the idea: what I’m trying to do is get vert and hori to drive towards the correct answers as part of the same optimization. As written, I’m just taking the gradient from the hypothesis matrix (y_hat) to the target (blockData), which is a trivial optimization.

I’m not familiar with the mentioned lesson in FastAI; however, from a code perspective you are detaching vert and hori from the computation graph by using numpy.dot and rewrapping the result in a new tensor with requires_grad=True.
Since vert and hori are already nn.Parameters, you could stick to PyTorch methods, so that Autograd will automatically create the computation graph (and can thus backpropagate).
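
Regarding your first question: you can simply print a tensor to see its values, and look at its requires_grad and grad_fn attributes to see whether autograd is tracking it. A rough sketch of the difference between the two routes (using .detach() so the NumPy conversion runs cleanly):

import numpy, torch
from torch import nn

vert = nn.Parameter(torch.randn(20, 10))
hori = nn.Parameter(torch.randn(10, 14))

# Route 1: go through NumPy and rewrap the result. The new tensor is a
# fresh leaf with no history, so backward() has nothing to follow.
detached = torch.tensor(numpy.dot(vert.detach().numpy(), hori.detach().numpy()))
print(detached)                                   # shows the values
print(detached.requires_grad, detached.grad_fn)   # False None

# Route 2: stay in PyTorch. The result remembers it came from a matmul.
attached = torch.matmul(vert, hori)
print(attached.requires_grad, attached.grad_fn)   # True, with a grad_fn (MmBackward)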

I’m not sure if this will yield the desired result, but this code updates vert and hori:

import torch
from torch import nn

def hypothesis(vert,hori):
    return torch.matmul(vert,hori)

def mse(y_hat,y):
    return ((y_hat-y)**2).mean()

def update():
    # compute the current prediction from the parameters
    y_hat = hypothesis(vert,hori)
    # perform gradient descent
    loss = mse(y_hat,blockData)
    loss.backward()
    if t%10 == 0:
        print(t,'-------------',loss)
    with torch.no_grad():
        vert.sub_(lr * vert.grad)
        hori.sub_(lr * hori.grad)
        vert.grad.zero_()
        hori.grad.zero_()

vecSize = 10
shape = (20,14)
# random large block of data
blockData = torch.randn(shape)

hori = nn.Parameter(torch.randn(vecSize,shape[1]))
vert = nn.Parameter(torch.randn(shape[0],vecSize))

lr = 1e-1
for t in range(101):
    update()
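
As a side note, the manual sub_ / grad.zero_ steps could also be handled by one of the built-in optimizers; a rough equivalent using torch.optim.SGD (same shapes and learning rate as above) might look like this:

import torch
from torch import nn

blockData = torch.randn(20, 14)
vert = nn.Parameter(torch.randn(20, 10))
hori = nn.Parameter(torch.randn(10, 14))

optimizer = torch.optim.SGD([vert, hori], lr=1e-1)
for t in range(101):
    loss = ((torch.matmul(vert, hori) - blockData) ** 2).mean()
    optimizer.zero_grad()   # clear old gradients
    loss.backward()         # populate vert.grad and hori.grad
    optimizer.step()        # in-place SGD update of both parameters
    if t % 10 == 0:
        print(t, '-------------', loss.item())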

Thanks for the clear explanation, that fixed the issue (and gave me some useful things to read up on to dive deeper into Torch). The optimization works now; here is the full code. The way I write out the resulting tensors/vectors at the end probably isn’t the best method, but it works.

One minor note about the code below: I changed vert to vertLatents and hori to horiLatents.

from fastai.basics import tensor, nn
import torch, numpy, pandas

def hypothesis(vertLatents,horiLatents):
    return torch.matmul(vertLatents,horiLatents)

def mse(vertLatents,horiLatents,y):
    return ((torch.matmul(vertLatents,horiLatents)-y)**2).mean()

def update(horiLatents,vertLatents):
    # perform gradient descent
    loss = mse(vertLatents,horiLatents,blockData)
    loss.backward()
    if t%100 == 0:
        print(t,'-------------',loss)
    with torch.no_grad():
        vertLatents.sub_(lr * vertLatents.grad)
        vertLatents.grad.zero_()
        horiLatents.sub_(lr * horiLatents.grad)
        horiLatents.grad.zero_()
    return loss.item()

vecLatents = 10
shape = (20,14)
# random large block of data
blockData = tensor(numpy.random.random_sample(shape))

horiLatents = \
        nn.Parameter(tensor(numpy.random.random_sample((vecLatents,shape[1]))))
vertLatents = \
        nn.Parameter(tensor(numpy.random.random_sample((shape[0],vecLatents))))

lr = 1e-1
horiLatents.requires_grad_(True)
vertLatents.requires_grad_(True)
lossDict = {}
for t in range(10001):
    lossDict[t] = update(horiLatents,vertLatents)
pandas.DataFrame.from_dict(lossDict,orient='index').to_csv('lossDict.csv',
                                                           index=False)
pandas.DataFrame(horiLatents.data.tolist()).to_csv('horiLatents.csv',
                                                   index=False)
pandas.DataFrame(vertLatents.data.tolist()).to_csv('vertLatents.csv',
                                                   index=False)
pandas.DataFrame(blockData.data.tolist()).to_csv('blockData.csv',
                                                   index=False)
pandas.DataFrame(torch.matmul(vertLatents,horiLatents).data.tolist()).to_csv(
    'optimized.csv',index=False)
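
Since the output step above probably isn’t ideal, a more direct variant (just a sketch, reusing the variables from the script above) is to convert the tensors to plain NumPy arrays and write them with numpy.savetxt, checking the final reconstruction error at the same time:

import numpy, torch

with torch.no_grad():
    reconstruction = torch.matmul(vertLatents, horiLatents)
    print('final mse:', ((reconstruction - blockData) ** 2).mean().item())
    # .detach() pulls the parameters out of the graph; .numpy() gives plain arrays
    numpy.savetxt('optimized.csv', reconstruction.numpy(), delimiter=',')
    numpy.savetxt('vertLatents.csv', vertLatents.detach().numpy(), delimiter=',')
    numpy.savetxt('horiLatents.csv', horiLatents.detach().numpy(), delimiter=',')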