End-to-End Multi-Task Learning with Attention
TL;DR Summary
The paper introduces a novel Multi-Task Attention Network (MTAN) for task-specific feature-level attention learning. This architecture employs a shared network and dedicated soft-attention modules, enabling efficient feature sharing and state-of-the-art performance in multi-task learning, while remaining parameter efficient and robust to the choice of loss weighting.
Abstract
We propose a novel multi-task learning architecture, which allows learning of task-specific feature-level attention. Our design, the Multi-Task Attention Network (MTAN), consists of a single shared network containing a global feature pool, together with a soft-attention module for each task. These modules allow for learning of task-specific features from the global features, whilst simultaneously allowing for features to be shared across different tasks. The architecture can be trained end-to-end and can be built upon any feed-forward neural network, is simple to implement, and is parameter efficient. We evaluate our approach on a variety of datasets, across both image-to-image predictions and image classification tasks. We show that our architecture is state-of-the-art in multi-task learning compared to existing methods, and is also less sensitive to various weighting schemes in the multi-task loss function. Code is available at https://github.com/lorenmt/mtan.
Mind Map
In-depth Reading
English Analysis
1. Bibliographic Information
1.1. Title
The central topic of the paper is an end-to-end multi-task learning architecture that incorporates attention mechanisms.
1.2. Authors
The authors are Shikun Liu, Edward Johns, and Andrew J. Davison, all affiliated with the Department of Computing, Imperial College London. Their research backgrounds appear to be in computer vision and machine learning, particularly in areas like convolutional neural networks and multi-task learning, given the content of the paper.
1.3. Journal/Conference
The paper was first published on arXiv, a preprint server, on 28 March 2018. While arXiv itself is a preprint repository and not a peer-reviewed journal or conference, the work was subsequently accepted to CVPR, a highly reputable and influential venue in computer vision.
1.4. Publication Year
The paper was published in 2018.
1.5. Abstract
The paper proposes a novel multi-task learning (MTL) architecture called the Multi-Task Attention Network (MTAN). This design allows for the learning of task-specific feature-level attention. The MTAN consists of a single shared network that creates a global feature pool, and for each task, a dedicated soft-attention module. These modules enable the network to learn features tailored to individual tasks from the global pool while simultaneously allowing feature sharing across different tasks. The architecture is designed to be end-to-end trainable, can be built upon any feed-forward neural network, is simple to implement, and is parameter efficient. The authors evaluate MTAN on various datasets for both image-to-image prediction (like semantic segmentation and depth estimation) and image classification tasks. They demonstrate that their architecture achieves state-of-the-art performance in multi-task learning compared to existing methods and is also less sensitive to the weighting schemes used in the multi-task loss function.
1.6. Original Source Link
The official source link is https://arxiv.org/abs/1803.10704. It is a preprint published on arXiv.
The PDF link is https://arxiv.org/pdf/1803.10704v2.pdf.
2. Executive Summary
2.1. Background & Motivation
The core problem the paper addresses is the inefficiency and complexity of building separate Convolutional Neural Networks (CNNs) for each task in real-world computer vision applications. While CNNs have achieved great success in single tasks like image classification or semantic segmentation, developing complete vision systems often requires performing multiple tasks simultaneously. Building independent networks for each task is inefficient in terms of memory consumption, inference speed, and data utilization, as related tasks often share common, informative visual features.
This leads to the importance of Multi-Task Learning (MTL), which aims to train a single network to perform several tasks concurrently. However, existing MTL approaches face two key challenges:
- Network Architecture (how to share): An effective MTL architecture needs to learn both task-shared features (for generalization and to avoid overfitting) and task-specific features (for tailored performance and to avoid underfitting). Many prior methods either separate features too explicitly, leading to a large number of parameters, or share features too rigidly, limiting flexibility.
- Loss Function (how to balance tasks): When multiple tasks are trained simultaneously, their individual loss contributions need to be balanced. Manually tuning these loss weights is tedious and often suboptimal, as easier tasks can dominate the training process. Automatically learning these weights, or designing an architecture robust to weighting choices, is highly desirable.

The paper's innovative idea is to introduce feature-level attention masks into a multi-task learning architecture. This design allows for greater flexibility in sharing complementary features and automatically determines which features are important for each task, directly addressing the architectural challenge. Consequently, by dynamically adapting feature importance, the network inherently becomes more robust to the choice of loss weighting schemes, indirectly addressing the loss balancing challenge.
2.2. Main Contributions / Findings
The paper makes several primary contributions:
- Novel Multi-Task Attention Network (MTAN) Architecture: The authors propose MTAN, a new architecture composed of a single shared network that forms a global feature pool, plus K task-specific attention networks. Each attention network comprises soft attention modules applied at various layers of the shared network. This design allows MTAN to automatically learn both task-shared and task-specific features in an end-to-end manner.
- Parameter Efficiency: MTAN is highly parameter efficient because it avoids the large number of parameters typically required by architectures that explicitly separate tasks (e.g., Cross-Stitch Networks or Progressive Networks). It scales gracefully with the number of tasks, requiring only a roughly 10% increase in parameters per additional task.
- Robustness to Loss Weighting: The MTAN architecture demonstrates inherent robustness to the choice of loss function weighting scheme. This significantly reduces the need for manual tuning of task weights, a common bottleneck in MTL.
- Dynamic Weight Average (DWA) Scheme: As part of their evaluation of robustness, the authors propose a simple yet effective adaptive weighting method called Dynamic Weight Average (DWA). DWA adjusts task weights over time by considering the rate of change of each task's loss, requiring only the numerical task losses (unlike GradNorm, which needs gradient access).
- State-of-the-Art Performance: MTAN achieves state-of-the-art or competitive performance on a variety of benchmarks, including image-to-image prediction tasks (semantic segmentation, depth estimation, and surface normal prediction on the CityScapes and NYUv2 datasets) and image classification tasks (the Visual Decathlon Challenge).
- Generalizability and Simplicity: MTAN can be built upon any feed-forward neural network and is simple to implement, making it widely applicable.

These findings collectively address the inefficiency of deploying separate single-task models by providing a flexible, efficient, and robust multi-task learning framework that performs well across diverse tasks and reduces the complexity of hyper-parameter tuning.
3. Prerequisite Knowledge & Related Work
3.1. Foundational Concepts
To fully understand the Multi-Task Attention Network (MTAN) paper, a reader should be familiar with several fundamental concepts in deep learning and computer vision:
- Convolutional Neural Networks (CNNs): CNNs are a class of deep neural networks specifically designed for processing structured, grid-like data such as images. They use convolutional layers, which apply learnable filters to the input to extract features. These networks typically consist of multiple layers, including convolutional layers, pooling layers (for down-sampling), and fully connected layers (for classification). CNNs have been highly successful in tasks like image classification, object detection, and semantic segmentation.
- Multi-Task Learning (MTL): MTL is a subfield of machine learning in which a single model is trained to perform multiple tasks simultaneously. The core idea is that by learning related tasks in parallel, the model can leverage shared representations and common features, leading to improved generalization and efficiency compared to training a separate model for each task. It can help prevent overfitting, improve data efficiency, and learn more robust representations.
- Encoder-Decoder Networks (e.g., SegNet): A common CNN architecture type, particularly for image-to-image prediction tasks like semantic segmentation.
  - An encoder progressively reduces the spatial dimension of the input image while increasing the feature dimension, capturing hierarchical features. This typically involves convolutional and pooling layers.
  - A decoder then up-samples the encoded features back to the original input resolution, using the learned features to produce a pixel-wise prediction for the output task. This typically involves up-sampling and convolutional layers.
  - SegNet [1] is a specific encoder-decoder architecture built upon the VGG-16 [27] encoder, known for its efficient memory usage during inference because it stores only pooling indices for up-sampling (a minimal code sketch of this mechanism follows this list).
- Attention Mechanism (General Concept): In neural networks, an attention mechanism allows the model to focus on the most relevant parts of the input when making a prediction. Instead of treating all input parts equally, attention assigns varying weights or scores to different elements, highlighting those that are more important for the current task. In the context of feature-level attention, this means learning to emphasize or de-emphasize certain feature channels or spatial locations within feature maps. Soft attention implies that these weights are continuous values (e.g., between 0 and 1), allowing a graded focus rather than a hard selection.
- Image-to-Image Prediction Tasks:
  - Semantic Segmentation: Classifying each pixel in an image into a predefined category (e.g., "road," "car," "sky"). The output is typically a pixel-wise label map.
  - Depth Estimation: Predicting the distance of each pixel from the camera. The output is a dense depth map.
  - Surface Normal Prediction: Estimating the orientation of the surface at each pixel. The output is typically a 3-channel image in which each channel represents one component of the normal vector (x, y, z).
- Loss Functions: In supervised learning, a loss function quantifies the discrepancy between the model's predictions and the true labels. The goal during training is to minimize this loss.
  - Cross-Entropy Loss: Commonly used for classification tasks. It measures the difference between two probability distributions. For pixel-wise cross-entropy in semantic segmentation, it is applied independently to each pixel's class probabilities.
  - L1 Norm Loss (Mean Absolute Error): Used for regression tasks. It measures the absolute difference between predicted and true values and is less sensitive to outliers than L2 (Mean Squared Error).
  - Dot Product Loss: For surface normal prediction, the dot product between two normalized vectors measures their angular similarity. A dot product close to 1 indicates parallel vectors (a good prediction), while -1 indicates anti-parallel vectors. Minimizing the negative dot product maximizes similarity.
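As a concrete illustration of the SegNet-style pooling-indices mechanism described above, here is a minimal PyTorch sketch. It is a sketch under assumptions: the layer sizes and block counts are illustrative, not SegNet's actual configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySegNet(nn.Module):
    """Minimal SegNet-style encoder-decoder: the encoder stores max-pooling
    indices, and the decoder reuses them for sparse up-sampling."""

    def __init__(self, in_ch=3, num_classes=7):
        super().__init__()
        self.enc = nn.Sequential(nn.Conv2d(in_ch, 64, 3, padding=1),
                                 nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.dec = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1),
                                 nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.head = nn.Conv2d(64, num_classes, 1)  # pixel-wise classifier

    def forward(self, x):
        f = self.enc(x)
        # Down-sample and remember which positions held the maxima.
        f, idx = F.max_pool2d(f, 2, return_indices=True)
        # Up-sample by scattering values back to the remembered positions.
        f = F.max_unpool2d(f, idx, kernel_size=2)
        return self.head(self.dec(f))

logits = TinySegNet()(torch.randn(1, 3, 64, 64))  # -> [1, 7, 64, 64]
```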
3.2. Previous Works
The paper contextualizes MTAN by comparing it to several prior Multi-Task Learning (MTL) approaches, highlighting their strengths and limitations:
- Cross-Stitch Networks [20]:
  - Concept: This architecture consists of one standard feed-forward network per task. To enable feature sharing, it introduces cross-stitch units between corresponding layers of different task-specific networks. These units learn to linearly combine the feature maps from one task's network with another's.
  - Limitation (highlighted by the MTAN paper): Cross-Stitch Networks require a large number of network parameters because they maintain largely separate networks for each task, scaling linearly with the number of tasks. This makes them inefficient for a large number of tasks.
  - Relevance: MTAN aims to achieve flexible feature sharing more efficiently using attention, instead of explicit linear combinations.
- Self-supervised approach of [6] (Multi-task Self-supervised Visual Learning):
  - Concept: Based on the ResNet101 architecture, this method learns a regularized combination of features from different layers of a single shared network. It primarily focuses on learning good representations through self-supervision for various low-level tasks.
  - Relevance: It explores feature sharing across layers of a single network, similar to MTAN's shared backbone, but MTAN introduces explicit attention for task-specific feature selection.
- UberNet [16]:
  - Concept: Proposes an image pyramid approach, in which images are processed across multiple resolutions. On top of a shared VGG-Net [27] backbone, additional task-specific layers are formed for each resolution. It aims to train a single network for a very wide range of low-, mid-, and high-level vision tasks.
  - Relevance: Like MTAN, it uses a shared backbone with task-specific components. However, UberNet's task-specific layers are more explicit and potentially less flexible than MTAN's attention modules.
- Progressive Networks [26]:
  - Concept: Uses a sequence of incrementally trained networks. Each new network is trained for a new task while leveraging knowledge transferred from previously trained networks via lateral connections. This allows knowledge transfer without catastrophic forgetting.
  - Limitation (highlighted by the MTAN paper): Like Cross-Stitch Networks, Progressive Networks also require a large number of network parameters, as each new task essentially adds a new network.
  - Relevance: MTAN aims for concurrent learning and parameter efficiency, contrasting with the sequential learning and growing parameter count of Progressive Networks.
- Weight Uncertainty [14]:
  - Concept: This method addresses the loss balancing challenge in MTL by modifying the loss functions based on task uncertainty. It learns the relative weights of different tasks by considering the noise or uncertainty associated with each task's ground truth; tasks with higher uncertainty receive lower weights.
  - Relevance: MTAN directly compares its robustness to this method and uses it as a baseline weighting scheme. MTAN's proposed DWA is also an adaptive weighting scheme, but simpler.
- GradNorm [3]:
  - Concept: GradNorm is an adaptive loss balancing method that manipulates gradient norms over time to control the training dynamics of multiple tasks. It dynamically adjusts task weights so that all tasks learn at similar rates, preventing any one task from dominating.
  - Limitation (highlighted by the MTAN paper): GradNorm requires access to the network's internal gradients, making its implementation more complex and architecture-dependent.
  - Relevance: MTAN's DWA is inspired by GradNorm but simplifies the approach by requiring only the numerical task losses.
- Dynamic Task Prioritisation [10]:
  - Concept: This method encourages prioritization of difficult tasks directly using performance metrics (e.g., accuracy, precision) instead of relying solely on task losses.
  - Relevance: Another approach to the loss balancing problem, which MTAN also addresses with its DWA scheme and architectural robustness.
3.3. Technological Evolution
The evolution of Multi-Task Learning in deep learning for computer vision can be broadly traced as follows:
- Early Single-Task CNNs: Initially, CNNs were predominantly designed and trained for single, specialized tasks (e.g., AlexNet for image classification). While highly successful, this led to a proliferation of models for complex vision systems.
- Implicit Feature Sharing (Hard Sharing): The simplest form of MTL involved hard parameter sharing, where an encoder or the initial layers of a CNN are shared across tasks, with separate task-specific heads (decoders or classifiers) at the end. This is efficient but offers limited flexibility in how features are shared. Many early MTL approaches implicitly fall into this category.
- Explicit Feature Sharing and Adaptation:
  - Cross-Stitch Networks [20]: Introduced explicit units (cross-stitch units) to learn how to linearly combine features between parallel, task-specific networks, allowing more adaptive sharing than simple hard sharing.
  - Progressive Networks [26]: Focused on sequential learning and knowledge transfer, with new task networks built incrementally and connected to previous ones.
  - Parameter-Efficient Adapters (e.g., Res. Adapt. [23], DAN [25], Piggyback [19], Parallel SVD [24]): Recent works, particularly around the Visual Decathlon Challenge, explored adapting a pre-trained network or efficiently adding task-specific parameters (such as adapters or masks) to avoid training a full network per task. These aimed for parameter efficiency and better generalization across many domains.
- Adaptive Loss Weighting: In parallel, research progressed on how to balance the losses of multiple tasks during training, moving from manual weighting to adaptive schemes such as Weight Uncertainty [14] and GradNorm [3].
- Attention-based Feature Selection (MTAN): This paper's work represents an evolution towards more flexible and automatic feature sharing through attention mechanisms. Instead of explicit cross-connections or separate networks, MTAN uses soft attention masks to dynamically select and emphasize relevant features from a shared pool for each task. This combines the efficiency of hard sharing with the flexibility of adaptive sharing, while also offering robustness to loss weighting.
3.4. Differentiation Analysis
Compared to the main methods in related work, MTAN presents several core differences and innovations:
- Implicit vs. Explicit Feature Sharing:
  - Cross-Stitch Networks [20] and Progressive Networks [26]: These methods rely on explicit mechanisms (e.g., cross-stitch units, lateral connections) between largely separate task-specific networks, or incrementally added networks, to share knowledge. This leads to a significant increase in parameter count, scaling linearly with the number of tasks.
  - MTAN: Uses a single, compact shared network that forms a global feature pool. Feature-level attention masks are then learned for each task to implicitly select and weigh features from this shared pool. This is a more dynamic and adaptive form of sharing that avoids explicit, redundant network structures.
- Parameter Efficiency:
  - Prior methods (e.g., Cross-Stitch, Progressive Networks): Often suffer from high parameter counts, as they effectively create or duplicate network components for each task.
  - MTAN: By sharing a single backbone and adding only small attention modules per task, MTAN is significantly more parameter efficient. The paper claims only a roughly 10% increase in parameters per additional learning task, which is crucial for scaling to many tasks.
- Nature of Task-Specific Components:
  - Multi-Task, Split or Multi-Task, Dense baselines: These often split the network into task-specific heads only at the very end, or feed all shared features to task-specific networks without explicit selection.
  - MTAN: Introduces soft attention modules at each convolution block of the shared network. These modules learn to "mask" or "select" features, allowing fine-grained task-specific feature extraction throughout the network, not just at the final layers.
- Robustness to Loss Weighting:
  - Most MTL methods: Sensitive to the choice of loss weighting scheme, requiring extensive hyper-parameter tuning (as noted by [20, 14]). Adaptive methods such as Weight Uncertainty [14] and GradNorm [3] attempt to mitigate this.
  - MTAN: The paper empirically demonstrates that MTAN exhibits inherent robustness across weighting schemes (equal weights, uncertainty weights, DWA). This is a significant advantage, reducing a major pain point in MTL deployment. The architecture's ability to automatically learn feature importance likely contributes to this robustness.
- Simplicity of Adaptive Weighting (DWA):
  - GradNorm [3]: Requires access to internal network gradients for its adaptive weighting scheme, which can be complex to implement across architectures.
  - DWA (proposed by MTAN): Offers a simpler alternative that requires only the numerical task losses, making it easier to integrate and use.

In summary, MTAN differentiates itself by proposing a more elegant, parameter-efficient, and inherently robust way to handle feature sharing in Multi-Task Learning through the strategic application of feature-level attention.
4. Methodology
4.1. Principles
The core idea behind the Multi-Task Attention Network (MTAN) is to leverage attention mechanisms to enable highly flexible and efficient feature sharing in Multi-Task Learning (MTL). The intuition is that while many tasks can benefit from shared low-level visual features, each task also requires specific, discriminative features for optimal performance. Instead of explicitly separating network branches or using rigid sharing mechanisms, MTAN proposes a self-supervised approach where soft attention masks are learned for each task. These masks dynamically "select" or "emphasize" the most relevant features from a global feature pool maintained by a single shared network. This allows the network to automatically determine which features are task-shared and which are task-specific in an end-to-end manner, maximizing both generalization across tasks and individual task performance. Furthermore, by dynamically adjusting feature importance, the network becomes more resilient to the choices made in the multi-task loss function weighting.
4.2. Core Methodology In-depth (Layer by Layer)
The MTAN architecture is composed of two main parts: a single shared network and K task-specific attention networks. While the shared network can be any feed-forward neural network (e.g., SegNet for dense prediction, Wide Residual Network for classification), each task-specific network consists of attention modules that interact with the shared network.
Figure 2 (from the original paper) illustrates the MTAN architecture based on VGG-16 as the encoder half of SegNet. The decoder half is symmetric but with individually learned weights.
Figure 2: Visualisation of MTAN based on VGG-16, showing the shared convolutional layers with task-specific attention modules in both the encoder and decoder. The figure depicts the encoder half of SegNet; the decoder half follows the same design, although its weights are learned individually.
4.2.1. Shared Network and Global Feature Pool
The shared network takes the input data (e.g., an image) and processes it through a series of convolutional blocks. Each block produces a set of feature maps that collectively form a global feature pool. These shared features are intended to capture general visual patterns useful across all tasks. For instance, in a SegNet-based MTAN, the VGG-16 encoder blocks form the shared network, learning hierarchical features.
4.2.2. Task-Specific Attention Modules
For each task $i$ (where $i = 1, \dots, K$), there is a corresponding attention network made up of multiple attention modules. Each attention module is designed to learn a soft attention mask that is applied to the shared features at a particular layer (or block) of the shared network. These masks act as feature selectors.
Let's break down the computation within an attention module:
- Feature Selection: The shared features in the $j$-th block of the shared network are denoted as $p^{(j)}$. For task $i$, the learned attention mask in this layer is $a_i^{(j)}$. The task-specific features for task $i$ at layer $j$, denoted as $\hat{a}_i^{(j)}$, are computed by an element-wise multiplication of the attention mask with the shared features: $ \hat{a}_i^{(j)} = a_i^{(j)} \odot p^{(j)} $ Here:
  - $p^{(j)}$ represents the global shared features output by the $j$-th block of the shared network.
  - $a_i^{(j)}$ is the soft attention mask learned for task $i$ at the $j$-th layer. This mask has the same spatial dimensions and number of channels as $p^{(j)}$, and its values lie in the range $[0, 1]$ (due to a sigmoid activation).
  - $\hat{a}_i^{(j)}$ represents the task-specific features for task $i$ at layer $j$, obtained by attending to the shared features.
  - $\odot$ denotes element-wise multiplication (also known as the Hadamard product). This operation scales the values in $p^{(j)}$ based on the corresponding values in $a_i^{(j)}$, effectively highlighting or suppressing features.
- Learning the Attention Mask: The attention mask $a_i^{(j)}$ itself is learned. Its computation depends on the shared features and, for subsequent layers, also on the previously attended task-specific features.
  - For the first attention module in the encoder, the input is solely the shared features from the corresponding shared network block.
  - For subsequent attention modules in block $j$ (where $j \geq 2$), the input is formed by a concatenation of the current shared features and the processed task-specific features from the previous layer. The attention mask for task $i$ at block $j$ is computed as: $ a_i^{(j)} = h_i^{(j)} \left( g_i^{(j)} \left( \left[ u^{(j)}; f^{(j)} \left( \hat{a}_i^{(j-1)} \right) \right] \right) \right), \quad j \geq 2 $ Here:
    - $u^{(j)}$ represents the shared features from the $j$-th block of the shared network.
    - $\hat{a}_i^{(j-1)}$ are the task-specific features for task $i$ from the previous block $(j-1)$.
    - $f^{(j)}$ is a convolutional layer acting as a shared feature extractor that processes $\hat{a}_i^{(j-1)}$ and prepares it for concatenation; it includes a pooling or sampling layer to match the resolution of $u^{(j)}$.
    - $[\cdot\,;\cdot]$ denotes concatenation along the channel dimension: the current shared features and the processed previous task-specific features are combined as input to the next stage of the attention module.
    - $g_i^{(j)}$ and $h_i^{(j)}$ are convolutional layers that learn the actual attention mask for task $i$ in block $j$ from the concatenated features.
    - A sigmoid activation function is applied after $h_i^{(j)}$ (not explicitly shown in the formula, but stated in the text: "The attention mask, following a sigmoid activation to ensure $a_i^{(j)} \in [0, 1]$"). This squashes the output values to the range $[0, 1]$, so the mask acts as a soft gating mechanism.

  This design allows the attention masks to be learned in a self-supervised fashion through back-propagation. If an attention mask approaches 1 for all its elements, it essentially acts as an identity map, and the attended features become identical to the global shared features, implying that the task utilizes all shared features. Conversely, where parts of the mask approach 0, those features are suppressed as irrelevant for the task. This flexibility ensures MTAN can perform no worse than a simple shared multi-task network.
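To make the module concrete, here is a minimal PyTorch sketch of one encoder-side attention module. It is a sketch under assumptions: the channel widths, kernel sizes, and BatchNorm/ReLU placement are illustrative choices, not the exact configuration of the released MTAN code.

```python
import torch
import torch.nn as nn

class AttentionModule(nn.Module):
    """One task-specific attention module (encoder side), sketching
    a^(j) = h(g([u^(j); f(a_hat^(j-1))])) followed by a_hat^(j) = a^(j) * p^(j).
    Channel and kernel sizes are illustrative."""

    def __init__(self, shared_ch, prev_ch):
        super().__init__()
        # f^(j): processes the previous attended features and halves their
        # resolution so they match the current block before concatenation.
        self.f = nn.Sequential(nn.Conv2d(prev_ch, shared_ch, 3, padding=1),
                               nn.BatchNorm2d(shared_ch), nn.ReLU(inplace=True),
                               nn.MaxPool2d(2))
        # g^(j), h^(j): produce the soft mask; the sigmoid keeps it in [0, 1].
        self.g = nn.Sequential(nn.Conv2d(2 * shared_ch, shared_ch, 1),
                               nn.BatchNorm2d(shared_ch), nn.ReLU(inplace=True))
        self.h = nn.Sequential(nn.Conv2d(shared_ch, shared_ch, 1),
                               nn.BatchNorm2d(shared_ch), nn.Sigmoid())

    def forward(self, u_j, p_j, a_hat_prev):
        z = torch.cat([u_j, self.f(a_hat_prev)], dim=1)  # [u^(j); f(a_hat^(j-1))]
        mask = self.h(self.g(z))                         # soft attention mask
        return mask * p_j                                # element-wise selection

u = p = torch.randn(1, 64, 32, 32)         # shared features of block j
prev = torch.randn(1, 32, 64, 64)          # attended features from block j-1
out = AttentionModule(64, 32)(u, p, prev)  # -> [1, 64, 32, 32]
```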
4.2.3. Model Objective (Loss Function)
For Multi-Task Learning with $K$ tasks, input $\mathbf{X}$ and task-specific labels $\mathbf{Y}_i$ for $i = 1, \dots, K$, the total loss function is defined as a linear combination of task-specific losses:
$
\mathcal { L } _ { t o t } ( { \mathbf { X } } , { \mathbf { Y } } _ { 1 : K } ) = \sum _ { i = 1 } ^ { K } \lambda _ { i } \mathcal { L } _ { i } ( { \mathbf { X } } , { \mathbf { Y } } _ { i } ) .
$
Here:
- $\mathcal{L}_{tot}$ is the total multi-task loss that the network aims to minimize during training.
- $\mathbf{X}$ represents the input data (e.g., an image).
- $\mathbf{Y}_{1:K}$ denotes the set of ground-truth labels for all $K$ tasks.
- $\mathcal{L}_i$ is the task-specific loss for the $i$-th task, comparing the network's prediction for task $i$ with its ground-truth label.
- $\lambda_i$ is the weighting coefficient for the $i$-th task. These weights determine the relative importance of each task's loss contribution to the total loss. The paper studies the effect of different weighting schemes for these $\lambda_i$.

The paper specifies the task-specific loss functions used for image-to-image prediction tasks:

- Semantic Segmentation Loss ($\mathcal{L}_1$): For semantic segmentation, a pixel-wise cross-entropy loss is applied for each predicted class label from a depth-softmax classifier. This measures how well the predicted probability distribution over classes at each pixel matches the true class label. $ \mathcal{L}_1(\mathbf{X}, \mathbf{Y}_1) = -\frac{1}{pq} \sum_{p,q} \mathbf{Y}_1(p,q) \log \hat{\mathbf{Y}}_1(p,q). $ Here:
  - $\mathbf{X}$ is the input image.
  - $\mathbf{Y}_1$ is the ground-truth semantic segmentation map.
  - $\hat{\mathbf{Y}}_1$ is the network's predicted semantic segmentation map (specifically, the softmax probabilities for each class at each pixel).
  - p, q are indices representing the spatial coordinates (pixel locations) in the image.
  - The term $\mathbf{Y}_1(p,q)$ is typically a one-hot encoded vector representing the true class at pixel (p,q), while $\hat{\mathbf{Y}}_1(p,q)$ is the predicted probability distribution over classes at that pixel. The negative sum of $\mathbf{Y}_1(p,q) \log \hat{\mathbf{Y}}_1(p,q)$ calculates the cross-entropy.
  - $\frac{1}{pq}$ is a normalization term, averaging the loss over all pixels.
- Depth Estimation Loss ($\mathcal{L}_2$): For depth estimation, an L1 norm (Mean Absolute Error) is used to compare the predicted depth map with the ground-truth depth map. The paper notes using true depth for NYUv2 and inverse depth for CityScapes (common for outdoor scenes, to better represent distant objects). $ \mathcal{L}_2(\mathbf{X}, \mathbf{Y}_2) = \frac{1}{pq} \sum_{p,q} |\mathbf{Y}_2(p,q) - \hat{\mathbf{Y}}_2(p,q)|. $ Here:
  - $\mathbf{X}$ is the input image.
  - $\mathbf{Y}_2$ is the ground-truth depth (or inverse depth) map.
  - $\hat{\mathbf{Y}}_2$ is the network's predicted depth (or inverse depth) map.
  - p, q are indices representing the spatial coordinates (pixel locations).
  - $|\mathbf{Y}_2(p,q) - \hat{\mathbf{Y}}_2(p,q)|$ is the absolute difference between the true and predicted depth at each pixel.
  - $\frac{1}{pq}$ is a normalization term, averaging the loss over all pixels.
- Surface Normal Prediction Loss ($\mathcal{L}_3$): For surface normals (available in NYUv2), an element-wise dot product is applied at each normalized pixel with the ground-truth normal map. The goal is to maximize the cosine similarity between predicted and true normal vectors, so the negative dot product is minimized. $ \mathcal{L}_3(\mathbf{X}, \mathbf{Y}_3) = -\frac{1}{pq} \sum_{p,q} \mathbf{Y}_3(p,q) \cdot \hat{\mathbf{Y}}_3(p,q). $ Here:
  - $\mathbf{X}$ is the input image.
  - $\mathbf{Y}_3$ is the ground-truth surface normal map, where each pixel contains a (typically normalized) 3D vector representing the normal.
  - $\hat{\mathbf{Y}}_3$ is the network's predicted surface normal map, also with normalized 3D vectors per pixel.
  - p, q are indices representing the spatial coordinates (pixel locations).
  - $\mathbf{Y}_3(p,q) \cdot \hat{\mathbf{Y}}_3(p,q)$ is the dot product between the true and predicted normal vectors at pixel (p,q). For normalized vectors, this equals the cosine of the angle between them: a value of 1 means perfect alignment, -1 means opposite.
  - The negative sign ensures that minimizing the loss maximizes the dot product (i.e., minimizes the angle between predicted and true normals).
  - $\frac{1}{pq}$ is a normalization term, averaging the loss over all pixels.

For image classification tasks, the paper applies the standard cross-entropy loss, without providing a specific formula, as it is a well-known concept.
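A compact sketch of how these three losses and the weighted total could be computed, assuming PyTorch. Tensor shapes, variable names, and the omission of invalid-pixel masking are simplifications, not the paper's exact implementation.

```python
import torch
import torch.nn.functional as F

def multi_task_loss(seg_logits, seg_gt, depth_pred, depth_gt,
                    norm_pred, norm_gt, lambdas=(1.0, 1.0, 1.0)):
    """Weighted sum of the three task losses: L_tot = sum_i lambda_i * L_i.
    Pixels with missing ground truth would normally be masked out; that is
    omitted here for brevity."""
    # L1: pixel-wise cross-entropy (seg_gt holds integer class indices).
    l_seg = F.cross_entropy(seg_logits, seg_gt)
    # L2: L1 norm between predicted and true (inverse) depth.
    l_depth = (depth_pred - depth_gt).abs().mean()
    # L3: negative dot product between unit normals, averaged over pixels.
    n_pred = F.normalize(norm_pred, dim=1)  # normalize the (x, y, z) channels
    n_gt = F.normalize(norm_gt, dim=1)
    l_norm = -(n_pred * n_gt).sum(dim=1).mean()
    losses = (l_seg, l_depth, l_norm)
    return sum(w * l for w, l in zip(lambdas, losses)), losses

seg_logits = torch.randn(2, 13, 32, 32)
seg_gt = torch.randint(0, 13, (2, 32, 32))
depth_pred, depth_gt = torch.rand(2, 1, 32, 32), torch.rand(2, 1, 32, 32)
norm_pred, norm_gt = torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)
total, per_task = multi_task_loss(seg_logits, seg_gt, depth_pred, depth_gt,
                                  norm_pred, norm_gt)
```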
4.2.4. Dynamic Weight Average (DWA)
To address the challenge of loss weighting in MTL, the paper proposes Dynamic Weight Average (DWA). This method adaptively adjusts task weights over time by considering the rate of change of the loss for each task. It is inspired by GradNorm [3] but is simpler as it only requires numerical task losses.
With DWA, the weighting $\lambda_k(t)$ for task $k$ at training step $t$ (indexed per iteration or epoch) is defined as:
$
\lambda_{k}(t) := \frac{K \exp\left(w_{k}(t-1)/T\right)}{\sum_{i} \exp\left(w_{i}(t-1)/T\right)}, \quad w_{k}(t-1) = \frac{\mathcal{L}_{k}(t-1)}{\mathcal{L}_{k}(t-2)}.
$
Here:
- $\lambda_k(t)$ is the weighting coefficient for task $k$ at the current training iteration (or epoch) $t$.
- $K$ is the total number of tasks.
- $w_k(t-1)$ is a term that calculates the relative descending rate of the loss for task $k$ between the previous two iterations/epochs. It is defined as the ratio of the loss for task $k$ at $t-1$ to the loss at $t-2$.
  - $\mathcal{L}_k(t-1)$ is the average loss for task $k$ at iteration/epoch $t-1$.
  - $\mathcal{L}_k(t-2)$ is the average loss for task $k$ at iteration/epoch $t-2$.
  - A higher $w_k$ (i.e., less decrease, or even an increase, in loss) indicates that task $k$ is struggling or its loss is decreasing slowly, suggesting it might need more attention.
- $T$ is a temperature hyper-parameter that controls the softness of task weighting.
  - A large $T$ results in a more even distribution of weights across different tasks, giving $\lambda_i \approx 1$ for all tasks (equal weighting).
  - A smaller $T$ makes the weighting more sensitive to differences in $w_k$, assigning higher weights to tasks whose losses are decreasing more slowly.
- The exponential term $\exp(w_k(t-1)/T)$ amplifies the weight for tasks with higher $w_k$.
- The denominator normalizes these exponential terms.
- The multiplication by $K$ ensures that $\sum_i \lambda_i(t) = K$, so the sum of weights remains constant across iterations.

In practice, $\mathcal{L}_k(t)$ is calculated as the average loss over an entire epoch to reduce noise from stochastic gradient descent. For the first two iterations ($t = 1, 2$), $w_k(t)$ is typically initialized to 1, implying equal initial importance for all tasks.
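The update is easy to reproduce from the formula alone. Below is a small sketch that assumes per-epoch average losses are recorded in a list of tuples; the function and variable names are illustrative, not taken from the released code.

```python
import torch

def dwa_weights(loss_history, K, T=2.0):
    """Dynamic Weight Average:
    lambda_k(t) = K * exp(w_k / T) / sum_i exp(w_i / T),
    with w_k = L_k(t-1) / L_k(t-2)."""
    if len(loss_history) < 2:
        # First two epochs: w_k is initialised to 1, i.e. equal weights.
        return [1.0] * K
    prev, prev2 = loss_history[-1], loss_history[-2]
    w = torch.tensor([prev[k] / prev2[k] for k in range(K)])
    exp_w = torch.exp(w / T)
    return (K * exp_w / exp_w.sum()).tolist()

# Example: task 0's loss stalls relative to task 1's, so it is up-weighted.
history = [(1.00, 1.00), (0.99, 0.80)]
print(dwa_weights(history, K=2))  # approximately [1.05, 0.95]
```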
5. Experimental Setup
5.1. Datasets
The paper evaluates MTAN on a variety of datasets across both image-to-image predictions and image classification tasks to demonstrate its versatility and performance.
5.1.1. CityScapes
- Source & Characteristics: The CityScapes dataset [4] consists of high-resolution street-view images, primarily from urban driving scenarios. It provides rich annotations for various computer vision tasks.
- Tasks: Used for two image-to-image prediction tasks:
  - Semantic Segmentation: The dataset contains 19 classes for pixel-wise semantic segmentation. For experimental flexibility, the authors also created coarser versions with 7 and 2 classes (excluding the void group for the 7- and 19-class settings). The details of these segmentation classes are presented in Table 1 below.
  - Depth Estimation: Ground-truth inverse depth labels are provided. Inverse depth is often used in outdoor scenes because it can more easily represent points at infinite distance (like the sky) and provides better numerical stability for distant objects.
- Scale & Preprocessing: All training and validation images were resized to [128 × 256] to speed up training.
- Choice Justification: A standard benchmark for urban scene understanding, suitable for evaluating dense prediction tasks in complex outdoor environments. Using different numbers of semantic classes allows for testing the method's performance under varying task complexities.

The following are the results from Table 1 of the original paper:

| 2-class | 7-class | 19-class |
| :--- | :--- | :--- |
| background | void | void |
| background | flat | road, sidewalk |
| background | construction | building, wall, fence |
| background | object | pole, traffic light, traffic sign |
| background | nature | vegetation, terrain |
| background | sky | sky |
| foreground | human | person, rider |
| foreground | vehicle | car, truck, bus, caravan, trailer, train, motorcycle |
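For illustration, the Table 1 grouping could be expressed as plain mappings. This is a hypothetical helper sketched from the table above, not code from the released repository.

```python
# Hypothetical mapping from Table 1: coarsening the 19 CityScapes classes
# into the 7-class grouping, and the 7 groups into the 2-class setting.
SEVEN_FROM_19 = {
    "void": ["void"],
    "flat": ["road", "sidewalk"],
    "construction": ["building", "wall", "fence"],
    "object": ["pole", "traffic light", "traffic sign"],
    "nature": ["vegetation", "terrain"],
    "sky": ["sky"],
    "human": ["person", "rider"],
    "vehicle": ["car", "truck", "bus", "caravan", "trailer", "train",
                "motorcycle"],
}
TWO_FROM_7 = {
    "background": ["void", "flat", "construction", "object", "nature", "sky"],
    "foreground": ["human", "vehicle"],
}
```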
5.1.2. NYUv2
- Source & Characteristics: The NYUv2 dataset [21] consists of RGB-D (color and depth) indoor scene images. It captures a more diverse range of viewpoints, lighting conditions, and object appearances compared to CityScapes.
- Tasks: Used for three image-to-image prediction tasks:
  - 13-class Semantic Segmentation: Defined in [5].
  - Depth Estimation: Uses true depth data recorded by Microsoft Kinect.
  - Surface Normal Prediction: Ground-truth surface normals are provided in [7].
- Scale & Preprocessing: All training and validation images were resized to [288 × 384] resolution.
- Choice Justification: Presents a more challenging indoor environment than CityScapes, with greater variability in scenes and objects. This allows a comprehensive understanding of how the proposed method behaves and scales under complex scenarios and with more tasks (three tasks vs. two for CityScapes).
5.1.3. Visual Decathlon Challenge
- Source & Characteristics: A recently proposed benchmark [23] consisting of 10 individual image classification tasks. It challenges models to learn across diverse visual domains, ranging from ImageNet to smaller, specialized classification datasets.
- Tasks: 10 image classification tasks, treated as a many-to-many prediction scenario.
- Choice Justification: A highly competitive benchmark designed to test the limits of multi-domain and multi-task learning in terms of generalization and efficiency across a large number of disparate tasks. It allows evaluation of MTAN's scalability and effectiveness beyond dense prediction tasks.
5.2. Evaluation Metrics
For each task, the paper uses standard evaluation metrics:
5.2.1. Semantic Segmentation
- Mean Intersection over Union (mIoU):
  - Conceptual Definition: mIoU is a common metric for semantic segmentation. It quantifies the similarity between the predicted segmentation map and the ground-truth map. For each class, it calculates the Intersection over Union (IoU, also known as the Jaccard index): the area of overlap between the predicted segmentation and the ground truth divided by the area of their union. The mIoU is then the average IoU across all classes. A higher mIoU indicates better segmentation quality.
  - Mathematical Formula: $ \mathrm{IoU}_k = \frac{\mathrm{TP}_k}{\mathrm{TP}_k + \mathrm{FP}_k + \mathrm{FN}_k}, \quad \mathrm{mIoU} = \frac{1}{C} \sum_{k=1}^{C} \mathrm{IoU}_k $
  - Symbol Explanation:
    - $\mathrm{IoU}_k$: Intersection over Union for class $k$.
    - $\mathrm{TP}_k$: Number of True Positives for class $k$ (pixels correctly predicted as class $k$).
    - $\mathrm{FP}_k$: Number of False Positives for class $k$ (pixels incorrectly predicted as class $k$).
    - $\mathrm{FN}_k$: Number of False Negatives for class $k$ (pixels of class $k$ incorrectly predicted as another class).
    - $C$: Total number of classes.
- Pixel Accuracy (Pix Acc):
  - Conceptual Definition: Pixel Accuracy is the simplest segmentation metric: the proportion of pixels that are correctly classified across all classes. It gives an overall sense of correct pixel predictions but can be misleading when classes are highly imbalanced (e.g., a large background class can dominate the score).
  - Mathematical Formula: $ \mathrm{Pix\ Acc} = \frac{\sum_{k=1}^{C} \mathrm{TP}_k}{\sum_{k=1}^{C} \left(\mathrm{TP}_k + \mathrm{FN}_k\right)} $
  - Symbol Explanation:
    - $\mathrm{TP}_k$: Number of True Positives for class $k$.
    - $\mathrm{FN}_k$: Number of False Negatives for class $k$.
    - $C$: Total number of classes.
    - The denominator $\sum_k (\mathrm{TP}_k + \mathrm{FN}_k)$ is the total number of pixels in the image, since each pixel belongs to exactly one ground-truth class and is either correctly or incorrectly predicted.
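A small NumPy sketch of both metrics, assuming integer label maps. This is an illustrative helper, not the paper's evaluation code.

```python
import numpy as np

def seg_metrics(pred, gt, num_classes):
    """mIoU and pixel accuracy from integer label maps of identical shape."""
    ious = []
    for k in range(num_classes):
        tp = np.sum((pred == k) & (gt == k))
        fp = np.sum((pred == k) & (gt != k))
        fn = np.sum((pred != k) & (gt == k))
        if tp + fp + fn > 0:          # skip classes absent from both maps
            ious.append(tp / (tp + fp + fn))
    miou = float(np.mean(ious))
    pix_acc = float(np.mean(pred == gt))  # correct pixels / total pixels
    return miou, pix_acc

pred = np.random.randint(0, 7, (128, 256))
gt = np.random.randint(0, 7, (128, 256))
print(seg_metrics(pred, gt, num_classes=7))
```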
5.2.2. Depth Estimation
- Absolute Error (Abs Err):
  - Conceptual Definition: Absolute Error (often referred to as Mean Absolute Error) calculates the average of the absolute differences between each predicted depth value and its corresponding ground-truth depth value. It measures the average magnitude of error, regardless of direction. A lower value is better.
  - Mathematical Formula: $ \mathrm{Abs\ Err} = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i| $
  - Symbol Explanation:
    - $N$: Total number of pixels with valid depth values.
    - $y_i$: Ground-truth depth value for pixel $i$.
    - $\hat{y}_i$: Predicted depth value for pixel $i$.
    - $|\cdot|$: Absolute value.
- Relative Error (Rel Err):
  - Conceptual Definition: Relative Error (often Mean Relative Error) measures the average of the absolute differences between predicted and true depth values, scaled by the true depth value. This metric is useful because it normalizes errors by the magnitude of the depth, making it more informative across different depth ranges. A lower value is better.
  - Mathematical Formula: $ \mathrm{Rel\ Err} = \frac{1}{N} \sum_{i=1}^{N} \frac{|y_i - \hat{y}_i|}{y_i} $
  - Symbol Explanation:
    - $N$: Total number of pixels with valid depth values.
    - $y_i$: Ground-truth depth value for pixel $i$.
    - $\hat{y}_i$: Predicted depth value for pixel $i$.
    - $|\cdot|$: Absolute value.
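The two depth metrics take only a few lines of NumPy. This is illustrative; the convention that positive ground-truth values mark valid pixels is an assumption.

```python
import numpy as np

def depth_errors(pred, gt, eps=1e-8):
    """Mean absolute and mean relative depth error over valid pixels
    (here, pixels where gt > 0 are assumed to have recorded depth)."""
    valid = gt > 0
    diff = np.abs(pred[valid] - gt[valid])
    abs_err = np.mean(diff)
    rel_err = np.mean(diff / (gt[valid] + eps))
    return abs_err, rel_err

pred = np.random.rand(128, 256) + 0.1
gt = np.random.rand(128, 256) + 0.1
print(depth_errors(pred, gt))
```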
5.2.3. Surface Normal Prediction
- Angle Distance (Mean, Median):
  - Conceptual Definition: For surface normals, the angle distance (or angular error) is a direct measure of the angle in degrees between the predicted 3D normal vector and the ground-truth 3D normal vector at each pixel. A perfect prediction has an angle distance of 0. The Mean and Median angle distance are computed over all valid pixels in the image. Lower values are better.
  - Mathematical Formula: $ \mathrm{Angle}(\mathbf{n}, \hat{\mathbf{n}}) = \arccos\left(\frac{\mathbf{n} \cdot \hat{\mathbf{n}}}{\|\mathbf{n}\| \, \|\hat{\mathbf{n}}\|}\right) \times \frac{180}{\pi} $
  - Symbol Explanation:
    - $\mathbf{n}$: Ground-truth normalized surface normal vector at a pixel.
    - $\hat{\mathbf{n}}$: Predicted normalized surface normal vector at a pixel.
    - $\cdot$: Dot product.
    - $\|\cdot\|$: Magnitude (norm) of a vector.
    - $\arccos$: Inverse cosine function, which returns the angle in radians.
    - $\frac{180}{\pi}$: Conversion factor from radians to degrees.
- Within $X^\circ$ (Accuracy):
  - Conceptual Definition: These metrics measure the percentage of pixels for which the angle distance between the predicted and ground-truth normal vectors is within a certain threshold ($11.25^\circ$, $22.5^\circ$, and $30^\circ$ in this paper). Higher values indicate better performance.
  - Mathematical Formula: $ \mathrm{Within}\ X^\circ = \frac{1}{N} \sum_{i=1}^{N} \mathbb{I}\left(\mathrm{Angle}(\mathbf{n}_i, \hat{\mathbf{n}}_i) \leq X^\circ\right) \times 100\% $
  - Symbol Explanation:
    - $N$: Total number of pixels with valid normal vectors.
    - $\mathbf{n}_i$: Ground-truth normal vector for pixel $i$.
    - $\hat{\mathbf{n}}_i$: Predicted normal vector for pixel $i$.
    - $\mathrm{Angle}(\mathbf{n}_i, \hat{\mathbf{n}}_i)$: The angle distance in degrees between the two vectors for pixel $i$.
    - $X^\circ$: The angular threshold ($11.25^\circ$, $22.5^\circ$, or $30^\circ$).
    - $\mathbb{I}$: Indicator function, which is 1 if the condition is true and 0 otherwise.
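Both the angle-distance statistics and the Within $X^\circ$ accuracies can be computed from the same per-pixel angles, as in this illustrative NumPy sketch:

```python
import numpy as np

def normal_metrics(pred, gt, thresholds=(11.25, 22.5, 30.0)):
    """Angular error statistics between predicted and ground-truth normals;
    pred and gt have shape (H, W, 3)."""
    # Normalize to unit vectors, then take the per-pixel dot product.
    p = pred / np.linalg.norm(pred, axis=-1, keepdims=True)
    g = gt / np.linalg.norm(gt, axis=-1, keepdims=True)
    cos = np.clip(np.sum(p * g, axis=-1), -1.0, 1.0)
    angles = np.degrees(np.arccos(cos))            # per-pixel angle in degrees
    within = {t: 100.0 * np.mean(angles <= t) for t in thresholds}
    return np.mean(angles), np.median(angles), within

pred = np.random.randn(128, 256, 3)
gt = np.random.randn(128, 256, 3)
print(normal_metrics(pred, gt))
```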
5.2.4. Image Classification
- Accuracy:
  - Conceptual Definition: Accuracy is the proportion of correctly classified samples (images in this case) out of the total number of samples. It is a fundamental metric for classification tasks; higher values are better.
  - Mathematical Formula: $ \mathrm{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}} $
  - Symbol Explanation:
    - "Number of Correct Predictions": The count of samples where the model's predicted class matches the true class.
    - "Total Number of Predictions": The total count of samples being evaluated.
- Cumulative Score (Visual Decathlon Challenge):
  - Conceptual Definition: For the Visual Decathlon Challenge, a cumulative score is assigned across all 10 classification tasks, with a maximum possible score of 10,000 (1,000 points per task). The exact scoring function is defined by the challenge itself and normalizes task accuracies before summing them; the goal is to maximize this combined score.
5.3. Baselines
For image-to-image prediction tasks, all baselines were implemented using the SegNet [1] architecture for fair comparison. For the Visual Decathlon Challenge, several state-of-the-art multi-domain and multi-task learning methods specific to that challenge are used as baselines.
5.3.1. Baselines for Image-to-Image Predictions (SegNet-based)
The paper implemented 5 different network architectures (2 single-task, 3 multi-task) based on SegNet:
- Single-Task, One Task: This is the baseline vanilla SegNet trained for a single task (e.g., only semantic segmentation or only depth estimation). It serves as a reference for single-task performance without any multi-task benefits.
- Single-Task, STAN (Single-Task Attention Network): This baseline applies the proposed MTAN architecture but performs only a single task. It helps isolate the effect of the attention modules themselves, ensuring that any performance gain is not just from having more parameters.
- Multi-Task, Split (Wide, Deep): The standard multi-task learning approach, where a shared encoder is used but the network splits into separate task-specific decoders or prediction heads at the last shared layer. To ensure a fair comparison of model capacity, two versions were implemented:
  - Wide: The number of convolutional filters was increased.
  - Deep: The number of convolutional layers was increased.
  - Both versions were tuned until they had at least as many parameters as MTAN, to validate that MTAN's performance is not simply due to an increase in network size.
- Multi-Task, Dense: A shared network (encoder) is used, and each task's dedicated network (decoder) receives all features from the shared network. Crucially, it does not employ any attention modules or selective feature access. This baseline demonstrates the specific benefit of MTAN's attention mechanism over simply passing all shared features.
- Multi-Task, Cross-Stitch [20]: A previously proposed adaptive multi-task learning approach. The authors implemented Cross-Stitch Networks on the SegNet architecture to provide a direct, fair comparison with a known state-of-the-art MTL method based on explicit feature sharing.
5.3.2. Baselines for Visual Decathlon Challenge
For the Visual Decathlon Challenge, the paper compares MTAN against several competitive baselines that often participate in the challenge:
- Scratch [23]: Training a separate Wide Residual Network from scratch for each of the 10 tasks.
- Finetune [23]: Finetuning a single Wide Residual Network (pre-trained on ImageNet) separately for each of the 10 tasks.
- Feature [23]: Using a single Wide Residual Network pre-trained on ImageNet as a fixed feature extractor, with linear classifiers trained on top for each task.
- Res. Adapt. [23] (Residual Adapters): A method that adds small, task-specific residual adapter modules to a shared backbone network to adapt it to multiple domains.
- DAN [25] (Deep Adaptation Network): A method focusing on incremental learning and deep adaptation.
- Piggyback [19]: Adapting a single network to multiple tasks by learning binary masks that select which weights to use for each task.
- Parallel SVD [24]: An efficient parametrization method for multi-domain deep neural networks, based on Singular Value Decomposition ideas.
5.4. Training Details
- Weighting Methods: For image-to-image prediction tasks, experiments were run with three types of loss weighting methods:
  - Equal Weighting: All $\lambda_i$ are set to 1.
  - Weight Uncertainty [14]: An adaptive method that weighs losses based on task uncertainty.
  - Dynamic Weight Average (DWA): The authors' proposed adaptive weighting method (with temperature $T = 2$).
  - GradNorm [3] was excluded from comparison due to its architecture-specific implementation requirements, which would complicate a fair baseline evaluation.
- Optimizer: The ADAM optimizer [15] was used for all models.
- Learning Rate: Initial learning rate of $10^{-4}$.
- Batch Size: 2 for the NYUv2 dataset, 8 for the CityScapes dataset.
- Learning Rate Schedule: For image-to-image prediction tasks, the learning rate was halved at 40k iterations, out of a total of 80k iterations.
- Visual Decathlon Specifics: MTAN was built on a Wide Residual Network [31] (depth 28, widening factor 4, stride 2 in the first convolutional layer of each block). Training used a batch size of 100 and the SGD optimizer with a learning rate of 0.1 and weight decay, for 300 epochs (learning rate halved every 50 epochs). A fine-tuning stage followed for 9 tasks (all except ImageNet) with a learning rate of 0.01 until convergence.
6. Results & Analysis
6.1. Core Results Analysis
The experimental results demonstrate MTAN's effectiveness across various tasks and datasets, highlighting its state-of-the-art performance, parameter efficiency, and robustness to loss weighting schemes.
6.1.1. Image-to-Image Predictions on CityScapes and NYUv2
The image-to-image prediction experiments used SegNet as the backbone architecture and were evaluated on the 7-class CityScapes dataset and the 13-class NYUv2 dataset.
The following are the results from Table 2 of the original paper:
| #P. | Architecture | Weighting | mIoU (↑) | Pix Acc (↑) | Abs Err (↓) | Rel Err (↓) |
| :--- | :--- | :--- | ---: | ---: | ---: | ---: |
| 2 | One Task | n.a. | 51.09 | 90.69 | 0.0158 | 34.17 |
| 3.04 | STAN | n.a. | 51.90 | 90.87 | 0.0145 | 27.46 |
| 1.75 | Split, Wide | Equal Weights | 50.17 | 90.63 | 0.0167 | 44.73 |
| | | Uncert. Weights [14] | 51.21 | 90.72 | 0.0158 | 44.01 |
| | | DWA, T = 2 | 50.39 | 90.45 | 0.0164 | 43.93 |
| 2 | Split, Deep | Equal Weights | 49.85 | 88.69 | 0.0180 | 43.86 |
| | | Uncert. Weights [14] | 48.12 | 88.68 | 0.0169 | 39.73 |
| | | DWA, T = 2 | 49.67 | 88.81 | 0.0182 | 46.63 |
| 3.63 | Dense | Equal Weights | 51.91 | 90.89 | 0.0138 | 27.21 |
| | | Uncert. Weights [14] | 51.89 | 91.22 | 0.0134 | 25.36 |
| | | DWA, T = 2 | 51.78 | 90.88 | 0.0137 | 26.67 |
| ≈2 | Cross-Stitch [20] | Equal Weights | 50.08 | 90.33 | 0.0154 | 34.49 |
| | | Uncert. Weights [14] | 50.31 | 90.43 | 0.0152 | 31.36 |
| | | DWA, T = 2 | 50.33 | 90.55 | 0.0153 | 33.37 |
| 1.65 | MTAN (Ours) | Equal Weights | 53.04 | 91.11 | 0.0144 | 33.63 |
| | | Uncert. Weights [14] | 53.86 | 91.10 | 0.0144 | 35.72 |
| | | DWA, T = 2 | 55.29 | 91.90 | 0.0144 | 34.14 |
The following are the results from Table 3 of the original paper:
Columns: Segmentation (mIoU, Pix Acc; higher better), Depth (Abs Err, Rel Err; lower better), Surface Normal angle distance (Mean, Median; lower better) and the percentage of pixels within 11.25°, 22.5°, 30° (higher better).

| Type | #P. | Architecture | Weighting | mIoU | Pix Acc | Abs Err | Rel Err | Mean | Median | 11.25 | 22.5 | 30 |
| :--- | :--- | :--- | :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| Single Task | 3 | One Task | n.a. | 15.10 | 51.54 | 0.7508 | 0.3266 | 31.76 | 25.51 | 22.12 | 45.33 | 57.13 |
| | 4.56 | STAN | n.a. | 15.73 | 52.89 | 0.6935 | 0.2891 | 32.09 | 26.32 | 21.49 | 44.38 | 56.51 |
| Multi Task | 1.75 | Split, Wide | Equal Weights | 15.89 | 51.19 | 0.6494 | 0.2804 | 33.69 | 28.91 | 18.54 | 39.91 | 52.02 |
| | | | Uncert. Weights [14] | 15.86 | 51.12 | 0.6040 | 0.2570 | 32.33 | 26.62 | 21.68 | 43.19 | 55.36 |
| | | | DWA, T = 2 | 16.92 | 53.72 | 0.6125 | 0.2546 | 32.34 | 27.10 | 20.69 | 42.73 | 54.74 |
| | 2 | Split, Deep | Equal Weights | 13.03 | 41.47 | 0.7836 | 0.3326 | 38.28 | 36.55 | 9.50 | 27.11 | 39.63 |
| | | | Uncert. Weights [14] | 14.53 | 43.69 | 0.7705 | 0.3340 | 35.14 | 32.13 | 14.69 | 34.52 | 46.94 |
| | | | DWA, T = 2 | 13.63 | 44.41 | 0.7581 | 0.3227 | 36.41 | 34.12 | 12.82 | 31.12 | 43.48 |
| | 4.95 | Dense | Equal Weights | 16.06 | 52.73 | 0.6488 | 0.2871 | 33.58 | 28.01 | 20.07 | 41.50 | 53.35 |
| | | | Uncert. Weights [14] | 16.48 | 54.40 | 0.6282 | 0.2761 | 31.68 | 25.68 | 21.73 | 44.58 | 56.65 |
| | | | DWA, T = 2 | 16.15 | 54.35 | 0.6059 | 0.2593 | 32.44 | 27.40 | 20.53 | 42.76 | 54.20 |
| | ≈3 | Cross-Stitch [20] | Equal Weights | 14.71 | 50.23 | 0.6481 | 0.2871 | 33.56 | 28.58 | 20.08 | 40.54 | 51.97 |
| | | | Uncert. Weights [14] | 15.69 | 52.60 | 0.6277 | 0.2702 | 32.69 | 27.26 | 21.63 | 42.84 | 54.45 |
| | | | DWA, T = 2 | 16.11 | 53.19 | 0.5922 | 0.2611 | 32.34 | 26.91 | 21.81 | 43.14 | 54.92 |
| | 1.77 | MTAN (Ours) | Equal Weights | 17.72 | 55.32 | 0.5906 | 0.2577 | 31.44 | 25.37 | 23.17 | 45.65 | 57.48 |
| | | | Uncert. Weights [14] | 17.67 | 55.61 | 0.5927 | 0.2592 | 31.25 | 25.57 | 22.99 | 45.83 | 57.67 |
| | | | DWA, T = 2 | 17.15 | 54.97 | 0.5956 | 0.2569 | 31.60 | 25.46 | 22.48 | 44.86 | 57.24 |
Analysis:
- CityScapes (Table 2): MTAN generally performs very well. With DWA weighting, MTAN achieves the best mIoU (55.29) and Pix Acc (91.90) for segmentation. For depth, its Abs Err (0.0144) is slightly higher than Dense (0.0134) but still very competitive, especially considering its parameter count.
- NYUv2 (Table 3): This dataset is noted as more challenging. Here, MTAN clearly outperforms all baselines across all three tasks (semantic segmentation, depth estimation, and surface normal prediction) and across all weighting methods. For instance, MTAN achieves the highest mIoU (17.72 with Equal Weights), the lowest Abs Err (0.5906 with Equal Weights), and the lowest Mean Angle Distance (31.25 with Uncertainty Weights), while also having the highest Within X° accuracies for surface normals.
- Parameter Efficiency: A crucial advantage is MTAN's parameter efficiency. The #P. column reports parameter counts as multiples of a single vanilla SegNet: on CityScapes, MTAN uses 1.65×, which is less than half of Dense (3.63×) and below both Split, Wide (1.75×) and Split, Deep (2×). Despite similar or fewer parameters, MTAN consistently achieves better performance, especially on the harder NYUv2 dataset. This indicates that the attention modules enable more intelligent feature sharing and selection without significantly increasing model complexity.
- Robustness to Weighting Schemes: The tables demonstrate that MTAN maintains strong performance across Equal Weights, Uncertainty Weights, and DWA. While the best performance may shift slightly between weighting schemes (e.g., MTAN's NYUv2 mIoU is highest with Equal Weights), the variations are smaller than for baselines such as Split, Deep or Cross-Stitch Networks, which fluctuate more drastically across weighting schemes. This confirms MTAN's claim of being less sensitive to loss weighting.

The paper provides qualitative results that further illustrate MTAN's performance. The following figure (Figure 4 from the original paper) shows qualitative results on the CityScapes validation dataset for semantic segmentation and depth estimation.

Figure 4: Qualitative comparison for multi-task semantic segmentation and depth estimation, showing input images, ground-truth labels, and predictions from vanilla single-task learning and from MTAN, with visible improvements from MTAN around boundaries and fine details.
Analysis of Figure 4: The visual comparison shows MTAN produces sharper edges and more accurate object boundaries for semantic segmentation and generally more detailed depth estimation compared to vanilla Single-Task learning. This suggests that multi-task learning with MTAN helps in learning richer and more robust features that benefit both tasks, leading to higher quality predictions.
6.1.2. Robustness to Loss Function Weighting Schemes
The following figure (Figure 3 from the original paper) plots validation performance curves on the NYUv2 dataset for semantics, depth, and normals (from left to right) for Cross-Stitch Network (top) and MTAN (bottom) across different weighting schemes.
Figure 3: Validation performance curves on the NYUv2 dataset, across all three tasks (semantics, depth, normals, from left to right), showing robustness to loss function weighting schemes on the Cross-Stitch Network [20] (top) and our Multi-task Attention Network (bottom). The curves correspond to equal weights (red), DWA (blue), and weight uncertainty (green); the x-axis is training epochs.
Analysis of Figure 3:
- Cross-Stitch Network (Top Row): The learning curves for the Cross-Stitch Network show notable differences in behavior across the three loss weighting schemes (Equal Weights, Uncert. Weights, DWA). For instance, Equal Weights may lead to faster convergence for one task but poorer performance or instability for another.
- MTAN (Bottom Row): In contrast, MTAN exhibits much more consistent learning trends across all three weighting schemes for each task. The performance curves for semantics, depth, and normals largely overlap, indicating that MTAN's performance is stable regardless of whether equal weights, uncertainty-based weights, or DWA are used. This provides strong visual evidence for MTAN's inherent robustness to the choice of loss function weighting scheme.
6.1.3. Effect of Task Complexity
The paper also analyzed the effect of task complexity by varying the number of semantic classes for CityScapes (2, 7, or 19 classes) while keeping the depth task constant. All networks were trained with equal weighting. The results are presented as performance improvement relative to vanilla Single-Task learning.
The following summarizes the results from Table 4 (left part) of the original paper, which is presented there as a graph. I summarize the key findings, as the graph itself is not directly transcribable as a table.
- 2-class setup: The Single-Task Attention Network (STAN) performs better than all multi-task methods. This suggests that for very simple tasks, dedicating all network parameters to a single task (even with attention) is more effective than attempting to share features, as there may not be sufficient benefit from sharing.
- Increased Task Complexity (7-class and 19-class): As the number of semantic classes increases (making the segmentation task more complex), the multi-task methods (including MTAN, Split, Dense, and Cross-Stitch) show significantly greater performance improvements relative to Single-Task learning. This indicates that multi-task learning becomes more beneficial for complex tasks by encouraging feature sharing and more efficient parameter utilization.
- MTAN's Scaling: Crucially, MTAN's relative performance gain increases at a greater rate than other multi-task implementations as task complexity rises. This suggests that its attention mechanism becomes even more effective at discerning and leveraging useful shared features when tasks are more intricate and benefit more from selective sharing.
6.1.4. Attention Masks as Feature Selectors
The paper visualizes the first layer attention masks learned by MTAN on the CityScapes dataset (7-class semantic segmentation and depth estimation).
The following figure (Figure 5 from the original paper) shows visualizations of these attention masks.
The image visualizes the input image, semantic mask, semantic features, shared features, depth mask, and depth features. The top and bottom halves show the results for two different input images; the learned and shared features reflect the effectiveness of multi-task learning.
Figure 5: Visualisation of the first layer of 7-class semantic and depth attention features of our proposed network. The colours for each image are rescaled to fit the data.
Analysis of Figure 5: The visualizations show clear differences between the attention masks learned for semantic segmentation and depth estimation.
- The depth masks often exhibit higher contrast, with certain regions strongly activated (brighter) while others are suppressed (darker). This implies that for depth estimation, the network focuses more acutely on specific features (e.g., edges and textures related to depth cues) and masks out less informative parts of the shared features.
- The semantic masks appear to have a more uniform distribution of attention across the feature maps, though still showing some variation. This could suggest that for semantic segmentation, a broader range of shared features is generally useful, and the task benefits from a more holistic view of the input.

This visualization confirms that the attention modules indeed act as feature selectors, automatically learning to highlight task-specific relevant information from the global feature pool for each task (a minimal sketch of such a module follows).
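To make the feature-selector behaviour concrete, here is a minimal PyTorch sketch of one task-specific soft-attention block in the spirit of MTAN. Channel sizes and the exact conv/BN arrangement are simplifications of the paper's design; the official code at https://github.com/lorenmt/mtan is the authoritative reference.

```python
from typing import Optional

import torch
import torch.nn as nn

class SoftAttentionBlock(nn.Module):
    """Sketch of a per-task soft-attention block: a mask with values in
    (0, 1) is predicted from the shared features (optionally concatenated
    with the task's attended features from the previous block) and applied
    element-wise, so each task selects its own subset of the global
    feature pool."""
    def __init__(self, in_channels: int, shared_channels: int):
        super().__init__()
        self.mask_net = nn.Sequential(
            nn.Conv2d(in_channels, shared_channels, kernel_size=1),
            nn.BatchNorm2d(shared_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(shared_channels, shared_channels, kernel_size=1),
            nn.BatchNorm2d(shared_channels),
            nn.Sigmoid(),  # soft mask: values in (0, 1)
        )

    def forward(self, shared_feat: torch.Tensor,
                prev_task_feat: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = shared_feat if prev_task_feat is None else torch.cat(
            [shared_feat, prev_task_feat], dim=1)
        mask = self.mask_net(x)    # same spatial shape as shared_feat
        return mask * shared_feat  # element-wise feature selection

# Usage (first block, no previous task features):
# block = SoftAttentionBlock(in_channels=64, shared_channels=64)
# task_feat = block(shared_feat)  # shared_feat: (N, 64, H, W)
```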
6.1.5. Visual Decathlon Challenge (Many-to-Many)
MTAN was also evaluated on the Visual Decathlon Challenge, which involves 10 image classification tasks.
The following are the results from Table 4 (right part) of the original paper:
| Method | #P. | ImNet. | Airc. | C100 | DPed | DTD | GTSR | Flwr | Oglt | SVHN | UCF | Mean | Score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Scratch [23] | 10 | 59.87 | 57.10 | 75.73 | 91.20 | 37.77 | 96.55 | 56.3 | 88.74 | 96.63 | 43.27 | 70.32 | 1625 |
| Finetune [23] | 10 | 59.87 | 60.34 | 82.12 | 92.82 | 55.53 | 97.53 | 81.41 | 87.69 | 96.55 | 51.20 | 76.51 | 2500 |
| Feature [23] | 1 | 59.67 | 23.31 | 63.11 | 80.33 | 45.37 | 68.16 | 73.69 | 58.79 | 43.54 | 26.8 | 54.28 | 544 |
| Res. Adapt.[23] | 2 | 59.67 | 56.68 | 81.20 | 93.88 | 50.85 | 97.05 | 66.24 | 89.62 | 96.13 | 47.45 | 73.88 | 2118 |
| DAN [25] | 2.17 | 57.74 | 64.12 | 80.07 | 91.30 | 56.54 | 98.46 | 86.05 | 89.67 | 96.77 | 49.38 | 77.01 | 2851 |
| Piggyback [19] | 1.28 | 57.69 | 65.29 | 79.87 | 96.99 | 57.45 | 97.27 | 79.09 | 87.63 | 97.24 | 47.48 | 76.60 | 2838 |
| Parallel SVD [24] | 1.5 | 60.32 | 66.04 | 81.86 | 94.23 | 57.82 | 99.24 | 85.74 | 89.25 | 96.62 | 52.50 | 78.36 | 3398 |
| MTAN (Ours) | 1.74 | 63.90 | 61.81 | 81.59 | 91.63 | 56.44 | 98.80 | 81.04 | 89.83 | 96.88 | 50.63 | 77.25 | 2941 |
Analysis:
- Competitive Performance: MTAN (with 1.74x the parameters of the single-task ImageNet baseline, per the #P. column) achieves a cumulative score of 2941, which surpasses most of the baselines (Scratch, Finetune, Feature, Res. Adapt., DAN, Piggyback) and is competitive with Parallel SVD (3398).
- Task-Specific Performance: MTAN shows strong performance on individual tasks, achieving, for example, 63.90% on ImageNet, 98.80% on GTSR, and 96.88% on SVHN. While Parallel SVD achieves the highest overall score, MTAN is very close to the top, doing well without resorting to complicated regularization strategies (like Dropout, dataset regrouping, or adaptive weight decay) that other methods may require. This underscores its simplicity and inherent robustness.
- Parameter Efficiency: MTAN achieves this strong performance with a relatively small parameter budget (1.74x the baseline, comparable to Res. Adapt. at 2x and DAN at 2.17x, and far lower than Scratch and Finetune at 10x). This again validates its parameter-efficient design for multi-domain learning.
6.2. Ablation Studies / Parameter Analysis
While the paper doesn't present a dedicated "ablation study" section with individual component removal, it implicitly performs several forms of analysis on the method's components and parameters:
- Single-Task Attention Network (STAN) vs. One Task: The STAN baseline effectively serves as an ablation showing the benefit of the attention modules themselves, even in a single-task context. From Table 2 (CityScapes), STAN (51.90 mIoU) outperforms One Task (51.09 mIoU) for semantic segmentation, and similarly for depth, suggesting the attention mechanism provides some benefit even without explicit multi-task interaction.
- Comparison to Multi-Task, Dense: This baseline directly evaluates the benefit of attention modules for feature selection in a multi-task setting. Dense provides all shared features to task-specific networks without attention. MTAN consistently outperforms Dense on NYUv2 (Table 3) and achieves comparable performance on CityScapes (Table 2) with significantly fewer parameters. This highlights the crucial role of the attention masks in intelligently selecting features rather than simply consuming all shared features.
- Effect of Task Complexity (Section 4.1.5): By evaluating MTAN on CityScapes with 2, 7, and 19 semantic classes, the authors implicitly analyze how the method scales with task difficulty. The observation that MTAN's performance gain increases at a greater rate with complexity (compared to other MTL methods) suggests that its attention mechanism is particularly effective when tasks demand more sophisticated feature differentiation and sharing.
- Robustness to Weighting Schemes (Section 4.1.4 and Figure 3): This extensive comparison across Equal Weights, Uncertainty Weights [14], and DWA (with temperature T = 2) for all baselines and MTAN acts as a sensitivity analysis. The finding that MTAN is less sensitive to these choices compared to methods like Cross-Stitch Networks indicates a robust architectural design, where the learned attention implicitly helps balance task learning dynamics.
- Visualization of Attention Masks (Section 4.1.6, Figure 5): The visualization directly inspects the learned attention modules. Showing distinct attention masks for different tasks (segmentation vs. depth) demonstrates that the modules are indeed learning task-specific feature selectors, providing an interpretable view of how MTAN achieves its performance.

Regarding the temperature parameter for DWA, the paper states it was "found empirically to be optimum across all architectures" at T = 2. This indicates some parameter tuning was performed for DWA itself; a minimal sketch of DWA follows.
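As referenced above, here is a minimal NumPy sketch of the DWA update following the paper's description; the function name and data layout are our choices:

```python
import numpy as np

def dwa_weights(avg_losses, temperature: float = 2.0):
    """Dynamic Weight Average: avg_losses[t][k] is task k's average loss
    over epoch t. Each weight tracks the relative rate of loss descent
    w_k = L_k(t-1) / L_k(t-2), softened by a softmax with temperature T
    and rescaled so the weights sum to the number of tasks K."""
    num_tasks = len(avg_losses[-1])
    if len(avg_losses) < 2:
        # While fewer than two epochs of history exist, use equal weights.
        return np.ones(num_tasks)
    w = np.asarray(avg_losses[-1]) / np.asarray(avg_losses[-2])
    exp_w = np.exp(w / temperature)
    return num_tasks * exp_w / exp_w.sum()

# Usage per epoch:
# weights = dwa_weights(loss_history, temperature=2.0)
# total_loss = sum(wk * Lk for wk, Lk in zip(weights, task_losses))
```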
7. Conclusion & Reflections
7.1. Conclusion Summary
The paper successfully introduces the Multi-Task Attention Network (MTAN), a novel multi-task learning architecture. MTAN's core innovation lies in its use of task-specific feature-level attention modules operating on a single global feature pool shared across tasks. This design enables the network to automatically learn both task-shared and task-specific features in an end-to-end manner. The experimental results consistently demonstrate that MTAN achieves state-of-the-art or competitive performance on a diverse set of image-to-image prediction tasks (semantic segmentation, depth estimation, surface normal prediction on CityScapes and NYUv2) and image classification tasks (Visual Decathlon Challenge). A significant finding is MTAN's robustness to different loss function weighting schemes, reducing the need for tedious hyper-parameter tuning. Furthermore, due to its efficient use of attention masks for weight sharing, MTAN is highly parameter efficient, making it scalable and practical for real-world applications. The proposed Dynamic Weight Average (DWA) also offers a simple yet effective adaptive weighting strategy.
7.2. Limitations & Future Work
The paper does not contain an explicit "Limitations" or "Future Work" section, but some points can be inferred:
- Complexity of Attention Modules: While MTAN is parameter-efficient, the addition of attention modules at each convolutional block may introduce some computational overhead during inference compared to simpler hard parameter sharing models, though the paper emphasizes efficiency.
- Generalizability of DWA: DWA is presented as a simpler alternative to GradNorm [3] that doesn't require access to gradients. While effective, GradNorm might offer more fine-grained control over training dynamics through its use of gradient norms. Future work could explore whether more sophisticated adaptive weighting schemes, potentially combining insights from GradNorm with MTAN's architecture, could yield further performance gains.
- Scaling to an Extremely Large Number of Tasks: The Visual Decathlon Challenge involves 10 tasks. While MTAN shows good scalability there, evaluating its performance and parameter efficiency on hundreds or thousands of tasks (relevant for extremely broad AI systems) could be a direction for future research.
- Interpretability of Attention: While the attention masks are visualized, a deeper quantitative analysis of why certain features are attended to, and how this correlates with task difficulty or task relationships, could provide further insights.
7.3. Personal Insights & Critique
- Elegance and Simplicity: I find the MTAN architecture particularly elegant. The idea of using soft attention masks to dynamically select features from a shared backbone is intuitive and powerful. It strikes a good balance between rigid hard parameter sharing (efficient but inflexible) and explicit task-specific networks (flexible but parameter-heavy). The fact that it can be built upon any feed-forward neural network is a testament to its versatility.
- Addressing Key MTL Challenges: The paper effectively addresses the two main challenges of MTL: network architecture (how to share features) and loss function balancing. The attention mechanism directly tackles feature sharing, and its inherent robustness simplifies the loss weighting problem, a major pain point in practice. DWA is a practical contribution that further aids loss balancing.
- Transferability: The methods and conclusions of this paper are highly transferable. The attention-based feature selection paradigm could be applied to various multi-modal or multi-domain learning scenarios beyond computer vision, wherever shared representations exist but task-specific nuances are crucial. For instance, in natural language processing, a shared encoder for multiple language understanding tasks could benefit from MTAN-like attention to focus on task-specific linguistic features.
- Potential Issues/Improvements:
  - Overhead of Attention: While parameter-efficient, the sequential computation of attention masks at each layer for each task, and the subsequent element-wise multiplications, might introduce some latency during inference, especially if tasks have very deep attention networks. Further work could analyze the computational graph and explore parallelization strategies or more lightweight attention mechanisms.
  - Learned Relationships: The paper demonstrates that attention masks are learned and what they look like, but a deeper dive into whether the learned attention patterns align with human intuition about task relationships (e.g., if two tasks are highly related, do their attention masks look more similar?) could be insightful.
  - Dynamic Attention in the Decoder: The paper focuses on attention in the encoder. Exploring how attention might be dynamically learned within the decoder, especially for complex image-to-image tasks where reconstruction details are crucial, could be another avenue.

Overall, MTAN represents a significant step forward in making Multi-Task Learning more practical, efficient, and robust, offering a compelling architecture that leverages the power of attention for adaptive feature sharing.