BatmanNet: Bi-branch Masked Graph Transformer Autoencoder for Molecular Representation
TL;DR Summary
BatmanNet, a novel bi-branch masked graph transformer autoencoder, is proposed for effective molecular representation learning. It uses a simple self-supervised strategy to capture both local and global information, achieving state-of-the-art results in drug discovery tasks.
Abstract
Although substantial efforts have been made using graph neural networks (GNNs) for AI-driven drug discovery (AIDD), effective molecular representation learning remains an open challenge, especially in the case of insufficient labeled molecules. Recent studies suggest that big GNN models pre-trained by self-supervised learning on unlabeled datasets enable better transfer performance in downstream molecular property prediction tasks. However, the approaches in these studies require multiple complex self-supervised tasks and large-scale datasets, which are time-consuming, computationally expensive, and difficult to pre-train end-to-end. Here, we design a simple yet effective self-supervised strategy to simultaneously learn local and global information about molecules, and further propose a novel bi-branch masked graph transformer autoencoder (BatmanNet) to learn molecular representations. BatmanNet features two tailored complementary and asymmetric graph autoencoders to reconstruct the missing nodes and edges, respectively, from a masked molecular graph. With this design, BatmanNet can effectively capture the underlying structure and semantic information of molecules, thus improving the performance of molecular representation. BatmanNet achieves state-of-the-art results for multiple drug discovery tasks, including molecular properties prediction, drug-drug interaction, and drug-target interaction, on 13 benchmark datasets, demonstrating its great potential and superiority in molecular representation learning.
In-depth Reading
English Analysis
1. Bibliographic Information
1.1. Title
The central topic of this paper is a novel approach for molecular representation learning using a bi-branch masked graph transformer autoencoder. The title is BatmanNet: Bi-branch Masked Graph Transformer Autoencoder for Molecular Representation.
1.2. Authors
The authors of the paper are:
- Zhen Wang (1,2)
- Zheng Feng (3)
- Yanjun Li (4)
- Bowen Li (2)
- Yongrui Wang (2)
- Chulin Sha (2)
- Min He (1,2)*
- Xiaolin Li (2,5)*
The affiliations are not explicitly detailed for all authors, but two corresponding authors are identified: Min He (hemin607@163.com) and Xiaolin Li (xiaolinli@ieee.org).
1.3. Journal/Conference
This paper was published on arXiv, a preprint server.
1.4. Publication Year
The paper was first posted to arXiv at 2022-11-25T09:44:28 (UTC), i.e., it was published in 2022.
1.5. Abstract
The paper addresses the challenge of effective molecular representation learning, particularly when labeled molecular data is scarce, a common issue in AI-driven drug discovery (AIDD). Existing self-supervised pre-training methods often involve multiple complex tasks, large datasets, and significant computational expense, making end-to-end pre-training difficult.
The authors propose BatmanNet, a novel bi-branch masked graph transformer autoencoder, coupled with a simple yet effective self-supervised strategy. This strategy simultaneously learns local and global molecular information by masking a high proportion of nodes and edges in a molecular graph and reconstructing the missing parts. BatmanNet features two tailored, complementary, and asymmetric graph autoencoders, one for node reconstruction and one for edge reconstruction. This design enables the model to efficiently capture the underlying structural and semantic information of molecules.
BatmanNet achieves state-of-the-art results across 13 benchmark datasets for various drug discovery tasks, including molecular property prediction, drug-drug interaction, and drug-target interaction. This demonstrates its potential and superiority in learning molecular representations.
1.6. Original Source Link
The official abstract page is https://arxiv.org/abs/2211.13979v3, and the PDF is available at https://arxiv.org/pdf/2211.13979v3.pdf. This paper is a preprint published on arXiv.
2. Executive Summary
2.1. Background & Motivation
The core problem this paper aims to solve is the effective learning of molecular representations, especially in AI-driven drug discovery (AIDD) where labeled molecular data is often insufficient. Molecular representation learning is crucial for various downstream tasks like molecular property prediction, drug-drug interaction (DDI) prediction, and drug-target interaction (DTI) prediction.
The importance of this problem stems from the high cost and time involved in acquiring high-quality molecular property labels through wet-lab experiments, leading to a scarcity of task-specific labeled data. Supervised training of deep models on these restricted datasets frequently results in overfitting, where the model learns to perform well on the training data but fails to generalize to new, unseen data.
Prior research has explored self-supervised learning (SSL) to pre-train large Graph Neural Network (GNN) models on unlabeled datasets, which has shown promise for transfer performance. However, these existing approaches face specific challenges and gaps:
- Complex Pre-training Tasks: Many methods require constructing multiple, complex self-supervised tasks to learn different aspects (local and global information) of molecules. This often involves introducing additional domain knowledge (e.g., motifs, subgraphs, atomic distance matrices, molecular descriptors), making the pre-training process complex, difficult to manage end-to-end, and less scalable.
- High Computational Complexity and Large Model Size: Recent Transformer-based models, while powerful, typically encode information directly over the entire molecular graph. This leads to a large number of parameters, high computational cost, and extensive memory consumption, limiting their accessibility for smaller research groups.
- Suboptimal Topological Representation: Some methods rely on SMILES (Simplified Molecular-Input Line-Entry System) representations, which are linear strings. While useful, SMILES lacks explicit topological information and may not accurately capture molecular structural similarities, potentially misleading language models.

The paper's entry point is to address these challenges by designing a simple yet effective self-supervised strategy that simultaneously learns both local and global molecular information, and by proposing a novel, computationally efficient bi-branch masked graph transformer autoencoder (BatmanNet).
2.2. Main Contributions / Findings
The paper makes several primary contributions to molecular representation learning:
- Novel Self-Supervised Pre-training Strategy: It introduces a straightforward yet powerful self-supervised pre-training strategy that involves masking a high proportion (specifically, 60%) of both nodes (atoms) and edges (bonds) in a molecular graph. The model then reconstructs these missing parts using an autoencoder architecture. This single task is designed to implicitly learn both local contextual information (beyond fixed k-hop neighborhoods) and global graph-level information without requiring predefined domain knowledge or multiple complex tasks, making it highly scalable and effective.
- Bi-branch Asymmetric Graph-based Autoencoder Architecture (BatmanNet): The paper develops BatmanNet, which features two complementary and asymmetric graph autoencoders, one specifically for reconstructing masked nodes and another for masked edges. This bi-branch design enhances the model's expressiveness by simultaneously learning node and edge representations. The asymmetric encoder-decoder design, in which the encoder operates only on the visible (unmasked) parts of the graph and a lightweight decoder reconstructs the full graph, significantly reduces computational complexity, memory consumption, and pre-training time.
- State-of-the-Art Performance Across Diverse Drug Discovery Tasks: BatmanNet achieves state-of-the-art results on 13 widely used benchmark datasets across multiple drug discovery tasks, including:
  - Molecular property prediction (e.g., physical, chemical, biophysical properties).
  - Drug-drug interaction (DDI) prediction.
  - Drug-target interaction (DTI) prediction.

These findings demonstrate the capacity, effectiveness, and generalizability of BatmanNet in learning high-quality molecular representations. The model achieves comparable or superior performance to previous SOTA models while using less training data and fewer model parameters.
3. Prerequisite Knowledge & Related Work
3.1. Foundational Concepts
3.1.1. Graph Neural Networks (GNNs)
Graph Neural Networks (GNNs) are a class of neural networks specifically designed to operate on graph-structured data. Unlike traditional neural networks that process Euclidean data (like images or sequences), GNNs can handle non-Euclidean data where relationships between data points are important. In the context of molecules, atoms are typically represented as nodes and chemical bonds as edges in a graph.
The core idea behind most GNNs is message passing (also known as neighborhood aggregation). In this mechanism, each node iteratively updates its hidden representation (or embedding) by aggregating information from its immediate neighbors and the edges connecting them. This process allows nodes to incorporate structural information from their local neighborhood into their representations. After multiple layers (or iterations) of message passing, a node's representation can capture structural information from its k-hop neighborhood.
The $k$-th layer of a GNN can be formulated as:
$ \mathbf{m}_v^{(k)} = \mathrm{AGG}^{(k)} \left( \left\{ \left( \mathbf{h}_v^{(k-1)}, \mathbf{h}_u^{(k-1)}, \mathbf{e}_{uv} \right) \mid u \in \mathcal{N}_v \right\} \right) $
$ \mathbf{h}_v^{(k)} = \sigma \left( \mathbf{W}^{(k)} \mathbf{m}_v^{(k)} + \mathbf{b}^{(k)} \right) $
Where:
- $\mathbf{h}_v^{(k)}$: The hidden representation (or embedding) of node $v$ at the $k$-th layer.
- $\mathbf{m}_v^{(k)}$: The aggregated message for node $v$ at the $k$-th layer, collected from its neighbors.
- $\mathrm{AGG}^{(k)}$: The aggregation function at layer $k$, which combines information from the node's neighbors. This can be a sum, mean, max, or a more complex neural network.
- $\mathcal{N}_v$: The set of all neighboring nodes of node $v$.
- $\mathbf{e}_{uv}$: The representation (features) of the edge connecting node $u$ and node $v$.
- $\sigma$: An activation function (e.g., ReLU, sigmoid, tanh).
- $\mathbf{W}^{(k)}$ and $\mathbf{b}^{(k)}$: Learnable weight matrix and bias vector for the $k$-th layer.
- $\mathbf{h}_v^{(0)}$: The initial features of node $v$.

After the final $K$ iterations, a READOUT function is applied to combine all node representations into a single graph-level representation:
$ \mathbf{h}_G = \operatorname{READOUT} \left( \left\{ \mathbf{h}_v^{(K)} \mid v \in \mathcal{V} \right\} \right) $
Where $\mathcal{V}$ is the set of all nodes in the graph.
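To make the generic formulation above concrete, here is a minimal, hedged sketch of one sum-aggregation message-passing layer with a sum readout, written in plain PyTorch. The class name, tensor layout (`edge_index` as a 2×E tensor), and the choice of sum aggregation and ReLU are illustrative assumptions, not the paper's implementation.

```python
import torch
import torch.nn as nn

class SimpleMessagePassingLayer(nn.Module):
    """Sketch of one GNN layer: m_v = AGG({(h_v, h_u, e_uv)}), h_v = sigma(W m_v + b)."""

    def __init__(self, node_dim: int, edge_dim: int):
        super().__init__()
        self.linear = nn.Linear(2 * node_dim + edge_dim, node_dim)

    def forward(self, h, edge_index, edge_attr):
        # h: (num_nodes, node_dim); edge_index: (2, num_edges) rows = (source u, target v)
        src, dst = edge_index
        # Build one message per edge from the tuple (h_v, h_u, e_uv).
        messages = torch.cat([h[dst], h[src], edge_attr], dim=-1)
        # Sum-aggregate the messages arriving at each target node v (AGG = sum over N(v)).
        m = torch.zeros(h.size(0), messages.size(-1), device=h.device)
        m.index_add_(0, dst, messages)
        # h_v^{(k)} = sigma(W^{(k)} m_v^{(k)} + b^{(k)})
        h_new = torch.relu(self.linear(m))
        # Simple sum readout over all nodes: h_G = READOUT({h_v})
        h_graph = h_new.sum(dim=0)
        return h_new, h_graph
```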
3.1.2. Multi-head Attention Mechanism
The multi-head attention mechanism is a core component of the Transformer architecture. It allows a model to jointly focus on information from different representation subspaces at different positions. Instead of performing a single attention function, it runs multiple attention functions in parallel, each with its own set of learned projection matrices. The outputs from these multiple "heads" are then concatenated and linearly transformed to produce the final output.
The basic unit is scaled dot-product attention, which takes three inputs: queries ($\mathbf{Q}$), keys ($\mathbf{K}$), and values ($\mathbf{V}$).
$ \mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V} $
Where:
- $\mathbf{Q}$: Query matrix, representing the elements seeking information.
- $\mathbf{K}$: Key matrix, representing the elements whose information is being sought.
- $\mathbf{V}$: Value matrix, containing the actual information to be retrieved.
- $d_k$: The dimension of the keys (and queries). Dividing by $\sqrt{d_k}$ scales the dot products to prevent the softmax function from having extremely small gradients.
- $\mathrm{softmax}$: Normalizes the attention scores so they sum to 1.

For multi-head attention with $h$ parallel attention layers (heads):
$ \mathrm{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{Concat}\left(\mathrm{head}_1, \dots, \mathrm{head}_h\right)\mathbf{W}^O $
$ \mathrm{head}_i = \mathrm{Attention}\left(\mathbf{Q}\mathbf{W}_i^{\mathbf{Q}}, \mathbf{K}\mathbf{W}_i^{\mathbf{K}}, \mathbf{V}\mathbf{W}_i^{\mathbf{V}}\right) $
Where:
- $\mathbf{W}_i^{\mathbf{Q}}, \mathbf{W}_i^{\mathbf{K}}, \mathbf{W}_i^{\mathbf{V}}$: Learnable projection weight matrices for the $i$-th head, used to project the queries, keys, and values into different representation subspaces.
- $\mathrm{Concat}$: Concatenates the outputs from all heads.
- $\mathbf{W}^O$: A final learnable weight matrix to linearly project the concatenated output to the desired dimension.
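The two formulas above map directly onto standard deep-learning toolkits. The following hedged sketch implements scaled dot-product attention by hand and then uses PyTorch's built-in multi-head attention module for the full MultiHead formulation; the tensor shapes and hyperparameters are illustrative.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    return torch.softmax(scores, dim=-1) @ V

# Multi-head attention via PyTorch's built-in module, which applies the learned
# per-head projections W_i^Q, W_i^K, W_i^V and the output projection W^O internally.
mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(2, 10, 64)        # (batch, sequence length, embedding dim)
out, attn_weights = mha(x, x, x)  # self-attention: queries, keys, values all come from x
```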
3.1.3. Autoencoders
An autoencoder is a type of artificial neural network used for unsupervised learning of efficient data codings (representations). It consists of two main parts:
- Encoder: This part compresses the input data into a lower-dimensional latent space representation (or encoding).
- Decoder: This part attempts to reconstruct the original input data from the latent space representation.

The goal of an autoencoder is to learn a representation that effectively captures the most important features of the input data, such that the reconstructed output is as close as possible to the original input. This is typically achieved by minimizing a reconstruction loss (e.g., mean squared error or cross-entropy) between the input and the output.
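A minimal sketch of this encoder/decoder pairing is shown below, assuming plain fully connected layers and a mean-squared-error reconstruction loss; the dimensions and layer choices are placeholders rather than anything specified by the paper.

```python
import torch
import torch.nn as nn

class TinyAutoencoder(nn.Module):
    def __init__(self, in_dim: int = 128, latent_dim: int = 16):
        super().__init__()
        # Encoder: compress the input into a low-dimensional latent code.
        self.encoder = nn.Sequential(nn.Linear(in_dim, latent_dim), nn.ReLU())
        # Decoder: reconstruct the original input from the latent code.
        self.decoder = nn.Linear(latent_dim, in_dim)

    def forward(self, x):
        z = self.encoder(x)      # latent representation
        return self.decoder(z)   # reconstruction

model = TinyAutoencoder()
x = torch.randn(32, 128)
loss = nn.MSELoss()(model(x), x)  # reconstruction loss to be minimized
```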
3.1.4. Self-supervised Learning (SSL)
Self-supervised learning is a paradigm where a model learns representations from unlabeled data by solving a pretext task (or self-supervised task). The pretext task is designed such that solving it requires the model to learn useful features of the data that can then be transferred to downstream supervised tasks. Common SSL strategies include:
- Masked Language Modeling (MLM): Used in models like BERT, where parts of the input sequence are masked, and the model learns to predict the masked tokens.
- Contrastive Learning: Learning representations by pulling together augmented versions of the same data point (positive pairs) and pushing apart different data points (negative pairs).
- Autoencoding with Reconstruction: Similar to standard autoencoders, but often with more aggressive masking or data corruption to make the reconstruction task more challenging and force the model to learn robust features. Masked Autoencoders (MAE) are a prominent example in computer vision.
3.1.5. Transformers
A Transformer is a deep learning model introduced in 2017, primarily known for its success in natural language processing (NLP). Its key innovation is the attention mechanism, particularly self-attention, which allows the model to weigh the importance of different parts of the input sequence when processing each element, without relying on sequential processing (unlike Recurrent Neural Networks - RNNs). Transformers are highly parallelizable and effective at capturing long-range dependencies. They typically consist of stacked encoder and decoder blocks, each containing multi-head attention and feed-forward neural networks.
3.1.6. Evaluation Metrics
- ROC-AUC (Receiver Operating Characteristic - Area Under the Curve):
  - Conceptual Definition: ROC-AUC is a performance metric used for binary classification problems. The ROC curve plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at various threshold settings. The AUC represents the area under this curve. A higher AUC (closer to 1) indicates that the model is better at distinguishing between positive and negative classes; an AUC of 0.5 suggests a random classifier.
  - Mathematical Formula:
    $ \text{TPR} = \frac{\text{TP}}{\text{TP} + \text{FN}} $
    $ \text{FPR} = \frac{\text{FP}}{\text{FP} + \text{TN}} $
    The AUC is the integral of the ROC curve: $ \text{AUC} = \int_{0}^{1} \text{TPR}(\text{FPR}^{-1}(x)) \, dx $
  - Symbol Explanation:
    - TP: True Positives (correctly predicted positive instances).
    - FN: False Negatives (actual positive instances incorrectly predicted as negative).
    - FP: False Positives (actual negative instances incorrectly predicted as positive).
    - TN: True Negatives (correctly predicted negative instances).
    - TPR: True Positive Rate, also known as Recall or Sensitivity.
    - FPR: False Positive Rate.
- RMSE (Root Mean Squared Error):
  - Conceptual Definition: RMSE is a widely used metric for regression tasks. It measures the average magnitude of the errors between predicted values and actual values, computed as the square root of the average of the squared differences between prediction and actual observation. A lower RMSE indicates better model performance.
  - Mathematical Formula: $ \text{RMSE} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2} $
  - Symbol Explanation:
    - $N$: The total number of observations or data points.
    - $y_i$: The actual (ground truth) value for the $i$-th observation.
    - $\hat{y}_i$: The predicted value for the $i$-th observation.
- PR-AUC (Precision-Recall Area Under the Curve):
  - Conceptual Definition: PR-AUC is another metric for binary classification, particularly useful for imbalanced datasets where the positive class is rare. The Precision-Recall curve plots Precision against Recall at various threshold settings. A higher PR-AUC indicates better performance, especially in cases where false positives are costly.
  - Mathematical Formula:
    $ \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} $
    $ \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} $
    The AUC is the integral of the Precision-Recall curve: $ \text{PR-AUC} = \int_{0}^{1} \text{Precision}(\text{Recall}^{-1}(x)) \, dx $
  - Symbol Explanation:
    - TP, FN, FP: Same as for ROC-AUC.
    - Precision: The proportion of correctly predicted positive instances among all instances predicted as positive.
    - Recall: The proportion of correctly predicted positive instances among all actual positive instances (same as TPR).
- F1 Score:
  - Conceptual Definition: The F1 score is the harmonic mean of Precision and Recall. It is a single metric that balances both precision and recall, providing a more comprehensive measure of a model's accuracy than precision or recall alone, especially when the class distribution is uneven. A higher F1 score indicates better model performance.
  - Mathematical Formula: $ \text{F1} = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} $
  - Symbol Explanation:
    - Precision: As defined above.
    - Recall: As defined above.
- Cross-entropy Loss ($\mathcal{L}_{ce}$):
  - Conceptual Definition: Cross-entropy loss is a common loss function used in classification tasks, particularly when the output is a probability distribution over multiple classes. It measures the difference between two probability distributions: the true distribution and the predicted distribution. Minimizing cross-entropy makes the predicted probabilities match the true probabilities as closely as possible.
  - Mathematical Formula (for multi-class classification): $ \mathcal{L}_{ce} = -\sum_{c=1}^{C} y_c \log(\hat{p}_c) $
  - Symbol Explanation:
    - $C$: The total number of classes.
    - $y_c$: A binary indicator (0 or 1) of whether class $c$ is the correct class for the observation.
    - $\hat{p}_c$: The predicted probability of the observation belonging to class $c$.
3.2. Previous Works
3.2.1. Early Feature-based Approaches
Early methods for molecular representation relied on handcrafted molecular descriptors and fingerprints [21, 22, 23]. These approaches represent molecules as fixed-length vectors based on predefined chemical properties or structural patterns.
- Example: ECFP (Extended Connectivity Fingerprints) [50] is a circular topological fingerprint.
- Limitations: These methods heavily depend on complex feature engineering, which is labor-intensive and requires significant domain expertise. They also suffer from vector sparsity issues (many zeros in the representation), which can hinder predictive performance.
3.2.2. Deep Learning on SMILES Representations
With the advent of deep learning, some works started using SMILES (Simplified Molecular-Input Line-Entry System) strings, a linear textual representation of molecules, as input for sequence-based models [24, 25, 26].
- RNN-based Models: Recurrent Neural Networks (RNNs) were used to generate molecular representations from SMILES [26].
- BERT-style Models: Inspired by Natural Language Processing (NLP), BERT-style models were pre-trained using masked language modeling tasks on SMILES strings [27, 28, 14, 15]. SMILES-BERT [14] is a prominent example.
- Autoencoder Frameworks: Some approaches used autoencoders to reconstruct SMILES representations for learning embeddings [29, 30, 26].
- Limitations: SMILES representations have inherent limitations. They are not designed to capture molecular similarity directly, meaning structurally similar molecules can have vastly different SMILES strings, which can confuse language models. They also struggle to explicitly represent essential chemical properties like molecular validity, often generating sequences corresponding to invalid molecules. Crucially, they lack an explicit topology representation, failing to learn molecular structural information directly.
3.2.3. Graph Neural Networks (GNNs) for Molecular Graphs
GNNs have gained significant traction for molecular representation learning due to their ability to directly model graph-structured data (atoms as nodes, bonds as edges) [11].
- Early GCNs: Works like GraphConv [52], Weave [31], and SchNet [32] encoded molecular graphs into neural fingerprints using graph convolutional networks.
- Attention Mechanisms: Graph Attention Network (GAT) [36] extensions like AttentiveFP [35] learned aggregation weights using attention.
- Message Passing Frameworks: MPNN [37] introduced a message-passing framework, which DMPNN [39] and MGCN [40] extended to model bond interactions and multi-level interactions, respectively.
- Limitations: While GNNs excel at modeling graph structures, their traditional message passing operators primarily aggregate local information. This makes them less effective at capturing long-range dependencies within molecules, which are crucial for understanding global molecular properties.
3.2.4. Self-supervised Learning (SSL) for Molecular Graphs
To overcome the data scarcity and limitations of traditional GNNs, SSL on molecular graphs has emerged as a key direction. These methods are typically categorized by the molecular information level used:
- 2D Topology-based Models: These models pre-train from the molecular 2D topological structure [13, 17, 18, 19, 43, 44, 45].
  - Transformer-style Architectures: Many employ Transformer-style architectures [13, 18, 19]. GROVER [18] and MPG [13] predefine and extract motifs or subgraphs as prediction targets. KPGT [19] introduces additional domain knowledge (molecular descriptors/fingerprints) and masks a proportion of this knowledge for reconstruction.
  - Masked Autoencoders for Graphs: GMAE [44] focuses on masking nodes only, MGAE [45] focuses on masking edges only, and GraphMAE [47] replaces masked nodes with descriptors.
- 3D Geometry-based Models: These methods work on 3D geometry graphs using the spatial positions of atoms with geometric GNN models [46, 20]. GEM [20] is an example, learning molecular geometry knowledge.
- Limitations of existing SSL GNNs:
  - Complex Pre-training Tasks: Many require multiple tasks and predefined domain knowledge (motifs, subgraphs, descriptors), making them complex, hard to implement end-to-end, and less scalable.
  - High Computational Complexity: Transformer-based models often encode the entire graph, leading to large model sizes and high computational costs.
  - Data Volume: Graph-based methods, especially complex ones, often require large volumes of data for pre-training, which can limit their generalization ability if data is sparse.
3.3. Technological Evolution
The field of molecular representation learning has evolved significantly:
- Handcrafted Features (Pre-2010s): Initial methods relied on expert-designed molecular descriptors and fingerprints. These were interpretable but limited by manual effort and potential information loss.
- SMILES-based Deep Learning (Mid-2010s): The rise of deep learning led to the use of SMILES strings with RNNs and later Transformers (BERT-style models). This offered automatic feature learning but struggled to capture explicit graph topology and molecular similarity.
- Graph Neural Networks (Late 2010s): GNNs emerged as a natural fit for molecular graphs, directly modeling atoms and bonds. Early GNNs focused on local message passing.
- Self-supervised GNNs (Early 2020s): To overcome the need for large labeled datasets, SSL strategies were applied to GNNs. This involved pre-training on unlabeled data using various pretext tasks such as predicting motifs, subgraphs, or reconstructing masked features. These approaches often became complex, computationally expensive, and relied on large Transformer models.
- BatmanNet's Position: BatmanNet fits into the latest wave of self-supervised GNNs, specifically addressing the complexity and computational cost issues of previous Transformer-based SSL methods. It proposes a simpler, more efficient masked autoencoder approach on graphs with an asymmetric, bi-branch design to learn both local and global information more effectively.
3.4. Differentiation Analysis
Compared to the main methods in related work, BatmanNet offers core differences and innovations primarily by tackling the twin challenges of complexity and computational cost in self-supervised graph learning:
- Simplifying Pre-training Tasks (vs. GROVER, MPG, KPGT, GEM):
  - Previous: GROVER and MPG require predefining and extracting chemical motifs or subgraphs for their SSL tasks. KPGT introduces additional domain knowledge such as molecular descriptors/fingerprints as reconstruction targets. GEM uses several geometry-level SSL strategies. These methods involve constructing multiple, complex tasks and rely on external domain knowledge.
  - BatmanNet: Proposes a single, straightforward bi-branch graph masking task. It randomly masks a high proportion (60%) of both nodes and edges and reconstructs their full features. This approach is domain-agnostic, requires no predefined motifs, subgraphs, or additional domain-specific chemical knowledge, and is therefore more scalable, intuitive, and easier to implement end-to-end. The high masking ratio makes the task challenging, forcing the model to learn robust local and global contextual information.
- Reducing Computational Complexity and Model Size (vs. GROVER, MPG, KPGT, GEM):
  - Previous: Most Transformer-based molecular pre-training models (like GROVER, MPG, and KPGT) directly encode information over the entire molecular graph, leading to a large number of parameters and high computational/memory costs.
  - BatmanNet: Employs an asymmetric encoder-decoder architecture, inspired by MAE [41]. The encoder operates only on the visible (unmasked) subset of the molecular graph. The decoder, which is much more lightweight (fewer layers than the encoder), handles the reconstruction of the entire graph from the encoder's output and mask tokens. This design ensures that the full graph is only processed by the lightweight decoder, significantly reducing the model's overall parameters, computation, and memory consumption during pre-training.
- Bi-branch Complementary Learning (vs. GMAE, MGAE, GraphMAE):
  - Previous: GMAE masks only nodes, MGAE masks only edges, and GraphMAE replaces masked nodes with descriptors rather than directly reconstructing them.
  - BatmanNet: Introduces a bi-branch complementary autoencoder. One branch specifically reconstructs masked nodes, and the other reconstructs masked edges. This dual approach simultaneously enhances the expressiveness of both node and edge representations, capturing a more complete picture of molecular structure and semantics. Its direct removal and reconstruction of masked parts makes the task more challenging and the learned representations more capable.

In essence, BatmanNet differentiates itself by offering a simpler, more efficient, and more holistic approach to self-supervised molecular graph representation learning, achieving SOTA results with a significantly reduced computational footprint and model complexity.
4. Methodology
4.1. Principles
The core idea behind BatmanNet is to leverage a masked autoencoder framework, specifically tailored for molecular graphs, to learn expressive representations in a self-supervised manner. The theoretical basis and intuition are rooted in the success of Masked Autoencoders (MAE) in computer vision and Transformer-based models in NLP, adapted for graph structures.
The key principles are:
- Reconstruction as a Self-Supervised Task: By masking a significant portion of a molecular graph (both atoms/nodes and bonds/edges) and forcing the model to reconstruct the original, unmasked graph from its partial view, the model is compelled to learn rich structural and semantic features. This challenging task implicitly encourages the learning of both local contextual information (how masked parts relate to their immediate neighbors) and global structural information (how the entire graph is assembled).
- Bi-branch Learning for Comprehensive Representation: Molecules inherently have distinct atom and bond features, and their interactions are crucial. BatmanNet introduces two parallel branches, a node branch and an edge branch, each focused on its specific type of entity. This ensures that both atom-level and bond-level information are explicitly and complementarily learned.
- Asymmetric Encoder-Decoder for Efficiency: Inspired by MAE, the model adopts an asymmetric architecture. A powerful encoder processes only the visible (unmasked) parts of the input graph, which is a sparse subset. A much lighter decoder then takes the latent representations from the encoder, along with mask tokens (placeholders for the masked parts), to reconstruct the complete original graph. This design dramatically reduces computational overhead during pre-training, as the full graph is only handled by the lightweight decoder.
- GNN-Attention Blocks for Local and Global Information: Each layer within the encoder and decoder combines a Graph Neural Network (GNN) for local message passing (capturing neighborhood information) and a Multi-head Attention mechanism for capturing long-range dependencies and global relationships across the graph.
4.2. Core Methodology In-depth (Layer by Layer)
4.2.1. Overview of BatmanNet Framework
BatmanNet is a bi-branch model, meaning it has two parallel processing pathways: a node branch and an edge branch. Each branch is dedicated to learning representations for nodes (atoms) or edges (bonds) from the input molecular graph. The overall architecture for each branch follows a transformer-style asymmetric encoder-decoder design, similar to MAE [41].
- Input: For a molecule, we define a node graph and an edge graph.
  - In the node graph, each atom is a node and each bond is an edge connecting a pair of atoms; initial features are provided for both nodes (atoms) and edges (bonds).
  - The edge graph is the line graph (or primary dual) of the node graph, where each edge of the node graph becomes a node in the edge graph, and two nodes in the edge graph are connected if their corresponding edges share a common node. This allows message passing over edges.
- Self-supervised Pre-training Strategy: A bi-branch graph masking strategy is applied. A high proportion (e.g., 60%) of nodes is randomly masked in the node branch, and a high proportion of edges is randomly masked in the edge branch.
- Encoder: The encoder (one for each branch) operates only on the partially observable signals, i.e., the unmasked nodes and edges of the molecular graph. It embeds these visible components into latent representations. The encoder is a multi-layer transformer-styled network.
- Decoder: The decoder (one for each branch) takes the latent representations of the unmasked nodes/edges (from the encoder) and mask tokens (placeholders for the removed nodes/edges) as input. It then reconstructs the original, complete molecule. The decoder uses a similar transformer-styled architecture but is designed to be much more lightweight (fewer layers than the encoder). This asymmetry is key for efficiency, as the full graph is only processed by the lighter decoder.
- Output: During pre-training, the output is the reconstructed features of the masked nodes and edges. For downstream tasks, only the encoder is used to produce molecular representations.

The following figure (Figure 2 from the original paper) shows BatmanNet's decoder:
Figure 2 (from the original paper): Schematic of the bi-branch masked graph transformer autoencoder (BatmanNet), illustrating the node and edge embedding process, including feature reordering, aggregation, and the use of multi-head attention, and highlighting the structural design and function of the BatmanNet encoder and decoder.
4.2.2. Details of Encoder and Decoder
Both the encoder and decoder are composed of a stack of GNN-Attention blocks, with the encoder using a deeper stack than the lightweight decoder (6 encoder blocks versus 2 decoder blocks in the paper's configuration). Each block implements a double-layer information extraction framework: a GNN for local information, followed by a Multi-head Attention layer for global information.
4.2.2.1. GNN-Attention Block
A GNN-Attention block consists of:
- GNN Layer: Performs message passing to extract local information from the input graph. This produces learned embeddings.
- Multi-head Attention Layer: Takes the embeddings from the GNN and applies Multi-head Attention to capture global information and long-range dependencies across the graph.

Specifically, within a GNN-Attention block, three GNNs are used to learn the embeddings for queries ($\mathbf{Q}$), keys ($\mathbf{K}$), and values ($\mathbf{V}$). Let $\mathbf{H}$ be the hidden representation matrix of the nodes, with embedding size $d$:
$ \mathbf{Q} = \mathbf{G}_{\mathbf{Q}}(\mathbf{H}) $
$ \mathbf{K} = \mathbf{G}_{\mathbf{K}}(\mathbf{H}) $
$ \mathbf{V} = \mathbf{G}_{\mathbf{V}}(\mathbf{H}) $
Where $\mathbf{G}_{\mathbf{Q}}$, $\mathbf{G}_{\mathbf{K}}$, and $\mathbf{G}_{\mathbf{V}}$ are three distinct GNN instances within the GNN-Attention block. Each takes the current hidden representations $\mathbf{H}$ as input and transforms them to generate the query, key, and value matrices, respectively.

After obtaining $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$, they are fed into the Multi-head Attention mechanism to compute the block's output. As a reminder, the Multi-head Attention mechanism calculates attention scores and combines information across multiple "heads":
$ \mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V} $
$ \mathrm{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{Concat}\left(\mathrm{head}_1, \dots, \mathrm{head}_h\right)\mathbf{W}^O $
$ \mathrm{head}_i = \mathrm{Attention}\left(\mathbf{Q}\mathbf{W}_i^{\mathbf{Q}}, \mathbf{K}\mathbf{W}_i^{\mathbf{K}}, \mathbf{V}\mathbf{W}_i^{\mathbf{V}}\right) $
Where:
- $d_k$: Dimension of keys.
- $\mathbf{W}_i^{\mathbf{Q}}, \mathbf{W}_i^{\mathbf{K}}, \mathbf{W}_i^{\mathbf{V}}$: Projection weights for the $i$-th attention head.
- $\mathbf{W}^O$: Output projection weight.
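The combination described above ("three GNNs feed Q, K, V into multi-head attention") can be sketched as follows. This is a hedged illustration, not the authors' code: the GNN sub-modules are abstracted as callables returning updated node embeddings, the class and argument names are invented, and the hidden size / head count simply mirror the hyper-parameters reported later (100 and 2).

```python
import torch
import torch.nn as nn

class GNNAttentionBlock(nn.Module):
    def __init__(self, gnn_q: nn.Module, gnn_k: nn.Module, gnn_v: nn.Module,
                 hidden_size: int = 100, num_heads: int = 2):
        super().__init__()
        # Three distinct GNNs produce Q, K, V from the current hidden states H.
        self.gnn_q, self.gnn_k, self.gnn_v = gnn_q, gnn_k, gnn_v
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, h, graph):
        # Local information: message passing over the (visible) graph.
        q, k, v = self.gnn_q(h, graph), self.gnn_k(h, graph), self.gnn_v(h, graph)
        # Global information: multi-head attention across all node embeddings.
        out, _ = self.attn(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0))
        return out.squeeze(0)
```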
4.2.2.2. Encoder Specifics
- Initial Processing: At the beginning of the encoder, a linear projection is applied to the initial features of the unmasked nodes and edges. Positional embeddings are then added to preserve their structural positions (indexed by RDKit before masking). The paper adopts the absolute sinusoidal positional encoding proposed by [48] (a minimal sketch of this encoding follows at the end of this list).
- Message Aggregation: The original graph and its dual graph are fed into their respective branches of the encoder. The aggregated node embedding and edge embedding are obtained as follows:
$ \mathbf{m}_v = \sum_{u \in \mathcal{N}_v} \mathbf{h}_u $
$ \mathbf{m}_{vw} = \sum_{u \in \mathcal{N}_v \setminus w} \mathbf{h}_{uv} $
Where:
  - $\mathbf{m}_v$: Aggregated message for node $v$.
  - $\mathcal{N}_v$: Neighbors of node $v$.
  - $\mathbf{h}_u$: Hidden representation of node $u$.
  - $\mathbf{m}_{vw}$: Aggregated message for edge (v,w).
  - $\mathcal{N}_v \setminus w$: Neighbors of node $v$ excluding the one connected by edge (v,w).
  The second formulation is slightly unusual: the sum runs over neighboring nodes $u$ of $v$, yet the summand $\mathbf{h}_{uv}$ appears to denote an edge representation that is not defined in the message-passing section, which suggests a typo in the original paper. Strictly following the paper's text and formula, we present it as is.
- Residual Connections: Long-range residual connections are added from the initial features of nodes and edges to the aggregated embeddings. These connections help mitigate vanishing gradients and over-smoothing during message passing, improving training stability and information flow.
- Final Output: After processing through multiple GNN-Attention blocks and residual connections, a Feed Forward network and LayerNorm are applied to obtain the unmasked node embedding and edge embedding as the final output of the encoder.
4.2.2.3. Decoder Specifics
- Feature Reordering: At the start of the decoder, a Feature Reordering layer (as shown in Figure 2) concatenates:
  - the embeddings of unmasked nodes and edges received from the encoder, and
  - mask tokens (learnable placeholders) for the removed nodes and edges.
  Positional embeddings corresponding to their original positions in the graph are added to restore the correct order of all features (a minimal sketch of this step follows below).
- Architecture: The decoder uses the same transformer-styled architecture as the encoder (i.e., stacked GNN-Attention blocks) but is designed to be more lightweight (fewer layers than the encoder).
- Output: The decoder produces node and edge embeddings for the entire graph, which are then used for the reconstruction task.
4.2.3. Reconstruction Target
BatmanNet's node and edge branches reconstruct molecules by predicting all features of the masked nodes and edges, respectively. The specific atom and bond features used are detailed in Supplementary Table S1.
- High-Dimension Multi-Label Predictions: The reconstruction tasks for both nodes (atoms) and edges (bonds) involve predicting high-dimensional, multi-label features. This design choice aims to alleviate the ambiguity problem noted in previous works [18], where using a limited number of atom or edge types as prediction targets might oversimplify the task.
- Reconstruction Loss: A linear layer is appended to the output of each decoder. Its output dimension matches the total feature size of atoms (for the node branch) or bonds (for the edge branch). The pre-training loss is computed only on the masked tokens, similar to MAE [41] (see the sketch after this list). The final pre-training loss is defined as:
$ \mathcal{L}_{\mathrm{pre-train}} = \mathcal{L}_{\mathrm{node}} + \mathcal{L}_{\mathrm{edge}} $
$ \mathcal{L}_{\mathrm{node}} = \sum_{v \in \mathcal{V}_{\mathrm{mask}}} \mathcal{L}_{\mathrm{ce}} \left( \pmb{p}_v, \pmb{y}_v \right) $
$ \mathcal{L}_{\mathrm{edge}} = \sum_{(u,v) \in \mathcal{E}_{\mathrm{mask}}} \mathcal{L}_{\mathrm{ce}} \left( \pmb{p}_{(u,v)}, \pmb{y}_{(u,v)} \right) $
Where:
  - $\mathcal{L}_{\mathrm{node}}$: The loss function for the node branch.
  - $\mathcal{L}_{\mathrm{edge}}$: The loss function for the edge branch.
  - $\mathcal{V}_{\mathrm{mask}}$: The set of masked nodes.
  - $\mathcal{E}_{\mathrm{mask}}$: The set of masked edges.
  - $\mathcal{L}_{\mathrm{ce}}$: The cross-entropy loss (explained in 3.1.6 Evaluation Metrics).
  - $\pmb{p}_v$, $\pmb{y}_v$: Predicted and ground-truth features for masked node $v$.
  - $\pmb{p}_{(u,v)}$, $\pmb{y}_{(u,v)}$: Predicted and ground-truth features for masked edge (u,v).
4.2.4. Pre-training Strategy: Bi-branch Graph Masking
The self-supervised pre-training strategy is a bi-branch graph masking and reconstruction approach.
- Random Masking: Given a molecular graph, the approach randomly masks a high proportion of its nodes (for the node branch) and edges (for the edge branch). The paper specifically uses a 60% masking ratio.
- Directed Masking for Edges: When masking edges, a directed masking scheme [45] is adopted. This means that if an edge (u,v) is removed, its reverse (v,u) is not necessarily removed. To distinguish directed edges, the feature of the starting node (head node) is added to the initial feature of the edge.
- Effectiveness of the Strategy:
  - Node-level Pre-training for Local Context: The high masking ratio ensures that each node/edge is likely to miss many of its neighbors. To reconstruct these missing parts, each node and edge embedding must learn comprehensive local contextual information, going beyond limited k-hop ranges or specific subgraph shapes. This contrasts with previous methods that rely on predefining motifs or subgraphs.
  - Graph-level Pre-training for Global Information: The task involves predicting the entire graph from the highly incomplete remaining nodes and edges. This is more challenging than SSL tasks that target smaller subgraphs or motifs, and it forces the model to develop a more powerful capacity for learning high-quality node and edge embeddings that capture both local and global molecular information effectively.

The following figure (Figure 1 from the original paper) shows the BatmanNet pre-training for reconstruction:
Figure 1 (from the original paper): Schematic of the node and edge masking process in BatmanNet. The upper half depicts the node masking strategy, in which the encoder transforms the initial node embeddings into node embeddings; the lower half depicts the edge masking strategy, in which the initial edge embeddings are processed by the encoder to obtain edge embeddings.
This strategy combines efficiency (minimal domain knowledge, a single task) with effectiveness (learning comprehensive local and global molecular information).
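A minimal sketch of the random masking step is shown below, assuming the 60% ratio reported by the paper is applied independently in the node and edge branches; the function and variable names are illustrative.

```python
import torch

def random_mask(num_items: int, mask_ratio: float = 0.6):
    """Randomly split item indices into masked (~60%) and visible subsets."""
    perm = torch.randperm(num_items)
    num_mask = int(num_items * mask_ratio)
    return perm[:num_mask], perm[num_mask:]  # (masked indices, visible indices)

# Applied independently in the two branches:
# masked_nodes, visible_nodes = random_mask(num_atoms)
# masked_edges, visible_edges = random_mask(num_directed_bonds)
```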
5. Experimental Setup
5.1. Datasets
5.1.1. Pre-training Datasets
- ZINC250K: This dataset was used for pre-training BatmanNet. It consists of 250,000 molecules sampled from the ZINC database [56]. The dataset was randomly split into training and validation sets in a 9:1 ratio.
- Why Chosen: ZINC is a widely used and publicly available database for chemical compounds, providing a large-scale source of unlabeled molecules suitable for self-supervised pre-training.
5.1.2. Downstream Task Datasets
The paper evaluates BatmanNet on 13 benchmark datasets across three categories of drug discovery tasks. Scaffold splitting [61, 62] was applied to all downstream datasets to divide them into training, validation, and test sets at an 8:1:1 ratio. This method ensures that structurally similar molecules are grouped, providing a more challenging and realistic evaluation of the model's ability to generalize to out-of-distribution molecules.
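To make the scaffold-splitting idea concrete, here is a hedged sketch of one common implementation using RDKit's Bemis-Murcko scaffolds; the exact splitting code used by the paper may differ, and the greedy largest-group-first fill shown here is one conventional choice.

```python
from collections import defaultdict
from rdkit.Chem.Scaffolds import MurckoScaffold

def scaffold_split(smiles_list, frac_train=0.8, frac_valid=0.1):
    """Group molecules by Murcko scaffold, then assign whole groups to splits."""
    groups = defaultdict(list)
    for i, smi in enumerate(smiles_list):
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(smiles=smi, includeChirality=False)
        groups[scaffold].append(i)

    train, valid, test = [], [], []
    # Fill train/valid/test with whole scaffold groups (largest first),
    # so structurally similar molecules never leak across splits.
    for idx in sorted(groups.values(), key=len, reverse=True):
        if len(train) + len(idx) <= frac_train * len(smiles_list):
            train.extend(idx)
        elif len(valid) + len(idx) <= frac_valid * len(smiles_list):
            valid.extend(idx)
        else:
            test.extend(idx)
    return train, valid, test
```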
5.1.2.1. Molecular Property Prediction Datasets
These datasets from MoleculeNet [60] cover various physical, chemical, and biophysical properties, including both classification and regression tasks.
- BBBP [69]: (Classification) Records the blood-brain barrier permeability property of compounds.
- SIDER [70]: (Classification) A database of marketed drugs and their adverse drug reactions, grouped into 27 system organ classes (multi-task classification).
- ClinTox [71]: (Classification) Compares drugs approved by the FDA with those eliminated due to toxicity during clinical trials (multi-task classification).
- BACE [72]: (Classification) Provides quantitative binding results for inhibitors of human β-secretase 1 (BACE-1).
- Tox21 [73]: (Classification) A public database measuring the toxicity of compounds on 12 different targets, including nuclear receptors and stress responses (multi-task classification).
- ToxCast [74]: (Classification) Provides toxicology data for 8615 compounds based on in vitro high-throughput screening for 617 endpoints (multi-task classification).
- FreeSolv [75]: (Regression) Provides experimental and calculated hydration free energies of small molecules in water.
- ESOL [76]: (Regression) A small dataset containing water solubility data for 1128 compounds.
- Lipo [77]: (Regression) Curated from the ChEMBL database; provides experimental results of the octanol/water distribution coefficient (log D at pH 7.4) for 4200 compounds.
- QM7 [78]: (Regression) A subset of the GDB-13 database, containing quantum mechanical properties for molecules with up to seven "heavy" atoms (C, N, O, S).
- QM8 [79]: (Regression) Applied to a collection of molecules with up to eight heavy atoms (a subset of GDB-17), containing computer-generated quantum mechanical properties.
5.1.2.2. Drug-Drug Interaction (DDI) Prediction Datasets
This task is formulated as a binary classification problem: predicting whether two drugs are likely to interact.
- BIOSNAP [67]: Consists of 1322 approved drugs with 41520 labeled DDIs, sourced from drug labels and scientific publications.
- TWOSIDES [68]: Contains 548 drugs and 48584 pairwise drug-drug interactions, with associated side effects.
5.1.2.3. Drug-Target Interaction (DTI) Prediction Datasets
This task identifies interactions between a compound and a target protein.
- Human and C. elegans: Created by Liu et al., these datasets include highly credible negative samples of compound-protein pairs alongside positive samples from DrugBank 4.1 and Matador. A balanced version with a 1:1 ratio of positive and negative samples was used.
5.1.3. Atom and Bond Features
The RDKit library is used to extract atom and bond features, which serve as the input for the GNNs and the reconstruction targets for BatmanNet.
The following are the results from Table S1 of the original paper:
| Category | Feature | Size | Description |
|---|---|---|---|
| Atom | Atom type | 23 | The atom type (e.g., C, N, O), by atomic number |
| Atom | Number of H | 6 | The number of bonded hydrogen atoms |
| Atom | Charge | 5 | The formal charge of the atom |
| Atom | Chirality | 4 | The chiral tag of the atom |
| Atom | Is-aromatic | 1 | Whether the atom is part of an aromatic system |
| Bond | Bond type | 5 | The bond type (e.g., single, double, triple, etc.) |
| Bond | Stereo | 6 | The stereo configuration of the bond |
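The raw attributes in this table map onto standard RDKit accessors. The snippet below is a hedged illustration of extracting them for an example molecule (aspirin); the one-hot encoding sizes listed above are the paper's design choice and are not reproduced here.

```python
from rdkit import Chem

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, as an example

for atom in mol.GetAtoms():
    print(atom.GetAtomicNum(),      # atom type
          atom.GetTotalNumHs(),     # number of bonded hydrogens
          atom.GetFormalCharge(),   # formal charge
          atom.GetChiralTag(),      # chirality tag
          atom.GetIsAromatic())     # aromaticity flag

for bond in mol.GetBonds():
    print(bond.GetBondType(),       # bond type (single, double, ...)
          bond.GetStereo())         # stereo configuration
```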
5.2. Evaluation Metrics
The choice of evaluation metrics depends on the task type (classification or regression):
- Classification Tasks (Molecular Property Prediction, DDI, DTI):
  - ROC-AUC (Receiver Operating Characteristic - Area Under the Curve): (Defined in 3.1.6) Measures the model's ability to distinguish between classes across various thresholds.
  - PR-AUC (Precision-Recall Area Under the Curve): (Defined in 3.1.6) Useful for imbalanced datasets; quantifies the trade-off between precision and recall.
  - F1 Score: (Defined in 3.1.6) The harmonic mean of precision and recall, balancing both.
  - Precision: (Defined in 3.1.6) The proportion of correctly predicted positive instances among all instances predicted as positive.
  - Recall: (Defined in 3.1.6) The proportion of correctly predicted positive instances among all actual positive instances.
- Regression Tasks (Molecular Property Prediction, e.g., FreeSolv, ESOL, Lipo, QM7, QM8):
  - RMSE (Root Mean Squared Error): (Defined in 3.1.6) Measures the average magnitude of prediction errors.
5.3. Baselines
The performance of BatmanNet was compared against 20 competitive baselines: 10 supervised learning models (without pre-training) and 10 self-supervised learning models (with pre-training).
5.3.1. Supervised Learning Models (without pre-training)
- ECFP [50]: A circular topological fingerprint, a traditional feature-based method.
- TF_Robust [51]: A DNN-based multi-task framework that uses molecular fingerprints as input.
- GraphConv [52]: A basic graph convolutional network model.
- Weave [31]: A graph convolution model that jointly encodes atom and bond features.
- SchNet [32]: A quantum chemistry-inspired GNN that uses continuous-filter convolutional layers.
- MPNN [37]: A general message-passing neural network framework.
- DMPNN [39]: A variant of MPNN that explicitly considers edge features during message passing.
- MGCN [40]: A hierarchical GNN model for molecular property prediction.
- AttentiveFP [35]: An extension of the graph attention network that learns aggregation weights.
- TrimNet [53]: A graph-based approach employing a triplet message mechanism for efficient molecular representation.
5.3.2. Self-supervised Learning Models (with pre-training)
- Mol2Vec [54]: An NLP-inspired approach for learning molecular representations from substructures.
- N-GRAM [43]: Another NLP-inspired method using N-gram graphs for molecular representations.
- SMILES-BERT [14]: A BERT-style model pre-trained on sequential SMILES representations using masked language modeling.
- pre-trainGNN [17]: A general framework for pre-training GNNs.
- GraphMAE [47]: A self-supervised masked graph autoencoder that replaces masked nodes with descriptors.
- GROVERbase, GROVERlarge [18]: Transformer-based models pre-trained on large-scale molecular data using self-supervised graph tasks (including motif prediction). GROVERlarge is a larger variant of GROVERbase.
- KPGT [19]: A knowledge-guided pre-training method based on the graph transformer, incorporating additional domain knowledge.
- MPG [13]: An effective self-supervised framework for learning expressive molecular global representations.
- GEM [20]: A geometry-enhanced molecular representation learning model with dedicated geometry-level self-supervised strategies.
5.4. Experimental Configurations
5.4.1. Pre-training Details
- Optimizer: Adam [57].
- Learning Rate Scheduler: Noam learning rate scheduler [42], which warms up the learning rate over the first epoch and then decreases it exponentially.
- Batch Size: 32.
- Masking Ratio: 0.6 (60%) for both node and edge branches.
- Encoder Layers (num_enc_mt_block): 6.
- Decoder Layers (num_dec_mt_block): 2, demonstrating the asymmetric lightweight decoder design.
- Hidden Size (hidden_size): 100.
- GNN Layers within a GNN-Attention block (depth): 3.
- Self-attention Heads (num_attn_head): 2.
- Parameters: Approximately 2.6 million.
- Training Time: Two days on a single Nvidia RTX 3090 GPU.

The following are the results from Table S2 of the original paper:
| Hyper-parameter | Value | Description |
|---|---|---|
| batch_size | 32 | The input batch size |
| hidden_size | 100 | The hidden size of the encoder and decoder |
| depth | 3 | The number of GNN layers in the GNN-Attention block |
| num_enc_mt_block | 6 | The number of GNN-Attention blocks in the encoder |
| num_dec_mt_block | 2 | The number of GNN-Attention blocks in the decoder |
| num_attn_head | 2 | The number of attention heads in the GNN-Attention block |
| mask_ratio | 0.6 | The mask ratio |
| init_lr | 0.0002 | The initial learning rate of the Noam learning rate scheduler |
| max_lr | 0.0004 | The maximum learning rate of the Noam learning rate scheduler |
| final_lr | 0.0001 | The final learning rate of the Noam learning rate scheduler |
5.4.2. Fine-tuning Details
For downstream tasks, only BatmanNet's encoder is used, and the input molecules are complete (no masking).
5.4.2.1. Node and Graph Level Embeddings
- Node Aggregation: After the GNN-Attention blocks in the encoder, both branches perform node aggregation to produce two node representations:
$ \mathbf{m}_v^{\mathrm{node-branch}} = \sum_{u \in \mathcal{N}_v} \overline{\mathbf{h}}_u $
$ \mathbf{m}_v^{\mathrm{edge-branch}} = \sum_{u \in \mathcal{N}_v \setminus w} \overline{\mathbf{h}}_{uv} $
Where:
  - $\overline{\mathbf{h}}_u$: The hidden state of node $u$ from the node-branch's GNN-Attention blocks.
  - $\overline{\mathbf{h}}_{uv}$: The hidden state of edge (u,v) from the edge-branch's GNN-Attention blocks.
  - As in the encoder explanation, $\overline{\mathbf{h}}_{uv}$ denotes an edge representation while the sum runs over neighboring nodes $u$, making the edge-branch formulation slightly ambiguous in its notation. However, we adhere to the formula as provided in the supplementary material.
- Residual Connections: A single long-range residual connection concatenates the aggregated representations with their respective initial node and edge features.
- Final Embeddings: These are then passed through Feed Forward layers and LayerNorm to generate the final node and edge embeddings.
- Self-attentive READOUT: Following GROVER [18], a shared self-attentive READOUT function is applied to these node embeddings to generate two graph-level embeddings, one per branch (a minimal sketch follows this list):
$ \mathbf{S} = \mathrm{softmax} \left( \mathbf{W}_2 \tanh \left( \mathbf{W}_1 \mathbf{H}^\top \right) \right) $
$ \pmb{g} = \mathrm{Flatten}(\mathbf{S}\mathbf{H}) $
Where:
  - $\mathbf{H}$: The matrix of node hidden representations (from either the node or edge branch).
  - $\mathbf{W}_1$ and $\mathbf{W}_2$: Weight matrices for the attention mechanism.
  - $\mathbf{S}$: The attention scores.
  - $\pmb{g}$: The resulting graph-level embedding.
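The self-attentive readout formula above can be sketched as follows. This is a hedged illustration: the class name is invented, and the hidden/attention sizes simply mirror the fine-tuning hyper-parameters listed in Table S3 (attn_hidden = 200, attn_out = 2).

```python
import torch
import torch.nn as nn

class SelfAttentiveReadout(nn.Module):
    """S = softmax(W2 tanh(W1 H^T));  g = Flatten(S H)."""

    def __init__(self, hidden_size: int = 100, attn_hidden: int = 200, attn_out: int = 2):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, attn_hidden, bias=False)
        self.w2 = nn.Linear(attn_hidden, attn_out, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (num_nodes, hidden_size) node embeddings from one branch
        s = torch.softmax(self.w2(torch.tanh(self.w1(h))), dim=0)  # (num_nodes, attn_out)
        g = s.transpose(0, 1) @ h                                  # (attn_out, hidden_size)
        return g.flatten()                                         # graph-level embedding
```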
5.4.2.2. Downstream Task Specifics
- Molecular Properties Prediction: Feed Forward layers are applied to both graph-level embeddings to obtain the predictions of the two branches:
$ \pmb{p}_i^{\mathrm{node-branch}} = f \left( W \pmb{g}^{\mathrm{node-branch}} + b \right) $
$ \pmb{p}_i^{\mathrm{edge-branch}} = f \left( W \pmb{g}^{\mathrm{edge-branch}} + b \right) $
Where $f$ is an activation function and $W$, $b$ are learnable parameters.
- DDI Prediction: For a pair of molecules, their respective graph-level embeddings are obtained and concatenated to form pair-wise embeddings for each branch:
$ \pmb{g}_{\mathrm{pair}}^{\mathrm{node-branch}} = \mathrm{Concat} \left( \pmb{g}_1^{\mathrm{node-branch}}, \pmb{g}_2^{\mathrm{node-branch}} \right) $
$ \pmb{g}_{\mathrm{pair}}^{\mathrm{edge-branch}} = \mathrm{Concat} \left( \pmb{g}_1^{\mathrm{edge-branch}}, \pmb{g}_2^{\mathrm{edge-branch}} \right) $
Then, predictions are made for each branch:
$ \pmb{p}_i^{\mathrm{node-branch}} = f \left( W \pmb{g}_{\mathrm{pair}}^{\mathrm{node-branch}} + b \right) $
$ \pmb{p}_i^{\mathrm{edge-branch}} = f \left( W \pmb{g}_{\mathrm{pair}}^{\mathrm{edge-branch}} + b \right) $
- DTI Prediction: The framework by Tsubaki et al. [65] is adapted, with BatmanNet's encoder replacing their compound encoder. A CNN model with an attention mechanism encodes protein sequences to derive the protein representation $\pmb{y}_p$. Given hidden vectors $s_i$ of protein sub-sequences:
$ \pmb{y}_p = \sum_i^n \left( \alpha_i \pmb{h}_i \right) $
$ \pmb{\alpha}_i = \sigma \left( \pmb{h}_m^T \pmb{h}_i \right) $
$ \pmb{h}_m = f \left( \pmb{W} \pmb{g}_m + b \right) $
$ \pmb{h}_i = f \left( \pmb{W} s_i + b \right) $
Where:
  - $\pmb{y}_p$: Protein sequence representation.
  - $\pmb{\alpha}_i$: Attention weight for sub-sequence $i$.
  - $\pmb{h}_m$: A molecular vector serving as the context vector for the attention.
  - $\pmb{h}_i$: Transformed sub-sequence vector.
  - $\pmb{W}$, $b$: Learnable weight matrix and bias vector.
  - $f$: Activation function.
  The molecular embeddings from BatmanNet's encoder are concatenated with the protein embeddings to form pair-wise representations:
$ \pmb{y}_{\mathrm{pair}}^{\mathrm{node-branch}} = \mathrm{Concat} \left( \pmb{g}^{\mathrm{node-branch}}, \pmb{y}_p^{\mathrm{node-branch}} \right) $
$ \pmb{p}_i^{\mathrm{node-branch}} = f \left( W \pmb{y}_{\mathrm{pair}}^{\mathrm{node-branch}} + b \right) $
$ \pmb{y}_{\mathrm{pair}}^{\mathrm{edge-branch}} = \mathrm{Concat} \left( \pmb{g}^{\mathrm{edge-branch}}, \pmb{y}_p^{\mathrm{edge-branch}} \right) $
$ \pmb{p}_i^{\mathrm{edge-branch}} = f \left( W \pmb{y}_{\mathrm{pair}}^{\mathrm{edge-branch}} + b \right) $
5.4.2.3. Fine-tuning Loss Function
The final loss for downstream tasks combines a supervised loss ($\mathcal{L}_{\mathrm{sup}}$) and a disagreement loss ($\mathcal{L}_{\mathrm{diss}}$) [80]. The disagreement loss encourages consistency between the predictions of the two branches.
$ \mathcal{L}_{\mathrm{fine-tune}} = \mathcal{L}_{\mathrm{sup}} + \mathcal{L}_{\mathrm{diss}} $
$ \mathcal{L}_{\mathrm{sup}} = \mathcal{L} \left( \pmb{p}_i^{\mathrm{node-branch}}, \pmb{y}_i \right) + \mathcal{L} \left( \pmb{p}_i^{\mathrm{edge-branch}}, \pmb{y}_i \right) $
$ \mathcal{L}_{\mathrm{diss}} = \left\| \pmb{p}_i^{\mathrm{node-branch}} - \pmb{p}_i^{\mathrm{edge-branch}} \right\|_2 $
Where:
- $\mathcal{L}$: The supervised loss function (e.g., cross-entropy for classification, RMSE for regression) between a prediction $\pmb{p}_i$ and its ground truth $\pmb{y}_i$.
- $\mathcal{L}_{\mathrm{diss}}$: The L2 norm (Euclidean distance) between the predictions of the two branches.
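A hedged sketch of this combined objective is given below for a binary classification task. Note two assumptions: the supervised term is instantiated with binary cross-entropy (the paper uses whichever supervised loss suits the task), and the disagreement term is scaled by the `dist_coff` coefficient from Table S3, although the formula above writes the two terms as a plain sum.

```python
import torch
import torch.nn.functional as F

def fine_tune_loss(p_node: torch.Tensor, p_edge: torch.Tensor,
                   target: torch.Tensor, dist_coff: float = 0.1) -> torch.Tensor:
    # Supervised loss applied to both branches' predictions (binary classification here).
    sup = F.binary_cross_entropy_with_logits(p_node, target) \
        + F.binary_cross_entropy_with_logits(p_edge, target)
    # Disagreement loss: L2 distance between the two branches' predictions.
    diss = torch.norm(p_node - p_edge, p=2)
    return sup + dist_coff * diss
```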
5.4.2.4. Fine-tuning Hyperparameters
Hyperparameters are chosen via random search for each task.
The following are the results from Table S3 of the original paper:
| Hyper-parameter | Value | Description |
|---|---|---|
| batch_size | 32 | The input batch size |
| ffn_hidden_size | 200 | The hidden_size of MLP layers |
| ffn_num_layer | 2 | The number of MLP layers |
| attn_hidden | 200 | The hidden_size for the self-attentive readout |
| attn_out | 2 | The number of output heads for the self-attentive readout |
| dist_coff | 0.1 | The coefficient of the disagreement loss |
| init_lr | max_lr / 10 | The initial learning rate of the Noam learning rate scheduler |
| max_lr | 0.0001 ~ 0.001 | The maximum learning rate of the Noam learning rate scheduler |
| final_lr | max_lr / (5-10) | The final learning rate of the Noam learning rate scheduler |
6. Results & Analysis
6.1. Core Results Analysis
6.1.1. Molecular Property Prediction
The paper first presents results on 9 molecular property prediction datasets, including both classification (BACE, BBBP, ClinTox, SIDER, Tox21, ToxCast) and regression (FreeSolv, ESOL, Lipo) tasks. The performance of BatmanNet is compared against 10 supervised models and 10 self-supervised models.
The following are the results from Table 1 of the original paper:
| Dataset | BACE | BBBP | ClinTox | SIDER | Tox21 | ToxCast | FreeSolv | ESOL | Lipo |
|---|---|---|---|---|---|---|---|---|---|
| #Molecules | 1513 | 2039 | 1478 | 1427 | 7831 | 8575 | 642 | 1128 | 4200 |
| #Tasks | 1 | 1 | 2 | 27 | 12 | 617 | 1 | 1 | 1 |
| ECFP [50] | 0.861(0.024) | 0.783(0.050) | 0.673(0.031) | 0.630(0.019) | 0.760(0.009) | 0.615(0.017) | 5.275(0.751) | 2.359(0.454) | 1.188(0.061) |
| TF_Robust [51] | 0.824(0.022) | 0.860(0.087) | 0.765(0.085) | 0.607(0.033) | 0.698(0.012) | 0.585(0.031) | 4.122(0.085) | 1.722(0.038) | 0.909(0.060) |
| GraphConv [52] | 0.854(0.011) | 0.877(0.036) | 0.845(0.051) | 0.593(0.035) | 0.772(0.041) | 0.650(0.025) | 2.900(0.135) | 1.068(0.050) | 0.712(0.049) |
| Weave [31] | 0.791(0.008) | 0.837(0.065) | 0.823(0.023) | 0.543(0.034) | 0.741(0.044) | 0.678(0.024) | 2.398(0.250) | 1.158(0.055) | 0.813(0.042) |
| SchNet [32] | 0.750(0.033) | 0.847(0.024) | 0.717(0.042) | 0.545(0.038) | 0.767(0.025) | 0.679(0.021) | 3.215(0.755) | 1.045(0.064) | 0.909(0.098) |
| MPNN [37] | 0.815(0.044) | 0.913(0.041) | 0.879(0.054) | 0.595(0.030) | 0.808(0.024) | 0.691(0.013) | 2.185(0.952) | 1.167(0.430) | 0.672(0.051) |
| DMPNN [39] | 0.852(0.053) | 0.919(0.030) | 0.897(0.040) | 0.632(0.023) | 0.826(0.023) | 0.718(0.011) | 2.177(0.914) | 0.980(0.258) | 0.653(0.046) |
| MGCN [40] | 0.734(0.030) | 0.850(0.064) | 0.634(0.042) | 0.552(0.018) | 0.707(0.016) | 0.663(0.009) | 3.349(0.097) | 1.266(0.147) | 1.113(0.041) |
| AttentiveFP [35] | 0.863(0.015) | 0.908(0.050) | 0.933(0.020) | 0.605(0.060) | 0.807(0.020) | 0.579(0.001) | 2.030(0.420) | 0.853(0.060) | 0.650(0.030) |
| TrimNet [53] | 0.843(0.025) | 0.892(0.025) | 0.906(0.017) | 0.606(0.006) | 0.812(0.019) | 0.652(0.032) | 2.529(0.111) | 1.282(0.029) | 0.702(0.008) |
| Mol2Vec [54] | 0.841(0.052) | 0.876(0.030) | 0.828(0.023) | 0.601(0.023) | 0.805(0.015) | 0.690(0.052) | 5.752(1.245) | 2.358(0.452) | 1.178(0.054) |
| N-GRAM [43] | 0.876(0.035) | 0.912(0.013) | 0.855(0.037) | 0.632(0.005) | 0.769(0.027) | - | 2.512(0.190) | 1.100(0.160) | 0.876(0.033) |
| SMILES-BERT [14] | 0.849(0.021) | 0.959(0.009) | 0.985(0.014) | 0.568(0.031) | 0.803(0.010) | 0.665(0.010) | 2.974(0.510) | 0.841(0.096) | 0.666(0.029) |
| pre-trainGNN [17] | 0.851(0.027) | 0.915(0.040) | 0.762(0.058) | 0.614(0.006) | 0.811(0.015) | 0.714(0.019) | - | - | - |
| GraphMAE [47] | 0.863(0.002) | 0.896(0.007) | 0.850(0.007) | 0.652(0.001) | 0.794(0.003) | 0.679(0.005) | - | - | - |
| GROVERbase [18] | 0.878(0.016) | 0.936(0.008) | 0.925(0.013) | 0.656(0.023) | 0.819(0.020) | 0.723(0.010) | 1.592(0.072) | 0.888(0.116) | 0.563(0.030) |
| GROVERlarge [18] | 0.894(0.028) | 0.940(0.019) | 0.944(0.021) | 0.658(0.023) | 0.831(0.025) | 0.737(0.010) | 1.544(0.397) | 0.831(0.120) | 0.560(0.035) |
| KPGT [19] | 0.855(0.011) | 0.908(0.010) | 0.946(0.022) | 0.649(0.009) | 0.848(0.013) | 0.746(0.002) | 2.121(0.837) | 0.803(0.008) | 0.600(0.010) |
| MPG [13] | 0.920(0.013) | 0.922(0.012) | 0.963(0.028) | 0.661(0.007) | 0.837(0.019) | 0.748(0.005) | 1.269(0.192) | 0.741(0.017) | 0.556(0.017) |
| GEM [20] | 0.925(0.010) | 0.953(0.007) | 0.977(0.019) | 0.663(0.014) | 0.849(0.003) | 0.742(0.004) | - | - | - |
| BatmanNet | 0.928(0.008) | 0.946(0.003) | 0.926(0.002) | 0.676(0.007) | 0.855(0.005) | 0.756(0.007) | 1.174(0.054) | 0.736(0.014) | 0.578(0.034) |
Analysis:
- Overall Superiority: BatmanNet achieves state-of-the-art (SOTA) performance on 6 out of 9 datasets.
- Vs. Supervised Models: It significantly outperforms all supervised models without pre-training across all datasets, highlighting the benefit of the self-supervised pre-training strategy.
- Vs. Previous SOTA:
  - Classification Tasks (BACE, BBBP, ClinTox, SIDER, Tox21, ToxCast): BatmanNet exceeds the previous SOTA model, GEM, on 4 out of 6 datasets (BACE, SIDER, Tox21, ToxCast) and shows very competitive performance on BBBP and ClinTox.
  - Regression Tasks (FreeSolv, ESOL, Lipo): BatmanNet outperforms MPG (the previous SOTA for regression) on 2 out of 3 datasets (FreeSolv, ESOL).
- Small Standard Deviations: The small standard deviations indicate high stability and robustness of BatmanNet's performance.
6.1.2. Efficacy and Effectiveness Analysis (Pre-training Data and Model Size)
This analysis investigates the efficiency of BatmanNet by comparing its average AUC on classification tasks against the pre-training dataset size and model size of other baselines.
The following are the results from Table S4 of the original paper:
| Model | Pre-training Data Size (M) | Model Size (M) | AVG-AUC (%) |
|---|---|---|---|
| GraphMAE | 2 | - | 78.90 |
| GROVERbase | 11 | 40 | 82.28 |
| GROVERlarge | 11 | 100 | 83.40 |
| KPGT | 2 | - | 82.53 |
| MPG | 11 | 55 | 84.18 |
| GEM | 20 | - | 85.15 |
| BatmanNet | 0.25 | 2.6 | 84.78 |
The following figure (Figure 4 from the original paper) illustrates the pre-training dataset size and model size for BatmanNet and a series of advanced baselines, along with their average AUC across all classification datasets about molecular property prediction.

Analysis:
- BatmanNet achieves an average AUC of 84.78%, very close to the top-performing GEM (85.15%) and slightly above MPG (84.18%).
- Crucially, BatmanNet achieves this performance with a significantly smaller pre-training dataset (0.25 million molecules from ZINC250K) compared to GROVER and MPG (11 million) or GEM (20 million).
- Furthermore, BatmanNet has a remarkably smaller model size (2.6 million parameters) than GROVERbase (40 million), GROVERlarge (100 million), and MPG (55 million). The model sizes for GraphMAE and KPGT are not explicitly provided in this table, but both are generally large transformer-style models.
- This demonstrates BatmanNet's superior efficacy and efficiency: it achieves comparable SOTA performance with substantially fewer training data and model parameters, addressing a key challenge identified in the introduction.
6.1.3. Drug-Drug Interaction (DDI) Prediction
BatmanNet's effectiveness is evaluated on DDI prediction as a binary classification task using BIOSNAP and TWOSIDES datasets.
The following are the results from Table 2 of the original paper:
| Model | AUC-ROC | PR-AUC | F1 |
|---|---|---|---|
| LR | 0.802(0.001) | 0.779(0.001) | 0.741(0.002) |
| Nat.Prot [63] | 0.853(0.001) | 0.848(0.001) | 0.714(0.001) |
| Mol2Vec [54] | 0.879(0.006) | 0.861(0.005) | 0.798(0.007) |
| MolVAE [30] | 0.892(0.009) | 0.877(0.009) | 0.788(0.033) |
| DeepDDI [2] | 0.886(0.007) | 0.871(0.007) | 0.817(0.007) |
| CASTER [64] | 0.910(0.005) | 0.887(0.008) | 0.843(0.005) |
| GEM [20] | 0.960(0.003) | 0.956(0.002) | 0.903(0.003) |
| MPG [13] | 0.966(0.004) | 0.960(0.004) | 0.905(0.008) |
| BatmanNet | 0.972(0.001) | 0.966(0.001) | 0.916(0.002) |
Analysis (BIOSNAP):
- BatmanNet achieves the highest AUC-ROC (0.972), PR-AUC (0.966), and F1 score (0.916) on the BIOSNAP dataset.
- It surpasses the previous SOTA (MPG) across all three metrics, demonstrating its superior performance in DDI prediction. The low standard deviations indicate consistent results.

The following are the results from Table 4 of the original paper:
| Model | Precision | Recall | F1 |
|---|---|---|---|
| DDI_PULearn [66] | 0.904 | 0.824 | 0.862 |
| GEM [20] | 0.928 | 0.929 | 0.928 |
| MPG [13] | 0.936 | 0.936 | 0.936 |
| BatmanNet | 0.939 | 0.939 | 0.939 |
Analysis (TWOSIDES):
- BatmanNet again shows the best performance, with Precision, Recall, and F1 scores of 0.939 on the TWOSIDES dataset.
- It slightly but consistently outperforms MPG and GEM, further solidifying its SOTA status for DDI prediction.
6.1.4. Drug-Target Interaction (DTI) Prediction
BatmanNet is evaluated for DTI prediction by replacing the compound encoder in an existing framework [65] with its pre-trained encoder.
The following are the results from Table 3 of the original paper:
| Datasets | Model | Precision | Recall | AUC |
|---|---|---|---|---|
| Human | Tsubaki et al. [65] | 0.923 | 0.918 | 0.970 |
| | GEM [20] | 0.930 | 0.930 | 0.972 |
| | MPG [13] | 0.952 | 0.940 | 0.985 |
| | BatmanNet | 0.983 | 0.982 | 0.998 |
| | (Relative improvement) | (3.26%) | (4.47%) | (1.32%) |
| C.elegans | Tsubaki et al. [65] | 0.938 | 0.929 | 0.978 |
| | GEM [20] | 0.955 | 0.954 | 0.988 |
| | MPG [13] | 0.954 | 0.959 | 0.986 |
| | BatmanNet | 0.988 | 0.987 | 0.999 |
| | (Relative improvement) | (3.46%) | (2.92%) | (1.11%) |
Analysis:
- BatmanNet achieves significantly improved performance on both the Human and C. elegans datasets.
- On the Human dataset, it shows relative improvements over MPG of 3.26% in Precision, 4.47% in Recall (ambiguously labeled as a second "Precision" column in the source table), and 1.32% in AUC.
- On the C. elegans dataset, the improvements over MPG are 3.46% (Precision), 2.92% (Recall), and 1.11% (AUC).
- The results, particularly AUC values approaching 0.998 and 0.999, indicate strong transfer learning capabilities of BatmanNet for DTI prediction, suggesting its high-quality molecular representations are highly effective in this task.
6.1.5. Pre-trained Representations Visualization
The paper visualizes the learned representations of the self-supervised tasks (without downstream fine-tuning) to observe their discriminative power. Using UMAP [59], molecular embeddings are projected into a 2D space.
- Methodology: 1,500 valid molecules from ZINC were mixed with invalid molecules generated by structural perturbations (shuffling atom features, altering atom/bond order).
- Observation: As illustrated in Figure 3, the pre-trained BatmanNet shows a clear separation between valid and invalid molecules, forming distinct clusters. In contrast, the non-pre-trained model exhibits much less clear separation.

The following figure (Figure 3 from the original paper) shows the clustering of molecular representations for the non-pre-trained and pre-trained models (caption, translated: blue points denote invalid molecules and red points denote valid molecules; after pre-training, the valid molecules form a much more concentrated cluster).

Analysis: This visualization provides qualitative evidence that the self-supervised pre-training strategy enables BatmanNet to effectively learn and distinguish the fundamental structural validity of molecules, confirming that it captures meaningful underlying molecular information even before fine-tuning on specific tasks.
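A minimal sketch of this visualization protocol, assuming the umap-learn and matplotlib packages; random arrays stand in for the BatmanNet encoder outputs, since the actual embedding call is not shown here.

```python
import numpy as np
import umap                      # umap-learn, assumed installed
import matplotlib.pyplot as plt

# Placeholders: in practice these would be encoder outputs for 1,500 valid ZINC
# molecules and their structurally perturbed (invalid) counterparts.
valid_emb = np.random.randn(1500, 256)
invalid_emb = np.random.randn(1500, 256)

emb = np.vstack([valid_emb, invalid_emb])
labels = np.array([1] * len(valid_emb) + [0] * len(invalid_emb))

coords = umap.UMAP(n_components=2, random_state=0).fit_transform(emb)
plt.scatter(coords[labels == 1, 0], coords[labels == 1, 1], s=3, c="red", label="valid")
plt.scatter(coords[labels == 0, 0], coords[labels == 0, 1], s=3, c="blue", label="invalid")
plt.legend()
plt.title("Molecular embeddings projected with UMAP")
plt.show()
```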
6.2. Ablation Studies / Parameter Analysis
Ablation studies were conducted on 6 classification benchmarks (molecular property prediction) to understand the influence of different architectural components and hyperparameters.
The following figure (Figure 5 from the original paper) illustrates the AUC results of the BatmanNet model across various drug discovery tasks.
The figure presents the AUC results of the BatmanNet model across various drug discovery tasks. Panel (a) compares the node branch, the edge branch, and the bi-branch BatmanNet, showing that the bi-branch model achieves the best performance on multiple benchmark datasets. Panel (b) shows the AUC of the pre-trained versus non-pre-trained model on each dataset, and panel (c) analyzes the effect of different masking ratios on the average AUC, which peaks at 0.848.
6.2.1. Effectiveness of the Bi-branch Information Extraction Network
This study evaluates the impact of the dual-branch design.
- Setup: BatmanNet (bi-branch) was compared against single-branch versions (either node branch only or edge branch only), all pre-trained under the same conditions and with a similar number of parameters (2.6M).
- Results (Figure 5a): The bi-branch BatmanNet improves the average AUC by 2.9% compared to the node-only branch and 3.0% compared to the edge-only branch.
- Analysis: This clearly demonstrates that the complementary nature of simultaneously learning from both nodes and edges significantly enhances the model's ability to capture molecular information and improves performance.
6.2.2. Impact of the Self-Supervised Pre-training Strategy
This study assesses the contribution of the self-supervised pre-training.
- Setup: The pre-trained BatmanNet was compared with BatmanNet without pre-training on molecular property prediction tasks, using identical hyperparameters for fine-tuning.
- Results (Figure 5b): The pre-trained model consistently outperforms the non-pre-trained model, showing an average AUC improvement of 4.0%.
- Analysis: This confirms that the proposed self-supervised pre-training strategy is highly effective. It successfully captures rich structural and semantic information from unlabeled molecules, which then transfers to improve performance on downstream tasks, validating the core hypothesis of the paper.
6.2.3. Effect of Different Masking Ratio
This study investigates how varying the masking ratio (proportion of masked nodes and edges) impacts performance.
- Setup: BatmanNet was pre-trained with masking ratios ranging from 10% to 90%, and the average AUC on all downstream classification tasks was measured.

The following are the results from Table S6 of the original paper:
| Ratio | BBBP | SIDER | ClinTox | BACE | Tox21 | ToxCast | Avg |
|---|---|---|---|---|---|---|---|
| 0.1 | 0.923(0.032) | 0.662(0.015) | 0.905(0.028) | 0.913(0.007) | 0.843(0.014) | 0.739(0.011) | 0.831 |
| 0.2 | 0.929(0.027) | 0.667(0.003) | 0.912(0.012) | 0.915(0.006) | 0.845(0.009) | 0.745(0.009) | 0.836 |
| 0.3 | 0.933(0.018) | 0.668(0.003) | 0.918(0.025) | 0.919(0.013) | 0.848(0.017) | 0.749(0.007) | 0.839 |
| 0.4 | 0.940(0.011) | 0.671(0.006) | 0.920(0.028) | 0.920(0.014) | 0.850(0.014) | 0.750(0.009) | 0.842 |
| 0.5 | 0.943(0.019) | 0.675(0.004) | 0.925(0.025) | 0.925(0.014) | 0.851(0.013) | 0.753(0.008) | 0.845 |
| 0.6 | 0.946(0.007) | 0.676(0.004) | 0.926(0.015) | 0.928(0.015) | 0.855(0.013) | 0.756(0.009) | 0.848 |
| 0.7 | 0.942(0.008) | 0.674(0.004) | 0.926(0.011) | 0.924(0.016) | 0.848(0.012) | 0.751(0.007) | 0.844 |
| 0.8 | 0.938(0.012) | 0.673(0.004) | 0.921(0.288) | 0.918(0.016) | 0.844(0.014) | 0.750(0.008) | 0.841 |
| 0.9 | 0.932(0.020) | 0.668(0.005) | 0.910(0.022) | 0.915(0.015) | 0.840(0.015) | 0.746(0.011) | 0.835 |
The following figure (Figure S1 from the original paper) shows the influence of the masking ratio on each benchmark dataset.

Analysis (Figure 5c, Table S6, Figure S1):
- The optimal masking ratio is found to be 60%, which yields the best average AUC (0.848); a minimal sketch of the masking step follows this list.
- Lower Ratios (10-50%): As the masking ratio increases from 10% toward 60%, performance generally improves. A higher masking ratio creates a more challenging pre-training task, forcing the remaining unmasked nodes and edges to learn more comprehensively about their context in order to recover the missing parts, which leads to richer and more powerful embeddings.
- Higher Ratios (70-90%): Beyond 60%, the performance starts to decline. This indicates that when too much information is masked, there is insufficient contextual information left in the visible parts of the graph for the model to reliably reconstruct the complete original graph. This impairs the quality of the learned embeddings.
- The consistency of this trend across various datasets (Figure S1) further validates the choice of 60% as an effective masking ratio.
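To make the masking step studied here concrete, the sketch below samples 60% node and edge masks for a single molecular graph; the graph representation and the handling of mask tokens are simplified assumptions, not the paper's implementation.

```python
import torch

def sample_masks(num_nodes: int, num_edges: int, mask_ratio: float = 0.6):
    """Independently mark ~mask_ratio of nodes and edges to be hidden from the encoder."""
    node_mask = torch.rand(num_nodes) < mask_ratio
    edge_mask = torch.rand(num_edges) < mask_ratio
    return node_mask, edge_mask

# Example: a small molecular graph with 24 atoms and 26 bonds, masked at 60%.
node_mask, edge_mask = sample_masks(num_nodes=24, num_edges=26, mask_ratio=0.6)
# The encoder would see only the visible (unmasked) nodes/edges; the decoder
# would then reconstruct the features of the masked ones.
```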
6.2.4. Additional Experiments on Molecular Property Prediction
The following are the results from Table S5 of the original paper:
| Dataset | BACE | BBBP | ClinTox | SIDER | Tox21 | ToxCast | HIV | MUV | Avg |
|---|---|---|---|---|---|---|---|---|---|
| #molecules | 1513 | 2039 | 1478 | 1427 | 7831 | 8575 | 41127 | 93087 | - |
| #tasks | 1 | 1 | 2 | 27 | 12 | 617 | 1 | 17 | - |
| D-MPNN | 0.809(0.006) | 0.710(0.003) | 0.906(0.006) | 0.570(0.007) | 0.759(0.007) | 0.655(0.003) | 0.771(0.005) | 0.786(0.014) | 0.746 |
| AttentiveFP | 0.784(0.022) | 0.643(0.018) | 0.847(0.003) | 0.606(0.032) | 0.761(0.005) | 0.637(0.002) | 0.757(0.014) | 0.766(0.015) | 0.735 |
| N-GramRF | 0.779(0.015) | 0.697(0.006) | 0.775(0.040) | 0.668(0.007) | 0.743(0.004) | - | 0.772(0.001) | 0.769(0.007) | - |
| N-GramXGB | 0.791(0.013) | 0.691(0.008) | 0.875(0.027) | 0.655(0.007) | 0.758(0.009) | - | 0.787(0.004) | 0.748(0.002) | - |
| PretrainGNN | 0.845(0.007) | 0.687(0.013) | 0.726(0.015) | 0.627(0.008) | 0.781(0.006) | 0.657(0.006) | 0.799(0.007) | 0.813(0.021) | 0.742 |
| GROVERbase | 0.826(0.007) | 0.700(0.001) | 0.812(0.030) | 0.648(0.006) | 0.743(0.001) | 0.654(0.004) | 0.625(0.009) | 0.673(0.018) | 0.710 |
| GROVERlarge | 0.810(0.014) | 0.695(0.001) | 0.762(0.037) | 0.654(0.001) | 0.735(0.001) | 0.653(0.005) | 0.682(0.011) | 0.673(0.018) | 0.708 |
| GraphMAE | 0.831(0.009) | 0.720(0.006) | 0.823(0.012) | 0.603(0.011) | 0.755(0.006) | 0.641(0.003) | 0.772(0.010) | 0.763(0.024) | 0.739 |
| GEM | 0.856(0.011) | 0.724(0.004) | 0.901(0.013) | 0.672(0.004) | 0.781(0.001) | 0.692(0.004) | 0.806(0.009) | 0.817(0.005) | 0.781 |
| BatmanNet | 0.861(0.028) | 0.838(0.005) | 0.897(0.012) | 0.659(0.003) | 0.792(0.003) | 0.718(0.007) | 0.812(0.009) | 0.784(0.014) | 0.795 |
The regression results from Table S5 (RMSE; lower is better):
| Model | ESOL | Avg |
|---|---|---|
| #molecules | 1128 | - |
| #tasks | 1 | - |
| D-MPNN | 1.050(0.008) | 1.272 |
| AttentiveFP | 0.877(0.029) | 1.224 |
| N-Gram_RF | 1.074(0.107) | 1.525 |
| N-Gram_XGB | 1.083(0.082) | 2.739 |
| PretrainGNN | 1.100(0.006) | 1.534 |
| GROVERbase | 0.983(0.090) | 1.325 |
| GROVERlarge | 0.895(0.017) | 1.330 |
| GEM | 0.798(0.029) | 1.112 |
| BatmanNet | 0.792(0.013) | 1.108 |
Analysis:
- BatmanNet achieves SOTA performance on 7 out of 11 datasets when compared under the experimental settings of GEM [20].
- It shows an overall relative improvement of 1.1% over the previous SOTA across all datasets (1.8% on classification tasks and 0.4% on regression tasks).
- These results further reinforce BatmanNet's robustness and generalizability across a broader range of molecular property prediction tasks.
7. Conclusion & Reflections
7.1. Conclusion Summary
This paper introduces BatmanNet, a novel bi-branch masked graph transformer autoencoder for learning effective molecular representations. The core innovation lies in its simple yet powerful self-supervised pre-training strategy, which involves masking a high proportion of both nodes and edges in a molecular graph and then reconstructing their complete features. This strategy efficiently captures both local and global molecular information without relying on complex, domain-specific tasks or additional knowledge.
The BatmanNet architecture features an asymmetric encoder-decoder design, where a robust encoder processes only the visible parts of the masked graph, and a lightweight decoder reconstructs the full graph. This design significantly reduces computational complexity, memory consumption, and pre-training time.
Extensive experimental results across 13 benchmark datasets for molecular property prediction, drug-drug interaction, and drug-target interaction tasks demonstrate that BatmanNet consistently achieves state-of-the-art performance. Notably, it delivers competitive or superior results compared to previous advanced models while utilizing substantially fewer pre-training data and model parameters, affirming its superior efficacy and efficiency.
7.2. Limitations & Future Work
The authors acknowledge several limitations and propose future research directions:
- Neglecting 3D Structural Information: The current BatmanNet primarily focuses on the 2D planar topological structure of molecules, overlooking crucial 3D structural information (e.g., conformations, spatial relationships).
  - Future Work: Incorporating 3D structural details into node and edge features could lead to a more comprehensive understanding of molecular features, improving performance, especially in tasks like drug-target interaction where 3D geometry is critical.
- Limited Pre-training Dataset Size: Due to current computational constraints, BatmanNet was pre-trained on a relatively small dataset (ZINC250K) with a small-scale model.
  - Future Work: Expanding the approach to larger pre-training datasets could further assess the model's potential and how much additional improvement can be achieved.
- Mitigating Inherent Biases in Data-driven Approaches: The current framework is primarily data-driven, which may introduce potential biases present in the training data.
  - Future Work: Exploring effective strategies to integrate domain knowledge into the data-driven pipeline could help mitigate these biases and further enhance model performance.
7.3. Personal Insights & Critique
This paper presents a compelling solution to the challenges of molecular representation learning. My personal insights and critique are as follows:
- Elegance of Simplicity: The most striking aspect of BatmanNet is its ability to achieve state-of-the-art results with a significantly simpler self-supervised task (a single bi-branch masking-and-reconstruction objective) and a more efficient asymmetric architecture. This challenges the trend of increasing complexity in pre-training tasks and model sizes, demonstrating that focused design choices can lead to superior outcomes. It suggests that the inherent information in molecular graph topology, when effectively extracted through a challenging reconstruction task, can be sufficient without heavy reliance on external domain knowledge during pre-training.
- Scalability and Accessibility: The reduced computational complexity and model size make BatmanNet a more accessible tool for researchers with limited computational resources, broadening the potential for AIDD advancements beyond large institutions. This is a critical practical advantage.
- Generalizability of Masked Autoencoding: The success of BatmanNet further validates the power of masked autoencoding as a general self-supervised learning paradigm, extending its proven effectiveness from NLP and computer vision to graph-structured data, particularly in a complex scientific domain like chemistry.
- Potential for Transfer to Other Domains: The core idea of a bi-branch masked graph autoencoder is highly transferable. It could be adapted to other graph-structured data in various fields, such as social networks, knowledge graphs, or biological networks (e.g., protein-protein interaction networks), where learning expressive node and edge representations is crucial.
- Areas for Improvement / Unverified Assumptions:
  - Specificity of Edge Graph Formulation: The formula for the aggregated edge embedding in the paper is slightly ambiguous. While I adhered to it rigorously, a more precise definition or a common variant (e.g., one in which the aggregated term explicitly refers to edge features and the sum runs over neighboring edges, or over the nodes incident to the edge itself, rather than just a node's neighbors) might enhance clarity. This might be a minor notation issue or a unique design choice that warrants further explanation.
  - Dynamic Masking Strategies: The paper fixes the masking ratio at 60%. While optimal for the tested datasets, dynamic masking strategies (e.g., curriculum learning for masking, or adaptive masking based on graph properties) could potentially yield even better or more robust representations, especially across diverse molecular sizes and complexities.
  - Interpretability of Bi-branch Embeddings: While effective, further work could explore the interpretability of the learned node and edge embeddings from the two branches. Do they capture distinct, complementary chemical insights? Visualizing or analyzing the feature spaces more deeply could provide valuable domain-specific understanding.
In conclusion, BatmanNet offers a significant step forward in self-supervised molecular representation learning, providing an efficient, effective, and simple framework with broad implications for AIDD and beyond. Its focus on simplicity and efficiency while maintaining high performance is a valuable lesson for future model design.