How to implement an equivalent of tf.gather in pytorch

Good day all,

I have written codes in both tensorflow and pytorch to create a modulated signal. The tensorflow code is working perfectly, but the equivalent pytorch isn’t. I understand that the problem arises from the way the indices are mapped to a tensor in pytorch. Could you please help me figure out how to correctly implement the equivalent indices tensor mapping in pytorch. The codes are shown below:

import os
import tensorflow as tf
from torch.autograd import Variable
import numpy as np

class Modulator(object):

    def __init__(self, mod_type, K): 
        # Set modulation type
        if (mod_type not in ['BPSK', '4PAM']):
            raise(Exception('Modulator: Unknown modulation format'))
        self.mod_type = mod_type
        self.K = K
        
        # Create constellation
        if (self.mod_type == 'BPSK'):
            self.constellation = np.array([-1.0, 1.0])
        
        elif (self.mod_type == '4PAM'):
            self.constellation = np.array([-3.0, -1.0, 1.0, 3.0])
        
        self.constellation_size = self.constellation.shape[0]
       
        # Normalize constellation to unit power and convert to tensor
        self.constellation /= np.sqrt(np.mean(np.abs(self.constellation)**2))
        self.constellation = tf.Variable(self.constellation, trainable=False, dtype=tf.float32)
        return
    
    def random_indices(self, batch_size=4):
        '''Generate random constellation symbol indices'''
        indices = tf.random_uniform(shape=[batch_size, self.K], minval=0, maxval=self.constellation_size,dtype=tf.int32)
  
        return indices
    
    def modulate(self, indices):
        '''Map indices to constellation symbols'''
        x = tf.gather(self.constellation, indices)
        return x 
mod = Modulator('4PAM', 6)
indices = mod.random_indices(4)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(mod.random_indices(4)))
    print(sess.run(mod.modulate(indices)))
    print(indices.shape)

Output from tensorflow:

[[0 3 1 3 0 1]
 [1 1 2 0 1 3]
 [0 1 1 2 1 2]
 [3 1 3 1 1 2]]
[[-1.3416408  0.4472136 -1.3416408 -0.4472136  1.3416408 -0.4472136]
 [-0.4472136 -1.3416408 -1.3416408 -0.4472136 -1.3416408 -0.4472136]
 [ 0.4472136 -1.3416408  1.3416408 -1.3416408 -0.4472136 -1.3416408]
 [ 1.3416408 -1.3416408 -1.3416408  0.4472136  1.3416408 -0.4472136]]

Pytorch code:

import os
import torch
from torch.autograd import Variable
import numpy as np

class Modulator(object):

    def __init__(self, mod_type, K): 
        # Set modulation type
        if (mod_type not in ['BPSK', '4PAM']):
            raise(Exception('Modulator: Unknown modulation format'))
        self.mod_type = mod_type
        self.K = K
        
        # Create constellation
        if (self.mod_type == 'BPSK'):
            self.constellation = np.array([-1.0, 1.0])
            self.constellation = (torch.from_numpy(self.constellation)).type(torch.float32)
        elif (self.mod_type == '4PAM'):
            self.constellation = np.array([-3.0, -1.0, 1.0, 3.0])
            self.constellation = (torch.from_numpy(self.constellation)).type(torch.float32)
        
        self.constellation_size = self.constellation.shape[0]
       
        # Normalize constellation to unit power and convert to tensor
        self.constellation /= torch.sqrt(torch.mean(torch.abs(self.constellation)**2))
        return
    
    def random_indices(self, batch_size=4):
        '''Generate random constellation symbol indices'''
        indices = torch.FloatTensor(batch_size, self.K).uniform_(0, self.constellation_size).int()
  
        return indices
    
    def modulate(self, indices):
        '''Map indices to constellation symbols'''
        x = torch.gather(self.constellation, indices)
        return x 
mod = Modulator('4PAM', 6)
indices = mod.random_indices(4)
x = mod.modulate(indices)
print(indices)

Pytorch output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-91-9cfe51fb9f8c> in <module>
     39 mod = Modulator('4PAM', 6)
     40 indices = mod.random_indices(4)
---> 41 x = mod.modulate(indices)
     42 print(indices)
     43 

<ipython-input-91-9cfe51fb9f8c> in modulate(self, indices)
     35     def modulate(self, indices):
     36         '''Map indices to constellation symbols'''
---> 37         x = torch.gather(self.constellation, indices)
     38         return x
     39 mod = Modulator('4PAM', 6)

TypeError: gather(): argument 'dim' (position 2) must be int, not Tensor

Thank you very much for your help

You should use Tensor.select(0, index) or simple slicing syntax tensor[index].

https://pytorch.org/docs/stable/tensors.html#torch.Tensor.select

Note that the index must be LongTensor.

def random_indices(self, batch_size=4):
        '''Generate random constellation symbol indices'''
        indices = torch.FloatTensor(batch_size, self.K).uniform_(0, self.constellation_size).long()
        return indices
   
    def modulate(self, indices):
        '''Map indices to constellation symbols'''
        x = self.constellation[indices]
        return x

Thank you very much Tony, it worked. It looked stupid of me spending the whole day trying to figure this out.