Adapting, Fast and Slow: A Causal Approach to Few-Shot Sequence Learning
Kasra Jalaldoust¹, Elias Bareinboim¹
¹Causal Artificial Intelligence Lab, Columbia University
May 2025
Abstract
Generalization to unseen target domains is a fundamental challenge in machine learning [1-3]. Our work introduces a **causal framework for supervised domain adaptation**, specifically addressing few-shot sequence learning [1, 2]. We investigate scenarios where ample source data is complemented by limited target data, a common setting in supervised domain adaptation (DA) [1, 2, 4, 5]. By developing both **causal structure-informed and structure-agnostic procedures**, we precisely characterize the conditions under which zero-shot or few-shot generalization becomes feasible [1, 2, 6, 7].
A key insight is that **generalization to an unseen target domain is inherently impossible without asserting a causal structure** that constrains the relationship between source and target domains [1, 2, 5, 8]. We extend our findings to sequential prediction tasks, demonstrating how knowledge of complex causal structure allows our structure-informed procedure to learn modular predictors from diverse source domains and systematically recompose them for faster adaptation in the target domain [1, 2, 7, 9]. Notably, we show that our structure-agnostic approach can achieve similarly fast rates in these scenarios [1, 2, 7, 9]. Our results provide a **causal theoretical basis for data-driven domain adaptation** and empirically corroborate these findings [1, 2].
The Challenge of Generalization
Traditional machine learning performance guarantees assume that the target domain, where a solution is evaluated, follows the same data distribution as the source domain used for training [3, 4]. However, even minor qualitative differences between source and target domains can severely impact performance, a problem broadly known as **distribution shift** or, in a scientific context, as a question of generalizability or external validity [3, 4].
In this context, **domain generalization** refers to situations where the learner only has access to large data from source domains and no target data [3, 4]. Our work focuses on the **domain adaptation (DA) problem**, a less extreme case where a small amount of target data is also available [4, 5]. The theoretical challenge in DA is not merely whether learning is possible (one could always discard the source data and rely solely on the target data), but rather **how fast learning can occur and how best to leverage data from source domains** [10, 11].
Arbitrary differences between domains pose a significant barrier, making source data potentially useless without a defined relationship or "structure" between domains [5, 8]. Humans excel at transferring knowledge across domains, and causality is widely recognized as central to human understanding and decision-making, especially in changing circumstances [10-18]. Principles of generalization to the unseen from a causal perspective have been extensively studied under "transportability" and "statistical invariances" rooted in an implicit causal structure [1, 10, 11, 19-27].
Our Core Idea: Causal Structure for Adaptation
We seek to characterize when and how certain aspects of source data are generalizable, enabling **fast adaptation** (zero-shot/few-shot learning) versus when source data might hinder learning, leading to **slow adaptation** [6, 11]. Our approach hinges on the fundamental role of an underlying causal structure [6, 7].
Key Contributions
- **Causal Structure for Faster Adaptation Rates:** We introduce a fine-grained causal structure for classification tasks, enabling transportability via a **structure-informed procedure** [6, 7]. This procedure leverages qualitative knowledge, such as causal graphs and domain discrepancies, to transfer inferences from source to target data [1, 2]. We provide performance guarantees for this predictor [6, 7, 21, 28]. Crucially, we also design a **structure-agnostic procedure** that achieves performance guarantees nearly as good as the structure-informed baseline, facilitating few-shot adaptation even without explicit structural knowledge [1, 2, 6, 7, 29, 30].
For instance, in multi-cause domain adaptation, the structure-informed approach can identify relevant source domains whose causal mechanism for the label $Y$ is compatible with the target [31, 32]. By pooling data from these compatible sources with the target data, a predictor can be trained on the identified causal parents [31, 32].
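To make this pooling step concrete, here is a minimal sketch in Python. It is our illustration, not the paper's code: the helper `fit_pooled_predictor`, its arguments, and the frequency-count estimator are all assumptions; the paper only specifies that compatible sources are pooled with the target and a predictor is trained on the causal parents.

```python
import numpy as np

def fit_pooled_predictor(source_datasets, target_data, parents, compatible):
    """Hypothetical sketch: pool the target sample with the sources in the
    compatible set J, then estimate P(Y | Pa(Y)) by conditional frequency
    counts over the parent columns."""
    pooled_X, pooled_Y = [target_data[0]], [target_data[1]]
    for j in compatible:                        # j ranges over the set J
        Xj, Yj = source_datasets[j]
        pooled_X.append(Xj)
        pooled_Y.append(Yj)
    X = np.concatenate(pooled_X)[:, parents]    # restrict to causal parents Pa(Y)
    Y = np.concatenate(pooled_Y)

    table = {}                                  # counts[parent values][label]
    for x, y in zip(map(tuple, X), Y):
        table.setdefault(x, {})
        table[x][y] = table[x].get(y, 0) + 1

    def predict(x_new):                         # majority label given Pa(Y) values
        counts = table.get(tuple(np.asarray(x_new)[parents]), {})
        return max(counts, key=counts.get) if counts else None
    return predict
```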
The **excess risk** for the structure-informed procedure ($\mu_{TR}$) is characterized as:
- If useful source data exists ($\mathcal{J} \neq \emptyset$), the error rate is $O(|X|^c \cdot |Y| / (\omega^2 \cdot N))$, indicating **fast adaptation** driven by the large source sample size $N$ [21, 28].
- If no source data can help ($\mathcal{J} = \emptyset$), the error rate is $O(|X|^c \cdot |Y| / (\omega \cdot n))$, meaning learning relies solely on the target sample size $n$ [21, 28].
For the structure-agnostic procedure ($\mu_{Ag}$), the target risk is $R_{P^*}(\mu_{Ag}) = O\big(R_{P^*}(\mu_{TR}) + c \cdot K \cdot M \cdot \log M / n\big)$ [29, 30]. This shows that the agnostic approach performs only marginally worse than the structure-informed one, with the added term representing the cost of learning the unknown discrepancies and parent sets [29, 33].
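The paper does not spell out the agnostic procedure at this level of detail, but one natural reading of "learning unknown discrepancies and parent sets" is model selection over hypotheses. The sketch below, which reuses `fit_pooled_predictor` and `np` from the snippet above and introduces the hypothetical name `structure_agnostic_select`, enumerates candidate parent sets and source subsets and keeps the hypothesis with the lowest held-out target risk:

```python
from itertools import chain, combinations

def structure_agnostic_select(source_datasets, target_train, target_val,
                              num_features, max_parents=2):
    """Hypothesized sketch: fit one pooled predictor per (source subset,
    parent set) hypothesis and keep the one with the smallest empirical
    risk on held-out target data."""
    domains = range(len(source_datasets))
    source_subsets = chain.from_iterable(
        combinations(domains, r) for r in range(len(source_datasets) + 1))
    best, best_risk = None, float("inf")
    for subset in source_subsets:
        for r in range(1, max_parents + 1):
            for parents in combinations(range(num_features), r):
                mu = fit_pooled_predictor(source_datasets, target_train,
                                          list(parents), list(subset))
                Xv, Yv = target_val
                risk = np.mean([mu(x) != y for x, y in zip(Xv, Yv)])
                if risk < best_risk:
                    best, best_risk = mu, risk
    return best, best_risk
```

This exhaustive enumeration also illustrates why such a procedure becomes computationally expensive, which motivates the two-stage alternative described later.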
- **Extension to Few-Shot Sequence Learning:** We extend our framework to sequential prediction tasks, where the goal is to predict the last token of a sequence from a prefix [7, 9, 33, 34]. This is particularly relevant for tasks like fine-tuning language models for reasoning [7, 9, 35].
We introduce a **discrepancy oracle** $\Delta(i, j; i', j')$ [34, 36]: a Boolean function that flags potential mismatches in the causal mechanisms or unobserved-variable distributions of variables $V_i$ and $V_{i'}$ across domains $j$ and $j'$ [34, 36]. This oracle makes it possible to match mechanisms across positions and domains in sequential settings [34, 36].
Our structure-informed algorithm for sequential DA uses this structural knowledge to learn **useful modular predictors** from combined source and target data, and then composes them for faster adaptation [7, 9, 37, 38]. For instance, a common causal module might govern how tokens are generated across different positions or domains [32, 35, 39].
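As a rough illustration of how the discrepancy oracle could drive module sharing, the sketch below greedily groups position-domain pairs into classes within which $\Delta$ reports no mismatch. The grouping strategy and the name `mechanism_classes` are our assumptions, not the paper's algorithm:

```python
def mechanism_classes(positions, domains, delta):
    """Greedily group (position, domain) pairs into shared-mechanism classes:
    a pair joins a class only if the oracle reports no discrepancy with any
    current member. delta(i, j, i2, j2) -> True means possible mismatch."""
    classes = []
    for i in positions:
        for j in domains:
            for cls in classes:
                if all(not delta(i, j, i2, j2) for (i2, j2) in cls):
                    cls.append((i, j))
                    break
            else:                    # no compatible class found: start a new one
                classes.append([(i, j)])
    return classes
```

A structure-informed learner could then fit one modular predictor per class on the pooled data of its members and compose the modules autoregressively over positions.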
The **excess risk** for sequential prediction ($\mu_{TR}$) is:
- $O(|V|^T / (\omega^2 \cdot N))$ if all components at positions $M+1$ through $T$ can be transported from the sources, enabling zero-shot generalization [22, 40].
- $O(|V|^{M+1} / (\omega \cdot n))$ otherwise, i.e., if at least one component cannot be transported, leading to slower adaptation [22, 40].
The structure-agnostic sequential DA procedure ($\mu_{Ag}$) achieves a target risk of $R_{P^*}(\mu_{Ag}) = O\big(R_{P^*}(\mu_{TR}) + c \cdot K \cdot T^3 \cdot \log T / n\big)$ [41, 42]. Thus, even without explicit knowledge of the discrepancy oracle or the causal diagrams, the agnostic procedure can achieve near-optimal performance [41, 42].
- **A Practical Solution: Two-Stage Adaptation:** Recognizing the computational intractability of the exhaustive structure-agnostic sequential algorithm (Algorithm 4) [43, 44], we introduce a theoretically equivalent but more practical two-stage procedure [43-46].
- **Pretraining:** This stage leverages the large source data to learn three key mappings [12, 43, 44, 47] (a minimal sketch of both stages follows this list):
- A **mechanism indicator** ($\varpi$) that maps position-domain pairs to categories, identifying identical causal mechanisms [43, 47].
- A **parent matrix** ($A_j$) for each domain $j$, encoding the causal diagram by indicating causal parents for each position [12, 47].
- A **universal predictor** ($\Psi$) that predicts a variable's value given its parents and the mechanism indicator [12, 47].
Maximizing a penalized likelihood on source populations ensures these properties are learned [12, 48].
- **Fine-tuning:** Target data is partitioned for training, fine-tuning, and held-out validation [49, 50]. During fine-tuning, certain components of the pretrained model (e.g., the positional encoding and the universal operator indicator) are frozen, while new target-specific parent queries/keys and operator indicators are trained [49, 50]. If a conditional distribution can be transported from a source, the pretrained universal predictor can be reused [49, 50]. Fine-tuning thus discovers the target parent matrix and target mechanism indicator, enabling efficient adaptation by reusing the learned universal causal functions while adapting the domain-specific parent selection [49-51].
The fine-tuning procedure ($\mu_{ft}$) achieves a performance rate comparable to the structure-agnostic DA Algorithm 4 ($\mu_{Ag}$), i.e., $R_{P^*}(\mu_{ft}) = O\big(R_{P^*}(\mu_{Ag})\big)$ [45, 46, 52], making it a computationally feasible alternative [45, 46].
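To ground the two-stage procedure, here is a hedged PyTorch sketch of the three pretrained mappings and the fine-tuning step. Every architectural detail (the class name `TwoStageModel`, soft parent weights via a lower-triangular sigmoid mask, a softmax mixture over mechanism embeddings, the exact freezing scheme) is our assumption rather than the paper's reported design:

```python
import torch
import torch.nn as nn

class TwoStageModel(nn.Module):
    """Sketch of the three pretrained mappings; details are illustrative."""
    def __init__(self, vocab, T, n_domains, n_mechs, dim=64):
        super().__init__()
        # parent matrices A_j: soft parent weights per (domain, child, parent)
        self.parent_logits = nn.Parameter(torch.zeros(n_domains, T, T))
        # mechanism indicator (varpi): category logits per (domain, position)
        self.mech_logits = nn.Parameter(torch.zeros(n_domains, T, n_mechs))
        # universal predictor (Psi): shared across domains and positions
        self.embed = nn.Embedding(vocab, dim)
        self.mech_embed = nn.Embedding(n_mechs, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, seq, j):
        # seq: (B, T) tokens from domain j; predict each token from its parents
        x = self.embed(seq)                                # (B, T, dim)
        A = torch.sigmoid(self.parent_logits[j]).tril(-1)  # parents precede child
        pa = torch.einsum("ts,bsd->btd", A, x)             # aggregate parent values
        w = self.mech_logits[j].softmax(-1)                # (T, n_mechs)
        mech = w @ self.mech_embed.weight                  # soft mechanism embedding
        return self.head(pa + mech)                        # (B, T, vocab) logits

def finetune(model, target_batches, j_target, lr=1e-2, steps=200):
    """Freeze the universal predictor Psi; train only the parent and
    mechanism parameters, which receive gradients only at row j_target."""
    for module in (model.embed, model.mech_embed, model.head):
        module.requires_grad_(False)
    opt = torch.optim.Adam([model.parent_logits, model.mech_logits], lr=lr)
    for step in range(steps):
        seq = target_batches[step % len(target_batches)]
        logits = model(seq, j_target)
        loss = nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), seq.reshape(-1))
        opt.zero_grad(); loss.backward(); opt.step()
```

The design mirrors the text: the universal predictor is reused wherever a conditional distribution transports from a source, while fine-tuning only has to discover the target's parent matrix and mechanism indicator.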
Validating Our Framework: Empirical Evaluation
We evaluated the two-stage adaptation method in both multi-cause and sequential settings, using synthetic data where sequences represent functional programs [45, 52, 53]. Experiments, typically with a single source domain and sequences of length 10, corroborated our theoretical results [45, 52].
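For intuition about the data, here is a hypothetical generator in the spirit of "sequences that represent functional programs" of length 10; the specific DAG sampling and operation set are our assumptions, not the paper's exact setup:

```python
import numpy as np

def sample_program_sequence(T=10, vocab=16, rng=np.random.default_rng(0),
                            parents=None, ops=None):
    """Hypothetical generator: each position t holds a value computed by an
    operation ops[t] applied to the values at its parent positions, so the
    sequence realizes a small functional program."""
    if parents is None:  # random DAG: each position picks <=2 earlier parents
        parents = [list(rng.choice(t, size=min(t, 2), replace=False)) if t else []
                   for t in range(T)]
    if ops is None:      # mechanism id per position
        ops = rng.integers(0, 3, size=T)
    seq = np.zeros(T, dtype=int)
    for t in range(T):
        if not parents[t]:
            seq[t] = rng.integers(vocab)                    # root token: exogenous
        elif ops[t] == 0:
            seq[t] = sum(seq[p] for p in parents[t]) % vocab
        elif ops[t] == 1:
            seq[t] = max(seq[p] for p in parents[t])
        else:
            seq[t] = (seq[parents[t][0]] + 1) % vocab
    return seq, parents, ops
```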
- **Pretraining discovers causal structure:** Investigation of the pretrained parameters confirmed that the learned parent matrix matched the underlying causal structure, and that the operation indicators aligned for position-domain pairs with matching causal mechanisms [54, 55]. This demonstrates the model's ability to capture not just causal dependencies but also mechanism matches/mismatches across domains [54, 56].
- **Fine-tuning exhibits fast and slow adaptation:** Our method achieved a smaller risk faster than the baselines (ERM-pool, ERM-joint) when "process supervision" was available (i.e., intermediate tokens were not hidden) [13, 57]. This suggests that re-learning compositions is simpler than re-learning entire circuits when components are shared [13, 57]. However, when intermediate tokens were hidden ("no process supervision"), all methods struggled to converge, highlighting the importance of process supervision for structure-informed adaptation [13, 58].
Key Takeaways
Our paper introduces a **causal framework** for supervised domain adaptation that offers both structure-informed and structure-agnostic algorithms [59, 60]. We demonstrate that **causal structure is critical** for identifying model components that can be reliably transported across domains [59, 60]. Even without explicit structural knowledge, our agnostic procedures can achieve **near-optimal performance** [59, 60]. Finally, the developed **two-stage learning procedure** provides a computationally tractable alternative that is theoretically equivalent to an exhaustive agnostic procedure [59, 60]. This work lays a causal theoretical foundation for data-driven domain adaptation through a unifying structure-agnostic scheme [1, 2].
Understanding with an Analogy: The LEGO Builder
Imagine you are a master LEGO builder. You have a huge collection of instruction manuals (source data) for many different LEGO sets (source domains), but each manual builds something slightly different. Now, someone gives you a very small, incomplete instruction manual for a new, unique LEGO model (target data), and asks you to build it as fast as possible.
Our work is like understanding the underlying principles of LEGOs (the **causal structure**). Instead of treating each manual as entirely separate, you realize that certain smaller sections of instructions (modular predictors) are actually identical or function in the same way, even if they appear in different manuals or build different parts of a model.
- The **structure-informed procedure** is like having X-ray vision that tells you exactly which sections of your old manuals are identical to the sections needed for the new model. You can then instantly grab those pre-built modules or instructions from your vast collection and snap them into place, building the new model very quickly (**fast adaptation**) [21, 28, 31, 32].
- The **structure-agnostic procedure** is like not having X-ray vision, but being smart about it. You try out different sections from your old manuals, combining them with the small new instructions you have, and quickly test which combination gets you closest to the new model [20, 29, 30, 61]. You'll still build it almost as quickly as someone with X-ray vision, because the underlying LEGO principles are consistent [29, 30, 41, 42].
- The **two-stage adaptation** is even more practical. First, you spend a lot of time organizing your entire LEGO collection by fundamental building principles (pretraining): classifying every brick and connection type, and how they function universally [12, 43, 44, 47]. Then, when you get a new model, you don't have to guess; you just adapt a few specific parts of your organized system (**fine-tuning**) to match the new model's unique requirements, reusing all your universal knowledge [49, 50]. It's like having a universal LEGO building kit that you can quickly customize for any new model [45, 46].
Ultimately, knowing the fundamental ways LEGOs connect (causal mechanisms) is key to building new and complex models quickly, even when you only have a few new instructions.
Acknowledgments
This research is supported in part by the NSF, ONR, AFOSR, DoE, Amazon, JP Morgan, and The Alfred P. Sloan Foundation [24, 62].
Additional Resources
For a more in-depth understanding of the theoretical underpinnings and empirical results, please refer to the complete PDF document of the paper, which includes comprehensive supplementary material, detailed analyses, and experimental setups [63, 64]. The appendices cover:
- Multi-source domain adaptation in the uni-cause case [63, 64]
- Extensive proofs for all lemmas, propositions, and theorems [63, 65, 66]
- More refined details on structure-informed adaptation rates [63, 66, 67]
- Detailed model architecture and design principles [63, 66, 68]
- Comprehensive experimental setup and reproducibility guidelines, including data generation, model architecture specifications, training hyperparameters, baseline configurations, and evaluation protocols [53, 63, 66, 69].