Hi,
I have data in a tf.data.Dataset, which I build with a map function as shown below:
dataset = source_dataset.map(encode_tf,
                             num_parallel_calls=tf.data.experimental.AUTOTUNE)

def encode_tf(inputs):
    inputs_plaintext = inputs['inputs_plaintext']
    targets_plaintext = inputs['targets_plaintext']
    encoded = tf.py_function(self.encode,
                             [inputs_plaintext, targets_plaintext],
                             [tf.int32, tf.int32, tf.int32, tf.int32])
    input_ids, input_attention, target_ids, target_attention = encoded
    input_ids.set_shape([None])
    target_ids.set_shape([None])
    input_attention.set_shape([None])
    target_attention.set_shape([None])
    data = {'input_ids': input_ids,
            'labels': target_ids,
            'attention_mask': input_attention,
            'decoder_attention_mask': target_attention}
    return data
Each element of the dataset is a dictionary of TensorFlow tensors, and I need to convert the values of every element to PyTorch tensors. Is there a way to do this conversion? Ideally, is there something like TensorFlow's map that would do the reformatting fast?
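To make the goal concrete, here is a minimal sketch of the kind of conversion I have in mind. It assumes TensorFlow 2.x in eager mode, where as_numpy_iterator() yields each element as a dictionary of NumPy arrays, and it wraps every array with torch.tensor. The names convert_elements and demo are hypothetical helpers of mine, and the toy dataset only stands in for the mapped dataset above. This runs element by element in Python, so I am not sure it counts as fast.

```python
from typing import Any, Callable, Dict, Iterable, Iterator


def convert_elements(
    elements: Iterable[Dict[str, Any]],
    to_tensor: Callable[[Any], Any],
) -> Iterator[Dict[str, Any]]:
    """Lazily apply `to_tensor` to every value of every dictionary element."""
    for element in elements:
        yield {key: to_tensor(value) for key, value in element.items()}


def demo() -> None:
    # Imported here so convert_elements can be used without either framework.
    import tensorflow as tf
    import torch

    # Toy stand-in for the mapped dataset from the post (not real data).
    dataset = tf.data.Dataset.from_tensor_slices(
        {"input_ids": [[1, 2, 3], [4, 5, 6]], "labels": [[7, 8], [9, 10]]}
    )

    # In eager mode, as_numpy_iterator() yields dicts of NumPy arrays;
    # torch.tensor copies each array into a new PyTorch tensor.
    for example in convert_elements(dataset.as_numpy_iterator(), torch.tensor):
        print({key: (value.dtype, tuple(value.shape)) for key, value in example.items()})
```

Calling demo() should print the dtype and shape of each converted tensor. The converter is passed in as an argument, so torch.tensor could be swapped for torch.from_numpy if avoiding the copy matters.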
Thanks.