Graph Neural Networks: Vanilla, GCN, GAT, GIN
Introduction
Graph Neural Networks (GNNs) have become a critical tool in the machine learning toolbox, providing a powerful way to analyze and interpret graph-structured data. As data in the real world often exhibit graph-like characteristics — from social networks and web pages to molecules and proteins — understanding how to apply GNNs is increasingly important.
In this article, we'll delve into the world of GNNs, starting with an introduction to the concept of message passing in these networks. We will then explore three specific types of GNNs: Graph Convolutional Networks (GCNs), Graph Attention Networks (GATs), and Graph Isomorphism Networks (GINs). Each of these models offers unique advantages and challenges when dealing with graph-structured data.
Graph Theory Basics
Before we dive into GNNs, it's crucial to understand some fundamental concepts from graph theory, which forms the backbone of these models.
A graph $G$ can be defined as a pair $(V, E)$ where $V$ is a set of vertices (or nodes) and $E$ is a set of edges. Each edge connects a pair of vertices and can be represented as a pair $(v_i, v_j)$, ordered for directed graphs and unordered for undirected ones. In the context of GNNs, these vertices and edges often represent entities and relationships, respectively.
A critical concept for understanding GNNs is the adjacency matrix A. For a graph with n vertices, the adjacency matrix is an $n \times n$ matrix where the entry $A_{ij}$ is 1 if there's an edge from vertex $v_i$ to vertex $v_j$ and 0 otherwise.
Along with the adjacency matrix, another important concept is the degree matrix D. This is a diagonal matrix where each entry $D_{ii}$ equals the degree of vertex $v_i$ (i.e., the number of edges connected to $v_i$). All non-diagonal elements are 0 in this matrix.
These matrices play a pivotal role in the functioning of GNNs, as we'll see in the upcoming sections. Please note that the specific formulae and the way matrices are used might change depending on the type of graph (undirected, directed, weighted, etc.) and the specific type of GNN. The general ideas, however, remain the same.
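As a concrete illustration, here is a minimal numpy sketch, on a hypothetical toy graph, of how the adjacency and degree matrices are built from an edge list:

```python
import numpy as np

# Toy undirected graph with 4 nodes: edges (0,1), (0,2), (1,2), (2,3)
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
n = 4

# Adjacency matrix A: A[i, j] = 1 if an edge connects v_i and v_j
A = np.zeros((n, n))
for i, j in edges:
    A[i, j] = 1
    A[j, i] = 1  # undirected: the matrix is symmetric

# Degree matrix D: diagonal, D[i, i] = degree of v_i
D = np.diag(A.sum(axis=1))

print(D.diagonal())  # degrees: [2. 2. 3. 1.]
```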
Message Passing Neural Networks
One of the core principles of GNNs is the idea of "message passing", which refers to the process of aggregating and passing information between connected nodes in a graph. This concept is inspired by the nature of graph-structured data, where the features of a node are often deeply influenced by its neighboring nodes.
The general process of message passing in a GNN can be broken down into the following steps:
- Message Function: This function, often denoted as $M$, is used to compute a "message" that a node sends to its neighboring nodes. The message is typically a function of the features of the node and its neighbors.
- Aggregate Function: After all nodes have sent their messages, the next step is to aggregate the received messages at each node. This is done using an aggregate function, often denoted as $A$, which takes all the messages that a node has received and combines them into a single vector. Notice that the order of the nodes in a graph is arbitrary, so the aggregation function must be order-invariant, like the sum or the average; subtraction, division, and concatenation are all order-dependent.
- Update Function: Once the aggregated message has been computed, an update function, often denoted as $U$, is applied to update the node's features based on the aggregated message.
Mathematically, these steps can be expressed as follows:
Step 1: Each node $v_i$ computes a message for each of its neighbors $v_j$, where $h_i$ and $h_j$ are the feature vectors of nodes $i$ and $j$, and $e_{ij}$ is the feature vector of the edge connecting them:

$$m_{ij} = M(h_i, h_j, e_{ij})$$

Step 2: Each node $v_i$ aggregates the messages from its neighbors, where $N(i)$ is the set of neighbors of node $i$:

$$a_i = A(\{m_{ij}\}_{j \in N(i)})$$

Step 3: Each node $v_i$ updates its feature vector, where $h'_i$ is the updated feature vector of node $i$:

$$h'_i = U(h_i, a_i)$$

This process of message passing can be repeated for a number of rounds, allowing information to propagate through the graph. After the final round, the updated node features are used for downstream tasks such as node classification, link prediction, or graph classification.
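The three steps above can be sketched in plain Python. The names `M`, `A`, and `U` follow the notation from the text; edge features $e_{ij}$ are omitted for simplicity, and `message_passing_round` is a name chosen here for illustration:

```python
import numpy as np

def message_passing_round(h, neighbors, M, A, U):
    """One round of message passing: message, aggregate, update."""
    h_new = []
    for i in range(len(h)):
        # Step 1: compute messages m_ij from each neighbor j of i
        messages = [M(h[i], h[j]) for j in neighbors[i]]
        # Step 2: aggregate with an order-invariant function
        a_i = A(messages)
        # Step 3: update node i's features
        h_new.append(U(h[i], a_i))
    return h_new

# Tiny example: identity messages, sum aggregation, additive update
h = [np.array([1.0]), np.array([2.0]), np.array([3.0])]
neighbors = {0: [1], 1: [0, 2], 2: [1]}  # a path graph 0-1-2
M = lambda hi, hj: hj                    # message = neighbor's features
A = lambda msgs: np.sum(msgs, axis=0)    # sum is order-invariant
U = lambda hi, ai: hi + ai               # update = add the aggregate

print(message_passing_round(h, neighbors, M, A, U))  # [[3.], [6.], [5.]]
```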
The power of message passing neural networks lies in their ability to learn complex patterns in the structure and features of the graph, providing a flexible framework for working with graph-structured data.
More info: Distill.pub's "A gentle introduction to GNNs"
Vanilla Graph Neural Networks and Graph Convolutional Networks
Before diving into Graph Convolutional Networks (GCNs), let's take a brief look at the most basic type of Graph Neural Network (GNN), often referred to as a Vanilla GNN.
Vanilla Graph Neural Networks
The simplest form of GNNs, sometimes called Vanilla GNNs, follows the message passing framework described in the previous section. The main idea behind these networks is to update the representation of each node by aggregating information from its neighboring nodes.
One of the most common formulations of a Vanilla GNN is the following:
$$h_i^{(l+1)}=\sigma\left(W^{l} \cdot \text{AGGREGATE}^l\{ h_j^{l} \mid v_j \in N(i) \} \right)$$

Here:
- $h^{(l+1)}_i$ is the feature representation of node $i$ at layer $l+1$ — typically a vector
- $W^l$ is a learnable weight matrix for layer $l$
- $\text{AGGREGATE}^l$ is an aggregation function for layer $l$ (such as mean, sum, or max)
- $\sigma$ is a non-linear activation function (such as ReLU)
- The set $\{ h^l_j \mid v_j \in N(i) \}$ contains the feature representations of the neighboring nodes of node $i$ at layer $l$
This equation essentially says that to update the feature representation of a node, we:
- Aggregate the feature representations of its neighbors
- Apply a linear transformation to the aggregated representation
- Finally apply a non-linear activation function
As you can see, this update does not use the current node's own representation $h_i$. The "message" being sent here is just the neighbor vector $h_j$, so the message function is simply the identity function!
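A minimal numpy sketch of this vanilla layer might look as follows (the function name and the toy graph are illustrative, not from a specific library):

```python
import numpy as np

rng = np.random.default_rng(0)

def vanilla_gnn_layer(H, neighbors, W):
    """h_i' = ReLU(W . mean of neighbor features). Ignores h_i itself."""
    H_new = np.zeros((H.shape[0], W.shape[0]))
    for i, nbrs in neighbors.items():
        agg = np.mean(H[nbrs], axis=0)       # AGGREGATE = mean over neighbors
        H_new[i] = np.maximum(0, W @ agg)    # sigma = ReLU
    return H_new

H = rng.normal(size=(3, 4))                  # 3 nodes, 4 features each
W = rng.normal(size=(5, 4))                  # maps 4 -> 5 features
neighbors = {0: [1], 1: [0, 2], 2: [1]}      # path graph 0-1-2
print(vanilla_gnn_layer(H, neighbors, W).shape)  # (3, 5)
```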
Graph Convolutional Networks (GCNs)
Building on the concept of Vanilla GNNs, Graph Convolutional Networks (GCNs) introduce the idea of convolution operations on graphs.
GCNs use a similar update rule as Vanilla GNNs, but with a couple of key differences. First, in GCNs, the aggregation function is typically the mean function. Second, and more importantly, the feature representation of a node in GCNs is updated based not only on its neighbors' representations, but also on its own representation.
The update rule for GCNs can be formulated as follows:
$$h_i^{(l+1)}=\sigma\left(W^{l} \cdot \text{MEAN}\{ h_j^{l} \mid v_j \in N(i) \cup \{v_i\} \} \right)$$

The terms are the same as in the vanilla GNN; notice, however, the "$\cup \{v_i\}$" in the neighborhood, which includes the current node's own vector $h_i^l$ in the mean.
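This update can be sketched compactly in matrix form: adding self-loops to the adjacency matrix and row-normalizing implements the mean over the neighborhood plus the node itself. Note that the original GCN paper uses a symmetric normalization $\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}$; the mean form below matches the formula in this article:

```python
import numpy as np

def gcn_layer(H, A, W):
    """GCN update: mean over each node's neighbors *and itself*."""
    A_hat = A + np.eye(A.shape[0])           # add self-loops
    D_hat = A_hat.sum(axis=1, keepdims=True)
    H_agg = (A_hat @ H) / D_hat              # row-wise mean aggregation
    return np.maximum(0, H_agg @ W)          # sigma = ReLU

A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)  # path 0-1-2
H = np.eye(3)                                 # one-hot node features
W = np.ones((3, 2))
print(gcn_layer(H, A, W))  # each aggregated row sums to 1, so all entries are 1.0
```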
One of the key strengths of GCNs is their ability to preserve both the local and global structure of the graph, enabling them to capture complex patterns in the data. However, they do have some limitations, such as the difficulty in handling graphs with different scales and the problem known as over-smoothing, which can degrade performance when many layers are used.
More info: Graph Convolutional Networks by Francesco Casalegno
Graph Attention Networks (GATs)
After understanding the basics of GNNs and the improvements made by GCNs, let's move on to another important variant of GNNs known as Graph Attention Networks (GATs). GATs introduce the concept of attention mechanisms to the realm of GNNs, which provides a weighted importance to the features of neighboring nodes rather than treating them equally.
Understanding Attention Mechanisms
Before we dive into GATs, it's important to understand what attention mechanisms are. In essence, attention mechanisms allow a model to focus on certain parts of the input more than others. This idea was first introduced in the field of Natural Language Processing (NLP) and has since been adapted to various other domains, including GNNs.
The key idea behind an attention mechanism is that it provides a set of weights, which sum to one and are used to create a weighted sum of the input features. The weights are typically computed using a function that considers the relative importance of different parts of the input.
Graph Attention Networks
In the context of GNNs, GATs use an attention mechanism to weigh the importance of a node's neighbors when aggregating their features. This allows GATs to capture different types of graph structures and to handle graphs where different neighbors have different levels of relevance to a node.
The update rule for GATs can be formulated as follows:
$$h^{(l+1)}_i = \sigma \left( \sum_{j \in N(i)} \alpha^{(l)}_{ij}W^{(l)} h^{(l)}_j \right)$$

Here, $\alpha^{(l)}_{ij}$ is the attention weight between nodes $i$ and $j$ at layer $l$, computed as:

$$\alpha^{(l)}_{ij} = \frac{\exp\left( \text{LeakyReLU}\left( \vec{a}^{(l)\top} [W^{(l)} h^{(l)}_i \,\|\, W^{(l)} h^{(l)}_j] \right) \right)}{\sum_{k \in N(i)} \exp\left( \text{LeakyReLU}\left( \vec{a}^{(l)\top} [W^{(l)} h^{(l)}_i \,\|\, W^{(l)} h^{(l)}_k] \right) \right)} = \text{softmax}_j\big(a(Wh_i, Wh_j)\big)$$

In this equation:
- $\vec{a}^{(l)}$ is a learnable weight vector
- $\|$ denotes concatenation, and the LeakyReLU function is applied element-wise
- The attention function $a^{(l)}$ can be seen as a single-layer feedforward neural network, parameterized by the weight vector $\vec{a}^{(l)}$ and followed by the LeakyReLU non-linearity
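A sketch of the attention-weight computation for a single head, using numpy (the helper name `attention_weights` is hypothetical):

```python
import numpy as np

def attention_weights(H, neighbors, W, a):
    """GAT attention: softmax over LeakyReLU(a^T [Wh_i || Wh_j])."""
    def leaky_relu(x, slope=0.2):
        return np.where(x > 0, x, slope * x)

    WH = H @ W.T                              # project all node features
    alpha = {}
    for i, nbrs in neighbors.items():
        scores = np.array([
            leaky_relu(a @ np.concatenate([WH[i], WH[j]])) for j in nbrs
        ])
        e = np.exp(scores - scores.max())     # numerically stable softmax
        alpha[i] = e / e.sum()                # weights over i's neighbors
    return alpha

rng = np.random.default_rng(0)
H = rng.normal(size=(3, 4))                   # 3 nodes, 4 input features
W = rng.normal(size=(2, 4))                   # projects 4 -> 2 features
a = rng.normal(size=4)                        # length 2 * output dim
neighbors = {0: [1, 2], 1: [0], 2: [0, 1]}
alpha = attention_weights(H, neighbors, W, a)
print({i: v.round(3) for i, v in alpha.items()})  # each node's weights sum to 1
```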
Multi-head attention
In addition to its basic implementation, multi-head attention in GATs works by applying the attention mechanism multiple times (once for each "head") in parallel, each with its own set of learnable parameters. The outputs of all the attention heads are then concatenated (for intermediary layers) or averaged (for the final layer) to form the final output.
$$h_i^{(l+1)} = \|_{k=1}^{K} \sigma \left( \sum_{j \in N(i)} \alpha^{(l,k)}_{ij} W^{(l,k)} h_j^{(l)} \right)$$

$$h_i^{(l+1)} = \sigma\left( \frac{1}{K} \sum_{k=1}^{K} \sum_{j \in N(i)} \alpha^{(l,k)}_{ij} W^{(l,k)} h_j^{(l)} \right) \quad \text{(last layer)}$$

Here $K$ is the number of independent attention mechanisms, or "heads", and $k$ indexes them.
- Notice that with concatenation the output has dimension $KF'$ (where $F'$ is the output feature dimension of a single head) instead of $F'$. That's why the last layer averages over the $K$ heads rather than concatenating.
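The dimension bookkeeping can be seen in a small sketch, where each of $K$ stand-in head outputs has shape $(n, F')$:

```python
import numpy as np

def multi_head(head_outputs, final_layer=False):
    """Combine K attention-head outputs: concatenate for intermediate
    layers, average for the final layer."""
    if final_layer:
        return np.mean(head_outputs, axis=0)      # shape (n, F')
    return np.concatenate(head_outputs, axis=1)   # shape (n, K * F')

K, n, F_out = 4, 5, 3
heads = [np.ones((n, F_out)) * k for k in range(K)]  # stand-in head outputs
print(multi_head(heads).shape)                       # (5, 12) = (n, K*F')
print(multi_head(heads, final_layer=True).shape)     # (5, 3)  = (n, F')
```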
Compared to GCNs, GATs have the advantage of being able to assign different importance to different nodes in the neighborhood, allowing for a more flexible and potentially more expressive model of the graph structure. However, they can also be more computationally intensive due to the need to compute and store the attention weights.
More info: Graph Attention Networks original site
Graph Isomorphism Networks (GINs)
Now that we have covered the concepts of GNNs, GCNs, and GATs, let's delve into another important variant of GNNs: Graph Isomorphism Networks (GINs).
Overview of Graph Isomorphism Networks and the Weisfeiler-Lehman Dilemma
Graph Isomorphism Networks (GINs) provide a powerful solution to a fundamental challenge in graph representation learning: distinguishing between non-isomorphic graph structures.
This problem, termed the Weisfeiler-Lehman (WL) dilemma, represents a limitation observed in earlier graph neural network models like GCNs and GATs. The issue stems from the fact that these models are not expressive enough to differentiate between certain pairs of non-isomorphic graphs that can be distinguished by the WL test, a heuristic used to check whether two graphs are isomorphic, i.e., identical up to a relabeling of nodes.
GINs address this issue by incorporating a learnable and continuous version of the WL test into their architecture, thus elevating their expressiveness to the level of the WL test. This is achieved through the introduction of a learnable parameter, $\epsilon$, that adjusts the balance between a node's own features and the aggregated features of its neighbors.
It's worth noting that graph isomorphism itself is a hard problem: it lies in NP, and no polynomial-time algorithm is known for the general case.
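The discrete 1-WL color-refinement procedure that GINs emulate can be sketched in a few lines of Python (the helper name `wl_colors` is chosen here for illustration):

```python
from collections import Counter

def wl_colors(neighbors, rounds=3):
    """1-WL color refinement: repeatedly relabel each node by its own
    color together with the multiset of its neighbors' colors."""
    colors = {v: 0 for v in neighbors}            # uniform initial coloring
    for _ in range(rounds):
        signatures = {
            v: (colors[v], tuple(sorted(colors[u] for u in neighbors[v])))
            for v in neighbors
        }
        # Give each distinct signature a fresh integer color
        palette = {s: i for i, s in enumerate(sorted(set(signatures.values())))}
        colors = {v: palette[signatures[v]] for v in neighbors}
    return Counter(colors.values())               # histogram of final colors

# WL assigns different color histograms to a triangle and a 3-node path,
# so the test tells these two graphs apart.
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
path = {0: [1], 1: [0, 2], 2: [1]}
print(wl_colors(triangle) == wl_colors(path))  # False
```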
The GIN Update Rule
The update rule for GINs can be formulated as follows:
$$h^{(l+1)}_i = \text{MLP} \left( (1 + \epsilon^{(l)}) h^{(l)}_i + \sum_{j \in N(i)} h^{(l)}_j \right)$$

Here, $\epsilon^{(l)}$ is a learnable parameter for layer $l$ that determines the relative importance of a node's own features versus the features of its neighbors.
The term $(1 + \epsilon^{(l)}) h^{(l)}_i + \sum_{j \in N(i)} h^{(l)}_j$ can be viewed as a more powerful aggregation function that includes the node's own features and allows for different levels of importance for the node's own features versus the neighbors' features.
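A minimal numpy sketch of this update (the two-layer `mlp` and the toy star graph are illustrative, and $\epsilon$ is fixed rather than learned here):

```python
import numpy as np

def gin_layer(H, A, eps, mlp):
    """GIN update: MLP((1 + eps) * h_i + sum of neighbor features)."""
    agg = (1 + eps) * H + A @ H               # sum aggregation + scaled self term
    return mlp(agg)

# A small 2-layer MLP, as the GIN formulation requires
rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(4, 8)), rng.normal(size=(8, 4))
mlp = lambda X: np.maximum(0, X @ W1) @ W2

A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)  # star graph
H = rng.normal(size=(3, 4))                   # 3 nodes, 4 features each
print(gin_layer(H, A, eps=0.1, mlp=mlp).shape)  # (3, 4)
```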
Compared to GCNs and GATs, GINs have the potential to capture a broader range of graph structures due to their more powerful aggregation function. However, like GATs, they can also be more computationally intensive due to the need to compute and store additional parameters.
More info:
- A great article by David Bieber on the WL Isomorphism test
- Another article by Professor Bronstein on the expressivity of GNNs
Conclusion
In conclusion, Graph Neural Networks (GNNs) have emerged as a remarkable approach to tackle the exciting challenges posed by graph-structured data. With their ability to capture intricate relationships and dependencies within graphs, GNN variants have opened up a world of possibilities across diverse domains.
Node classification, one of the key applications of GNNs, finds itself in good company. From predicting the interests of social media users to identifying protein functions in bioinformatics, GNNs excel at understanding the characteristics of individual nodes within a graph.
Link prediction, another fascinating application, allows us to unravel hidden connections and anticipate missing edges in a graph. GNNs prove their mettle by discerning potential videos we want to watch in social networks, foreseeing collaborations in citation networks, or even predicting potential drug-target interactions in the field of drug discovery.
As we traverse the realm of graph representation learning, the potential for GNNs seems boundless. With their ability to capture complex patterns, model intricate relations, and unleash the potential hidden within graphs, GNNs are an extremely versatile tool in our ML toolkit, enabling us to unravel the mysteries of interconnected systems one node and one link at a time.