Multi-Stage Influence Function
Abstract
Multistage training and knowledge transfer, from a large-scale pretraining task to various finetuning tasks, have revolutionized natural language processing and computer vision, resulting in state-of-the-art performance improvements. In this paper, we develop a multistage influence function score to track predictions from a finetuned model all the way back to the pretraining data. With this score, we can identify the pretraining examples in the pretraining task that contribute most to a prediction in the finetuning task. The proposed multistage influence function generalizes the original influence function for a single model in (Koh & Liang, 2017), thereby enabling influence computation through both pretrained and finetuned models. We study two different scenarios, with the pretrained embeddings either fixed or updated in the finetuning tasks. We test our proposed method in various experiments to show its effectiveness and potential applications.
1 Introduction
Multistage training (pretraining and then finetuning) has become increasingly important and has achieved state-of-the-art results in many tasks. In natural language processing (NLP) applications, it is now common practice to first learn word embeddings (e.g., word2vec [mikolov2013distributed], GloVe [pennington2014glove]) or contextual representations (e.g., ELMo [peters2018deep], BERT [devlin2018bert]) from a large unsupervised corpus, and then refine or finetune the model on supervised end tasks. Similar ideas in transfer learning have also been widely used in many different tasks. Intuitively, the success of these multistage learning paradigms is due to knowledge transfer from pretraining tasks to the end task. However, current approaches using multistage learning are usually based on trial-and-error, and many fundamental questions remain unanswered. For example, which part of the pretraining data/task contributes most to the end task? How can one detect “false transfer”, where some pretraining data/task could be harmful to the end task? If a test point is wrongly predicted by the finetuned model, can we trace back to the problematic examples in the pretraining data? Answering these questions requires a quantitative measurement of how the data and loss function in the pretraining stage influence the end model, which has not been studied in the past and is the main focus of this paper.
To find the most influential training data responsible for a model’s prediction, the influence function was first introduced by [IF1980], from a robust statistics point of view. More recently, as large-scale applications became more challenging for influence function computation, [koh2017understanding] proposed using a first-order approximation to measure the effect of removing one training point on the model’s prediction, to overcome the computational challenges. These methods are widely used in model debugging, and there are also applications in machine learning fairness [brunet19a, WangUC19]. However, all existing influence-score computation algorithms study the case of single-stage training – where there is only one model with one set of training/prediction data in the training process. To the best of our knowledge, the influence of pretraining data on a subsequent finetuning task and model has not been studied, and it is nontrivial to apply the original influence function of [koh2017understanding] to this scenario. A naive approach to this problem is to remove each individual instance from the pretraining data one at a time and retrain both the pretrained and finetuned models; this is prohibitively expensive, especially given that pretrained models are often large-scale and may take days to train.
In this work, we study the influence function from pretraining data to the end task, and propose a novel approach to estimate influence scores in multistage training that requires no additional retraining, does not require model convexity, and is computationally tractable. The proposed approach is based on the definition of the influence function, and estimates the influence score under two multistage training settings, depending on whether the embedding from the pretrained model is retrained in the finetuning task. The derived influence function explains well how the pretraining data benefits the finetuning task. In summary, our contributions are threefold:

We propose a novel estimation of influence scores for multistage training. In experiments on real datasets across various tasks, our predicted influence scores of the pretraining data on the finetuned model correlate well with the actual ones. This shows the effectiveness of our proposed technique for estimating influence scores in multistage models.

We propose effective methods to determine how test data from the finetuning task are impacted by changes in the pretraining data. We show that the influence of the pretraining data on the finetuned model consists of two parts: the influence of the pretraining data on the pretrained model, and the influence of the pretraining data on the finetuned model. These two parts can be quantified using our proposed technique.

We propose methods to decide whether the pretraining data can benefit the finetuning task. We show that the influence of the pretraining data on the finetuning task depends strongly on 1) the similarity of the two tasks or stages, and 2) the amount of training data in the finetuning task. Our proposed technique provides a novel way to measure how the pretraining data helps or benefits the finetuning task.
2 Related Work
Multistage model training, which trains models in multiple stages on different tasks to improve the end task, has been widely used in many areas of machine learning. For example, transfer learning has been widely used to transfer knowledge from a source task to a target task [pan2009survey]. More recently, researchers have shown that training a computer vision or NLP encoder on a source task with a huge amount of data can often benefit the performance on small end tasks, and these techniques, including BERT [devlin2018bert], ELMo [peters2018deep], and large ResNet pretraining [mahajan2018exploring], have achieved state-of-the-art results on many tasks.
Although multistage models have been widely used, there are few works on understanding them or on exploiting the influence of the training data in the pretraining step to benefit the finetuning task. In contrast, there are many works that focus on understanding single-stage machine learning models and explaining model predictions. Algorithms developed along this line of research can be categorized into feature-based and data-based approaches. Feature-based approaches aim to explain predictions with respect to model variables, and trace back the contribution of variables to the prediction [oramas2018visual, pmlrv97guo19b, shrikumar2017learning, smilkov2017smoothgrad, simonyan2013deep, SundararajanTY16, Fong2017Interpretable, Dabkowski2017Real, ancona2017unified]. However, they do not aim to attribute the prediction back to the training data.
On the other hand, data-based approaches seek to connect model predictions and training data, and trace back the training data most responsible for the model’s prediction. Among them, the influence function [IF1980, koh2017understanding], which models the prediction changes when training data are added or removed, has been shown to be effective in many applications. There is a series of works on influence functions, including investigating the influence of a group of data on the prediction [groupInfluence], using influence functions to detect bias in word embeddings [brunet19a], and using them to prevent data poisoning attacks [Steinhardt]. There are also works on data importance estimation to explain the model from the data perspective [NIPS2018_8141, NIPS2019_8674, KhannaKGK19].
All of these previous works, however, only consider a single-stage training procedure, and it is not straightforward to apply them to multistage models. In this paper, we propose to analyze the influence of pretraining data on predictions of the subsequent finetuned model and end task.
3 Algorithms
In this section, we detail the procedure of multistage training, show how to compute the influence score for the multistage training, and then discuss how to scale up the computation.
3.1 Multi-Stage Model Training
Multistage models, which train different models in consecutive stages, have been widely used in various ML tasks. Mathematically, let $Z = \{z_1, \dots, z_m\}$ be the training set for the pretraining task with data size $m$, and $S = \{s_1, \dots, s_n\}$ be the training data for the finetuning task with data size $n$. In the pretraining stage, we assume the parameters of the pretrained network have two parts: the parameters $W$ that are shared with the end task, and the task-specific parameters $U$ that will only be used in the pretraining stage. Note that $W$ could be a word embedding matrix (e.g., in word2vec) or a representation extraction network (e.g., ELMo, BERT, ResNet), while $U$ is usually the last few layers that correspond to the pretraining task. After training on the pretraining task, we obtain the optimal parameters $W^*, U^*$. The pretraining stage can be formulated as
(1)  $W^*, U^* = \arg\min_{W,U} \; L(Z; W, U) := \frac{1}{m}\sum_{i=1}^{m} L(z_i; W, U)$
where $L(z_i; W, U)$ is the loss function for the pretraining task on example $z_i$, and $L(Z; W, U)$ is the average loss with respect to all the pretraining examples.
In the finetuning stage, the network parameters are $(W, V)$, where $W$ is shared with the pretraining task and $V$ is the rest of the parameters specifically associated with the finetuning task. We initialize the $W$ part with $W^*$. Let $\ell(s_j; W, V)$ denote the finetuning loss and $\ell(S; W, V) := \frac{1}{n}\sum_{j=1}^{n} \ell(s_j; W, V)$ the average loss with respect to the finetuning data. There are two cases when finetuning on the end task:

Finetuning Case 1: Fix the embedding parameters $W = W^*$, and only finetune $V$:
(2)  $V^* = \arg\min_{V} \; \ell(S; W^*, V)$
Finetuning Case 2: Finetune both the embedding parameters $W$ (initialized from $W^*$) and $V$. Sometimes updating the embedding parameters in the finetuning stage is necessary, as the embedding parameters from the pretrained model may not be good enough for the finetuning task. This corresponds to the following formulation:
(3)  $\hat{W}, \hat{V} = \arg\min_{W,V} \; \ell(S; W, V)$, with $W$ initialized at $W^*$
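The two finetuning cases above can be made concrete with a toy linear model. The following NumPy sketch pretrains a linear embedding and then fits a task head with the embedding fixed (Case 1) or updates both jointly by gradient descent (Case 2); all shapes, data, and learning rates are illustrative assumptions, not the paper's actual models:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic pretraining task: recover a linear embedding W (shared parameters).
X_pre = rng.normal(size=(200, 10))
W_true = rng.normal(size=(4, 10))
Y_pre = X_pre @ W_true.T                        # pretraining targets

# Pretraining stage, Eq. (1): least-squares fit of the shared embedding W.
W_star, *_ = np.linalg.lstsq(X_pre, Y_pre, rcond=None)
W_star = W_star.T                               # shape (4, 10)

# Fine-tuning task: regression on top of the embedding phi(x) = W x.
X_fine = rng.normal(size=(50, 10))
y_fine = X_fine @ W_true.T @ np.ones(4)         # transfer-friendly targets

# Case 1, Eq. (2): embedding fixed, fit only the task head v.
Phi = X_fine @ W_star.T
v_star, *_ = np.linalg.lstsq(Phi, y_fine, rcond=None)
case1_loss = np.mean((Phi @ v_star - y_fine) ** 2)

# Case 2, Eq. (3): update W and v jointly, initialized at (W_star, v_star).
W, v = W_star.copy(), v_star.copy()
for _ in range(200):
    r = (X_fine @ W.T) @ v - y_fine             # residuals
    v -= 0.01 * 2 * (X_fine @ W.T).T @ r / len(r)
    W -= 0.01 * 2 * np.outer(v, r @ X_fine) / len(r)
case2_loss = np.mean(((X_fine @ W.T) @ v - y_fine) ** 2)
```

Because the synthetic targets are exactly realizable through the pretrained embedding, both cases drive the finetuning loss to (near) zero here; the point is only the structure of the two stages.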
3.2 Influence function for multistage models
We derive the influence function for the multistage model to trace the influence of pretraining data on the finetuned model. In Figure 1 we show the task we are interested in solving in this paper. Note that we use the same definition of influence function as [koh2017understanding] and discuss how to compute it in the multistage training scenario. As discussed at the end of Section 3.1, depending on whether or not we are updating the shared parameters in the finetuning stage, we will derive the influence functions under two different scenarios.
3.2.1 Case 1: embedding parameters are fixed in finetuning
To compute the influence of pretraining data on the finetuning task, the main idea is to perturb one data example in the pretraining data, and study how that impacts the test data. Mathematically, if we perturb the loss of a pretraining example $z$ by a small weight $\epsilon$, the perturbed model can be defined as
(4)  $W^*_\epsilon, U^*_\epsilon = \arg\min_{W,U} \; L(Z; W, U) + \epsilon\, L(z; W, U)$
Note that different choices of $\epsilon$ result in different effects on the loss function relative to the original solution in (1). For instance, setting $\epsilon = -\tfrac{1}{m}$ is equivalent to removing the sample $z$ from the pretraining dataset.
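With the mean-loss formulation, the choice $\epsilon = -1/m$ exactly cancels one example's contribution. This can be checked on a toy problem where the pretraining loss is a squared distance to each point, so every optimum has a closed form (the numbers are illustrative):

```python
import numpy as np

z = np.array([1.0, 2.0, 4.0, 9.0])   # toy pretraining points
m = len(z)

# Unperturbed optimum of (1/m) * sum_i (w - z_i)^2 is the sample mean.
w_star = z.mean()

# Perturbed objective, Eq. (4): (1/m) * sum_i (w - z_i)^2 + eps * (w - z[0])^2.
# Setting the gradient to zero gives the minimizer in closed form:
def w_eps(eps):
    return (z.mean() + eps * z[0]) / (1.0 + eps)

# eps = -1/m reproduces the leave-one-out optimum exactly.
w_loo = z[1:].mean()
print(w_eps(-1.0 / m), w_loo)   # both equal 5.0 here
```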
For the finetuning stage, since we consider Case 1, where the embedding parameters are fixed, the new model for the end task will thus be
(5)  $V^*_\epsilon = \arg\min_{V} \; \ell(S; W^*_\epsilon, V)$
The influence function that measures the impact of a small perturbation of $z$ on the finetuning loss of a test sample $s_{\text{test}}$ from the finetuning task is defined as
(6)  $\mathcal{I}(z, s_{\text{test}}) := \left.\frac{d\,\ell(s_{\text{test}}; W^*_\epsilon, V^*_\epsilon)}{d\epsilon}\right|_{\epsilon=0}$
(7)  $\phantom{\mathcal{I}(z, s_{\text{test}})} = \nabla_V \ell(s_{\text{test}}; W^*, V^*)^\top \left.\frac{dV^*_\epsilon}{d\epsilon}\right|_{\epsilon=0} + \nabla_W \ell(s_{\text{test}}; W^*, V^*)^\top \left.\frac{dW^*_\epsilon}{d\epsilon}\right|_{\epsilon=0}$
where $\frac{dV^*_\epsilon}{d\epsilon}$ measures the influence of $z$ on the finetuning task parameters $V^*$, and $\frac{dW^*_\epsilon}{d\epsilon}$ measures how $z$ influences the pretrained model $W^*$. Therefore we can split the influence of $z$ on the test sample into two pieces: one is the impact of $z$ on the pretrained model $W^*$, and the other is the impact of $W^*$ on the finetuned model $V^*$. It is worth mentioning that, due to linearity, if we want to estimate influence function scores for a set of test examples with respect to a set of pretraining examples, we can simply sum up the pairwise influence functions, and so define
(8)  $\mathcal{I}(Z', S') := \sum_{z \in Z'} \sum_{s \in S'} \mathcal{I}(z, s)$
where $Z'$ contains a set of pretraining data and $S'$ contains the group of finetuning test data that we are targeting. Next we derive the two influence scores $\frac{dV^*_\epsilon}{d\epsilon}$ and $\frac{dW^*_\epsilon}{d\epsilon}$ (see the detailed derivations in the appendix) in Theorem 1 below.
Theorem 1.
(9)  $\left.\frac{dV^*_\epsilon}{d\epsilon}\right|_{\epsilon=0} = -\left(\nabla^2_{V}\,\ell(S; W^*, V^*)\right)^{-1} \nabla^2_{VW}\,\ell(S; W^*, V^*)\, \left.\frac{dW^*_\epsilon}{d\epsilon}\right|_{\epsilon=0}$
(10)  $\left.\frac{dW^*_\epsilon}{d\epsilon}\right|_{\epsilon=0} = -\left[\left(\nabla^2_{W,U}\, L(Z; W^*, U^*)\right)^{-1} \nabla_{W,U}\, L(z; W^*, U^*)\right]_{W}$
where $[\cdot]_W$ means taking the $W$ part of the vector. Plugging (9) and (10) into (7) gives the multistage influence function
(11)  $\mathcal{I}(z, s_{\text{test}}) = \left(\nabla_V \ell(s_{\text{test}})^\top \left(\nabla^2_{V}\ell(S)\right)^{-1} \nabla^2_{VW}\,\ell(S) - \nabla_W \ell(s_{\text{test}})^\top\right)\left[\left(\nabla^2_{W,U} L(Z)\right)^{-1} \nabla_{W,U} L(z)\right]_{W}$
where all gradients and Hessians are evaluated at $(W^*, U^*, V^*)$.
3.2.2 Case 2: embedding parameters are also updated in the finetuning stage
For the second finetuning case in (3), we further train the embedding parameters $W$ from the pretraining stage. When $W$ is also updated in the finetuning stage, it is challenging to characterize the influence, since the pretrained embedding $W^*$ is only used as an initialization. In general, the final model may be totally unrelated to $W^*$; for instance, when the objective function is strongly convex, any initialization of $W$ in (3) will converge to the same solution.
However, in practice the initialization of $W$ strongly influences the finetuning stage in deep learning, since the finetuning objective is usually highly nonconvex and initializing with $W^*$ will lead to convergence to a local minimum near $W^*$. Therefore, we propose to approximate the whole training procedure as
(12)  $W^*, U^* = \arg\min_{W,U} L(Z; W, U), \qquad \hat{W}, \hat{V} = \arg\min_{W,V} \; \ell(S; W, V) + \frac{\gamma}{2}\|W - W^*\|^2$
where $(W^*, U^*)$ are optimal for the pretraining stage, $(\hat{W}, \hat{V})$ are optimal for the finetuning stage, and $\gamma$ is a small value. This characterizes the fact that in the finetuning stage we seek a solution that minimizes the finetuning loss and stays close to $W^*$. In this way, the pretrained parameters are connected with the finetuning task, and thus the influence of pretraining data on the finetuning task becomes tractable. The results in our experiments show that with this approximation, the computed influence score can still reflect the real influence quite well.
Similarly, we can define $\frac{dW^*_\epsilon}{d\epsilon}$, $\frac{d\hat{W}_\epsilon}{d\epsilon}$, and $\frac{d\hat{V}_\epsilon}{d\epsilon}$ to measure the difference between the original optimal solutions in (12) and the optimal solutions under a perturbation of the pretraining example $z$. Similar to (6), the influence function that measures the influence of a perturbation of $z$ on test sample $s_{\text{test}}$’s loss is
(13)  $\mathcal{I}(z, s_{\text{test}}) = \nabla_W \ell(s_{\text{test}}; \hat{W}, \hat{V})^\top \left.\frac{d\hat{W}_\epsilon}{d\epsilon}\right|_{\epsilon=0} + \nabla_V \ell(s_{\text{test}}; \hat{W}, \hat{V})^\top \left.\frac{d\hat{V}_\epsilon}{d\epsilon}\right|_{\epsilon=0}$
The influence of a small perturbation of $z$ on $W^*$, i.e., $\frac{dW^*_\epsilon}{d\epsilon}$, is computed exactly as in (10). For the finetuning stage, the optimality conditions of (12) are $\nabla_V \ell(S; \hat{W}, \hat{V}) = 0$ and $\nabla_W \ell(S; \hat{W}, \hat{V}) + \gamma(\hat{W} - W^*) = 0$; differentiating them with respect to $\epsilon$, following the same approach as in Subsection 3.2.1, leads to
(14)  $\left.\frac{d\hat{V}_\epsilon}{d\epsilon}\right|_{\epsilon=0} = -\left(\nabla^2_{V}\ell(S; \hat{W}, \hat{V})\right)^{-1} \nabla^2_{VW}\ell(S; \hat{W}, \hat{V}) \left.\frac{d\hat{W}_\epsilon}{d\epsilon}\right|_{\epsilon=0}$
(15)  $\left.\frac{d\hat{W}_\epsilon}{d\epsilon}\right|_{\epsilon=0} = \gamma\left(\nabla^2_{W}\ell(S) + \gamma I - \nabla^2_{WV}\ell(S)\left(\nabla^2_{V}\ell(S)\right)^{-1}\nabla^2_{VW}\ell(S)\right)^{-1} \left.\frac{dW^*_\epsilon}{d\epsilon}\right|_{\epsilon=0}$
After plugging (14) and (15) into (13), we obtain the influence function $\mathcal{I}(z, s_{\text{test}})$ for Case 2.
Similarly, the algorithm for computing the influence score in Case 2 can follow Algorithm 1 for Case 1 by replacing the corresponding gradient computations. From the derivation we can see that our proposed multistage influence function does not require model convexity.
3.3 Computation Challenges
The influence function computation for the multistage model was presented in the previous section. As we can see from Algorithm 1, the influence score computation involves many Hessian matrix operations, which are very expensive and sometimes unstable for large-scale models. We use several strategies to speed up the computation and make the scores more stable.
Large Hessian Matrices
As we can see from Algorithm 1, our algorithm involves several Hessian inverse operations, which are known to be computation- and memory-demanding. For a Hessian matrix of size $p \times p$, where $p$ is the number of parameters in the model, it requires $O(p^2)$ memory to store the matrix and $O(p^3)$ operations to invert it. Therefore, for large deep learning models with thousands or even millions of parameters, it is almost impossible to perform the Hessian matrix inverse explicitly. Similar to [koh2017understanding], we avoid explicitly computing and storing the Hessian matrix and its inverse, and instead compute the product of the inverse Hessian with a vector directly. More specifically, every time we need an inverse Hessian vector product $H^{-1}v$, we invoke conjugate gradients (CG), which transforms the linear system into a quadratic optimization problem. In each iteration of CG, instead of computing the Hessian explicitly, we compute a Hessian-vector product, which can be done efficiently by backpropagating through the model twice, with $O(p)$ time complexity [Hessians].
The aforementioned conjugate gradient method requires the Hessian matrix to be positive definite. However, in practice the Hessian may have negative eigenvalues, since we run SGD and the final iterate may not be exactly at a local minimum. To tackle this issue, we solve
(16)  $\min_x \; \|Hx - g\|^2$
whose solution can be shown to be the same as $x = H^{-1}g$, since the Hessian matrix $H$ is symmetric. The corresponding system matrix $H^\top H$ is guaranteed to be positive definite as long as $H$ is invertible, even when $H$ has negative eigenvalues. If $H^\top H$ is not ill-conditioned, we can solve (16) directly. The rate of convergence of CG depends on $\sqrt{\kappa}$, where $\kappa$ is the condition number of $H^\top H$, which can be very large if $H$ is ill-conditioned. When $H$ is ill-conditioned, to stabilize the solution and to encourage faster convergence, we add a small damping term $\lambda$ on the diagonal and solve $(H^\top H + \lambda I)x = H^\top g$.
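A matrix-free sketch of this damped normal-equation solve, with the Hessian-vector product treated as a black box (in a real model it would be two backprop passes; here a small explicit symmetric matrix stands in, and all sizes and the damping value are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# A symmetric "Hessian" that may have negative eigenvalues, as can happen
# when training stops short of an exact local minimum.
A = rng.normal(size=(20, 20))
H = (A + A.T) / 2.0
g = rng.normal(size=20)
damp = 0.5                               # damping term lambda on the diagonal

def hvp(v):
    """Matrix-free Hessian-vector product (two backprops in a real model)."""
    return H @ v

def cg_ihvp(hvp, g, damp, iters=200, tol=1e-10):
    """Solve (H^T H + damp*I) x = H^T g by conjugate gradients; the system
    matrix is positive definite even when H has negative eigenvalues."""
    def matvec(v):
        return hvp(hvp(v)) + damp * v    # H symmetric, so H^T H v = H(H v)
    b = hvp(g)                           # H^T g
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        Hp = matvec(p)
        alpha = rs / (p @ Hp)
        x += alpha * p
        r -= alpha * Hp
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

x = cg_ihvp(hvp, g, damp)
# Dense reference solve, feasible only at this toy scale.
x_direct = np.linalg.solve(H @ H + damp * np.eye(20), H @ g)
```

Only `hvp` touches the Hessian, so the same loop scales to models where $H$ can never be materialized.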
Time Complexity
As mentioned above, we can get an inverse Hessian vector product iteratively via Hessian-vector products. To analyze the time complexity of Algorithm 1, assume there are $p_1$ parameters in our pretrained model and $p_2$ parameters in our finetuned model; it then takes $O(mp_1)$ or $O(np_2)$ time to compute a Hessian vector product over the full training set, where $m$ is the number of pretraining examples and $n$ is the number of finetuning examples. For the two inverse Hessian vector products in Algorithm 1, the time complexity is therefore $O(rmp_1)$ and $O(rnp_2)$, where $r$ is the number of iterations in CG. For the other operations in Algorithm 1, a vector product has time complexity $O(p_1)$ or $O(p_2)$, and computing the gradients of all pretraining examples has complexity $O(mp_1)$. So the total time complexity of computing a multistage influence score is $O(r(mp_1 + np_2))$. We can thus see that the computation is tractable, as it is linear in the number of training samples and model parameters. All the computation related to the inverse Hessian can use inverse Hessian vector products (IHVP), which makes the memory usage and computation efficient.
4 Experiments
In this section, we conduct experiments on real datasets in both vision and NLP tasks to show the effectiveness of our proposed method.
4.1 Evaluation of Influence Score Estimation
We first evaluate the effectiveness of our proposed approach for the estimation of the influence function. For this purpose, we build two CNN models based on the CIFAR10 and MNIST datasets. The model structures are shown in Table A in the Appendix. For both the MNIST and CIFAR10 models, CNN layers are used as embeddings and fully connected layers are task-specific. At the pretraining stage, we train the models with examples from two classes (“bird" vs. “frog") for CIFAR10 and four classes (0, 1, 2, and 3) for MNIST. The resulting embedding is used in the finetuning tasks, where we finetune the model with the examples from the remaining eight classes in CIFAR10, or the remaining six digits in MNIST, for classification.
We test the correlation between an individual pretraining example’s multistage influence score and the real loss difference when that pretraining example is removed. We test two cases (as mentioned in Section 3.1) – where the pretrained embedding is fixed, and where it is updated during finetuning. For a given example in the pretraining data, we calculate its influence function score with respect to each test example in the finetuning task test set using the method presented in Section 3. To evaluate this pretraining example’s contribution to the overall performance of the model, we sum up the influence function scores across the whole test set in the finetuning task.
To validate the score, we remove that pretraining example and go through the aforementioned process again, retraining the models. Then we run a linear regression between the true loss difference values and the computed influence scores to show their correlation. The detailed hyperparameters used in these experiments are presented in Appendix B.
Embedding is fixed  In Figures 2(a) and 2(b) we show the correlation results for the CIFAR10 and MNIST models when the embedding is fixed during finetuning. From Figures 2(a) and 2(b) we can see that there is a linear correlation between the true loss difference and the influence scores obtained. The correlation is evaluated with Pearson’s $r$ value. It is almost impossible to get an exact linear correlation, because the influence function is based on the first-order optimality conditions (gradients equal to zero) of the loss function, which may not hold exactly in practice. [koh2017understanding] reports an $r$ value of around 0.8, but their correlation is based on a single model with a single data source, whereas we consider a much more complex case with two models and two data sources: the relationship between pretraining data and the finetuning loss function. So we expect a lower $r$ value, and 0.6 is reasonable to show a strong correlation between the pretraining data’s influence scores and the finetuning loss differences. This supports our argument that we can use this score to detect the examples in the pretraining set that contribute most to the model’s performance.
One may doubt whether the expensive inverse Hessian computation in our formulation is necessary. As a comparison, we replace all inverse Hessians in (11) with identity matrices to compute the influence scores for the MNIST model. The results are shown in Figure 3, with a much smaller Pearson’s $r$ of 0.17. This again shows the effectiveness of our proposed influence function.
Embedding is updated in finetuning  Practically, the embedding can also be updated in the finetuning process. In Figure 1(c) we show the correlation between the true loss difference and the influence scores computed using (12). We can see that even under this more challenging condition, our multistage influence function from (12) still has a strong correlation with the true loss difference, as measured by Pearson’s $r$.
In Figure 4 we show misclassified test images from the finetuning task together with the images in the pretraining dataset that have the largest positive influence scores (meaning most influential). Examples with large positive influence scores are expected to have a negative effect on the model’s performance, since intuitively, when they are added to the pretraining dataset, the loss of the test example increases. From Figure 4 we can indeed see that the identified examples are of low quality, and they could easily be misclassified even by human eyes.
test example  pretrain example  test example  pretrain example 
prediction=6  influence score=91.6  prediction=“cat"  influence score=1060.9 
true label=5  true label=2  true label=“automobile"  true label=“bird" 
4.2 Data Cleansing using Predicted Influence Score
Since the pretraining examples with large positive influence scores are the ones that increase the loss value, they indicate negative transfer. Based on the computed influence scores, we can mitigate this negative transfer. We perform an experiment on the CIFAR10 dataset with the same setting as in Section 4.1. After removing from the pretraining (source) data the top 10% of examples by (positive) influence score, we improve the accuracy on the target data from 58.15% to 58.36%. As a reference, if we randomly remove 10% of the pretraining data, the accuracy drops to 58.08%. Note that the influence score computation is fast. For example, on the CIFAR10 dataset, computing the influence function with respect to all pretraining data takes 230 seconds on a single Tesla V100 GPU, using 200 iterations of conjugate gradients for the two IHVPs in (9), (10) and (11).
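The cleansing step itself is a simple selection once the scores are available; a sketch with hypothetical per-example scores (positive means removal is expected to reduce the fine-tuning test loss):

```python
import numpy as np

# Hypothetical influence scores of each pretraining example on the summed
# fine-tuning test loss; the values are made up for illustration.
scores = np.array([0.8, -0.2, 0.05, 1.3, -0.9, 0.1, 2.1, -0.4, 0.0, 0.6])

# Data cleansing: drop the top 10% of examples by influence score,
# keep the rest of the pretraining set for retraining.
k = max(1, int(0.1 * len(scores)))
order = np.argsort(scores)                    # ascending, most harmful last
keep_idx = np.sort(order[:len(scores) - k])
print(keep_idx)                               # -> [0 1 2 3 4 5 7 8 9]
```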
4.3 The Finetuning Task’s Similarity to the Pretraining Task
In this experiment, we explore the relationship between the influence function score and the finetuning task’s similarity to the pretraining task. Specifically, we study whether the influence function score increases in absolute value when the finetuning task is very similar to the pretraining task. To do this, we use the CIFAR10 embedding obtained from the “bird vs. frog" classification and test its influence function scores on two finetuning tasks. Finetuning task A is exactly the same as the pretraining “bird vs. frog" classification, while finetuning task B is a classification on two other classes (“automobile vs. deer"). All hyperparameters used in the two finetuning tasks are the same. In Figure 5, for both tasks we plot the distribution of the influence function values with respect to each pretraining example, summing the influence scores over all test examples. We can see that the influence scores for finetuning task A have much larger absolute values than those for task B. The average absolute influence score for task A is 0.055, much larger than that for task B, which is 0.025. This supports the argument that if the pretraining and finetuning tasks are similar, the pretraining data will have a larger influence on the finetuning task’s performance.
4.4 Influence Function Score with Different Numbers of Finetuning Examples
We also study the relationship between the influence function scores and the number of examples used in finetuning. In this experiment, we update the pretrained embedding in the finetuning stage. We use the same pretraining and finetuning tasks as in Section 4.1. The results are presented in Figure 6: model C is the model used in Section 4.1, while in model D we triple the number of finetuning examples as well as the number of finetuning steps. Figure 6 shows the distribution of each pretraining example’s influence score with respect to the whole test set. The average absolute influence score of model D is 0.15, much less than that of model C. This indicates that with more finetuning examples and more finetuning steps, the influence of the pretraining data on the finetuning model’s performance decreases. This makes sense: if the finetuning data does not carry sufficient information to train a good finetuning model, the pretraining data will have more impact on the finetuning task.
Test Sentence  Max absolute influence function value  Sentence in Pretrain  Min absolute influence function value  Sentence in Pretrain 
JebBush said he cut FL taxes by $19B. But that includes cuts in estate taxes mandated by federal law.  0.0841  Specifically , British Prime Minister Gordon Brown has recommended that security control in five provinces be handed over by the end of 2010.  One red shirt suffered a gunshot wound , most likely from a rubber bullet.  
Creating jobs is our greatest moral purpose because they strengthen our families and communities.  0.0070  And Friday, the Commerce Department reports on durable goods orders and the University of Michigan releases its reading on consumer sentiment.  Then there is the issue of why readers buy print publications , and whether the content they are buying can be consumed more easily or conveniently on the Internet .  
The Kryptonian science council was more worried about climate change than these scary people.  0.0102  In addition the PTM also runs a vast network of mobile courts in the rest of the Fata areas, he said .  Until recently , such terror attacks inside Iraq could have coerced the village into sheltering Al Qaeda. 
4.5 Quantitative Results on NLP Task
In this section we show the application of our proposed method to an NLP task. In this experiment, the pretraining task is training an ELMo [peters2018deep] model on the one-billion-word (OBW) dataset [chelba2013one], which contains 30 million sentences and 8 million unique words. The final pretrained ELMo model contains 93.6 million parameters. The finetuning task is binary sentiment classification on the First GOP Debate Twitter Sentiment data^1 (https://www.kaggle.com/crowdflower/firstgopdebatetwittersentiment), containing 16,654 tweets about the early August GOP debate in Ohio. The finetuned model uses the original pretrained ELMo embedding and a feedforward neural network with hidden size 64 to build the classifier. The embedding is fixed in the finetuning task. To show quantitative results, we randomly pick a test sentence from the finetuning task, and sample a subset of 1000 sentences from the one-billion-word dataset to check the influence of these pretraining examples on the test sentence. In Table 1 we show examples of test sentences and the pretraining sentences with the largest and the smallest absolute influence function score values. Note that the computation time for this large-scale experiment (the model contains 93.6 million parameters) is reasonable – each pretraining data point takes an average of 0.94 seconds for the influence score computation. For extremely large models and datasets, the computation can be further sped up with a parallel algorithm, as each data point’s influence computation is independent.
5 Conclusion
We introduce a multistage influence function for two multistage training setups: 1) the pretrained embedding is fixed during finetuning, and 2) the pretrained embedding is updated during finetuning. Our experimental results on CV and NLP tasks show strong correlation between the score of an example, computed from the proposed multistage influence function, and the true loss difference when the example is removed from the pretraining data. We believe our multistage influence function is a promising approach to connect the performance of a finetuned model with pretraining data.
References
Appendix A Proof of Theorem 1
Proof.
Since $W^*_\epsilon$, $U^*_\epsilon$, and $V^*_\epsilon$ are optimal solutions of the perturbed problems (4) and (5), they satisfy the following optimality conditions:
(17)  $0 = \nabla_V\, \ell(S; W^*_\epsilon, V^*_\epsilon)$
(18)  $0 = \nabla_{W,U}\left[L(Z; W^*_\epsilon, U^*_\epsilon) + \epsilon\, L(z; W^*_\epsilon, U^*_\epsilon)\right]$
where $\nabla_{W,U}$ means concatenating $W$ and $U$ into a single vector $[W; U]$ and computing the gradient with respect to it. We define the changes of parameters as $\Delta_{WU} = [W^*_\epsilon; U^*_\epsilon] - [W^*; U^*]$, $\Delta_W = W^*_\epsilon - W^*$, and $\Delta_V = V^*_\epsilon - V^*$. Applying a Taylor expansion to the r.h.s. of (18) we get
(19)  $0 \approx \nabla_{W,U} L(Z; W^*, U^*) + \epsilon\, \nabla_{W,U} L(z; W^*, U^*) + \left[\nabla^2_{W,U} L(Z; W^*, U^*) + \epsilon\, \nabla^2_{W,U} L(z; W^*, U^*)\right] \Delta_{WU}$
Since $(W^*, U^*)$ are optimal for the unperturbed problem (1), $\nabla_{W,U} L(Z; W^*, U^*) = 0$, and we have
(20)  $\Delta_{WU} \approx -\left[\nabla^2_{W,U} L(Z; W^*, U^*) + \epsilon\, \nabla^2_{W,U} L(z; W^*, U^*)\right]^{-1} \epsilon\, \nabla_{W,U} L(z; W^*, U^*)$
Since $\epsilon \to 0$, we have the further approximation
(21)  $\Delta_{WU} \approx -\epsilon \left(\nabla^2_{W,U} L(Z; W^*, U^*)\right)^{-1} \nabla_{W,U} L(z; W^*, U^*)$
Taking the $W$ part of $\Delta_{WU}$ and dividing by $\epsilon$ gives (10). Similarly, based on (17) and applying a first-order Taylor expansion to its r.h.s., we have
(22)  $\Delta_V \approx -\left(\nabla^2_{V}\ell(S; W^*, V^*)\right)^{-1} \nabla^2_{VW}\,\ell(S; W^*, V^*)\, \Delta_W$
which, after dividing by $\epsilon$, gives (9). ∎
Appendix B Models and Hyperparameters for the Experiments in Sections 4.1, 4.2, 4.3 and 4.4
The model structures we used in Sections 4.1, 4.2, 4.3 and 4.4 are listed in Table A. As mentioned in the main text, for all models, CNN layers are used as embeddings and fully connected layers are taskspecific. The number of neurons on the last fully connected layer is determined by the number of classes in the classification. There is no activation at the final output layer and all other activations are Tanh.

For the MNIST experiments in Section 4.1 with the embedding fixed, we train a four-class classification (0, 1, 2, and 3) in pretraining. All examples in the original MNIST training set with these four labels are used in pretraining. The finetuning task is to classify the remaining six classes, and we subsample only 5000 examples for finetuning. The pretrained embedding is fixed in finetuning. We run the Adam optimizer in both pretraining and finetuning with a batch size of 512. The pretrained and finetuned models are trained to convergence. When validating the influence function score, we remove an example from the pretraining dataset. Then we rerun the pretraining and finetuning process with this leave-one-out pretraining dataset, starting from the original models’ weights. In this process, we only run 100 steps of pretraining and finetuning, as the models converge quickly. When computing the influence function scores, small damping terms are added to the pretrained and finetuned models’ Hessians. We sample 1000 pretraining examples when computing the pretrained model’s Hessian summation.

For the CIFAR experiments with the embedding fixed, we train a two-class classification (“bird" vs “frog") in pretraining. All examples in the original CIFAR training set with these two labels are used in pretraining. The finetuning task is to classify the remaining eight classes, and we subsample only 5000 examples for finetuning. The pretrained embedding is fixed in finetuning. We run the Adam optimizer to train both the pretrained and finetuned models with a batch size of 128. The pretrained and finetuned models are trained to convergence. When validating the influence function score, we remove an example from the pretraining dataset. Then we rerun the pretraining and finetuning process with this leave-one-out pretraining dataset, starting from the original models’ weights. In this process, we only run 6000 steps of pretraining and 3000 steps of finetuning. When computing the influence function scores, small damping terms are added to the pretrained and finetuned models’ Hessians. The same hyperparameters are used in the experiments in Sections 4.3 and 4.4. We also use these hyperparameters in the CIFAR10 experiments with the embedding updated, except that the pretrained embedding is updated in finetuning and the number of finetuning steps is reduced to 1000 in validation. The constant $\gamma$ in Equation (15) is chosen as 0.01. We sample 1000 pretraining examples when computing the pretrained model’s Hessian summation.
Dataset        | MNIST           | CIFAR
Embedding      | Conv            | Conv
               | Maxpool         | Conv
               | Conv            | Maxpool
               | Maxpool         | Conv
               |                 | Maxpool
               |                 | Conv
               |                 | Maxpool
Task specific  | FC <# classes>  | FC 1500
               |                 | FC <# classes>