# Below is an implementation that can both flatten the network's parameters and unflatten them again. Hopefully you can reuse some of it.
import numpy as np
#############################################################################
# Flattening the NET
#############################################################################
def flattenNetwork(net):
    """Flatten every parameter of ``net`` into a single 1-D NumPy array.

    Parameters
    ----------
    net
        Object whose ``parameters()`` method yields tensors supporting
        ``.detach().cpu().numpy()`` (presumably a PyTorch ``torch.nn.Module``;
        not checked here).

    Returns
    -------
    (flat, shapes)
        ``flat`` is a new 1-D array with all parameter values concatenated
        in iteration order. ``shapes`` is a list holding each parameter's
        original shape tuple, in the same order, so ``unFlattenNetwork``
        can rebuild the parameters.
    """
    arrays = []
    shapes = []
    for param in net.parameters():
        # detach() lets numpy() work on tensors that require grad;
        # cpu() moves GPU tensors to host memory before conversion.
        arr = param.detach().cpu().numpy()
        shapes.append(arr.shape)
        # reshape(-1) works for any rank: 0-d scalars, 3-D (Conv1d) and
        # 5-D (Conv3d) weights, not only 1-, 2- and 4-D tensors.
        arrays.append(arr.reshape(-1))
    if not arrays:
        # np.concatenate rejects an empty sequence; a parameterless net
        # yields an empty array, as before.
        return np.array([]), shapes
    # One vectorized copy instead of appending element by element. The
    # result never shares memory with the network's parameters.
    return np.concatenate(arrays), shapes
#############################################################################
# UN-Flattening the NET
#############################################################################
def unFlattenNetwork(weights, shapes):
    """Split a flat weight vector back into arrays of the given shapes.

    This is the inverse of ``flattenNetwork``. ``weights`` is consumed front
    to back, one chunk per entry of ``shapes``.

    Parameters
    ----------
    weights
        Array-like holding all values. It is flattened to 1-D first.
    shapes
        Sequence of shape tuples, in the order their values appear in
        ``weights``. Any rank is supported, including ``()`` for scalars.

    Returns
    -------
    numpy.ndarray
        1-D ``object`` array of length ``len(shapes)``. Element ``i`` is a
        new array (a copy, not a view of ``weights``) with shape
        ``shapes[i]``.

    Raises
    ------
    ValueError
        If the total element count described by ``shapes`` differs from the
        number of values in ``weights``.
    """
    flat = np.asarray(weights).reshape(-1)
    # int() because np.prod(()) is the float 1.0.
    sizes = [int(np.prod(shape)) for shape in shapes]
    total = sum(sizes)
    if total != flat.size:
        raise ValueError(
            "shapes describe {} values but weights has {}".format(total, flat.size)
        )
    # Fill an object array element by element. np.array(list_of_arrays)
    # raises on ragged shapes in modern NumPy, and it silently stacks them
    # into one numeric array when all shapes happen to match.
    params = np.empty(len(shapes), dtype=object)
    start = 0
    for i, (shape, size) in enumerate(zip(shapes, sizes)):
        stop = start + size
        # np.array copies, so the results never alias ``weights``.
        params[i] = np.array(flat[start:stop]).reshape(shape)
        start = stop
    return params
# Example usage. `model` is the caller's network (presumably a
# torch.nn.Module); it is not defined in this snippet.
flat_weights, shapes = flattenNetwork(model)
# Gives you back the parameters as NumPy n-dimensional arrays in their
# original shapes (in your case, the image), which you can assign directly
# to a variable.
unflat_weights = unFlattenNetwork(flat_weights, shapes)