[Part V] End to End Guide for heart disease prediction : Explainability with Attention Mechanism
This is the final part of the five-part blog series. In the previous post we discussed deploying our model using Flask. In this final post I will talk about how we can make our model explainable using an explainability technique called the attention mechanism.
Introduction
Explainability is vital for AI models as it fosters trust, transparency, and user acceptance. It enables stakeholders to understand how decisions are made, helps identify biases and errors, and facilitates model improvement. Additionally, explainable models ensure compliance with regulations and ethical standards, promoting responsible and fair AI deployment across various domains.
There are many explainability techniques, but one that is particularly important for deep learning models is the attention mechanism, as it highlights the most relevant parts of the input data considered during model decision-making. By assigning weights to different input elements, it allows users to understand which features or context influenced the model's predictions, enhancing interpretability and insight into its inner workings.
So, in this article we will explore adding the attention mechanism explainability technique to our heart disease prediction model.
Attention mechanism
First, we define a function for the attention mechanism. The function should perform the following operations:
1. **Reshaping Inputs**: The input data, denoted as `inputs`, is reshaped to have an additional dimension at the end. This new shape is `[batch_size, num_features, 1]`. The purpose of this reshaping is to prepare the inputs for the dot-product attention operation.
2. **Calculating Attention Weights**: The attention mechanism calculates the importance or relevance of each feature in the input data. This is achieved through dot-product attention: the reshaped inputs are multiplied by their transpose, which computes the pairwise dot products between the features of the input data.
3. **Applying Softmax Activation**: After obtaining the dot products, the softmax activation function is applied along the `num_features` dimension to obtain the attention weights. The softmax function normalises the dot products, turning them into values between 0 and 1, where higher values represent higher importance or attention.
4. **Calculating Context Vector**: The attention weights obtained in the previous step represent the importance of each feature in the input data. These weights are used to calculate a weighted sum of the input features, resulting in the context vector. This is done by element-wise multiplication of the original `inputs` with the attention weights, followed by `tf.reduce_sum` along the `num_features` dimension, which yields the final context vector.
import tensorflow as tf

def attention_mechanism(inputs):
    # Reshape inputs to [batch_size, num_features, 1] for dot-product attention.
    inputs_with_dim = tf.expand_dims(inputs, axis=-1)
    # Pairwise dot products between features, normalised into attention weights.
    attention_weights = tf.nn.softmax(
        tf.matmul(inputs_with_dim, inputs_with_dim, transpose_b=True), axis=-1)
    # Weighted sum of the input features gives the context vector.
    context_vector = tf.reduce_sum(tf.expand_dims(inputs, axis=1) * attention_weights, axis=-1)
    return context_vector
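As a quick sanity check, the function can be run on a random tensor to confirm that the context vector keeps the `[batch_size, num_features]` shape; the batch size and feature count below are purely illustrative.

# Illustrative only: a dummy batch of 4 samples with 13 features.
dummy_batch = tf.random.normal([4, 13])
print(attention_mechanism(dummy_batch).shape)  # (4, 13)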
# Self-attention inside the model: the concatenated encoded features attend over themselves.
attention_output = layers.Attention(name="attention")([all_features, all_features])
In the model itself, self-attention is implemented with the Keras functional API by passing the same set of encoded features as both query and value to a `layers.Attention` layer. This captures important dependencies within the data, and the output is an attention-weighted representation of the input features; a sketch of how the layer fits into the model follows below.
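For orientation, here is a minimal sketch of how the attention layer might be wired into the functional model. It assumes `all_inputs` (the list of Keras `Input` tensors) and `all_features` (the concatenated encoded features) from the earlier parts of this series, and the layer sizes are illustrative rather than the exact configuration used before.

from tensorflow import keras
from tensorflow.keras import layers

# Sketch only: `all_inputs` and `all_features` are assumed to come from the
# feature-encoding step earlier in this series; layer sizes are illustrative.
attention_output = layers.Attention(name="attention")([all_features, all_features])
x = layers.Dense(32, activation="relu")(attention_output)
x = layers.Dropout(0.5)(x)
output = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(all_inputs, output)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])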
Visualization
from tensorflow import keras

def get_attention_weights(model, input_dict):
    # Build a sub-model that ends at the attention layer.
    attention_layer = model.get_layer("attention")
    attention_model = keras.Model(inputs=model.input, outputs=attention_layer.output)
    # Run the sample through the sub-model to obtain its attention output.
    attention_output = attention_model(input_dict)
    return attention_output
# Attention output for one sample (`sample_input` is the feature dictionary
# prepared earlier in this series), normalized with softmax for plotting.
attention_output = get_attention_weights(model, sample_input)
normalized_attention = tf.nn.softmax(attention_output, axis=-1)
import numpy as np
import matplotlib.pyplot as plt

# Tick labels for the x-axis: one position per input feature.
feature_names = list(sample_input.keys())
num_features = len(feature_names)
tick_placement = np.arange(0, num_features, 1)

plt.figure(figsize=(10, 6))
plt.imshow(normalized_attention.numpy(), cmap="viridis", aspect="auto")
plt.xticks(tick_placement, feature_names, rotation=90)
plt.xlabel("Features")
plt.ylabel("Attention Weights")
plt.title("Attention Heatmap")
plt.colorbar()
plt.show()
To visualize the attention weights of the trained model, we define a function that obtains the attention layer's output for a specific input sample and normalizes it with softmax. The sample input, represented as a dictionary, is passed through a sub-model that ends at the attention layer to compute the attention output. The code then creates a heatmap using Matplotlib, where each feature corresponds to a position on the x-axis and the attention weights are shown on the y-axis.
The heatmap illustrates the relative importance of features during prediction, aiding in model interpretability. By displaying the attention distribution, users can understand which input elements the model focuses on, enhancing insight into the decision-making process.
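Beyond the heatmap, the same normalized weights can also be printed as a ranked list. The snippet below is illustrative and assumes each column of the attention output lines up with one raw input feature; if the encoded feature space is wider than the raw inputs (for example after one-hot encoding), the labels would need to be mapped accordingly.

# Illustrative only: pair each feature name with its normalized weight and
# print them from most to least attended.
weights = normalized_attention.numpy().flatten()
ranking = sorted(zip(feature_names, weights), key=lambda pair: pair[1], reverse=True)
for name, weight in ranking:
    print(f"{name}: {weight:.3f}")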
Conclusion
In conclusion, attention mechanisms play a vital role in enhancing the explainability of deep learning classification models. By capturing the importance of different input elements, attention mechanisms provide valuable insights into the decision-making process of the model. This transparency fosters trust in the model's predictions, improves model debugging, and helps identify potential biases or errors, and visualization tools such as heatmaps make these weights easy to inspect. Additionally, attention mechanisms enable users to better understand how specific features influence the model's output, making the model more interpretable and facilitating human-AI collaboration.
Also check out
Finding Harmony in Dialogue : Balancing ChatGPT and Human Interactions for effective Communication
By Maddula Syam Pavan