Multiplying an embedding vector with an image of shape (h, w, 1)

Hello, I’m trying to migrate my GAN model code from Keras (TF backend) to PyTorch.

Currently I’m stuck at the operation where I combine the output of an embedding layer with the image input through element-wise multiplication.

In Keras, the code looks like this:

from keras.layers import Embedding, Flatten, Multiply

label_embedding = Flatten()(Embedding(n_class, embed_dim)(label))
model_input = Multiply()([input_img, label_embedding])

Here label_embedding gives a vector of size embed_dim and input_img has shape (h, w, 1), so model_input ends up with shape (h, w, embed_dim) via broadcasting.
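To make sure I understand what Keras is doing, the broadcasting seems to follow the usual NumPy rules: the trailing size-1 channel of the image expands against the embedding vector. A minimal sketch with hypothetical sizes:

```python
import numpy as np

h, w, embed_dim = 4, 4, 8   # hypothetical sizes for illustration

img = np.zeros((h, w, 1))   # image with a single channel
emb = np.ones(embed_dim)    # flattened embedding vector

# trailing dims align: (h, w, 1) * (embed_dim,) -> (h, w, embed_dim)
out = img * emb
print(out.shape)            # (4, 4, 8)
```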

In PyTorch, however, I can’t simply multiply by the embedding result because the dimensions don’t match:

self.embedding = nn.Embedding(n_class, embed_dim)

# raises a shape error: self.embedding(y) is a flat vector per label,
# while input_img still has its spatial dimensions
model_input = torch.mul(input_img, self.embedding(y))

How do I reshape the embedding output to match the image input so that it can be combined with torch.mul?
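For reference, here is a minimal sketch of the reshape I think should reproduce the Keras behavior, though I’m not sure it’s the idiomatic way. It assumes batch-first, channels-first tensors, and all sizes are hypothetical:

```python
import torch
import torch.nn as nn

batch, h, w, n_class, embed_dim = 2, 4, 4, 10, 8  # hypothetical sizes

embedding = nn.Embedding(n_class, embed_dim)
input_img = torch.randn(batch, 1, h, w)           # (N, 1, H, W) image batch
y = torch.randint(0, n_class, (batch,))           # (N,) class labels

emb = embedding(y)                                # (N, embed_dim)
emb = emb.view(batch, embed_dim, 1, 1)            # (N, embed_dim, 1, 1)

# broadcasting: (N, 1, H, W) * (N, embed_dim, 1, 1) -> (N, embed_dim, H, W)
model_input = input_img * emb
print(model_input.shape)                          # torch.Size([2, 8, 4, 4])
```

Is view followed by broadcast multiplication the right approach here, or is there a cleaner equivalent of the Keras Multiply layer?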