Pruning nn.Linear weights causing unexpected errors

This fails:

import torch

def test1():
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  # prune the last 10 input features in place
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()   # RuntimeError raised here
test1()

with the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-bb36a010bd86> in <cell line: 10>()
      8     x = 5 - torch.sum(layer(torch.ones(90)))
      9     x.backward()
---> 10 test1()
     11 # and this works as well
     12 

2 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    249     # some Python versions print out the first line of a multi-line function
    250     # calls in the traceback and some print out the last line
--> 251     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252         tensors,
    253         grad_tensors_,

RuntimeError: Function TBackward0 returned an invalid gradient at index 0 - got [10, 90] but expected shape compatible with [10, 100]

This works:

import torch

def test2():
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  del x    # main change: delete the output of the first forward pass
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()   # succeeds
test2()

And this works as well:

import torch
def test3():
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  layer.weight = torch.nn.Parameter(layer.weight)   # main change: re-wrap as a fresh Parameter
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()   # succeeds
test3()

I encountered this while trying to implement a paper on model pruning. I believe it has something to do with the autograd graph, but I am not sure what exactly is going on. Any explanation of why these almost identical code snippets fail or work would be much appreciated.
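For context, here is a minimal sketch of the workaround I've been using in the meantime: instead of slicing .data in place, I rebuild the layer as a fresh nn.Linear (the helper name prune_in_features is mine, not from the paper). It avoids the error, presumably because the new Parameters are fresh leaves with no old graph attached:

import torch

def prune_in_features(layer, keep):
  # Build a fresh Linear over the kept input columns; copying into brand-new
  # Parameters gives autograd clean leaf tensors with no history attached.
  new_layer = torch.nn.Linear(keep, layer.out_features)
  with torch.no_grad():
    new_layer.weight.copy_(layer.weight[:, :keep])
    new_layer.bias.copy_(layer.bias)
  return new_layer

layer = torch.nn.Linear(100, 10)
x = 5 - torch.sum(layer(torch.ones(100)))
x.backward()
layer = prune_in_features(layer, 90)
x = 5 - torch.sum(layer(torch.ones(90)))
x.backward()   # no error

Thanks in advance.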