ValueError: Input images must have the same dimensions.

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, GlobalAveragePooling2D, MaxPooling2D, GlobalMaxPooling2D, BatchNormalization, Reshape, Multiply, Add, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD
from skimage.metrics import structural_similarity as ssim
from skimage.transform import resize
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

Define the self attention

class SelfAttention(tf.keras.layers.Layer):
def init(self, num_channels):
super(SelfAttention, self).init()
self.query_conv = tf.keras.layers.Conv2D(num_channels // 8, kernel_size=1)
self.key_conv = tf.keras.layers.Conv2D(num_channels // 8, kernel_size=1)
self.value_conv = tf.keras.layers.Conv2D(num_channels, kernel_size=1)
self.gamma = self.add_weight(‘gamma’, shape=[1], initializer=‘zeros’, trainable=True)

def call(self, inputs):
    batch_size, height, width, num_channels = inputs.shape
    query = self.query_conv(inputs)
    key = self.key_conv(inputs)
    value = self.value_conv(inputs)
    energy = tf.matmul(query, tf.transpose(key, [0, 1, 3, 2]))
    attention = tf.nn.softmax(energy, axis=-1)
    context = tf.matmul(attention, value)
    out = self.gamma * context + inputs
    return out, attention

Define the channel attention

class ChannelAttention(tf.keras.layers.Layer):
def init(self, num_channels, reduction_ratio=8):
super(ChannelAttention, self).init()
self.avg_pool = tf.keras.layers.GlobalAveragePooling2D()
self.max_pool = tf.keras.layers.GlobalMaxPooling2D()
self.fc1 = tf.keras.layers.Dense(num_channels // reduction_ratio, activation=‘relu’)
self.fc2 = tf.keras.layers.Dense(num_channels, activation=‘sigmoid’)

def call(self, inputs):
    avg_pool = self.avg_pool(inputs)
    max_pool = self.max_pool(inputs)
    avg_fc = self.fc2(self.fc1(avg_pool))
    max_fc = self.fc2(self.fc1(max_pool))
    max_fc = tf.expand_dims(max_fc, axis=1)  # Add this line
    attention = avg_fc[:, tf.newaxis, :] * max_fc[:, :, tf.newaxis]
    out = attention * inputs
    return out, attention

Define the spatial attention

class SpatialAttention(tf.keras.layers.Layer):
def init(self, kernel_size=7):
super(SpatialAttention, self).init()
self.conv = tf.keras.layers.Conv2D(filters=1, kernel_size=kernel_size, padding=‘same’, activation=‘sigmoid’)

def call(self, inputs):
    attention = self.conv(inputs)
    out = attention * inputs
    return out, attention

Define the channel-spatial attention

class ChannelSpatialAttention(tf.keras.layers.Layer):
def init(self, num_channels, reduction_ratio=8, kernel_size=7):
super(ChannelSpatialAttention, self).init()
self.channel_attention = ChannelAttention(num_channels, reduction_ratio)
self.spatial_attention = SpatialAttention(kernel_size)

def call(self, inputs):
    out, attention_channel = self.channel_attention(inputs)
    out, attention_spatial = self.spatial_attention(out)
    attention = attention_channel * attention_spatial
    out = attention * inputs
    return out, attention

Define a function to get attention map for input image

def get_attention_map(model, image, size=(224, 224)):
layer_outputs = [layer.output for layer in model.layers if ‘attention’ in layer.name]
activation_model = tf.keras.models.Model(inputs=model.inputs, outputs=layer_outputs)
activations = activation_model.predict(image)
attention_maps = []
for activation in activations:
if isinstance(activation, tuple):
activation = activation[0]
if activation.ndim != 4:
continue
attention_map = tf.reduce_mean(activation, axis=-1)
attention_map = tf.expand_dims(attention_map, axis=-1)
attention_maps.append(attention_map)
return attention_maps

Load an image

img_path = ‘sample_img.jpg’
image = tf.keras.preprocessing.image.load_img(img_path, target_size=(224, 224))
image = tf.keras.preprocessing.image.img_to_array(image)
image = tf.expand_dims(image, axis=0)

Create the models

model_self_attention = tf.keras.applications.ResNet50(include_top=False, input_shape=(224, 224, 3))
model_channel_attention = tf.keras.applications.ResNet50(include_top=False, input_shape=(224, 224, 3))
model_spatial_attention = tf.keras.applications.ResNet50(include_top=False, input_shape=(224, 224, 3))
model_channel_spatial_attention = tf.keras.applications.ResNet50(include_top=False, input_shape=(224, 224, 3))

for layer in model_self_attention.layers:
if ‘conv5_block3_out’ in layer.name:
out, _ = SelfAttention(2048)(layer.output)
model_self_attention = tf.keras.models.Model(inputs=model_self_attention.inputs, outputs=out)

for layer in model_channel_attention.layers:
if ‘conv5_block3_out’ in layer.name:
out, _ = ChannelAttention(2048)(layer.output)
model_channel_attention = tf.keras.models.Model(inputs=model_channel_attention.inputs, outputs=out)

for layer in model_spatial_attention.layers:
if ‘conv5_block3_out’ in layer.name:
out, _ = SpatialAttention()(layer.output)
model_spatial_attention = tf.keras.models.Model(inputs=model_spatial_attention.inputs, outputs=out)

for layer in model_channel_spatial_attention.layers:
if ‘conv5_block3_out’ in layer.name:
out, _ = ChannelSpatialAttention(2048)(layer.output)
model_channel_spatial_attention = tf.keras.models.Model(inputs=model_channel_spatial_attention.inputs, outputs=out)

Get the attention maps

attention_maps_self_attention = get_attention_map(model_self_attention, image)
attention_maps_channel_attention = get_attention_map(model_channel_attention, image)
attention_maps_spatial_attention = get_attention_map(model_spatial_attention, image)
attention_maps_channel_spatial_attention = get_attention_map(model_channel_spatial_attention, image)

Plot the attention maps

plt.figure(figsize=(20, 20))
plt.subplot(331)
plt.imshow(tf.keras.preprocessing.image.array_to_img(image[0]))
plt.title(‘Original Image’)
plt.axis(‘off’)

attention_maps_self_attention = get_attention_map(model_self_attention, image)
plt.subplot(332)
plt.imshow(tf.squeeze(attention_maps_self_attention[0]))
plt.title(‘Self-Attention Map’)
plt.axis(‘off’)

attention_maps_channel_attention = get_attention_map(model_channel_attention, image)
plt.subplot(333)
plt.imshow(tf.squeeze(attention_maps_channel_attention[0]))
plt.title(‘Channel Attention Map’)
plt.axis(‘off’)

plt.subplot(334)
plt.imshow(tf.squeeze(attention_maps_spatial_attention[0]))
plt.title(‘Spatial Attention Map’)
plt.axis(‘off’)

attention_maps_channel_spatial_attention = get_attention_map(model_channel_spatial_attention, image)
plt.subplot(335)
plt.imshow(tf.squeeze(attention_maps_channel_spatial_attention[0]))
plt.title(‘Channel-Spatial Attention Map’)
plt.axis(‘off’)

plt.tight_layout()
plt.show()

print(image.shape)
print(attention_maps_self_attention)

Calculate the SSIM scores between the attention maps and the original image

ssim_score_self_attention = ssim(image[0], tf.expand_dims(attention_maps_self_attention[0], axis=-1))
ssim_score_channel_attention = ssim(image[0], tf.expand_dims(attention_maps_channel_attention[0], axis=-1))
ssim_score_spatial_attention = ssim(image[0], tf.expand_dims(attention_maps_spatial_attention[0], axis=-1))
ssim_score_channel_spatial_attention = ssim(image[0], tf.expand_dims(attention_maps_channel_spatial_attention[0], axis=-1))

Based on your code snippet it seems you are using TensorFlow with Keras, so I would recommend posting this question in their discussion board as you will find TF experts there :wink: