Graph Neural Networks for Recommendation Systems: Link Prediction
Graph Neural Networks (GNNs)
Graph Neural Networks (GNNs) are a class of neural networks designed to operate on graph-structured data. They can be viewed as a generalization of convolutional neural networks (CNNs), which work well on grid-like data such as images, to irregular, non-grid structures such as graphs.
GNNs work by iteratively updating node representations (embeddings) based on the representations of their neighboring nodes. Over several iterations, a node's representation begins to encode the structure of its local graph neighborhood, allowing for tasks such as node classification, link prediction, and graph classification.
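The update rule can be sketched in a few lines of NumPy. The scheme is simplified for illustration and is an assumption, not a specific published layer: each node's new embedding is the mean of its own embedding and its neighbors' embeddings, with no learned weights or nonlinearity. On a three-node path, a signal placed on node 0 reaches node 2 only after the second iteration, which is the "local neighborhood" effect described above.

```python
import numpy as np

def mean_aggregate(x, edge_index, num_layers=1):
    """Simplified message passing: each node's new embedding is the mean of
    its own embedding and its in-neighbors' embeddings (no learned weights,
    no nonlinearity), repeated num_layers times."""
    num_nodes = x.shape[0]
    src, dst = edge_index
    for _ in range(num_layers):
        summed = x.copy()                       # self-loop: keep own embedding
        np.add.at(summed, dst, x[src])          # add each neighbor's message
        counts = 1 + np.bincount(dst, minlength=num_nodes)
        x = summed / counts[:, None]            # mean over self + neighbors
    return x

# Path graph 0 - 1 - 2, each undirected edge stored in both directions
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
x0 = np.array([[1.0], [0.0], [0.0]])           # only node 0 carries a signal

one_hop = mean_aggregate(x0, edge_index, num_layers=1).ravel()
two_hop = mean_aggregate(x0, edge_index, num_layers=2).ravel()
print(one_hop)  # node 2 is still 0
print(two_hop)  # node 2 now receives part of node 0's signal
```

Real GNN layers (GCN, GraphSAGE, GAT) add a learned transformation and a nonlinearity around this aggregation step, but the hop-by-hop spread of information is the same.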
For a deeper dive into GNNs, refer to my detailed guide on Graph Neural Networks.
What is Link Prediction?
Link prediction is a task in graph theory where the goal is to estimate the likelihood of a connection (link) between two nodes in a network. Link prediction can be used to uncover missing links, predict future connections, or suggest potential links in a network. For example, in a social network graph, link prediction could be used to suggest new friends or followers.
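Before any learning is involved, link prediction can be done with neighborhood heuristics. A classic one is the Jaccard coefficient: two unconnected nodes that share a large fraction of their neighbors are likely to connect. The sketch below uses plain Python sets and a hypothetical friendship graph; NetworkX provides the same score as `nx.jaccard_coefficient(G, ebunch)`.

```python
def jaccard_scores(adj, candidate_pairs):
    """Score each candidate (u, v) by |N(u) & N(v)| / |N(u) | N(v)|,
    the Jaccard coefficient of the two neighbor sets."""
    scores = {}
    for u, v in candidate_pairs:
        union = adj[u] | adj[v]
        scores[(u, v)] = len(adj[u] & adj[v]) / len(union) if union else 0.0
    return scores

# Hypothetical friendship graph as an adjacency dict of neighbor sets
adj = {
    "ana": {"ben", "cai"},
    "ben": {"ana", "cai", "dev"},
    "cai": {"ana", "ben", "eve"},
    "dev": {"ben"},
    "eve": {"cai"},
}
# Candidate links are pairs that are not connected yet
candidates = [("ana", "dev"), ("ben", "eve"), ("dev", "eve")]
scores = jaccard_scores(adj, candidates)
for pair, score in sorted(scores.items(), key=lambda kv: -kv[1]):
    print(pair, round(score, 3))
```

Here "ana" and "dev" share half of their combined neighbors, so that pair would be the first friend suggestion.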
Relating link prediction to recommendation systems, imagine a situation where nodes represent users and items (like movies or books), and edges represent interactions (like purchases or ratings). A link prediction task in this context would involve predicting potential future interactions, i.e., recommendations. If we can predict a high likelihood of a link between a user and an item not yet interacted with, that item would be a good recommendation for the user.
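A tiny user-item version of this idea, sketched in NumPy under stated assumptions: the interaction matrix `R` is hypothetical, and the score (the number of user-item-user-item walks of length 3, computed as `R @ R.T @ R`) is a simple co-occurrence heuristic, not a GNN. It ranks each user's unseen items by how often users with overlapping histories interacted with them.

```python
import numpy as np

# Hypothetical interactions: rows are users, columns are items (1 = interacted)
R = np.array([
    [1, 1, 0, 0],   # user 0
    [1, 1, 1, 0],   # user 1
    [0, 0, 1, 1],   # user 2
])

# (R R^T)[u, v] counts items shared by users u and v; multiplying by R again
# counts walks user -> item -> user -> item of length 3 from u to each item.
paths3 = R @ R.T @ R

# Only items the user has not interacted with are candidate links
scores = np.where(R == 0, paths3, -1)
recommended = scores.argmax(axis=1)
print(recommended)  # top unseen item per user
```

User 0 is steered toward item 2 because user 1, who shares both of user 0's items, also interacted with item 2.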
Draw it with NetworkX and PyTorch Geometric
Creating a Simple Graph
NetworkX is a Python library that allows you to create, manipulate, and study the structure, dynamics, and functions of complex networks.
Here's how you can create a simple graph using NetworkX:
```python
import networkx as nx

# Create an empty graph object
G = nx.Graph()
# Add nodes to the graph
G.add_node(1)
G.add_nodes_from([2, 3])
# Add edges to the graph
G.add_edge(1, 2)
G.add_edges_from([(1, 3), (2, 3)])
# Print nodes and edges
print(G.nodes())
print(G.edges())
```
PyTorch Geometric (PyG) is a powerful library for implementing graph neural networks in PyTorch. Its fundamental data structure is the Data class, which encapsulates a graph as node features, edge indices, and edge features.
Here's how you can create the same triangle graph in PyG. Note that PyG uses 0-based node indices and stores each undirected edge once in each direction:
```python
import torch
from torch_geometric.data import Data

# Define the edges: row 0 holds source nodes, row 1 holds target nodes.
# NetworkX nodes 1, 2, 3 become 0, 1, 2, and each undirected edge appears twice.
edge_index = torch.tensor([[0, 1, 0, 2, 1, 2],
                           [1, 0, 2, 0, 2, 1]], dtype=torch.long)
# Define the graph
data = Data(edge_index=edge_index, num_nodes=3)
# Print edges
print(data.edge_index.t().tolist())
```
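Going from a NetworkX-style edge list to PyG's layout means mapping every node ID to `0..N-1` and writing every undirected edge in both directions. A minimal NumPy sketch of that conversion (the helper name `to_edge_index` is hypothetical; `torch.tensor(edge_index)` turns the result into a PyG-ready tensor, and PyG also ships `torch_geometric.utils.from_networkx` for whole graphs):

```python
import numpy as np

def to_edge_index(edges, node_ids):
    """Map arbitrary node IDs to 0..N-1 and store every undirected edge in
    both directions, giving an int64 array of shape [2, 2 * len(edges)]."""
    index = {node: i for i, node in enumerate(node_ids)}
    src = [index[u] for u, v in edges] + [index[v] for u, v in edges]
    dst = [index[v] for u, v in edges] + [index[u] for u, v in edges]
    return np.array([src, dst], dtype=np.int64)

# The triangle from the NetworkX example: nodes 1, 2, 3
edge_index = to_edge_index([(1, 2), (1, 3), (2, 3)], node_ids=[1, 2, 3])
print(edge_index)
```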
Importing Test Graphs
Both NetworkX and PyG offer functionalities to import popular benchmark datasets or randomly generated graphs.
In NetworkX, you can generate a random graph as follows:
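```python
import networkx as nx

# Generate a random graph with 10 nodes, where each possible edge
# is included independently with probability 0.5
G = nx.gnp_random_graph(10, 0.5)
```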
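`gnp_random_graph` implements the Erdős-Rényi G(n, p) model: each of the n(n-1)/2 possible edges is included independently with probability p, so G(10, 0.5) has 0.5 * 45 = 22.5 edges on average. A NumPy sketch of the same sampling procedure, for illustration only (the helper `gnp_edges` is hypothetical and returns an edge list rather than a NetworkX graph):

```python
import numpy as np

def gnp_edges(n, p, seed=None):
    """Sample an Erdos-Renyi G(n, p) graph as a list of (i, j) pairs with i < j:
    each of the n*(n-1)/2 unordered pairs is kept independently with probability p."""
    rng = np.random.default_rng(seed)
    rows, cols = np.triu_indices(n, k=1)       # all pairs i < j, no self-loops
    keep = rng.random(rows.size) < p
    return list(zip(rows[keep].tolist(), cols[keep].tolist()))

n, p = 10, 0.5
edges = gnp_edges(n, p, seed=0)
print(len(edges), "edges; expected", p * n * (n - 1) / 2)  # expected value is 22.5
```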
In PyG, you can import benchmark graph datasets like this:
```python
from torch_geometric.datasets import Planetoid

# Import the Cora citation network
dataset = Planetoid(root='/tmp/Cora', name='Cora')
```
Visualizing Graphs
Visualizing graphs helps us understand their structure and characteristics. NetworkX provides built-in functions for graph visualization:
```python
import matplotlib.pyplot as plt

nx.draw_networkx(G, with_labels=True)
plt.show()
```
Challenges
While visualization is a powerful tool, it's not always the best way to understand large or high-dimensional graphs. Graphs can become very complex, making them difficult to draw in a way that highlights important features. Also, as the number of nodes and edges increases, the visualization can become cluttered and unreadable. It's often necessary to use additional methods like statistics, metrics, or subgraph sampling to understand complex graph structures.
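One common form of subgraph sampling is extracting the k-hop neighborhood around a node of interest, which is also exactly the part of the graph a k-layer GNN sees. A pure-Python breadth-first-search sketch on a hypothetical adjacency list; NetworkX offers `nx.ego_graph(G, n, radius=k)` and PyG offers `torch_geometric.utils.k_hop_subgraph` for the same purpose.

```python
from collections import deque

def k_hop_subgraph(adj, seed, k):
    """Return the set of nodes within k hops of `seed` using breadth-first search."""
    dist = {seed: 0}
    queue = deque([seed])
    while queue:
        node = queue.popleft()
        if dist[node] == k:
            continue                      # do not expand beyond k hops
        for nbr in adj.get(node, ()):
            if nbr not in dist:
                dist[nbr] = dist[node] + 1
                queue.append(nbr)
    return set(dist)

# Hypothetical path graph 0 - 1 - 2 - 3 - 4 as adjacency lists
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
print(sorted(k_hop_subgraph(adj, seed=2, k=1)))
```

Drawing only such a neighborhood keeps the plot readable even when the full graph has thousands of nodes.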
For PyG, you can convert the graph to NetworkX for visualization:
```python
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx

# Load Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Convert to NetworkX graph
G = to_networkx(data, to_undirected=True)
plt.figure(figsize=(20, 20))
# Color nodes according to their class (Cora has 7 classes)
color_map = {0: 'blue', 1: 'red', 2: 'green', 3: 'yellow',
             4: 'purple', 5: 'pink', 6: 'orange'}
# Draw the graph using NetworkX
nx.draw_networkx(
    G,
    node_color=[color_map[data.y[i].item()] for i in range(data.num_nodes)],
    pos=nx.spring_layout(G, k=0.035),
    with_labels=False,
    node_size=50
)
plt.show()

print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {int(data.train_mask.sum())}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.has_isolated_nodes()}')
print(f'Contains self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
# Cora is not connected, so nx.diameter(G) raises an error;
# compute the diameter of the largest connected component instead.
largest_cc = G.subgraph(max(nx.connected_components(G), key=len))
print(f"Diameter (largest component): {nx.diameter(largest_cc)}")
print(f"Average clustering coefficient: {nx.average_clustering(G):.4f}")
```
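The "average node degree" line divides `data.num_edges` by `data.num_nodes` without a factor of 2. That works because PyG stores each undirected edge in both directions, so `num_edges` already counts every edge twice, matching the textbook average degree 2|E|/|V|. A NumPy sketch on a hypothetical four-node graph that also checks for isolated nodes:

```python
import numpy as np

# Hypothetical graph: triangle 0-1-2 plus node 3 attached to node 2,
# each undirected edge stored in both directions as in PyG
edge_index = np.array([
    [0, 1, 0, 2, 1, 2, 2, 3],
    [1, 0, 2, 0, 2, 1, 3, 2],
])
num_nodes = 4

degree = np.bincount(edge_index[0], minlength=num_nodes)  # edges leaving each node
print(degree)                           # per-node degrees
print(edge_index.shape[1] / num_nodes)  # num_edges / num_nodes
print(degree.mean())                    # the same average degree
print((degree == 0).any())              # any isolated nodes?
```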
Training a Graph Neural Network for Link Prediction
Next, we'll discuss the training process for a GNN for link prediction. We'll use PyTorch Geometric for this process due to its extensive support for graph neural networks.
Start Small
Starting small is always a good idea when dealing with GNNs. Take a batch or a small graph that you can optimize on a single GPU before scaling up. Graphs can be quite large and GNNs can be memory-intensive, so starting small lets you understand the process before dealing with the challenges of scale. Depending on your node features, I've successfully loaded graphs of 2M nodes, each with a 128-dimensional node embedding, and roughly five times as many edges as nodes. Your mileage may vary!
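A back-of-the-envelope memory estimate helps decide what "small" means. The sketch below assumes float32 node features and int64 edge indices (PyG's `edge_index` is a `torch.long` tensor) and reads the figures above as 2M nodes, 128-dimensional features and about 10M edges. It only counts the raw inputs; activations, gradients and optimizer state add to this during training.

```python
def graph_memory_bytes(num_nodes, feature_dim, num_edges,
                       feature_bytes=4, index_bytes=8):
    """Raw size of a dense node feature matrix (float32 by default) and a
    [2, num_edges] edge index (int64 by default), in bytes."""
    features = num_nodes * feature_dim * feature_bytes
    edge_index = 2 * num_edges * index_bytes
    return features, edge_index

# Assumed reading of the figures above: 2M nodes, 128-dim features,
# ~10M edges (double the edge count if edges are stored in both directions)
features, edges = graph_memory_bytes(2_000_000, 128, 10_000_000)
print(f"features: {features / 1e9:.2f} GB, edge_index: {edges / 1e9:.2f} GB")
# features: 1.02 GB, edge_index: 0.16 GB
```

The feature matrix dominates, which is why the size of your node features matters more than the edge count at this scale.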
Simple Architecture: Graph Autoencoder
A simple architecture to start with is the Graph Autoencoder. It consists of an encoder that generates node embeddings and a decoder that predicts links from the embeddings. Here's an example "from scratch" with PyG:
```python
import torch
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    # Score each candidate edge (positive or "negative") by the dot product
    # of its two endpoint embeddings.
    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)


@torch.no_grad()
def eval_link_predictor(model, data, predict_probs=True):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1)
    return torch.sigmoid(out) if predict_probs else out


# Split the edges of the Cora graph (`data`, loaded earlier). Cora is undirected,
# so is_undirected=True keeps both directions of an edge in the same split and
# prevents validation/test edges from leaking into training.
transform = T.RandomLinkSplit(
    num_val=0.1, num_test=0.2, is_undirected=True, neg_sampling_ratio=1.0,
    add_negative_train_samples=False  # we will do this during training
)
train_data, val_data, test_data = transform(data)

# Train
epochs_count = 50
model = Net(train_data.x.shape[-1], 256, 128)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs_count)

for epoch in range(1, epochs_count + 1):
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)

    # Sample new negative edges on every epoch.
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index,
        num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1),
        method="sparse"
    )

    # Combine positive and sampled negative edges, labeled 1 and 0
    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ])

    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the cosine learning-rate schedule once per epoch

    # Get loss and probabilities for metric purposes
    val_logits = eval_link_predictor(model, val_data, predict_probs=False)
    val_loss = criterion(val_logits, val_data.edge_label)
    val_probs = torch.sigmoid(val_logits)
    test_probs = eval_link_predictor(model, test_data)
    val_auc = roc_auc_score(val_data.edge_label.cpu().numpy(),
                            val_probs.cpu().numpy())
    test_auc = roc_auc_score(test_data.edge_label.cpu().numpy(),
                             test_probs.cpu().numpy())
    print(
        f"Epoch: {epoch:03d}, Train Loss: {loss:.3f}, "
        f"Val Loss: {val_loss:.3f}, Val AUC: {val_auc:.3f}, "
        f"Test AUC: {test_auc:.3f}"
    )
```
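The training loop above calls PyG's `negative_sampling` to draw fresh "negative" edges every epoch. Conceptually, it picks node pairs that are not observed edges and labels them 0. The pure-Python sketch below illustrates that idea with uniform rejection sampling; it is a simplified stand-in written for this explanation, not PyG's actual implementation, and the helper name is hypothetical.

```python
import random

def sample_negative_edges(pos_edges, num_nodes, num_samples, seed=None):
    """Uniformly sample directed node pairs (u, v) with u != v that are not
    positive edges. A simplified stand-in for PyG's negative_sampling."""
    existing = set(pos_edges)
    available = num_nodes * (num_nodes - 1) - sum(1 for u, v in existing if u != v)
    if num_samples > available:
        raise ValueError("not enough non-edges to sample from")
    rng = random.Random(seed)
    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and (u, v) not in existing:
            negatives.add((u, v))
    return sorted(negatives)

# Path graph 0 - 1 - 2 plus an unconnected node 3, edges in both directions
pos = [(0, 1), (1, 0), (1, 2), (2, 1)]
neg = sample_negative_edges(pos, num_nodes=4, num_samples=3, seed=42)
# Positives get label 1 and negatives label 0, as in the training loop
labels = [1.0] * len(pos) + [0.0] * len(neg)
print(neg)
```

Resampling every epoch exposes the model to many different non-edges over training instead of a single fixed set.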
The Net model above is a very basic example. For more advanced architectures, consider Graph Isomorphism Networks (GIN), Graph Attention Networks (GAT), or GraphSAGE. Each takes a different approach to learning node representations and may perform better on your data.
Node and Edge Embeddings
The encoder part of the GNN learns node embeddings: dense vector representations that capture each node's position and role within the graph. In our example, an edge embedding can be obtained as the element-wise (Hadamard) product of its two node embeddings; summing the entries of that vector gives exactly the dot-product score used by the decoder.
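A quick NumPy check of that relationship with two hypothetical embeddings: summing the Hadamard product gives the same value as the dot product, and a sigmoid turns that logit into a link probability.

```python
import numpy as np

z_u = np.array([0.5, -1.0, 2.0])   # hypothetical embedding of node u
z_v = np.array([1.0, 0.5, 0.25])   # hypothetical embedding of node v

edge_embedding = z_u * z_v          # Hadamard (element-wise) product
score = edge_embedding.sum()        # same value as the dot product z_u @ z_v
prob = 1 / (1 + np.exp(-score))     # sigmoid turns the logit into a probability
print(edge_embedding, score, prob)
```

Keeping the Hadamard vector instead of summing it is useful when you want a per-edge feature vector, for example as input to a downstream classifier.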
The decoder part of the GNN then uses these embeddings to predict whether a link (or edge) exists between two nodes. In the case of link prediction, the goal is to learn a function that can accurately predict whether an edge exists between two nodes based on their embeddings.
In our GAE (Graph Autoencoder) model, we implement a simple dot-product decoder: the score for an edge between two nodes is the dot product of their embeddings, and a sigmoid turns that score into a probability. More complex decoders, such as an MLP applied to the combined endpoint embeddings, could also be used depending on the specific task and the nature of the graph.
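To show how a different decoder slots in, here is a NumPy sketch contrasting the dot-product decoder with a bilinear decoder, score = z_u^T W z_v. The bilinear form is one example of a more expressive decoder, chosen here for illustration, and `W` is a hypothetical weight matrix that would normally be learned. With `W` equal to the identity it reduces to the dot product, while a non-symmetric `W` can score u to v and v to u differently, which suits directed links.

```python
import numpy as np

def dot_decoder(z, pairs):
    """Score each (u, v) column of `pairs` by <z_u, z_v>, like Net.decode."""
    return np.einsum("ij,ij->i", z[pairs[0]], z[pairs[1]])

def bilinear_decoder(z, pairs, W):
    """Hypothetical alternative decoder: score = z_u^T W z_v."""
    return np.einsum("ij,jk,ik->i", z[pairs[0]], W, z[pairs[1]])

rng = np.random.default_rng(0)
z = rng.normal(size=(5, 4))                 # 5 nodes, 4-dim embeddings
pairs = np.array([[0, 1, 2], [3, 4, 0]])    # candidate edges (0,3), (1,4), (2,0)

print(dot_decoder(z, pairs))
print(bilinear_decoder(z, pairs, np.eye(4)))  # identical to the line above
```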
The loss function for link prediction is typically Binary Cross Entropy (BCE), as the task is a binary classification problem (predicting the existence or non-existence of a link).
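Here is that loss spelled out in NumPy for illustration. Working directly on logits, as `torch.nn.BCEWithLogitsLoss` does in the training loop above, avoids the overflow you get from applying a sigmoid and then a log to very large scores.

```python
import numpy as np

def bce_with_logits(logits, labels):
    """Mean binary cross-entropy from raw scores (logits), written in the
    numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)), which is
    algebraically equal to -[y*log(sigmoid(x)) + (1-y)*log(1 - sigmoid(x))]."""
    x = np.asarray(logits, dtype=float)
    y = np.asarray(labels, dtype=float)
    return np.mean(np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x))))

# Decoder logits for two positive edges (label 1) and one negative edge (label 0)
print(bce_with_logits([2.0, 0.0, -1.0], [1.0, 1.0, 0.0]))
# A very confident wrong prediction stays finite instead of overflowing
print(bce_with_logits([1000.0], [0.0]))
```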
PyG's GAE class has a built-in recon_loss method that computes the BCE loss over the given positive edges and randomly sampled negative edges. In fact, PyG's wrappers can make the modeling code above much simpler.
Metrics for link prediction are typically those used in binary classification, like Accuracy, Precision, Recall, F1-score, or Area Under the Receiver Operating Characteristic Curve (AUROC).
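Accuracy, precision, recall and F1 all depend on a decision threshold, while AUROC does not: it equals the probability that a randomly chosen positive edge receives a higher score than a randomly chosen negative edge. A small pure-Python sketch of that definition, quadratic in the number of edges and meant only for illustration (sklearn's `roc_auc_score`, used in the training loop above, computes the same value efficiently):

```python
def auroc(labels, scores):
    """AUROC as the fraction of (positive, negative) pairs in which the
    positive edge scores higher; ties count as half a win."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

labels = [1, 1, 0, 0, 1]              # 1 = true edge, 0 = negative sample
scores = [0.9, 0.4, 0.35, 0.8, 0.7]   # hypothetical predicted probabilities
print(auroc(labels, scores))          # 4 of 6 pairs ranked correctly
```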
Using PyG Primitives
Let's use the GAE primitive in PyG to simplify our code above:
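```python
import torch
from torch_geometric.nn import GCNConv, GAE

class Encoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv2 = GCNConv(2 * out_channels, out_channels)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

encoder = Encoder(dataset.num_features, out_channels=16)
model = GAE(encoder)  # default decoder: inner product of node embeddings
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# train_data / val_data come from the RandomLinkSplit above; no negative
# training edges are needed because recon_loss samples them itself.
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.edge_label_index)
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test(split):
    model.eval()
    z = model.encode(split.x, split.edge_index)
    pos = split.edge_label_index[:, split.edge_label == 1]
    neg = split.edge_label_index[:, split.edge_label == 0]
    return model.test(z, pos, neg)  # returns (AUC, average precision)

for epoch in range(1, 101):
    loss = train()
    val_auc, val_ap = test(val_data)
```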
Typical Features and Advantages of GNNs vs Other Systems for Link Prediction
GNNs offer several advantages over traditional link prediction methods, such as Collaborative Filtering, Matrix Factorization, Content-Based Filtering, or neighborhood Heuristics.
Features
- Non-Euclidean Data Processing: GNNs are specifically designed to work with graph data, which is fundamentally non-Euclidean. This allows GNNs to process relational data in its native form.
- Node and Edge Features: GNNs can incorporate node features and, in many architectures, edge features into their computations, enabling a richer representation of the graph.
- Local and Global Views: GNNs combine node features with the topology of the graph. Each layer extends a node's receptive field by one hop, so a k-layer GNN captures structure up to k hops away.
Advantages
- Strong Performance: GNNs often outperform traditional methods on graph-structured data, particularly when both node features and multi-hop structure carry predictive signal.
- Robustness to Sparse Data: GNNs are able to leverage the graph structure to make predictions even when individual node features are sparse or missing.
- Interpretability: The node and edge embeddings learned by GNNs can provide insights into the structure and properties of the graph.
- Transferability: The embeddings learned by GNNs can be used for a variety of downstream tasks, not just link prediction.
Drawbacks
- Computational Complexity: GNNs can be computationally intensive, particularly on large graphs.
- Scalability: While methods exist to scale GNNs, they typically require significant computational resources. Training GNNs on very large graphs can be challenging and often requires a sampling strategy.
- Handling Dynamic Graphs: While there are variants of GNNs designed for dynamic graphs (where nodes and edges can be added or removed over time), handling such changes remains a challenging problem for GNNs. Adding a temporal dimension to GNNs is possible, but still very much an active area of research.
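A common sampling strategy is GraphSAGE-style neighbor sampling: instead of aggregating every neighbor, each layer keeps at most a fixed number (the "fanout") of randomly chosen neighbors per node, which bounds the work per target node even around high-degree hubs. A simplified pure-Python sketch on a hypothetical star graph; in PyG, `torch_geometric.loader.NeighborLoader` with its `num_neighbors` argument implements this kind of sampling.

```python
import random

def sample_neighbors(adj, seeds, fanouts, seed=None):
    """For each layer, keep at most `fanout` random neighbors of every node in
    the current frontier. Returns the nodes needed to embed `seeds`."""
    rng = random.Random(seed)
    frontier, visited = list(seeds), set(seeds)
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            nbrs = adj.get(node, [])
            for nbr in rng.sample(nbrs, min(fanout, len(nbrs))):
                if nbr not in visited:
                    visited.add(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return visited

# Hypothetical star graph: hub 0 connected to 100 leaves
adj = {0: list(range(1, 101))}
adj.update({leaf: [0] for leaf in range(1, 101)})

print(len(sample_neighbors(adj, [0], fanouts=[10, 5], seed=1)))  # hub + 10 sampled leaves
print(len(sample_neighbors(adj, [0], fanouts=[200])))            # fanout above degree: all 101
```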
Literature
There's a growing body of research on using GNNs for link prediction and recommendation systems. Here are a few notable works:
- "Graph Convolutional Neural Networks for Web-Scale Recommender Systems." (Ying et al., 2018): Introduced PinSage, a GCN model used by Pinterest for web-scale recommendation.
- "Graph Neural Networks for Social Recommendation." (Fan et al., 2019): Proposed GraphRec for social recommendation, leveraging both user-item interaction history and social network information.
- "Session-based Recommendation with Graph Neural Networks." (Wu et al., 2019): Introduced SR-GNN, a method for using GNNs for session-based recommendation.
- "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation." (He et al., 2020): Proposed a simplified GCN model for recommendation that drops feature transformation and nonlinear activation, keeping only neighborhood aggregation.
- "Bipartite Graph Neural Network for Recommender Systems." (Liu et al., 2020): Proposed a GNN model specifically designed for the bipartite graphs common in recommendation systems.
- "Graph Neural Networks for Recommender Systems: Challenges, Methods, and Directions." (Gao et al., 2021): A comprehensive survey of GNNs applied to recommendation systems.
- "BERT4Rec: Sequential Recommendation with Bidirectional Encoder Representations from Transformer." (Sun et al., 2019): Applies the Transformer model to sequential recommendation tasks with competitive performance.