SVCCA on LSTM output

Hi everyone,

I'm fairly new to PyTorch and ML. I have been trying to measure correlations between neurons using Google's svcca module.

The examples work well, and when I use random tensors converted to NumPy arrays I can compute SVCCA values.

I then tried to capture the activations of an LSTM layer by adding a self.activations attribute to the model, which stores the LSTM layer's output:

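    def forward(self, text):
        # text: (seq_len, batch) token indices
        embedded = self.embedding(text)
        # output: (seq_len, batch, hidden_dim); hidden = (h_n, c_n)
        output, hidden = self.rnn(embedded)
        # keep the per-timestep LSTM output for the SVCCA analysis
        self.activations = output
        # sanity check: the last timestep equals h_n (single-layer, unidirectional LSTM)
        assert torch.equal(output[-1,:,:], hidden[0].squeeze(0))
        out = self.linear(hidden[0].squeeze(0))
        return out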
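An alternative that leaves forward() untouched is a forward hook registered on the LSTM submodule. The sketch below is illustrative only: TinyLSTMClassifier and its sizes are hypothetical stand-ins for the real model, and it assumes the LSTM lives at model.rnn as in the forward() above. register_forward_hook itself is standard PyTorch.

```python
import torch
import torch.nn as nn


class TinyLSTMClassifier(nn.Module):
    """Hypothetical stand-in for the model in the post (sizes are arbitrary)."""

    def __init__(self, vocab_size=100, emb_dim=16, hidden_dim=32, out_dim=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim)  # default layout: (seq_len, batch, feature)
        self.linear = nn.Linear(hidden_dim, out_dim)

    def forward(self, text):
        embedded = self.embedding(text)
        output, (h_n, c_n) = self.rnn(embedded)
        return self.linear(h_n.squeeze(0))


captured = []  # one (seq_len, batch, hidden_dim) tensor per forward pass


def save_lstm_output(module, inputs, outputs):
    # nn.LSTM returns (output, (h_n, c_n)); keep only the per-timestep output
    output, _ = outputs
    captured.append(output.detach().cpu())


model = TinyLSTMClassifier()
handle = model.rnn.register_forward_hook(save_lstm_output)

with torch.no_grad():
    for _ in range(3):
        text = torch.randint(0, 100, (7, 5))  # random token ids, (seq_len=7, batch=5)
        model(text)

handle.remove()  # stop capturing once the evaluation pass is done
print(len(captured), tuple(captured[0].shape))  # prints: 3 (7, 5, 32)
```

The hook keeps the capture logic out of the model class and can be removed when it is no longer needed.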

Then, during evaluation, I append the activations to a global list:

    activations = []  # global list of per-batch LSTM outputs
    with torch.no_grad():
        for batch in iterator:
            predictions = model(batch.text).squeeze(1)
            activations.append(model.activations)
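Once the per-batch outputs are collected and converted to NumPy, one possible way to build a single (neurons, datapoints) matrix is to treat every (timestep, example) pair as a datapoint. This is an assumption about how to pool the data, not something the svcca module prescribes; the random arrays below are hypothetical stand-ins for real LSTM outputs, and padded timesteps would be included unless masked out.

```python
import numpy as np

# Hypothetical stand-ins for per-batch LSTM outputs after .cpu().numpy():
# each array is (seq_len, batch, hidden); seq_len may differ between batches.
rng = np.random.default_rng(0)
hidden = 32
per_batch = [rng.standard_normal((seq_len, 16, hidden)) for seq_len in (12, 9, 15)]


def to_neurons_by_datapoints(arrays):
    """Flatten (seq_len, batch, hidden) arrays into one (hidden, datapoints) matrix.

    Every (timestep, example) pair becomes a separate datapoint (column).
    """
    flat = [a.reshape(-1, a.shape[-1]) for a in arrays]  # (seq_len*batch, hidden)
    stacked = np.concatenate(flat, axis=0)               # (total_datapoints, hidden)
    return stacked.T.astype(np.float64)                  # (hidden, total_datapoints)


acts = to_neurons_by_datapoints(per_batch)
print(acts.shape)  # (32, 576)
```

Pooling timesteps this way gives far more datapoints than neurons, compared with using a single timestep of a single batch.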

Then, as a test, I take one of the saved activations, transpose it to the (neurons, datapoints) layout the svcca module requires, and compute its CCA similarity with itself, which should give coefficients of 1:

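    import numpy as np
    import cca_core  # from the svcca repository

    a = activations[1].cpu().detach().numpy()  # (seq_len, batch, hidden)
    a = a[0,:,:]                               # first timestep: (batch, hidden)
    a = np.transpose(a, (1,0))                 # (hidden, batch) = (neurons, datapoints)

    results = cca_core.get_cca_similarity(a, a, verbose=True)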
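To cross-check the expectation that a matrix compared with itself gives coefficients of 1, a small textbook CCA can be written directly in NumPy. This is a QR-based CCA, not the svcca implementation, and canonical_correlations is a hypothetical helper; it assumes each matrix has full row rank after centering (more datapoints than neurons, no constant neurons).

```python
import numpy as np


def canonical_correlations(x, y):
    """Canonical correlations between x and y, both laid out (neurons, datapoints).

    QR-based textbook CCA; assumes both centered matrices have full row rank.
    """
    xc = (x - x.mean(axis=1, keepdims=True)).T  # (datapoints, neurons)
    yc = (y - y.mean(axis=1, keepdims=True)).T
    qx, _ = np.linalg.qr(xc)  # orthonormal basis of each neuron subspace
    qy, _ = np.linalg.qr(yc)
    # singular values of Qx^T Qy are the cosines of the principal angles
    return np.linalg.svd(qx.T @ qy, compute_uv=False)


rng = np.random.default_rng(1)
a = rng.standard_normal((20, 200))
rho = canonical_correlations(a, a)
print(np.allclose(rho, 1.0))  # True
```

If this check passes on the real activations but get_cca_similarity still fails, the problem is more likely in conditioning than in the tensor-to-array conversion.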

Instead, I get this error:

      File "<stdin>", line 1, in <module>
      File "/home/main/Projects/svcca/", line 295, in get_cca_similarity
      File "/home/main/Projects/svcca/", line 162, in compute_ccas
        u, s, v = np.linalg.svd(arr)
      File "<__array_function__ internals>", line 6, in svd
      File "/home/main/.local/lib/python3.6/site-packages/numpy/linalg/", line 1636, in svd
        u, s, vh = gufunc(a, signature=signature, extobj=extobj)
    ValueError: On entry to DLASCL parameter number 4 had an illegal value

I checked the array for NaNs and Infs and found none.
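A finite input can still produce NaNs inside the computation if some neurons have zero variance across the datapoints. For example, if every sequence in a batch starts with the same token, the first-timestep slice a[0,:,:] would be identical across the batch. The sketch below uses a hypothetical helper, find_constant_neurons, to flag such rows, and shows that normalizing an all-zero covariance by its largest entry yields NaN. Whether svcca rescales its covariances in exactly this way is an assumption here; the division only illustrates the mechanism.

```python
import numpy as np


def find_constant_neurons(acts, tol=1e-10):
    """Return indices of rows (neurons) whose variance across datapoints is ~0.

    acts is assumed to be laid out as (neurons, datapoints), as for svcca.
    """
    variances = acts.var(axis=1)
    return np.flatnonzero(variances <= tol)


# Toy example: neuron 1 is constant across all datapoints.
rng = np.random.default_rng(0)
acts = rng.standard_normal((4, 50))
acts[1, :] = 0.5
print(find_constant_neurons(acts))  # [1]

# A fully constant matrix has an all-zero covariance; dividing it by its
# largest absolute entry is 0/0, which gives NaN although the input is finite.
const = np.full((4, 50), 0.5)
cov = np.cov(const)
with np.errstate(invalid="ignore"):
    scaled = cov / np.max(np.abs(cov))
print(np.isnan(scaled).all())  # True
```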

But if I generate random matrices with the same shape (65, 128), get_cca_similarity works fine. Does anyone have an idea where I might be going wrong? Is there something wrong in how I convert the activations from torch tensors to NumPy arrays?
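One way to see how real activations differ from random ones of the same shape is to compare the numerical rank of the centered matrix. A random Gaussian (65, 128) matrix is almost surely full rank, while real activations can be strongly correlated or saturated. In the sketch below, numerical_rank_report is a hypothetical helper and the low-rank matrix is a made-up stand-in, not real LSTM output; the relative tolerance rtol is an arbitrary choice.

```python
import numpy as np


def numerical_rank_report(acts, rtol=1e-7):
    """Summarize how well-conditioned a (neurons, datapoints) matrix is."""
    centered = acts - acts.mean(axis=1, keepdims=True)
    s = np.linalg.svd(centered, compute_uv=False)  # singular values, descending
    rank = int(np.sum(s > rtol * s[0])) if s[0] > 0 else 0
    return {"neurons": acts.shape[0], "datapoints": acts.shape[1], "rank": rank}


rng = np.random.default_rng(0)
random_acts = rng.standard_normal((65, 128))

# Hypothetical low-rank "activations": 65 neurons driven by only 5 latent signals.
low_rank_acts = rng.standard_normal((65, 5)) @ rng.standard_normal((5, 128))

random_report = numerical_rank_report(random_acts)
low_rank_report = numerical_rank_report(low_rank_acts)
print(random_report)
print(low_rank_report)
```

A rank well below the number of neurons means the neuron covariance is singular, which is the situation where CCA routines that invert covariance matrices tend to become unstable.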

Thanks a lot.