How to cat a layer with batch_size of 1 onto a larger batch?

I have a set of data which I’ve used to generate an image, which I then pass through convolution layers and ultimately flatten into a single feature layer of size [1, 11552].

I now want to attach that layer to every sample in a batch of, say, [4000, 310] so that it is included in training. I know that if I copied that layer out to size [4000, 11552] I could then use torch.cat() to combine the features, but that seems needlessly wasteful (though I may be missing something important).

Is there a way to attach that layer to each sample from the larger batch without having to duplicate the data, so that during training I would effectively be training on a batch of size [4000, 11862]?

Hey @pumplerod
Could you share a clearer picture of the shapes of everything you are dealing with?

I am a little confused about how the model is processing the images and about what exactly you are asking.

Sorry for the confusion. Basically it boils down to this… I have a dense layer which I would like to keep common across all samples in a batch…

Ideally I would not want to duplicate that common layer n_batch times, since it’s always the same. What I’m currently doing is creating a new tensor of ones with the batch size of my samples and the width of the common layer, multiplying it by the common layer, and then using torch.cat() to join them, but this just feels wrong.
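
Roughly, with the real sizes, the current approach looks like this (a minimal sketch of what I describe above; the tensors are just random placeholders):

import torch

common = torch.randn(1, 11552)    # the flattened image features, shared by every sample
batch = torch.randn(4000, 310)    # per-sample features for one training batch

# Current workaround: broadcast `common` into a full-size copy, then concatenate.
ones = torch.ones(batch.size(0), common.size(1))   # [4000, 11552]
replicated = ones * common                          # broadcasts to [4000, 11552]
combined = torch.cat((replicated, batch), dim=1)    # [4000, 11862]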

I think @ptrblck has already solved this.

Hope this helps! 😄

I spent some time going over the variations in the thread and I don’t quite see the situation I’m referring to. Perhaps I’ve missed it.

To simplify my issue, I’m including the code that produces the error along with my “workaround”.

Here is the code, which breaks because torch.cat() complains about mismatched dimensions…

import torch

layers = [ torch.nn.Linear( 15, 5),
           torch.nn.ReLU( inplace=True),
           torch.nn.Linear( 5, 3),
           torch.nn.ReLU( inplace=True),
           torch.nn.Linear( 3, 1),
           torch.nn.Sigmoid() ]

model = torch.nn.Sequential( *layers)

single_layer = torch.randn( (1,10))
batch = torch.randn( (100,5))

# Fails: torch.cat along dim=1 requires every other dimension to match,
# but single_layer has batch size 1 while batch has batch size 100.
input_layer = torch.cat( (single_layer, batch), dim=1)
output = model( input_layer)

And here is what I’ve done to get around the issue…

import torch

layers = [ torch.nn.Linear( 15, 5),
           torch.nn.ReLU( inplace=True),
           torch.nn.Linear( 5, 3),
           torch.nn.ReLU( inplace=True),
           torch.nn.Linear( 3, 1),
           torch.nn.Sigmoid() ]

model = torch.nn.Sequential( *layers)

single_layer = torch.randn( (1,10))
print( f'Constant Layer Size: { single_layer.size()}')

batch = torch.randn( (100,5))
print( f'Batch Size: { batch.size()}')

# Broadcast the constant layer across the batch dimension via a tensor of ones.
replicated_single_layer = torch.ones( (batch.size(0), single_layer.size(1)))
replicated_single_layer *= single_layer
print( f'Replicated Layer Size: { replicated_single_layer.size()}')

input_layer = torch.cat( (replicated_single_layer, batch), dim=1)
print( f'Input Layer Size: { input_layer.size()}')
output = model( input_layer)
print( output.size())

Perhaps the thread you’ve suggested has the solution staring me in the face and I’m just lost on how to adapt it to my case.

Hey @pumplerod
I did some digging and I think this should be your best option:

single_layer = torch.randn((1,10))
batch = torch.randn((100,5))

single = single_layer.repeat(100,10)
input_layer = torch.cat([batch,single],dim=1)

Hope this helps 😄

Oh, interesting. Thank you so much. It looks like that will certainly save a multiplication step. I’m not sure what implications it will have for autograd, but hopefully there’s some speed improvement, even though it looks like it will still consume memory where it ought not to.

I believe the one slight modification to your suggestion is that I need to make sure to repeat only along the batch axis, rather than also repeating the 10 features 10 times. So my final code looks like:

single_layer = torch.randn( ( 1,10))
batch = torch.randn( ( 100,5))

# Repeat only along the batch dimension (dim 0), keeping the feature width unchanged.
replicated_single = single_layer.repeat( batch.size(0), 1)
input_layer = torch.cat( [ batch, replicated_single], dim=1)
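
For anyone finding this later: a minimal sketch of a variant that avoids materializing the repeated copy before the cat, using expand() to create a broadcast view instead of repeat() (torch.cat itself still allocates the full concatenated tensor, so peak memory ends up similar):

import torch

single_layer = torch.randn(1, 10)
batch = torch.randn(100, 5)

# expand() returns a view with stride 0 along the expanded dimension, so no data is copied here.
expanded_single = single_layer.expand(batch.size(0), -1)   # [100, 10], no copy

# The cat still materializes a [100, 15] tensor, but the intermediate
# replicated copy of single_layer is avoided.
input_layer = torch.cat([batch, expanded_single], dim=1)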