import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, GlobalAveragePooling2D, MaxPooling2D, GlobalMaxPooling2D, BatchNormalization, Reshape, Multiply, Add, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD
from skimage.metrics import structural_similarity as ssim
from skimage.transform import resize
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Define the self-attention layer
class SelfAttention(tf.keras.layers.Layer):
    """Self-attention over the spatial positions of a feature map.

    Query/key/value projections are 1x1 convolutions. The attended context
    is blended into the input through a learnable scalar ``gamma`` that is
    initialised to zero, so the layer starts out as an identity mapping.

    Args:
        num_channels: Number of channels of the incoming feature map. The
            query/key projections use ``num_channels // 8`` filters and the
            value projection uses ``num_channels`` filters.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(**kwargs)
        self.query_conv = tf.keras.layers.Conv2D(num_channels // 8, kernel_size=1)
        self.key_conv = tf.keras.layers.Conv2D(num_channels // 8, kernel_size=1)
        self.value_conv = tf.keras.layers.Conv2D(num_channels, kernel_size=1)
        # Pass the name by keyword: the first positional parameter of
        # add_weight is not the name in every Keras version.
        self.gamma = self.add_weight(
            name='gamma', shape=(1,), initializer='zeros', trainable=True
        )

    def call(self, inputs):
        """Apply self-attention.

        Args:
            inputs: 4-D feature map; assumed channels-last
                ``(batch, height, width, channels)`` with ``channels`` equal
                to the ``num_channels`` given at construction (required by
                the residual addition below).

        Returns:
            A tuple ``(out, attention)``: ``out`` has the shape of
            ``inputs``; ``attention`` holds the softmax weights with shape
            ``(batch, height * width, height * width)``.
        """
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]

        def flatten_positions(features):
            # (B, H, W, C') -> (B, H*W, C') so every position can attend to
            # every other position, not only to positions in its own row.
            return tf.reshape(features, [batch_size, -1, features.shape[-1]])

        query = flatten_positions(self.query_conv(inputs))
        key = flatten_positions(self.key_conv(inputs))
        value = flatten_positions(self.value_conv(inputs))

        energy = tf.matmul(query, key, transpose_b=True)  # (B, N, N)
        attention = tf.nn.softmax(energy, axis=-1)
        context = tf.reshape(tf.matmul(attention, value), input_shape)

        out = self.gamma * context + inputs
        return out, attention
# Define the channel-attention layer
class ChannelAttention(tf.keras.layers.Layer):
    """Channel attention built from globally pooled descriptors.

    The average-pooled and max-pooled descriptors go through the same
    two-layer bottleneck (ReLU then sigmoid); the two resulting gates are
    multiplied together and used to rescale each channel of the input.

    Args:
        num_channels: Number of channels of the incoming feature map.
        reduction_ratio: Bottleneck reduction factor; the hidden layer has
            ``num_channels // reduction_ratio`` units.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_channels, reduction_ratio=8, **kwargs):
        super().__init__(**kwargs)
        self.avg_pool = tf.keras.layers.GlobalAveragePooling2D()
        self.max_pool = tf.keras.layers.GlobalMaxPooling2D()
        self.fc1 = tf.keras.layers.Dense(num_channels // reduction_ratio, activation='relu')
        self.fc2 = tf.keras.layers.Dense(num_channels, activation='sigmoid')

    def call(self, inputs):
        """Rescale the channels of ``inputs``.

        Args:
            inputs: 4-D feature map; assumed channels-last
                ``(batch, height, width, channels)``.

        Returns:
            A tuple ``(out, attention)``: ``out`` has the shape of
            ``inputs``; ``attention`` has shape ``(batch, 1, 1, channels)``.
        """
        avg_gate = self.fc2(self.fc1(self.avg_pool(inputs)))  # (B, C)
        max_gate = self.fc2(self.fc1(self.max_pool(inputs)))  # (B, C)
        # Combine per sample first, then add the two singleton spatial axes.
        # Mixing differently-ranked views of the gates broadcasts the batch
        # axis against itself and only works when the batch size is 1.
        attention = (avg_gate * max_gate)[:, tf.newaxis, tf.newaxis, :]
        out = attention * inputs
        return out, attention
# Define the spatial-attention layer
class SpatialAttention(tf.keras.layers.Layer):
    """Spatial attention: a single-filter sigmoid convolution gates each position.

    Args:
        kernel_size: Size of the convolution kernel that produces the
            one-channel attention map.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.conv = tf.keras.layers.Conv2D(
            filters=1, kernel_size=kernel_size, padding='same', activation='sigmoid'
        )

    def call(self, inputs):
        """Gate every spatial position of ``inputs``.

        Args:
            inputs: 4-D feature map accepted by ``Conv2D``.

        Returns:
            A tuple ``(out, attention)``: ``attention`` is the one-channel
            sigmoid map produced by the convolution (``padding='same'``
            keeps the spatial size) and ``out`` is ``attention * inputs``.
        """
        attention = self.conv(inputs)
        out = attention * inputs
        return out, attention
# Define the channel-spatial attention layer
class ChannelSpatialAttention(tf.keras.layers.Layer):
    """Channel attention followed by spatial attention.

    Args:
        num_channels: Number of channels of the incoming feature map.
        reduction_ratio: Bottleneck reduction factor passed to
            ``ChannelAttention``.
        kernel_size: Convolution kernel size passed to ``SpatialAttention``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_channels, reduction_ratio=8, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.channel_attention = ChannelAttention(num_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def call(self, inputs):
        """Apply channel attention, then spatial attention.

        Args:
            inputs: 4-D feature map accepted by both sub-layers.

        Returns:
            A tuple ``(out, attention)``: ``out`` is the input gated by both
            attentions and ``attention`` is the product of the channel and
            spatial attention maps.
        """
        channel_refined, attention_channel = self.channel_attention(inputs)
        out, attention_spatial = self.spatial_attention(channel_refined)
        attention = attention_channel * attention_spatial
        # `out` already equals attention_spatial * (attention_channel * inputs),
        # i.e. attention * inputs up to floating-point rounding, so there is
        # no need to multiply the full feature map a second time.
        return out, attention
# Define a function to get the attention maps for an input image
def get_attention_map(model, image, size=(224, 224)):
    """Compute one spatial map per attention layer of ``model``.

    Every layer whose name contains ``'attention'`` is tapped. When such a
    layer produces several outputs, only the first one is used (for the
    layers in this file, which return ``(out, attention)``, that is the
    attended feature map). Each 4-D activation is averaged over its last
    axis and, unless ``size`` is ``None``, resized to ``size``.

    Args:
        model: Keras model to inspect.
        image: Batched input passed unchanged to ``predict``.
        size: ``(height, width)`` the maps are resized to, or ``None`` to
            keep the native resolution of each activation.

    Returns:
        A list of tensors of shape ``(batch, height, width, 1)``; empty when
        the model has no matching layer or no 4-D activation.
    """
    attention_layers = [layer for layer in model.layers if 'attention' in layer.name]
    if not attention_layers:
        # A model cannot be built with an empty output list.
        return []

    layer_outputs = [layer.output for layer in attention_layers]
    activation_model = tf.keras.models.Model(inputs=model.inputs, outputs=layer_outputs)
    activations = activation_model.predict(image)
    if not isinstance(activations, (list, tuple)):
        # A lone array would otherwise be iterated sample by sample.
        activations = [activations]

    attention_maps = []
    for activation in activations:
        if isinstance(activation, (list, tuple)):
            # Multi-output layer: keep its first output.
            activation = activation[0]
        activation = np.asarray(activation)
        if activation.ndim != 4:
            continue
        attention_map = tf.reduce_mean(activation, axis=-1, keepdims=True)
        if size is not None:
            attention_map = tf.image.resize(attention_map, size)
        attention_maps.append(attention_map)
    return attention_maps
# Load an image
img_path = 'sample_img.jpg'
# NOTE: the name `image` shadows the `tensorflow.keras.preprocessing.image`
# module imported at the top of this file; the fully qualified `tf.keras...`
# path is used here so the shadowing is harmless.
pil_image = tf.keras.preprocessing.image.load_img(img_path, target_size=(224, 224))
# Convert to a float array and add a leading batch axis.
image = tf.expand_dims(tf.keras.preprocessing.image.img_to_array(pil_image), axis=0)
# Create the models
def _build_attention_model(attention_layer):
    """Return a fresh ResNet50 backbone with ``attention_layer`` on top.

    The attention layer is applied to the output of the backbone layer named
    ``'conv5_block3_out'`` and must return ``(out, attention)``; the
    resulting model outputs ``out``.
    """
    backbone = tf.keras.applications.ResNet50(include_top=False, input_shape=(224, 224, 3))
    # get_layer raises if the layer is missing, instead of silently leaving
    # `out` undefined (or stale) as a name-matching loop would.
    features = backbone.get_layer('conv5_block3_out').output
    out, _ = attention_layer(features)
    return tf.keras.models.Model(inputs=backbone.inputs, outputs=out)


model_self_attention = _build_attention_model(SelfAttention(2048))
model_channel_attention = _build_attention_model(ChannelAttention(2048))
model_spatial_attention = _build_attention_model(SpatialAttention())
model_channel_spatial_attention = _build_attention_model(ChannelSpatialAttention(2048))
# Get the attention maps
# The ResNet50 backbones are built without a `weights` argument, so they use
# the library default; those pretrained weights expect inputs that went
# through the matching `preprocess_input`. `image` is a tf.Tensor, so this
# returns a new tensor and the raw pixels stay available for display.
resnet_input = tf.keras.applications.resnet50.preprocess_input(image)
attention_maps_self_attention = get_attention_map(model_self_attention, resnet_input)
attention_maps_channel_attention = get_attention_map(model_channel_attention, resnet_input)
attention_maps_spatial_attention = get_attention_map(model_spatial_attention, resnet_input)
attention_maps_channel_spatial_attention = get_attention_map(model_channel_spatial_attention, resnet_input)
# Plot the attention maps
plt.figure(figsize=(20, 20))

plt.subplot(331)
plt.imshow(tf.keras.preprocessing.image.array_to_img(image[0]))
plt.title('Original Image')
plt.axis('off')

# (subplot position, maps computed earlier, panel title). The maps are reused
# as computed above rather than re-running every model a second time.
attention_panels = [
    (332, attention_maps_self_attention, 'Self-Attention Map'),
    (333, attention_maps_channel_attention, 'Channel Attention Map'),
    (334, attention_maps_spatial_attention, 'Spatial Attention Map'),
    (335, attention_maps_channel_spatial_attention, 'Channel-Spatial Attention Map'),
]
for subplot_position, panel_maps, panel_title in attention_panels:
    plt.subplot(subplot_position)
    # Drop the singleton batch/channel axes so imshow receives a 2-D array.
    plt.imshow(tf.squeeze(panel_maps[0]))
    plt.title(panel_title)
    plt.axis('off')

plt.tight_layout()
plt.show()
print(image.shape)
print(attention_maps_self_attention)
# Calculate the SSIM scores between the attention maps and the original image
def _ssim_against_image(attention_map):
    """Return the SSIM between the grayscale input image and one attention map.

    SSIM needs two arrays of identical shape, so the image is reduced to a
    single channel and the attention map is squeezed to 2-D and resized to
    the image resolution. Both are scaled to [0, 1] so that ``data_range``
    can be stated explicitly for the floating-point inputs.
    """
    # Assumes `image` holds 8-bit pixel intensities (0-255) - TODO confirm.
    reference = np.asarray(image[0]).astype(np.float64).mean(axis=-1) / 255.0

    # Assumes a single-image batch, so squeezing leaves a 2-D map.
    candidate = np.squeeze(np.asarray(attention_map).astype(np.float64))
    candidate = resize(candidate, reference.shape)

    lowest = candidate.min()
    span = candidate.max() - lowest
    # A constant map has no range to normalise; treat it as all zeros.
    candidate = (candidate - lowest) / span if span > 0 else np.zeros_like(candidate)

    return ssim(reference, candidate, data_range=1.0)


ssim_score_self_attention = _ssim_against_image(attention_maps_self_attention[0])
ssim_score_channel_attention = _ssim_against_image(attention_maps_channel_attention[0])
ssim_score_spatial_attention = _ssim_against_image(attention_maps_spatial_attention[0])
ssim_score_channel_spatial_attention = _ssim_against_image(attention_maps_channel_spatial_attention[0])