|
Pretrain–Test Task Alignment Governs Generalization in In-Context Learning |
Soundness: 3: good
Presentation: 4: excellent
Contribution: 3: good
Rating: 8: accept, good paper
Confidence: 4: You are confident in your assessment, but not absolutely certain. It is unlikely, but not impossible, that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work. |
This paper studies how the structure of pretraining data affects in-context learning (ICL) using a simplified linear attention model for linear regression, where training and testing tasks come from Gaussian distributions with different covariances. The authors derive an exact analytical expression for ICL error and propose a new alignment measure between pretraining and test task distributions that accounts for finite-sample effects, task diversity, and noise. This measure correlates extremely well with actual ICL performance, even in trained nonlinear Transformers. Finally, they show that pretraining on the same task distribution is not always optimal—when task diversity is limited, models generalize better if pretrained on a lower-dimensional (power-law structured) set of tasks.
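To make the summarized setup concrete, here is a minimal sketch of the data-generating process as I understand it (variable names such as `C_train`, `C_test`, and `context_len`, and the specific spectra, are my own illustrative choices, not the paper's):

```python
import numpy as np

rng = np.random.default_rng(0)
d, context_len, noise_std = 8, 16, 0.1

# Task covariances: anisotropic (power-law-like) for pretraining, isotropic at test time.
C_train = np.diag(1.0 / np.arange(1, d + 1))
C_test = np.eye(d)

def sample_context(C):
    """Draw one in-context regression task: w ~ N(0, C), then (x_i, y_i) pairs."""
    w = rng.multivariate_normal(np.zeros(d), C)
    X = rng.standard_normal((context_len, d))
    y = X @ w + noise_std * rng.standard_normal(context_len)
    return X, y, w

X_train, y_train, _ = sample_context(C_train)  # a pretraining sequence
X_test, y_test, _ = sample_context(C_test)     # a test sequence from a different task distribution
```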
1. The paper is well written and clearly states its problem; overall, it is an engaging and enjoyable read.
2. The analysis in Section 4 is particularly insightful and well structured. The authors examine two complementary cases—fixing `C_train` to find the optimal `C_test`, and fixing `C_test` to find the optimal `C_train`—to understand when pretraining and test task structures align or conflict. The finding that pretraining on the true test distribution `C_test` can be suboptimal in low-data regimes, where a lower-dimensional or more anisotropic pretraining distribution yields better ICL performance, is especially thought-provoking.
3. The proposed alignment measure is well-motivated and technically novel. It quantifies how well the pretraining and test task structures align while explicitly accounting for finite-sample effects, task diversity, and noise—factors often ignored in simpler population-level alignment metrics.
4. The paper provides strong empirical validation: the proposed measure correlates almost perfectly with actual ICL generalization error in both the analytically solvable linear model and trained nonlinear Transformers, outperforming other common similarity metrics such as `C_test C_train^{-1}` and CKA (baselines of the kind sketched below).
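For reference, the baseline similarity metrics named above could be computed along these lines (a hedged sketch; the paper's exact definitions of the `C_test C_train^{-1}` reduction and of CKA may differ):

```python
import numpy as np

def linear_cka(A, B):
    """Linear CKA between two PSD matrices, one common population-level similarity score."""
    return np.trace(A @ B) / (np.linalg.norm(A, 'fro') * np.linalg.norm(B, 'fro'))

def trace_ratio(C_test, C_train):
    """A simple scalar summary of C_test C_train^{-1} (one of several possible reductions)."""
    return np.trace(C_test @ np.linalg.inv(C_train)) / C_test.shape[0]

d = 8
C_train = np.diag(1.0 / np.arange(1, d + 1))
C_test = np.eye(d)
print(linear_cka(C_train, C_test), trace_ratio(C_test, C_train))
```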
I think the paper is beautiful from a theoretical perspective. However, from a more practical ICL standpoint:
1. The paper’s main results are construction-based. It would strengthen the work to show whether a linear attention model trained with stochastic gradient methods actually converges to the analytically derived (constructed) solution. This matters because, while constructions can represent any algorithm in principle, demonstrating convergence under realistic training dynamics would make the results more compelling for practical settings.
2. It remains unclear how additional architectural components—such as multiple layers or MLP blocks—would affect the proposed alignment and the observed “non-optimality of pretraining on the test distribution.” In more general Transformer settings, would the same phenomenon hold, or could deeper architectures mitigate or amplify this effect? Some discussion or intuition on this would be valuable.
See weaknesses above and the following:
I’d appreciate the authors’ view on how Section 4 relates to the “task-scaling vs. context-scaling” perspective (e.g., Abedsoltan et al., 2024: https://arxiv.org/pdf/2410.12783). That work finds that Transformers (unlike MLPs) benefit from both (i) increasing the number of pretraining tasks and (ii) longer context at test time.
In your Fig. 4, my understanding is that context length is held fixed, while task diversity (κ) and spectral shape (p_train) are varied. Could you clarify:
1) If we allow context length to grow (while task diversity remains small), do your results predict that matching the pretraining distribution to the test distribution becomes optimal in the large-context limit? Or does some degree of structured misalignment remain beneficial?
2) Since Fig. 4 already shows that increasing task diversity removes the non-optimality, how would you expect context scaling to interact with this effect? Would larger contexts further reduce alignment sensitivity, or play an orthogonal role?
3) Finally, it would be interesting to understand whether your framework can jointly explain task- and context-scaling effects. |
Lightly AI-edited |
|
Pretrain–Test Task Alignment Governs Generalization in In-Context Learning |
Soundness: 3: good
Presentation: 3: good
Contribution: 2: fair
Rating: 4: marginally below the acceptance threshold
Confidence: 4: You are confident in your assessment, but not absolutely certain. It is unlikely, but not impossible, that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work. |
This paper investigates how alignment or misalignment between the training and test task distributions contributes to, or reduces, ICL error. The authors use a simplified linear attention model for ICL of linear regression tasks, which enables them to formally characterize ICL generalization error and gives them a new way of measuring alignment between training and test data sets with respect to tasks.
Understanding generalization error in ICL is clearly linked to pretraining, and this paper provides a novel approach to thinking about generalization in terms of task diversity, whereas prior work has looked at other generalization problems. This work can in principle give us a better understanding of how pretraining affects ICL. Intuitively, task diversity is a dimension along which real ICL, as opposed to memorization, can emerge.
However, if I understand correctly, task diversity here has quite a limited definition. As in equations 2, 3, and 18, for example, task diversity in the linear-equation setup discussed here is primarily the variance of the matrix $W$ in the linear relationship $y_i = Wx_i + b$ for the inputs in the task. You do not vary $d$, the dimensionality of the tokens in C train, nor do you vary the distributions of $x_i$ or $b$.
Thus, the authors propose an alignment metric for train and test that predicts ICL error on a limited testing suite where the main parameter of variation is the slope of the linear relationship between $x$ and $y$ in the input context. As the linear relationship here is in $d$ dimensions, each dimension $m$ can in fact convey a "different" affine equation, $y^m = a^m x^m$. Task diversity thus comes down to how many linear equations are encoded in the matrix $W$.
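To spell out this restricted notion of diversity, here is a minimal sketch of the kind of data-generating process being described, where only the task matrix $W$ varies while $d$, the input distribution, and the intercept stay fixed (the names and the finite pool size `k` are my own assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, context_len = 8, 4, 16   # d and the input distribution are fixed; only the task pool varies

# Finite pool of k task matrices: the only source of "task diversity" in this reading.
task_pool = [rng.standard_normal((d, d)) for _ in range(k)]
b = np.zeros(d)                # the intercept is also held fixed

def sample_pretraining_sequence():
    W = task_pool[rng.integers(k)]           # diversity = which W is drawn, nothing else
    X = rng.standard_normal((context_len, d))
    Y = X @ W.T + b                          # y_i = W x_i + b, same x-distribution every time
    return X, Y

X, Y = sample_pretraining_sequence()
```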
This is potentially very interesting. If we knew that language generation tasks are in general expressible in terms of $m$ linear relationships over $d$ dimensions, then this would give us a general approach to ICL. However, even for simple tasks and simple attention-layer-only models, the outputs of the attention layer, formulated as a mathematical function, become highly nonlinear.
This is a partial step, but only a partial step, toward learning how ICL really works and how it relies on pretraining. The reason for this partiality is that the divergence between test and train is very limited. Only in the case where C train is varied with respect to the task, in the very specific sense you have defined, do the elements $F(z)$ and $M(z)$, for instance, measure the degree to which we can capture the information in C train.
But generalization should also cover generalizing over the test-time distributions of the inputs $x$ and the intercepts $b$ in the linear function learning task, as in Garg et al. (2022). Naim and Asher (2024), for instance, take a detailed look at what happens when the distributions of inputs, slopes, and intercepts $(x, a, b)$ are shifted in a one-dimensional linear function learning task. Their results show that transformer models with softmax and normalization don't really compute functions like regression and develop systematic errors as training and test distributions diverge. Another paper, Naim and Asher (2025), shows formally that transformer architectures have inherent limitations in computing linear functions, and that the same behavior extends to polynomial functions and even some continuous functions. There should be some discussion of how generalization over task diversity interacts with these other forms of generalization, especially since the observation that adding task samples does not always improve performance matches Naim and Asher's observation that predictions degrade when in-context examples are extended beyond a certain point.
More generally, omitting a discussion of how other dimensions of the problem affect the distribution vitiates, at least to some degree, the authors' claim that "these results identify train-test task alignment as a key determinant of generalization in ICL."
There are also some specific worries.
In formula 4, you apply the Add but not the Norm that is standard in transformer architectures. With linear attention, in addition, there is no exponential. Thus formula 9 appears to become the output of a one-layer, one-attention-head model, which Naim and Asher (2024) show doesn't really have much predictive power.
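For concreteness, the architectural simplification being pointed at might look like the following sketch, contrasting a single linear-attention head (residual add, no norm, no exponential) with standard softmax attention; this is my paraphrase of the discussed formulas, not the authors' exact parameterization:

```python
import numpy as np

def linear_attention_head(Z, W_k, W_q, W_v):
    """One linear-attention head: no softmax (no exponential), no LayerNorm, only a residual add."""
    K, Q, V = Z @ W_k, Z @ W_q, Z @ W_v
    scores = Q @ K.T / Z.shape[0]      # raw inner products, never exponentiated
    return Z + scores @ V              # "Add" without "Norm"

def softmax_attention_head(Z, W_k, W_q, W_v):
    """Standard single-head softmax attention for comparison (also shown without LayerNorm)."""
    K, Q, V = Z @ W_k, Z @ W_q, Z @ W_v
    scores = Q @ K.T / np.sqrt(Z.shape[1])
    A = np.exp(scores - scores.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return Z + A @ V
```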
In equation 8, you drop certain terms that do contribute to the model's output values for given tokens. So I have worries about how these analytical results actually transfer to working transformer models with multiple layers, multiple attention heads, and nonlinear scoring functions.
You are working with linear attention and with restricted assumptions on how values in the attention layers are computed (see remarks above). So while it's really interesting that the measure works on a more standard transformer as in H2, we can't transfer the mathematical analysis very readily to the nonlinear transformer of H2.
The last section is interesting, but given that only task diversity seems to be changing, it's not clear why the result is so surprising.
In the conclusion you speak of "fully general task covariates". But the entire exercise is built around a linear relationship between inputs x_i and labels y_i. So how does this provide a fully general task variation? In many generation ICL tasks it is not at all obvious that there is any linear relationship between the input context and the generated string. The linear relationship is a very specific ICL restriction.
Omar Naim and Nicholas Asher. Re-examining learning linear functions in context. arXiv:2411.11465 [cs.LG], 2024.
Omar Naim and Nicholas Asher. Analyzing limits of ICL. arXiv:2502.03503, 2025. https://arxiv.org/pdf/2502.03503
Are you predicting only for $y_n$, given a fixed input sequence $x_1, y_1, \ldots, x_n$? What is the distribution governing the input?
Is generalization, or testing as in equation 12, just limited to shifting the slope, since the distributions for the test input $x$ and the noise scalar $\epsilon$ are identical to those in training?
Here are a few typos and grammatical suggestions:
l. 391: "increases", "aggravates", or "exacerbates", but not "extremizes".
l. 431: "all the signal" |
Fully human-written |
|
Pretrain–Test Task Alignment Governs Generalization in In-Context Learning |
Soundness: 3: good
Presentation: 4: excellent
Contribution: 3: good
Rating: 6: marginally above the acceptance threshold
Confidence: 4: You are confident in your assessment, but not absolutely certain. It is unlikely, but not impossible, that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work. |
This paper studies a linear self-attention model trained on an in-context linear regression task, where the task vectors in training and at test time follow different distributions. By deriving the test error in terms of the data statistics in a high-dimensional limit, the authors quantify how the information in training is useful at test time. The analysis also reveals that the optimal pretraining distribution in the sense of minimizing the test error is not the test distribution in general.
- The paper is well-written and mostly clear. Frequently providing intuitive interpretations for the equations was very helpful for me to follow the theoretical results and their implications.
- The paper studies an interesting setting in which the in-context task distribution is different in training and at test time, providing analytical expressions for the test error in a high-dimensional limit.
- The finding that it is not always optimal to pretrain on the test distribution is interesting and somewhat surprising.
- The alignment measure introduced in this paper predicts the test error of softmax transformers more accurately than the conventional alignment measures, justifying the applicability of the theory developed in the simplified setup.
- I have several questions regarding the motivation to drop two terms in Equation (7).
First, [1] showed that $v_{21}$ remains zero if initialized at zero with $w_\mu\sim \mathcal N(0,I)$ and $n\to \infty$. Their proof relies on the mean of $w_\mu$ over $n$ sample sequences being zero. However, in this paper's setup with a finite (and possibly small) $k$, the mean of $w_\mu$ isn't close to zero. Hence, [1] may not be applicable as motivation for setting $v_{21}=0$ in line 165.
Second, I'm not sure why $x_ix_i^\top v_{21}$ is considered to contain no task information (line 161). One reason linear attention cannot achieve zero training/test loss, even when the noise $\epsilon=0$ and the sequence length is sufficient to recover $w$, is that the model relies on the in-context covariance estimated from the training set to approximate $\sum_{i=1}^l x_i x_i^\top$ for new test sequences, rather than using the true $x_i$ covariance of the current context. So $\sum_{i=1}^l x_i x_i^\top$ should contain essential information for robustly inferring the task vector. The limitation instead seems to come from expressivity -- single-layer linear attention lacks the expressivity to properly utilize the information in $\sum_{i=1}^l x_i x_i^\top$.
- In Equation (11), does ridge regularization in the reduced model match ridge regularization in the original model? Training the reduced model while regularizing $\Gamma$ may yield different models than training the original linear attention while regularizing the value, key, and query weights. If they are not equivalent, it would be important to clarify this distinction (a small numeric sketch of why the two penalties can differ follows the reference below).
[1] Zhang, R., Frei, S., & Bartlett, P. L. Trained transformers learn linear models in-context. JMLR, 2024.
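On the ridge question above: a minimal numeric sketch, under my own toy assumption that the reduced parameter is the product $\Gamma = V \cdot KQ$ of the original weights, shows that penalizing $\|\Gamma\|_F^2$ is generally not equivalent to penalizing $\|V\|_F^2 + \|KQ\|_F^2$:

```python
import numpy as np

rng = np.random.default_rng(2)
d = 4
V = rng.standard_normal((d, d))
KQ = rng.standard_normal((d, d))
Gamma = V @ KQ   # toy assumption: reduced parameter as a product of the original weights

ridge_reduced = np.linalg.norm(Gamma, 'fro') ** 2                                  # penalty on Gamma
ridge_original = np.linalg.norm(V, 'fro') ** 2 + np.linalg.norm(KQ, 'fro') ** 2    # penalty on the factors

# The two penalties generally differ. Minimizing the factored penalty over all
# factorizations of a fixed Gamma gives 2 * (sum of singular values), i.e. a
# nuclear-norm penalty on Gamma rather than a squared Frobenius penalty.
print(ridge_reduced, ridge_original, 2 * np.linalg.norm(Gamma, 'nuc'))
```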
- Section 4 makes it clear that, given a test distribution, it is not always optimal to pretrain on the test distribution. But it's not yet clear to me how one should design the optimal pretraining distribution that minimizes the test error. Could the authors clarify whether this question is addressed in the paper, or if it remains an open problem?
- In line 369, what are MLP connections? Is the architecture softmax attention and MLP with skip connections?
- A very minor suggestion: it might be helpful to avoid using the word "attention" in its psychological sense (lines 38, 407) in a paper that studies the self-attention model. |
Fully human-written |
|
Pretrain–Test Task Alignment Governs Generalization in In-Context Learning |
Soundness: 3: good
Presentation: 4: excellent
Contribution: 3: good
Rating: 6: marginally above the acceptance threshold
Confidence: 3: You are fairly confident in your assessment. It is possible that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work. Math/other details were not carefully checked. |
The authors present a theory of in-context learning of linear regression in a simplified one-layer transformer and calculate the generalization error for in- and out-of-distribution tasks.
The authors provide key steps towards a more complete understanding of ICL in transformers, through their analysis of a simplified transformer & problem setting.
- The subject matter considered is relevant to the state of the field and in particular to AI safety (Understanding distribution shifts in transformers is important for safety concerns)
- The work is, to my knowledge, original
- The paper is presented clearly, with all mathematical objects given appropriate intuition
- While I have not carefully checked the calculations, the results presented follow earlier work I am familiar with and the methods used seem plausible.
In my opinion, the main weakness of the paper as it stands is that not enough is done to explain the limits of the toy model considered. While the authors do show empirically that their model explains behaviors of non-toy, nonlinear, transformers trained on linear problems, it is not yet clear how far one can stretch this model.
In particular:
- The data is still quite simple: Does the model hold up for non-gaussian data (with a linear rule)?
- I imagine the model will break if the in-context regression problem is nonlinear. Is this correct? (A minimal probe along these lines is sketched after this list.)
- What can we learn about real data from the alignment measures proposed in the paper?
- The authors claim that mismatched task distributions are often optimal. Is there a real-world example of this phenomenon? (I.e. does the theory at least make a qualitative prediction about LLMs?)
- What further steps must be taken to bring theory & practice closer together?
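One concrete way to probe the first two questions above would be to rerun the comparison between theory and simulation under non-Gaussian inputs or a nonlinear in-context rule; a hedged sketch of such a probe (all names and distributional choices are mine, not the paper's):

```python
import numpy as np

rng = np.random.default_rng(3)
d, context_len, noise_std = 8, 16, 0.1

def sample_context(input_dist="gaussian", rule="linear"):
    """Generate one in-context task under alternative assumptions that stress the theory."""
    if input_dist == "gaussian":
        X = rng.standard_normal((context_len, d))
    else:                                    # heavy-tailed inputs, outside the Gaussian assumption
        X = rng.standard_t(df=3, size=(context_len, d))
    w = rng.standard_normal(d)
    y = X @ w if rule == "linear" else (X @ w) ** 2   # a simple nonlinear in-context rule
    return X, y + noise_std * rng.standard_normal(context_len)

# One could then check whether the proposed alignment measure still tracks the empirical
# ICL error of a trained transformer under each of these variations.
X, y = sample_context(input_dist="student_t", rule="nonlinear")
```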
In addition to the questions above, I have a few more:
- From Figure 1, it seems that $e_\text{misalign}$ captures much of the behavior of $e_\text{ICL}$. In what regime is this true?
- In many figures, specific values for the thermodynamic parameters $\alpha, \tau, \kappa$ are chosen. Are the values chosen part of a typical regime? I might imagine that some of the conclusions drawn will change as these parameters are varied (as an example, the authors note that previous work identifies a phase transition at $\tau=1$)
- What are the main obstacles to generalizing the theory beyond linear problems & Gaussian data? (I.e. what goes wrong in the calculation) |
Fully human-written |