Overview
Direct Answer
Multi-task learning is a machine learning paradigm in which a single model is trained simultaneously on multiple related prediction tasks, leveraging shared representations to improve generalisation and reduce overfitting compared to training separate task-specific models.
How It Works
The model architecture contains shared hidden layers that extract common features, with task-specific output heads branching from these layers. During training, loss functions from all tasks are combined—typically through weighted summation—and gradients propagate through both task-specific and shared parameters, forcing the model to learn representations beneficial across tasks.
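The hard-parameter-sharing pattern described above can be sketched in a few lines of NumPy. All dimensions, weight names, and task-weight values here are illustrative assumptions, not a reference implementation: a shared hidden layer feeds two task-specific heads (one regression, one classification), and the training objective is a weighted sum of the per-task losses.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions: 8 input features, 4 shared hidden units,
# a regression head (1 output) and a 3-class classification head.
W_shared = rng.normal(scale=0.1, size=(8, 4))  # shared representation layer
W_reg = rng.normal(scale=0.1, size=(4, 1))     # task A head: regression
W_clf = rng.normal(scale=0.1, size=(4, 3))     # task B head: classification

def forward(x):
    """Shared ReLU features branch into two task-specific outputs."""
    h = np.maximum(x @ W_shared, 0.0)
    return h @ W_reg, h @ W_clf

def combined_loss(x, y_reg, y_clf, w_reg=1.0, w_clf=0.5):
    """Weighted summation of per-task losses (weights are assumptions)."""
    pred_reg, logits = forward(x)
    mse = np.mean((pred_reg.ravel() - y_reg) ** 2)
    # Softmax cross-entropy for the classification head.
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    ce = -np.mean(log_probs[np.arange(len(y_clf)), y_clf])
    return w_reg * mse + w_clf * ce

x = rng.normal(size=(16, 8))
y_reg = rng.normal(size=16)
y_clf = rng.integers(0, 3, size=16)
loss = combined_loss(x, y_reg, y_clf)
print(loss)
```

Because both task losses depend on `W_shared` through the same hidden activations, gradients from either task would update the shared weights, which is what forces the shared layer to learn features useful to both tasks.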
Why It Matters
Multi-task learning reduces data requirements per task, decreases training time and computational cost, and improves robustness by allowing models to exploit task relationships. Organisations benefit from fewer model deployments and improved performance in data-constrained domains such as rare disease diagnosis or low-resource languages.
Common Applications
Applications include natural language processing (simultaneous part-of-speech tagging, named entity recognition, and dependency parsing), computer vision (joint detection, segmentation, and classification), and autonomous systems (predicting vehicle trajectory alongside traffic sign recognition).
Key Considerations
Effective multi-task learning requires careful selection of related tasks; incompatible or weakly related tasks can degrade performance through negative transfer. Task weighting and architectural design significantly influence outcomes and often require domain expertise to optimise.
Cross-References
Referenced By: 1 term mentions Multi-Task Learning
Other entries in the wiki whose definitions reference Multi-Task Learning, useful for understanding how this concept connects across Machine Learning and adjacent domains.
More in Machine Learning
Transfer Learning
Advanced Methods: A technique where knowledge gained from training on one task is applied to a different but related task.
Bagging
Advanced Methods: Bootstrap Aggregating, an ensemble method that trains multiple models on random subsets of data and averages their predictions.
Gradient Descent
Training Techniques: An optimisation algorithm that iteratively adjusts parameters in the direction of steepest descent of the loss function.
Anomaly Detection
Anomaly & Pattern Detection: Identifying data points, events, or observations that deviate significantly from the expected pattern in a dataset.
Regularisation
Training Techniques: Techniques that add constraints or penalties to a model to prevent overfitting and improve generalisation to new data.
Adam Optimiser
Training Techniques: An adaptive learning rate optimisation algorithm combining momentum and RMSProp for efficient deep learning training.
Elastic Net
Training Techniques: A regularisation technique combining L1 and L2 penalties, balancing feature selection and coefficient shrinkage.
Ridge Regression
Training Techniques: A regularised regression technique that adds an L2 penalty term to prevent overfitting by constraining coefficient magnitudes.