Graph-based learning using Graph Neural Networks: This is a beginner-friendly exploration of Graph Neural Networks (GNNs), where we unravel the fundamentals of this powerful technique for analyzing interconnected data structures and pave the way for a deeper understanding and practical applications. This will be a precursor to a subsequent hands-on workshop that will be announced later.
This talk was delivered as part of the neo4j meetup held on 19th August 2023 at Thoughtworks, Bangalore. Meetup link: https://www.meetup.com/graph-database-bangalore/events/294780261
4. Spotlight: A chemist’s tale
What’s this? Is a molecule a graph? If so, what are its nodes, edges and features?
5. Predict if a molecule is a potent drug
1. Train a GNN on a curated dataset where the response is known
2. Once trained, apply the model to any candidate molecule
3. Select the ~top-100 candidates
4. Get chemists to thoroughly investigate them
6. Halicin - a powerful drug discovered with the help of a GNN in 2020
“We wanted to develop a platform that would allow us to harness the power of artificial intelligence to usher in a new age of antibiotic drug discovery. Our approach revealed this amazing molecule which is arguably one of the more powerful antibiotics that has been discovered.” - James Collins, Professor of Medical Engineering and Science, MIT
8. Machine Learning problems using Graph data
Node-level predictions - Does this person smoke? (unlabelled node)
Link-level predictions - Next Prime Video?
Graph-level predictions - Is this molecule a suitable drug?
19. Two-step message passing: a recap
[Diagram: a target node and its neighbours (nodes 1-5); the neighbours' messages are AGGREGATEd into the target node]
1. AGGREGATE - pass information (the “message”) from the target node’s neighbours to the target node
2. UPDATE - update each node’s features based on the “message” to form an embedded representation
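A minimal sketch of one round of this two-step scheme (my own illustration in plain NumPy, not code from the talk), with node 0 as the target node and nodes 1-4 as its neighbours, mirroring the diagram above:

import numpy as np

# Node features: one 3-dimensional vector per node (5 nodes)
h = np.random.rand(5, 3)

# Adjacency list: node 0 is connected to nodes 1-4
neighbours = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}

def aggregate(v):
    # 1. AGGREGATE: collect the "message" from v's neighbours (here, their sum)
    return sum(h[u] for u in neighbours[v])

def update(h_v, message, w_self=0.5, w_neigh=0.5):
    # 2. UPDATE: combine the node's own features with the aggregated message
    return w_self * h_v + w_neigh * message

# One round of message passing over every node
h_new = np.stack([update(h[v], aggregate(v)) for v in range(5)])
print(h_new.shape)  # (5, 3): same graph, one updated embedding per node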
20. Generic form of message passing
[Diagram: the target node and its neighbours again, highlighting the AGGREGATE step]
h = node features or embeddings
k = number of hops
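In symbols, the generic form of message passing at hop k (a standard formulation, consistent with the legend above) can be written as:

h_v^{(k)} = \mathrm{UPDATE}\Big( h_v^{(k-1)},\ \mathrm{AGGREGATE}\big(\{\, h_u^{(k-1)} : u \in \mathcal{N}(v) \,\}\big) \Big)

i.e. at hop k, each node v aggregates the hop-(k-1) embeddings of its neighbours N(v) and updates its own embedding with the result.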
21. Using neural networks for aggregate and update
Each node’s updated value becomes a weighting of its previous value plus a weighting of its neighbours’ values
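In symbols (a common instantiation, not necessarily the slide's exact notation), with separate weight matrices for the node itself and its neighbours:

h_v^{(k)} = \sigma\Big( W_{\mathrm{self}}\, h_v^{(k-1)} + W_{\mathrm{neigh}} \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)} + b \Big)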
22. Make message passing more efficient by simplifying, generalizing and sharing parameters
Collapse the two weight matrices into a single W by adding self-loops to the adjacency matrix
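In matrix form (a standard way to write this simplification, with A the adjacency matrix, I the identity and H^{(k)} the matrix of all node embeddings at hop k):

H^{(k)} = \sigma\big( (A + I)\, H^{(k-1)}\, W^{(k)} \big)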
23. Base GNNs on a “convolution” perspective
Normalize by the number of nodes in the neighbourhood
“Original” GNN (2009) vs. GCN (2016)
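Two common ways of writing the normalization (a hedged rendering; the slide's exact formulas may differ): a plain mean over the neighbourhood, versus the symmetric normalization used by the GCN:

\text{Mean over neighbourhood:}\quad h_v^{(k)} = \sigma\Big( W^{(k)} \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)} \Big)

\text{GCN:}\quad h_v^{(k)} = \sigma\Big( W^{(k)} \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{h_u^{(k-1)}}{\sqrt{|\mathcal{N}(u)|\,|\mathcal{N}(v)|}} \Big)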
25. A wealth of libraries
You can mix-and-match some of these libraries to train and predict node/edge/graph classification problems.
1. pyTorch and pyG
2. tensorflow
3. GDS from neo4j
Set expectation for the session
You should be able to understand why people use GNNs
You should have a 20,000-foot view of what they actually do - at least develop an intuition of what they do
Appreciate the math (if you have some exposure to Neural networks)
Overall, we will be setting the stage for a subsequent workshop (perhaps at the next neo4j meetup) where we will do some hands-on work.
Must-watch: https://www.youtube.com/watch?v=cWIeTMklzNg
Just like graph databases are found everywhere, predictions based on graphs also have applications everywhere.
From a social graph such as LinkedIn's, you can suggest new connections.
Could be used in e-commerce recommender systems.
Could be used in medicine/pharmacy - say, is this drug potent?
Could be used in social networks - say, providing connection recommendations.
https://www.youtube.com/watch?v=fOctJB4kVlM&list=PLV8yxwGOxvvoNkzPfCx2i8an--Tkt7O8Z&index=1 2:14
Molecule
Nodes - atoms
Edges - bonds
Features of nodes - atom types, number of protons, charge
Features of edges - bond type
Once you have a molecule represented as a graph, you can train a Graph Neural Network to perform, say, a binary classification task: will this molecule inhibit a given bacterium?
First, you will have a curated dataset of several molecules for which the outcome is known - i.e. whether each one inhibits E. coli or not.
You train a Graph Neural Network with this training data.
Once you have trained the GNN, you can feed it a bunch of candidate molecules.
Take the N most promising molecules.
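A minimal sketch of this pipeline (my own illustration, not code from the talk) using pyTorch and pyG. It assumes a hypothetical train_dataset of molecules with a known 0/1 inhibition label and a candidate_dataset of unlabelled molecules, both as lists of torch_geometric.data.Data objects:

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

class MoleculeClassifier(torch.nn.Module):
    def __init__(self, num_node_features, hidden=64):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.readout = torch.nn.Linear(hidden, 1)    # one graph-level score

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))        # message passing, hop 1
        x = F.relu(self.conv2(x, edge_index))        # message passing, hop 2
        g = global_mean_pool(x, batch)               # aggregate node embeddings per molecule
        return self.readout(g).squeeze(-1)           # one logit per molecule

model = MoleculeClassifier(train_dataset[0].num_node_features)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 1-2. Train on the curated, labelled dataset
model.train()
for epoch in range(50):
    for batch in DataLoader(train_dataset, batch_size=32, shuffle=True):
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.batch)
        loss = F.binary_cross_entropy_with_logits(logits, batch.y.float())
        loss.backward()
        optimizer.step()

# 3. Score every candidate molecule and keep the most promising ones
model.eval()
with torch.no_grad():
    scores = [model(d.x, d.edge_index, torch.zeros(d.num_nodes, dtype=torch.long)).item()
              for d in candidate_dataset]
top_candidates = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:100]
# 4. Hand the top candidates over to chemists for thorough investigation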
Halicin in 2020
GNNs in the 1990s in the chemical industry - a bit of history
So what can you do with graph data?
Node-level predictions - If you have a graph with unlabelled nodes, you may want to predict attributes about them or classify them. GNNs, as you'll see soon, use information from other nodes to infer on these unlabelled nodes.
Edge prediction - Typically used by companies to predict which product a customer will purchase next. Here, it can be used to predict whether a link will form between a person node and an item node.
Graph-level prediction - Whether a molecule is a potent drug or not.
https://www.youtube.com/watch?v=fOctJB4kVlM&list=PLV8yxwGOxvvoNkzPfCx2i8an--Tkt7O8Z&index=2
3:20
This is a commonly encountered concept in NLP and deep learning. How do you represent discrete things in a large set, say the words in a dictionary?
We have a huge sparse vector that is mostly 0s, with a 1 at the index of the word under consideration. The problem is that such vectors don't capture how the words correlate with one another.
Most deep learning models tend to learn a distributed vector representation. In this case, you can see that “banana” and “mango” have some similarity - maybe because they are both fruits and are yellow.
So V could be 10,000, and D maybe 128 or 500. In a way, it compresses as well.
Distributed Vector representations - https://www.youtube.com/watch?v=zCEYiCxrL_0 5:29
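For example (a tiny sketch with made-up sizes, assuming V = 10,000 words and D = 128 dimensions), in PyTorch the dense, learned representation is just an embedding table:

import torch
import torch.nn as nn

V, D = 10_000, 128                  # vocabulary size and embedding dimension (example values)
one_hot = torch.zeros(V)            # sparse representation: all zeros ...
one_hot[42] = 1.0                   # ... except a 1 at the word's index

embedding = nn.Embedding(V, D)      # learned, dense ("distributed") representation
dense = embedding(torch.tensor(42))
print(one_hot.shape, dense.shape)   # torch.Size([10000]) torch.Size([128])
# Similar words (e.g. "banana" and "mango") can end up with nearby dense vectors,
# which a one-hot encoding cannot express.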
You start with a graph - a modelling decision. It should encode the problem.
For each of the nodes, there could be some information - either an embedding or a set of features.
The output of a GNN is the same graph. For each node, there is a vector representation. Each node now has information not simply about itself, but about how it belongs in the context of the graph.
Training and recommendation
Recommendation
The aggregate function should be permutation invariant.
Training
Too much message passing - over-smoothing
In the first step, the nodes only know about themselves.
In the next step, all of them know about their neighbours.
Then their neighbours’ neighbours.
In each step, the receptive field increases.
Finally, every node would have learnt how it “belongs” in the graph.
Once you have these vectors, you can do:
Node classification - classify each node independently by applying a shared layer
Link classification
Graph classification - aggregate all the h vectors
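A small sketch (my own illustration) of how the final node vectors can feed each of these tasks, assuming the GNN has produced a tensor h of shape [num_nodes, dim]:

import torch
import torch.nn as nn

num_nodes, dim, num_classes = 6, 16, 3
h = torch.randn(num_nodes, dim)           # final node embeddings from the GNN

# Node classification: one shared linear layer applied to every node independently
node_head = nn.Linear(dim, num_classes)
node_logits = node_head(h)                # shape [num_nodes, num_classes]

# Link classification: score a pair of nodes, e.g. via a dot product of their embeddings
link_score = (h[0] * h[3]).sum()          # scalar score for a potential edge (0, 3)

# Graph-level prediction: aggregate all the h vectors, then classify the whole graph
graph_head = nn.Linear(dim, 1)
graph_logit = graph_head(h.mean(dim=0))   # e.g. "is this molecule a suitable drug?"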
12:59 https://www.youtube.com/watch?v=8owQBFAHw7E
Recommendation
Show n layers of convolution
The aggregate function needs to be permutation invariant. For example, if there is a new neighbour, the output shape (whether it is a number or a matrix) should not change.
Summation is one choice - you could also use mean, max, min, etc.
Do not treat target and neighbour nodes differently.
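To make the permutation-invariance point concrete, a tiny sketch (my own illustration) with a few neighbour embeddings:

import torch

neigh = torch.randn(3, 4)                 # 3 neighbours, 4-dimensional embeddings

# Permutation-invariant aggregators: reordering the neighbours never changes the result,
# and adding a neighbour never changes the output shape
agg_sum  = neigh.sum(dim=0)               # shape [4]
agg_mean = neigh.mean(dim=0)              # shape [4]
agg_max  = neigh.max(dim=0).values        # shape [4]

shuffled = neigh[torch.randperm(3)]
assert torch.allclose(agg_sum, shuffled.sum(dim=0))   # order of neighbours does not matter

one_more = torch.cat([neigh, torch.randn(1, 4)])      # a new neighbour appears
print(one_more.sum(dim=0).shape)                      # still torch.Size([4])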
The adjacency matrix defines which nodes are neighbours of which.
pyTorch - From Meta, now under the Linux Foundation. Optimized tensor library for deep learning using GPUs and CPUs.
pyG - PyTorch extension for graph learning. Supports GNN architectures like GCN and GAT.
Tensorflow
GDS - Random-walk-with-restarts sampling (server-side and client-side components; uses Apache Arrow for in-memory analytics)