Convert mask of shape [224,224,1] to mask [224,224,3]

Hello, I am having an issue converting a mask of shape [224, 224, 1] to a mask of shape [224, 224, 3]. The image shape is [224, 224, 3], and I am trying to bring both the image and the mask into the same shape. I have applied augmentation only to the images. This is my preprocessing code; where did I make a mistake? Many thanks.

Code:

# Load the image and mask data
def process_path(image_path, mask_path):
  """Read an image/mask pair of PNG files and return both as float32 tensors.

  Args:
    image_path: Path of the PNG image file.
    mask_path: Path of the PNG mask file.

  Returns:
    A tuple ``(img, mask)``. ``img`` is the image decoded with 3 channels.
    ``mask`` is decoded with 3 channels as well and then collapsed to a
    single trailing channel by taking the per-pixel maximum over channels.
  """
  # Image: read -> decode as 3-channel PNG -> convert to float32.
  img = tf.image.convert_image_dtype(
      tf.image.decode_png(tf.io.read_file(image_path), channels=3),
      tf.float32)

  # Mask: same loading steps as the image.
  mask_rgb = tf.image.convert_image_dtype(
      tf.image.decode_png(tf.io.read_file(mask_path), channels=3),
      tf.float32)
  # keepdims=True keeps the channel axis, so the mask ends with a size-1
  # channel dimension instead of losing that axis entirely.
  mask = tf.math.reduce_max(mask_rgb, axis=-1, keepdims=True)

  print("mask1", mask.shape)

  return img, mask

def preprocess(image, mask):
  """Resize an image/mask pair to 224x224 and return the image as float32.

  Args:
    image: Image tensor. May be an integer image (e.g. uint8) or an image
      that has already been converted to float.
    mask: Mask tensor belonging to ``image``.

  Returns:
    A tuple ``(input_image, input_mask)``, both resized to 224x224 with
    nearest-neighbour interpolation. ``input_image`` is float32.
  """
  # Nearest-neighbour resizing keeps the input dtype and, for the mask,
  # avoids inventing interpolated label values.
  input_image = tf.image.resize(image, (224, 224), method='nearest')
  input_mask = tf.image.resize(mask, (224, 224), method='nearest')
  print("input_image",input_image.shape)
  print("input_mask",input_mask.shape)

  # Do not divide by 255 unconditionally: process_path already converts the
  # image with tf.image.convert_image_dtype, so a second division squeezes
  # the pixel values into roughly [0, 0.004]. convert_image_dtype rescales
  # integer images to [0, 1] and leaves float images unscaled, so the image
  # is normalized exactly once regardless of which form arrives here.
  input_image = tf.image.convert_image_dtype(input_image, tf.float32)

  return input_image, input_mask

def augment(image, mask):
  """Preprocess an image/mask pair and apply random augmentation.

  The horizontal flip is a geometric transform, so it is applied to the
  image and the mask together using one shared random decision. The
  photometric transforms (brightness, contrast, saturation, hue) change
  pixel values only and are therefore applied to the image alone.

  Args:
    image: Image tensor, passed through ``preprocess`` first.
    mask: Mask tensor, passed through ``preprocess`` first.

  Returns:
    A tuple ``(image, mask)`` of the augmented image and its aligned mask.
  """
  image, mask = preprocess(image, mask)

  # Flipping only the image would leave the mask pointing at the wrong
  # pixels, so draw the coin once and flip both or neither.
  do_flip = tf.random.uniform(()) > 0.5
  image = tf.cond(do_flip,
                  lambda: tf.image.flip_left_right(image),
                  lambda: image)
  mask = tf.cond(do_flip,
                 lambda: tf.image.flip_left_right(mask),
                 lambda: mask)

  # Colour-only augmentations: the mask is intentionally left untouched.
  image = tf.image.random_brightness(image, max_delta=0.5)
  image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
  image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
  image = tf.image.random_hue(image, max_delta=0.2)
  print("image.shape",image.shape)
  return image, mask

# Load and augment the data
def load_and_augment(image_path, mask_path):
  """Load the image/mask pair at the given paths and return it augmented.

  Args:
    image_path: Path of the image file, forwarded to ``process_path``.
    mask_path: Path of the mask file, forwarded to ``process_path``.

  Returns:
    The ``(image, mask)`` tuple produced by ``augment``.
  """
  # process_path yields (image, mask); unpack it straight into augment.
  return augment(*process_path(image_path, mask_path))

# Create the dataset
def create_dataset(image_paths, mask_paths, batch_size=32, shuffle_buffer_size=None, repeat=True):
  """Build a batched tf.data pipeline of augmented (image, mask) pairs.

  Pipeline order: decode -> cache -> shuffle -> augment -> batch -> repeat
  -> prefetch.

  Args:
    image_paths: Sequence of image file paths.
    mask_paths: Sequence of mask file paths, aligned with ``image_paths``.
    batch_size: Number of samples per batch.
    shuffle_buffer_size: Shuffle buffer size; shuffling is skipped when this
      is ``None`` or 0.
    repeat: When true, the batched dataset is repeated 5 times.

  Returns:
    A ``tf.data.Dataset`` yielding ``(image_batch, mask_batch)`` tuples.
  """
  autotune = tf.data.experimental.AUTOTUNE

  images = tf.data.Dataset.from_tensor_slices(image_paths)
  masks = tf.data.Dataset.from_tensor_slices(mask_paths)
  dataset = tf.data.Dataset.zip((images, masks))

  # Cache only the deterministic work (file reading and decoding). Caching
  # after the random augmentation would freeze the first pass's random
  # choices and replay identical samples on every repeat.
  # NOTE: this keeps the decoded, un-resized tensors in memory; drop the
  # cache() call if the dataset does not fit.
  dataset = dataset.map(process_path, num_parallel_calls=autotune)
  dataset = dataset.cache()

  if shuffle_buffer_size:
    dataset = dataset.shuffle(shuffle_buffer_size)

  # Random augmentation runs after the cache so it is re-drawn on each pass.
  dataset = dataset.map(augment, num_parallel_calls=autotune)
  dataset = dataset.batch(batch_size)

  if repeat:
    dataset = dataset.repeat(5)

  # Prefetch last so that whole ready-made batches are prepared while the
  # consumer is busy with the previous one.
  return dataset.prefetch(autotune)

# Create the training and validation datasets
batch_size = 32

# Training pipeline: the shuffle buffer spans the whole file list and
# repetition is enabled.
train_dataset = create_dataset(
    train_image_filenames,
    train_masks_filenames,
    batch_size=batch_size,
    shuffle_buffer_size=len(train_image_filenames),
    repeat=True,
)

# Validation pipeline: no shuffle buffer is passed and repetition is off.
validation_dataset = create_dataset(
    test_image_filenames,
    test_masks_filenames,
    batch_size=batch_size,
    repeat=False,
)

# Show the element specs (shapes and dtypes) of both pipelines.
print("train_dataset", train_dataset)
print("validation_dataset", validation_dataset)

This is the output:

train_dataset <RepeatDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None))>
validation_dataset <BatchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None))>

Based on your code snippet it seems you are using TensorFlow so their discussion board might be a better place to post this question.

1 Like