# How to invert stacking a set of tensors?

Hi All,

TL;DR: How can I de-stack a large tensor back into a dict of tensors in PyTorch?

Let’s say I compute the gradients of a model for a set of samples and concatenate them into a single tensor. Is there a way to invert this process?

I have a minimal reproducible example below. Is there an efficient way to convert `grad_params_stack` back to `grads_params`?

```python
import torch
from torch import nn
from torch.func import jacrev, vmap, functional_call

class network(nn.Module):

    def __init__(self, num_input, num_hidden, num_output):
        super(network, self).__init__()
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.num_output = num_output

        self.func = nn.Tanh()

        self.fc1 = nn.Linear(self.num_input, self.num_hidden)
        self.fc2 = nn.Linear(self.num_hidden, self.num_hidden)
        self.fc3 = nn.Linear(self.num_hidden, self.num_hidden)
        self.fc4 = nn.Linear(self.num_hidden, self.num_output)

    def forward(self, x):
        x = self.fc1(x)
        x = self.func(x)
        x = self.fc2(x)
        x = self.func(x)
        x = self.fc3(x)
        x = self.func(x)
        x = self.fc4(x)
        return x.squeeze(0)

batch_size = 100
num_input = 4
num_hidden = 32
num_output = 1

net = network(num_input=num_input, num_hidden=num_hidden, num_output=num_output)
num_params = sum(p.numel() for p in net.parameters())

# inputs
params = dict(net.named_parameters())
x = torch.randn(batch_size, num_input)

def fcall(params, x):
    return functional_call(net, params, x)

# per-sample gradients of the network output w.r.t. model params
# (the original snippet was cut off here; the step below is one standard
# pattern consistent with the imports above and the names used later)
grads_params = vmap(jacrev(fcall), in_dims=(None, 0))(params, x)
# grads_params: dict of name -> tensor of shape (batch_size, *param.shape)

# flatten each gradient and concatenate into one (batch_size, num_params) tensor
grad_params_stack = torch.cat(
    [g.reshape(batch_size, -1) for g in grads_params.values()], dim=1
)
```

Figured it out:

```python
# How to invert grad_params_stack back into the per-parameter dict
import numpy as np

# offsets of each parameter's flattened slice within the stacked tensor
param_loc = np.cumsum([0] + [p.numel() for p in net.parameters()])
layer_idx = [(param_loc[i], param_loc[i + 1]) for i in range(len(param_loc) - 1)]

# slice out each parameter's columns and restore its original shape
grad_params_destacked = {
    key: grad_params_stack[:, start:end].reshape(batch_size, *value.shape)
    for (key, value), (start, end) in zip(params.items(), layer_idx)
}
```
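
A quick sanity check plus an alternative, building on the variables above (this assumes the `grads_params` dict from the reconstructed gradient step in the first block): the de-stacked dict should match the original per-sample gradients exactly, and `torch.split` can do the same slicing without the manual offset bookkeeping.

```python
# sanity check: each de-stacked gradient should equal the original one
for key, g in grads_params.items():
    assert torch.allclose(grad_params_destacked[key], g)

# alternative: torch.split accepts a list of chunk sizes, one per parameter,
# so the cumulative-offset bookkeeping can be skipped entirely
chunks = torch.split(grad_params_stack, [p.numel() for p in net.parameters()], dim=1)
grad_params_destacked2 = {
    key: chunk.reshape(batch_size, *value.shape)
    for (key, value), chunk in zip(params.items(), chunks)
}
```

Both versions produce the same dict; the `split` variant is just less index arithmetic.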