How to invert stacking a set of tensors?

Hi All,

TL;DR How can I de-stack a large tensor into a dict of Tensors in PyTorch?

Let’s say I compute the per-sample gradients of a model and concatenate them into a single tensor. Is there a way to invert this process?

I have a minimal reproducible example below. Is there an efficient way to convert grad_params_stack back to grads_params?

import torch
from torch import nn
from torch.func import jacrev, vmap, functional_call

class network(nn.Module):
  # simple 4-layer MLP with tanh activations, used only for this example

  def __init__(self, num_input, num_hidden, num_output):
    super().__init__()
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.num_output = num_output

    self.func = nn.Tanh()

    self.fc1 = nn.Linear(self.num_input, self.num_hidden)
    self.fc2 = nn.Linear(self.num_hidden, self.num_hidden)
    self.fc3 = nn.Linear(self.num_hidden, self.num_hidden)
    self.fc4 = nn.Linear(self.num_hidden, self.num_output)

  def forward(self, x):
    x = self.fc1(x)
    x = self.func(x)
    x = self.fc2(x)
    x = self.func(x)
    x = self.fc3(x)
    x = self.func(x)
    x = self.fc4(x)
    return x.squeeze(0)  # per-sample output is a scalar, so jacrev returns parameter-shaped gradients
    
batch_size = 100
num_input = 4
num_hidden = 32
num_output = 1

net = network(num_input=num_input, num_hidden=num_hidden, num_output=num_output)
num_params = sum([p.numel() for p in net.parameters()])

#inputs
params = dict(net.named_parameters())
x = torch.randn(batch_size, num_input)

def fcall(params, x):
  # forward pass with an explicit parameter dict, so jacrev can differentiate w.r.t. params
  return functional_call(net, params, x)

# per-sample gradients of the network output w.r.t. the model parameters
grads_params = vmap(jacrev(fcall, argnums=0), in_dims=(None, 0))(params, x)  # computed via torch.func

# flatten each per-sample gradient and concatenate into one (batch_size, num_params) tensor
grad_params_stack = torch.cat([v.flatten(start_dim=1) for v in grads_params.values()], dim=-1)
print(grad_params_stack.shape)  # (batch_size, num_params)

#How to invert grad_params_stack back to grads_params?

Figured it out:

# Invert grad_params_stack back to grads_params
import numpy as np
# cumulative offsets of each parameter tensor inside the flattened dimension
param_loc = np.cumsum([0] + [p.numel() for p in net.parameters()])
layer_idx = [(param_loc[i], param_loc[i + 1]) for i in range(len(param_loc) - 1)]
# slice out each parameter's columns and restore its original shape
grad_params_destacked = {
  key: grad_params_stack[:, start:end].reshape(batch_size, *value.shape)
  for (key, value), (start, end) in zip(params.items(), layer_idx)
}

for (k1, v1), (k2, v2) in zip(grads_params.items(), grad_params_destacked.items()):
  assert k1 == k2
  print(k1, torch.allclose(v1, v2))  # True for every parameter
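
For completeness, here is a sketch of the same inversion using torch.split instead of the manual cumsum indexing (split_sizes and grad_params_destacked_v2 are just illustrative names, not part of the solution above); it assumes params.values() iterates in the same order used to build grad_params_stack:

# Sketch: let torch.split do the offset bookkeeping
# (assumes params.values() matches the order used when concatenating)
split_sizes = [v.numel() for v in params.values()]
chunks = torch.split(grad_params_stack, split_sizes, dim=-1)
grad_params_destacked_v2 = {
  key: chunk.reshape(batch_size, *value.shape)
  for (key, value), chunk in zip(params.items(), chunks)
}

for (k1, v1), (k2, v2) in zip(grads_params.items(), grad_params_destacked_v2.items()):
  assert k1 == k2
  print(k1, torch.allclose(v1, v2))  # True for every parameter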