A federated forgetting learning method for collaborative training of large and small models
By introducing co-training of large and small models and structured distillation into federated learning, the method addresses few-shot learning and model heterogeneity, achieves efficient data forgetting with privacy protection, and improves model performance and robustness.
Patent Information
- Authority / Receiving Office
- CN · China
- Patent Type
- Applications(China)
- Current Assignee / Owner
- DALIAN NATIONALITIES UNIVERSITY
- Filing Date
- 2026-04-02
- Publication Date
- 2026-06-19
AI Technical Summary
Federated learning faces challenges in few-shot learning and model heterogeneity, and existing methods must retrain the global model when data is to be forgotten, which increases computational and storage overhead and fails to effectively guarantee data privacy.
A federated forgetting learning method oriented towards collaboration between large and small models is adopted. The small model is trained locally on the client, and only its prompt parameters are updated. The server performs structured distillation and perturbation-based forgetting to achieve knowledge fusion and accurate forgetting.
The method improves model performance and generalization while ensuring data privacy, reduces computation and communication overhead, and achieves efficient data forgetting and collaborative training of heterogeneous models.
Figure CN122242652A_ABST
Abstract
Description
Technical Field
[0001] This invention belongs to the field of machine learning and relates to federated forgetting learning (also known as federated unlearning), parameter-efficient fine-tuning, and knowledge distillation. Specifically, it is a federated forgetting learning method for collaboration between large and small models.
Background Technology
[0002] Federated learning, as an emerging distributed learning framework, has been widely applied in fields with high privacy requirements, such as healthcare and finance. Its core idea is to allow multiple clients to train locally while ensuring data privacy, and then send model updates to the server for aggregation, thus avoiding the privacy risks associated with centralized data storage. However, federated learning faces many challenges in practical applications, especially in few-shot learning and model heterogeneity. Few-shot learning requires models to be effectively trained with only a small amount of data, but traditional methods mostly rely on large amounts of data, leading to performance issues in data-scarce scenarios. Furthermore, clients in federated learning typically have limited computing power, using small models for local training, while the server uses complex, large models for global aggregation. Knowledge distillation has been introduced into federated learning as an effective means of knowledge transfer between models. Through distillation, knowledge from large models can be transferred to smaller models, enabling collaborative training between heterogeneous models. How to efficiently integrate the knowledge of these heterogeneous models and ensure that the global model can fully absorb effective information from each client remains a pressing technical challenge.
[0003] Meanwhile, as data privacy and compliance requirements become increasingly stringent, data forgetting has become a new challenge in federated learning. In some applications, users may wish to withdraw their data contributions from the model to comply with legal or privacy protection requirements. Existing solutions typically retrain the global model, which consumes significant computational and storage resources yet still cannot guarantee that the contribution of the target data is completely removed.
Summary of the Invention
[0004] To address the challenges that existing federated learning methods face in few-shot learning, model heterogeneity, and data forgetting, this invention discloses a federated forgetting method for collaborative learning between small and large models. The method addresses model heterogeneity by training a small model locally on each client and updating only its prompt parameters; on the server side, a large model is combined with structured distillation to integrate the knowledge of the small models into the large model. The invention further incorporates a perturbation forgetting mechanism that generates perturbed prompt parameters which disrupt the inter-class relationships and intra-class consistency of the target data; these parameters are sent to the server for an aggregation operation. The method thereby achieves accurate forgetting of the target data while maintaining data privacy and communication efficiency.
[0005] The method comprises the following steps.
Step 1: The server initializes the clients' prompt parameters and sends them to the clients. Each client loads a pre-trained small model and trains it locally on a small number of samples, freezing the backbone network parameters and updating only the prompt parameters to adapt to the new task.
Step 2: The client uploads the updated prompt parameters, which contain the knowledge it has learned from its local data, to the server. Because only the prompt parameters are uploaded, communication volume is greatly reduced and data privacy is maintained.
Step 3: After receiving the prompt parameters uploaded by each client, the server aggregates the clients' model outputs by weighted aggregation, assigning weights according to each client's sample size and computing power. The aggregated output combines information from multiple clients and is used to optimize the global model.
Step 4: The server optimizes the global model by structured knowledge distillation, aligning the client outputs with the global model output so that the global model effectively absorbs knowledge from each client. Instance-level, batch-level, and category-level losses are computed during distillation to improve the performance of the global model. The federated learning process of Steps 1 to 3 is repeated to further optimize the model.
Step 5: When a client wishes to withdraw its data contribution, a perturbation is generated by minimizing a forgetting loss function, yielding an offset representation that is approximately orthogonal to the target data features so that the influence of the target data on the model is effectively removed. The client uploads prompt parameters containing the offset representation to the server, which removes the target data's contribution on that basis.
Step 6: After receiving the prompt parameters containing the offset representation, the server performs an aggregation operation to update the parameters of the global model. This aggregation removes the influence of the target data without retraining the global model, which improves efficiency.
Furthermore, Step 1 specifically includes the following. The client loads a small pre-trained neural network model, previously trained on a large natural image dataset, and trains it locally on a limited number of samples. During training, the client freezes the backbone network parameters of the small model and updates only the parameters of the prompt module to adapt to the new task, i.e.:
$$\hat{y} = f\left(x;\ \theta_{b},\ \theta_{p}\right)$$
[0006] where $x$ is the input sample, $\theta_{b}$ denotes the frozen backbone network parameters, $\theta_{p}$ denotes the parameters of the prompt module to be updated, and $\hat{y}$ is the output of the model with the prompt module inserted, which guides the model to adapt to the new task. During training, the client optimizes the prompt module with the cross-entropy loss:
$$\mathcal{L}_{\mathrm{CE}} = -\frac{1}{n}\sum_{j=1}^{n} y_{j}^{\top}\log f\left(x_{j};\ \theta_{b},\ \theta_{p}\right)$$
[0007] where $y_{j}$ is the true (one-hot) label of sample $x_{j}$, $f(x_{j};\theta_{b},\theta_{p})$ is the model's predicted output for $x_{j}$, and $n$ is the number of local samples. Minimizing this loss reduces the difference between the predictions and the true labels, and only the prompt module parameters $\theta_{p}$ are updated. After training, Step 2 is performed: the client uploads the updated prompt module parameters to the server.
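The local update in Step 1 (frozen backbone, only prompt parameters trained with cross-entropy) can be illustrated with a toy NumPy sketch. It is not the ResNet-18 prompt module of the embodiment: as a hypothetical stand-in, the "backbone" is a frozen linear map `W_b` and the prompt is a learnable vector added to every input; `train_prompt`, the learning rate, and the epoch count are illustrative. What it shows is that the gradient flows through the frozen weights into the prompt while `W_b` itself is never updated.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy(probs, labels):
    """Mean cross-entropy of predicted probabilities against integer labels."""
    n = labels.shape[0]
    return float(-np.log(probs[np.arange(n), labels] + 1e-12).mean())


def train_prompt(W_b, x, labels, prompt, lr=0.1, epochs=50):
    """Gradient descent on the prompt only; the frozen backbone W_b is never written.

    Hypothetical toy model: logits = (x + prompt) @ W_b.T, i.e. an additive
    input prompt shared by all samples in front of a frozen linear "backbone".
    """
    num_classes = W_b.shape[0]
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        probs = softmax((x + prompt) @ W_b.T)
        grad_logits = (probs - onehot) / x.shape[0]    # d(mean CE)/d(logits)
        grad_prompt = (grad_logits @ W_b).sum(axis=0)  # chain rule through x + prompt
        prompt = prompt - lr * grad_prompt             # update the prompt only
    return prompt


rng = np.random.default_rng(0)
W_b = rng.normal(size=(3, 4))          # frozen "backbone": 3 classes, 4 features
x = rng.normal(size=(6, 4))            # six local few-shot samples
labels = np.array([0, 1, 2, 0, 1, 2])
W_before = W_b.copy()

p0 = np.zeros(4)
p1 = train_prompt(W_b, x, labels, p0)
loss_before = cross_entropy(softmax((x + p0) @ W_b.T), labels)
loss_after = cross_entropy(softmax((x + p1) @ W_b.T), labels)
print(f"CE before: {loss_before:.4f}, after: {loss_after:.4f}")
```

Because only the prompt vector changes, it is the only quantity a client in this sketch would need to upload in Step 2.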
[0008] Furthermore, Step 3 specifically includes the following. After receiving the updated prompt parameters from multiple clients, which contain the knowledge each client has learned from its local data and reflect its adaptation to the specific task, the server aggregates the clients' model outputs by weighted aggregation. To merge the outputs of different clients effectively, the server first softens each client's output. Let $z_{i}$ be the output (logits) of client $i$; softening transforms it into a probability distribution:
$$q_{i,k} = \frac{\exp\left(z_{i,k}/T\right)}{\sum_{j=1}^{C}\exp\left(z_{i,j}/T\right)}$$
[0009] where $T$ is the temperature parameter, $z_{i,k}$ is the prediction of client $i$ for category $k$, $C$ is the number of categories, and $q_{i,k}$ is the softened probability for that category. Softening makes the probability distribution of each client's output smoother, which allows outputs from different clients to be fused better. The server then assigns each client a weight according to its sample size and aggregates the softened outputs:
$$\bar{q} = \sum_{i=1}^{N} w_{i}\, q_{i}$$
[0010] where $w_{i}$ is the weight of the $i$-th client, $q_{i}$ is the softened output of client $i$, $N$ is the number of clients, and $\bar{q}$ is the aggregated global output. Through weighted aggregation, the server combines the outputs of the various clients to optimize the global model.
[0011] Finally, the server uses the aggregated output to further optimize the global model, enabling it to absorb knowledge from the various clients and thereby improving its performance and generalization ability.
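The softening and weighted aggregation of [0008] to [0010] can be sketched as follows. This assumes the clients' logits on a shared batch are stacked into one array and that weights are proportional to sample size only; the document also lets computing power (and, in claim 3, data quality) influence the weights, which the hypothetical `sample_size_weights` helper ignores.

```python
import numpy as np


def soften(logits, T):
    """Temperature-scaled softmax over the last axis (T > 1 flattens the distribution)."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sample_size_weights(sizes):
    """Hypothetical weighting rule: weights proportional to client sample counts."""
    sizes = np.asarray(sizes, dtype=float)
    return sizes / sizes.sum()


def aggregate(client_logits, weights, T):
    """Weighted sum of softened client outputs.

    client_logits: (num_clients, batch, num_classes); returns (batch, num_classes).
    """
    q = soften(client_logits, T)
    w = np.asarray(weights, dtype=float).reshape(-1, 1, 1)
    return (w * q).sum(axis=0)


# two clients, one public-set sample, three classes
logits = np.array([[[2.0, 0.0, -1.0]],
                   [[0.0, 1.0, 0.0]]])
w = sample_size_weights([30, 10])      # -> [0.75, 0.25]
q_bar = aggregate(logits, w, T=2.0)
print(q_bar)
```

Since each softened output sums to one and the weights sum to one, the aggregated output is itself a valid probability distribution.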
[0012] Furthermore, step 4 specifically includes: The server optimizes the global model using structured knowledge distillation, aligning it at the instance, batch, and class levels. The server first receives softened output from multiple clients. These outputs are the client model's predictions for each category. The server calculates the distillation loss by measuring the Kullback-Leibler divergence between the client output and the global model output. This loss reflects the difference between the client model and the global model output, as shown in the following formula:
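$$\mathcal{L}_{\mathrm{ins}} = \sum_{i=1}^{N} w_{i}\,\mathrm{KL}\left(q_{i}\,\|\,q_{g}\right) = \sum_{i=1}^{N} w_{i}\sum_{k=1}^{C} q_{i,k}\log\frac{q_{i,k}}{q_{g,k}}$$
[0013] where $\mathrm{KL}(q_{i}\|q_{g})$ is the KL divergence between the output of client $i$'s model and the global model output, $w_{i}$ is the client's weight, $N$ is the number of clients, $C$ is the number of categories, and $q_{i}$ and $q_{g}$ are the softened outputs of the client model and the global model, respectively. By minimizing the KL divergence, the server encourages the global model to gradually absorb knowledge from each client.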
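A minimal sketch of the instance-level term, assuming the client-weighted divergence KL(q_i || q_g) is averaged over the samples of a batch; the batch averaging and the clipping constant `eps` are assumptions not stated in the document, and the function names are hypothetical.

```python
import numpy as np


def kl_rows(p, q, eps=1e-12):
    """KL(p || q) along the last axis; inputs are clipped to avoid log(0)."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return (p * (np.log(p) - np.log(q))).sum(axis=-1)


def instance_loss(client_probs, global_probs, weights):
    """Client-weighted KL(q_i || q_g), averaged over the batch (averaging is an assumption).

    client_probs: (num_clients, batch, C); global_probs: (batch, C); weights: (num_clients,)
    """
    per_client = kl_rows(client_probs, global_probs[None]).mean(axis=1)
    return float(np.dot(weights, per_client))


q_g = np.array([[0.7, 0.2, 0.1],
                [0.1, 0.8, 0.1]])
w = np.array([0.5, 0.5])
same = np.stack([q_g, q_g])
diff = np.stack([q_g, q_g[:, ::-1]])   # second client reverses the class order
print(instance_loss(same, q_g, w), instance_loss(diff, q_g, w))
```

The loss is zero when every client agrees with the global model and grows as any weighted client diverges from it.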
[0014] The server computes a batch-level alignment loss to keep the global model's outputs consistent across the entire batch of samples. This loss is computed from the difference between the Gram matrices of the client and the global model within a batch. The calculation formula is:
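$$\mathcal{L}_{\mathrm{batch}} = \left\lVert G^{c} - G^{g}\right\rVert_{F}^{2}$$
[0015] where $G^{c}$ and $G^{g}$ are the feature-map Gram matrices of the client and the global model within the batch, respectively, and $\lVert\cdot\rVert_{F}$ denotes the Frobenius norm.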
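The batch-level Gram alignment can be sketched as below, assuming each model's per-sample features are L2-normalized before forming the B x B Gram matrix and the loss is the squared Frobenius distance; both choices are assumptions. One property worth noting: the Gram matrix is B x B regardless of feature width, so a small client model and a large server model with different feature dimensions can still be compared, which fits the heterogeneous setting described here. The example builds 16-dimensional "global" features that are an isometric embedding of 8-dimensional client features, so the loss is essentially zero for them and positive for unrelated features.

```python
import numpy as np


def batch_gram(features):
    """B x B Gram matrix of L2-normalized per-sample features (normalization is assumed)."""
    f = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-12)
    return f @ f.T


def batch_loss(client_feats, global_feats):
    """Squared Frobenius distance between the client and global batch Gram matrices."""
    diff = batch_gram(client_feats) - batch_gram(global_feats)
    return float((diff ** 2).sum())


rng = np.random.default_rng(0)
client_feats = rng.normal(size=(4, 8))             # small model: 8-dim features
# A global model whose 16-dim features are an isometric embedding of the client's:
Q = np.linalg.qr(rng.normal(size=(16, 8)))[0].T    # (8, 16) with orthonormal rows
aligned = client_feats @ Q
unrelated = rng.normal(size=(4, 16))
print(batch_loss(client_feats, aligned), batch_loss(client_feats, unrelated))
```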
[0016] The server calculates the category-level alignment loss by computing the Gram matrix for each category to ensure the global model better understands the relationships between different categories. The formula for calculating the category-level alignment loss is:
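$$\mathcal{L}_{\mathrm{cls}} = \frac{1}{C^{2}}\sum_{k=1}^{C}\sum_{l=1}^{C}\left(\left\langle p^{c}_{k},\,p^{c}_{l}\right\rangle - \left\langle p^{g}_{k},\,p^{g}_{l}\right\rangle\right)^{2}$$
[0017] where $C$ is the total number of categories, and $p^{c}_{k}$ and $p^{g}_{k}$ are the predicted probabilities of class $k$ over the batch from the client and the global model, respectively; their pairwise inner products form the category-level Gram matrices.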
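For the category-level term, one plausible reading, assumed here, is a C x C Gram matrix whose (k, l) entry is the inner product of the class-k and class-l probability columns over the batch, compared between client and global outputs by a mean squared difference. The sketch implements that reading; the normalization by C squared and the function names are assumptions. Because the class Gram matrix sums over samples, it is unchanged by reordering the batch, so it captures relationships between categories rather than per-sample agreement.

```python
import numpy as np


def class_gram(probs):
    """C x C Gram matrix between per-class probability columns over the batch."""
    return probs.T @ probs


def class_loss(client_probs, global_probs):
    """Mean squared difference between client and global class-level Gram matrices."""
    num_classes = client_probs.shape[1]
    diff = class_gram(client_probs) - class_gram(global_probs)
    return float((diff ** 2).sum() / num_classes ** 2)


p = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.8, 0.1],
              [0.3, 0.3, 0.4]])
uniform = np.full_like(p, 1 / 3)
print(class_loss(p, p[[2, 0, 1]]), class_loss(p, uniform))
```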
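[0018] Finally, the server sums the instance-level, batch-level, and category-level alignment losses to obtain the total distillation loss:
$$\mathcal{L}_{\mathrm{KD}} = \mathcal{L}_{\mathrm{ins}} + \mathcal{L}_{\mathrm{batch}} + \mathcal{L}_{\mathrm{cls}}$$
where $\mathcal{L}_{\mathrm{ins}}$, $\mathcal{L}_{\mathrm{batch}}$, and $\mathcal{L}_{\mathrm{cls}}$ denote the instance-level, batch-level, and category-level alignment losses, respectively.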
[0019] By minimizing this loss function, the server can continuously optimize the global model, enabling it to fully absorb knowledge from multiple clients and improve the generalization ability of the global model.
[0020] Furthermore, Step 5 specifically includes the following. When a client wants to withdraw its local data's contribution to the global model, it sends a forgetting request to the server. Let $P$ denote the features of the target data. The client constructs a perturbation vector $\delta$ locally and keeps it approximately orthogonal to the original features:
$$\left\langle \delta,\ P\right\rangle \approx 0$$
[0021] The perturbation representation $\delta$ is used to disrupt the inter-class relationships and intra-class consistency of the target data in the prediction space. This is achieved through a forgetting loss function consisting of an inter-class loss $\mathcal{L}_{\mathrm{inter}}$ and an intra-class loss $\mathcal{L}_{\mathrm{intra}}$.
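The document requires the perturbation to stay approximately orthogonal to the target features but does not say how this is enforced. One simple way, assumed here, is to project any candidate perturbation onto the orthogonal complement of the feature vector (a single Gram-Schmidt step), which makes the inner product zero up to rounding. `orthogonalize` is a hypothetical helper name.

```python
import numpy as np


def orthogonalize(delta, p):
    """Remove the component of delta along p (one Gram-Schmidt step), so <delta, p> = 0."""
    p = np.asarray(p, dtype=float)
    delta = np.asarray(delta, dtype=float)
    return delta - (delta @ p) / (p @ p) * p


def cosine(a, b):
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


p = np.array([1.0, 2.0, 2.0])         # target-sample features
raw = np.array([3.0, 0.0, 0.0])       # unconstrained perturbation
delta = orthogonalize(raw, p)          # -> [8/3, -2/3, -2/3]
print(delta, cosine(delta, p))
```

Projection keeps the part of the candidate perturbation that is already orthogonal, so it changes the candidate as little as possible while meeting the constraint exactly.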
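[0022] The inter-class loss disrupts the original discriminative relationship between the target sample and the center of each class. Let $\mu_{c}$ be the center vector of class $c$, and let $\tilde{P} = P + \delta$ be the perturbed target representation (the target features $P$ plus the perturbation $\delta$). The inter-class loss is defined as:
$$\mathcal{L}_{\mathrm{inter}} = \frac{1}{C}\sum_{c=1}^{C}\cos\left(\tilde{P},\ \mu_{c}\right)$$
[0023] where $\cos(\cdot,\cdot)$ denotes cosine similarity and $C$ is the number of classes. Minimizing this loss substantially alters the discriminative relationship between the target sample and the different categories.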
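A small sketch of the inter-class loss, under the assumption that it is the mean cosine similarity between the perturbed target representation and every class center; the exact aggregation over classes is not recoverable from the text, and the function names are hypothetical. Minimizing it pushes the representation away from all centers at once. The toy centers are the standard basis vectors.

```python
import numpy as np


def cosine(a, b, eps=1e-12):
    """Cosine similarity; eps guards against zero-norm vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + eps))


def inter_class_loss(p_tilde, centers):
    """Mean cosine similarity between the perturbed representation and every class center."""
    return float(np.mean([cosine(p_tilde, mu) for mu in centers]))


centers = np.eye(3)                                   # toy class centers
close = inter_class_loss(np.array([1.0, 0.0, 0.0]), centers)
pushed = inter_class_loss(np.array([-1.0, -1.0, -1.0]), centers)
print(close, pushed)   # lower values mean weaker alignment with the class centers
```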
[0024] The intra-class loss reduces the clustering of the target sample within its own category. Let the target sample belong to category $y$, let $S_{y}$ be the sample set of that category, and let $\tilde{P}$ be the perturbed target representation. The intra-class loss is then expressed as:
$$\mathcal{L}_{\mathrm{intra}} = \frac{1}{\lvert S_{y}\rvert}\sum_{s \in S_{y}}\cos\left(\tilde{P},\ s\right)$$
[0025] Minimizing this loss reduces the similarity between the perturbed target sample and samples of the same class, thereby reducing its contribution within the class.
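Similarly, a sketch of the intra-class loss, assuming it is the mean cosine similarity between the perturbed target and the samples of its class (function names hypothetical). The example adds an offset orthogonal to the target and shows that the similarity to same-class samples drops.

```python
import numpy as np


def cosine(a, b, eps=1e-12):
    """Cosine similarity; eps guards against zero-norm vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + eps))


def intra_class_loss(p_tilde, same_class):
    """Mean cosine similarity between the perturbed target and the samples of its class."""
    return float(np.mean([cosine(p_tilde, s) for s in same_class]))


same_class = np.array([[1.0, 0.1], [0.9, -0.1], [1.0, 0.0]])   # toy class-y samples
target = np.array([1.0, 0.0])
perturbed = target + np.array([0.0, 2.0])   # offset orthogonal to the target
print(intra_class_loss(target, same_class), intra_class_loss(perturbed, same_class))
```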
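[0026] The overall forgetting loss is defined as follows:
$$\mathcal{L}_{\mathrm{forget}} = \mathcal{L}_{\mathrm{inter}} + \lambda\,\mathcal{L}_{\mathrm{intra}}$$
[0027] where $\mathcal{L}_{\mathrm{inter}}$ and $\mathcal{L}_{\mathrm{intra}}$ are the inter-class and intra-class losses and $\lambda$ is a hyperparameter that balances the two terms. By minimizing the forgetting loss, the client generates an offset representation after the orthogonal perturbation and uploads the perturbed prompt parameters to the server, which then forgets the target data in the global model through Step 6.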
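The offset can be generated by projected gradient descent: take a gradient step on the combined forgetting loss with respect to the perturbation, then project back onto the subspace orthogonal to the target features so the constraint holds at every step. This is a hedged sketch, not the patent's procedure: it assumes the inter-class term is the mean cosine similarity to the class centers and the intra-class term the mean cosine similarity to same-class samples, uses an analytic cosine gradient, and the step size, iteration count, toy data, and all function names are hypothetical. Setting `lam=1.0` mirrors the 1:1 ratio used in the embodiment.

```python
import numpy as np


def cos_and_grad(a, b):
    """cos(a, b) and its gradient with respect to a."""
    na, nb = np.linalg.norm(a), np.linalg.norm(b)
    c = a @ b / (na * nb)
    return c, b / (na * nb) - c * a / na ** 2


def forget_loss_and_grad(p, delta, centers, same_class, lam):
    """Assumed L_forget = mean cos to class centers + lam * mean cos to same-class
    samples, evaluated at p + delta; the gradient is with respect to delta."""
    p_tilde = p + delta
    loss, grad = 0.0, np.zeros_like(p)
    for mu in centers:
        c, g = cos_and_grad(p_tilde, mu)
        loss += c / len(centers)
        grad += g / len(centers)
    for s in same_class:
        c, g = cos_and_grad(p_tilde, s)
        loss += lam * c / len(same_class)
        grad += lam * g / len(same_class)
    return loss, grad


def project_out(v, p):
    """Project v onto the orthogonal complement of p."""
    return v - (v @ p) / (p @ p) * p


def generate_offset(p, centers, same_class, lam=1.0, lr=0.1, steps=100):
    """Projected gradient descent: every iterate satisfies <delta, p> = 0 up to rounding."""
    delta = np.zeros_like(p)
    for _ in range(steps):
        _, grad = forget_loss_and_grad(p, delta, centers, same_class, lam)
        delta = project_out(delta - lr * grad, p)
    return delta


rng = np.random.default_rng(1)
p = rng.normal(size=8)                              # toy target-sample representation
centers = rng.normal(size=(4, 8))                   # toy class centers
same_class = p + 0.1 * rng.normal(size=(5, 8))      # toy same-class samples
loss0, _ = forget_loss_and_grad(p, np.zeros(8), centers, same_class, 1.0)
delta = generate_offset(p, centers, same_class, lam=1.0)  # lam = 1 mirrors the 1:1 ratio
loss1, _ = forget_loss_and_grad(p, delta, centers, same_class, 1.0)
print(f"L_forget: {loss0:.4f} -> {loss1:.4f}, <delta, p> = {delta @ p:.2e}")
```

Because the perturbed representation is p + delta with delta orthogonal to p, its norm never falls below that of p, which keeps the cosine gradients well defined throughout the iteration.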
[0028] This invention discloses a federated forgetting learning method for collaborative training of models of different sizes. By building a collaborative mechanism between lightweight client models and a high-capacity server model, the method fulfills forgetting requirements while maintaining model performance. Specifically, a prompt learning mechanism is introduced on the client side that updates only the small number of prompt module parameters. This fully exploits the feature extraction capability of the pre-trained model under limited-sample conditions, avoids the risk of overfitting caused by excessive reliance on limited data, and significantly reduces communication and computational overhead. On the server side, a hierarchical structured distillation approach aligns the knowledge representations uploaded by the clients at multiple levels, namely instance-level, batch-level, and category-level distribution structure modeling, ensuring effective knowledge transfer and fusion between models of different sizes and architectures.
[0029] Compared with traditional distillation methods that rely solely on aligning output probabilities, this invention captures semantic structural relationships more comprehensively, improving the global model's adaptability to heterogeneous environments and few-shot scenarios. When a client issues a forgetting request, the invention further introduces a structured perturbation forgetting mechanism: by applying an orthogonal offset in the prediction space, it disrupts the target data's contribution to inter-class relationships and intra-class consistency and generates a compensating representation independent of the original feature space. On the server side, the impact of the target data on the global model is eliminated with a single aggregation, without storing historical update records or performing full retraining, which reduces the cost and complexity of forgetting.
Attached Figure Description
[0030] To more clearly illustrate the technical solutions in the embodiments of the present invention or the prior art, the drawings used in the description of the embodiments or the prior art will be briefly introduced below. Obviously, the drawings described below are some embodiments of the present invention. For those skilled in the art, other drawings can be obtained based on these drawings without creative effort.
[0031] Figure 1 is a flowchart of the federated forgetting learning method for collaborative training of large and small models according to the present invention. Figure 2 is a framework diagram of the federated forgetting learning method for collaborative training of large and small models according to the present invention.
Detailed Implementation
[0032] To enable those skilled in the art to better understand the present invention, the technical solutions of the present invention will be clearly and completely described below with reference to the accompanying drawings of the embodiments of the present invention. Obviously, the described embodiments are only some embodiments of the present invention, and not all embodiments. Based on the embodiments of the present invention, all other embodiments obtained by those skilled in the art without creative effort should fall within the scope of protection of the present invention.
[0033] It should be noted that the terms "first," "second," etc., in the specification, claims, and accompanying drawings of this invention are used to distinguish similar objects and are not necessarily used to describe a specific order or sequence. It should be understood that such data can be interchanged where appropriate so that the embodiments of the invention described herein can be implemented in orders other than those illustrated or described herein. Furthermore, the terms "comprising" and "having," and any variations thereof, are intended to cover a non-exclusive inclusion; for example, a process, method, system, product, or apparatus that comprises a series of steps or units is not necessarily limited to those steps or units explicitly listed, but may include other steps or units not explicitly listed or inherent to such processes, methods, products, or apparatus.
[0034] This invention provides a federated forgetting learning method for collaborative learning between large and small models. By enabling collaborative training between clients and the server within a heterogeneous federated learning framework, it meets the requirements of privacy protection and efficient forgetting. The method combines prompt learning, structured distillation, and a perturbation forgetting mechanism to improve the model's adaptability and classification performance in the target domain. Especially in few-shot scenarios, it makes effective use of client data, improves the accuracy of the target-domain model, and precisely removes the contribution of specific client data from the global model.
[0035] Figure 1 shows the flowchart of the method of this invention. Experiments were conducted on the TinyImageNet, Food101, OxfordPets, and BreakHis datasets for validation. In this embodiment, each client has limited local data and computing resources and accesses only its local data, ensuring data privacy. In each training round, the clients and the server train collaboratively. The specific training steps are as follows.
Step 1: The client data are set to 1, 2, 16, 64, and 128 samples per category in successive experiments, and the remaining samples are placed on the server for auxiliary aggregation. The clients and the server use the same test set to evaluate performance.
[0036] Step 2: Each client loads a pre-trained ResNet-18 model, while the server uses a pre-trained ResNet-50 model, which has more parameters and stronger generalization ability. All models use the same input data format and normalization method; the pre-trained parameters are frozen and prompt parameters are added.
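The per-category sampling in Step 1 can be sketched as a split of a labelled pool: a fixed number of random samples per category for the client and everything else for the server. This is a hypothetical helper; the document does not say how samples are drawn, so uniform sampling without replacement and the fixed `seed` are assumptions. If a category has fewer samples than requested, the sketch takes all of them.

```python
import numpy as np


def few_shot_split(labels, shots, seed=0):
    """Pick `shots` random samples per class for the client; every other index goes to the server.

    labels: 1-D array of integer class labels for the full data pool.
    Returns (client_idx, server_idx) as sorted integer arrays.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client = []
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        take = min(shots, idx.size)                  # a class may have fewer than `shots`
        client.extend(rng.choice(idx, size=take, replace=False).tolist())
    client_idx = np.sort(np.array(client, dtype=int))
    server_idx = np.setdiff1d(np.arange(labels.size), client_idx)
    return client_idx, server_idx


labels = np.repeat(np.arange(3), 10)        # toy pool: 3 classes x 10 samples
for k in (1, 2, 16, 64, 128):               # shot settings from Step 1
    client_idx, server_idx = few_shot_split(labels, k)
    print(k, len(client_idx), len(server_idx))
```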
[0037] Step 3: Each client loads the student model locally and performs few-shot training based on local data. The client uses gradient descent for local training, setting the batch size to 32, the learning rate to 0.005, the momentum to 0.9, the learning rate decay factor to 0.01, and training for 10 epochs. At the end of each training epoch, the client evaluates the performance on the local test set.
[0038] Step 4: After each client completes training, the trained student model parameters are uploaded to the server without transmitting the original data to ensure data privacy and security.
[0039] Step 5: The server aggregates all student model parameters uploaded by clients and performs knowledge fusion using a structured distillation strategy. Instance-level, batch-level, and class-level alignment is implemented by minimizing the KL divergence between the weighted aggregated student model and the teacher model output. Distillation loss includes instance-level loss, batch-level loss, and class-level loss, ultimately optimizing the global model. Parameter settings: The SGD optimizer is used with a learning rate of 0.001, momentum of 0.9, and weight decay of 0.0005, for 10 training epochs. Finally, the aggregated global model is compressed using structured distillation and distributed to each client.
[0040] Step 6: When a client performs a forgetting operation, it first applies the local prediction perturbation and then trains the model with stochastic gradient descent for one epoch, using a learning rate of 0.08 and a batch size of 2. In the accompanying distillation process, the ratio of the inter-class perturbation loss to the intra-class shuffling loss is 1:1.
[0041] Step 7: After receiving the forgetting representation uploaded by the client that requested forgetting, the server performs weighted aggregation to eliminate the contribution of the target data to the global model, and the discrepancy between the outputs of the student and teacher models is minimized through structured distillation.
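For reference, the hyperparameters reported in Steps 1 to 6 of this embodiment can be gathered into one configuration object. The class and field names are hypothetical; the values are those stated in the text, with the 1:1 inter-class to intra-class ratio expressed as a weight of 1.0 on the intra-class term (an interpretation).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EmbodimentConfig:
    """Hyperparameters reported in the embodiment; field names are hypothetical."""
    # Step 1: per-category shot settings on the clients
    shots_per_class: tuple = (1, 2, 16, 64, 128)
    # Steps 2-3: client small model, local few-shot training
    client_model: str = "ResNet-18"
    client_batch_size: int = 32
    client_lr: float = 0.005
    client_momentum: float = 0.9
    client_lr_decay: float = 0.01
    client_epochs: int = 10
    # Step 5: server large model, structured distillation with SGD
    server_model: str = "ResNet-50"
    server_lr: float = 0.001
    server_momentum: float = 0.9
    server_weight_decay: float = 0.0005
    server_epochs: int = 10
    # Step 6: forgetting
    forget_lr: float = 0.08
    forget_batch_size: int = 2
    forget_epochs: int = 1
    forget_lambda: float = 1.0   # inter-class : intra-class loss ratio of 1:1


cfg = EmbodimentConfig()
print(cfg.client_model, cfg.server_model, cfg.forget_lr)
```

Freezing the dataclass keeps a run's settings immutable once created, which helps when comparing the shot settings across experiments.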
[0042] This invention uses accuracy as a model performance evaluation metric. The accuracy calculation formula is as follows:
$$\mathrm{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$
[0043] where $TP$ (true positives) is the number of samples that are actually positive and are correctly predicted as positive by the model; $TN$ (true negatives) is the number of actually negative samples correctly predicted as negative; $FP$ (false positives) is the number of actually negative samples incorrectly predicted as positive; and $FN$ (false negatives) is the number of actually positive samples incorrectly predicted as negative.
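The accuracy formula translates directly into code. The sketch below is a minimal helper (name hypothetical) that also guards against an empty confusion matrix.

```python
def accuracy(tp, tn, fp, fn):
    """Accuracy = (TP + TN) / (TP + TN + FP + FN)."""
    total = tp + tn + fp + fn
    if total == 0:
        raise ValueError("accuracy is undefined for an empty confusion matrix")
    return (tp + tn) / total


print(accuracy(tp=40, tn=50, fp=5, fn=5))   # 90 correct out of 100
```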
[0044] Based on the above steps, the federated forgetting method for collaboration between large and small models is constructed as shown in Figure 2. As shown in Table 1, the proposed method significantly improves clean accuracy and substantially reduces backdoor accuracy, indicating that it not only enhances model performance but also accurately removes the contribution of the target data, ensuring data privacy while strengthening model security and robustness. In the backdoor attack experiment, 40% of the clients were randomly selected as malicious clients; each malicious client injected a 5×5 white square as a backdoor trigger into 60% of its local samples and assigned these samples a fixed target label. The results show that the method effectively reduces the success rate of backdoor attacks: on all datasets, the backdoor attack success rate is significantly lower than that of the comparison methods, falling as low as 0.92%, which demonstrates that the method retains efficient learning capability while defending against malicious attacks.
[0045] Table 1 Backdoor attack experiments
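The backdoor setting described in [0044] (40% of clients malicious, a 5×5 white square stamped on 60% of each malicious client's local samples, poisoned samples relabelled to a fixed target) can be reproduced on toy data with the sketch below. The trigger position (bottom-right corner), the [0, 1] pixel range in which white equals 1.0, the seeds, and the helper names are assumptions not given in the document.

```python
import numpy as np


def pick_malicious(num_clients, frac=0.4, seed=0):
    """Randomly choose round(frac * num_clients) malicious client ids."""
    rng = np.random.default_rng(seed)
    k = int(round(frac * num_clients))
    return np.sort(rng.choice(num_clients, size=k, replace=False))


def inject_backdoor(images, labels, target_label, poison_frac=0.6, patch=5, seed=0):
    """Stamp a white patch x patch square on a random poison_frac of the images
    and relabel them to target_label.

    images: (N, H, W, C) floats in [0, 1], so "white" is 1.0 (assumed);
    the square goes in the bottom-right corner (assumed position).
    Inputs are copied, never modified in place.
    """
    rng = np.random.default_rng(seed)
    images, labels = images.copy(), labels.copy()
    n_poison = int(round(poison_frac * len(images)))
    idx = rng.choice(len(images), size=n_poison, replace=False)
    images[idx, -patch:, -patch:, :] = 1.0
    labels[idx] = target_label
    return images, labels, np.sort(idx)


malicious = pick_malicious(10)                       # 4 of 10 clients
clean_x = np.zeros((10, 32, 32, 3))
clean_y = np.zeros(10, dtype=int)
bad_x, bad_y, poisoned = inject_backdoor(clean_x, clean_y, target_label=7)
print(malicious, poisoned)
```

The backdoor attack success rate reported in Table 1 would then be the fraction of triggered test images that a model classifies as the target label.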
[0046] Finally, it should be noted that the above embodiments are only used to illustrate the technical solutions of the present invention, and not to limit them; although the present invention has been described in detail with reference to the foregoing embodiments, those skilled in the art should understand that modifications can still be made to the technical solutions described in the foregoing embodiments, or equivalent substitutions can be made to some or all of the technical features; and these modifications or substitutions do not cause the essence of the corresponding technical solutions to deviate from the scope of the technical solutions of the embodiments of the present invention.
Claims
1. A federated forgetting learning method for collaborative training of large and small models, characterized in that: a pre-trained small model is loaded on the client side and trained locally on a small number of samples by prompt learning, updating only the prompt parameters while keeping the backbone parameters fixed; the updated prompt parameters, and not the complete model, are uploaded to the server; on the server side, a large model receives and loads the prompt parameters uploaded by the clients, obtains the output of each client model on a public dataset, and performs weighted aggregation of these outputs to form an aggregated representation; based on the aggregated representation, the server applies structured knowledge distillation to constrain the learning of the large model so that the large model acquires the distribution information of the client model outputs; when any client issues a forgetting request, a perturbation is introduced for the target data in the prediction space to generate an offset representation that is approximately orthogonal to the features of the target data; and the prompt parameters containing the offset representation are sent to the server, which performs an aggregation operation to remove the contributions related to the target data.
2. The federated forgetting learning method for collaborative training of large and small models according to claim 1, characterized in that, when local training is performed on a small number of samples by prompt learning and only the prompt parameters are updated, the prompt parameters are implemented by prompt modules inserted into each stage of the backbone, which enhance adaptability to small samples while maintaining the stability of the original model structure; the optimization objective is the cross-entropy loss
$$\mathcal{L}_{\mathrm{CE}} = -\sum_{(x,\,y)\in D} y^{\top}\log f\left(x;\ \theta_{b},\ \theta_{p}\right)$$
where $f(x;\theta_{b},\theta_{p})$ is the predicted output of the small model with the prompt modules inserted on sample $x$ of the local dataset $D$, and $y$ is the corresponding label; the update involves only the prompt parameters $\theta_{p}$.
3. The federated forgetting learning method for collaborative training of large and small models according to claim 1, characterized in that structured distillation aggregates the small models into the large model: the server receives the prompt parameters uploaded by the clients and, using a public dataset $D_{\mathrm{pub}}$, computes and aggregates their softened outputs, the softened output of client small model $i$ on sample $x$ being
$$q_{i}(x) = \mathrm{softmax}\left(\frac{z_{i}(x)}{T}\right)$$
where $z_{i}(x)$ is the original logits of client small model $i$ on sample $x$ and $T$ is the temperature parameter; the aggregated client representation is
$$\bar{q}(x) = \sum_{i=1}^{N} w_{i}\, q_{i}(x)$$
where the weight $w_{i}$ is determined by the client's sample size, data quality, or computing power and balances the different contributions in a heterogeneous environment; the server uses this aggregated representation as the distillation target and applies the structured knowledge distillation loss
$$\mathcal{L}_{\mathrm{KD}} = \mathcal{L}_{\mathrm{ins}} + \mathcal{L}_{\mathrm{batch}} + \mathcal{L}_{\mathrm{cls}}$$
where $\mathcal{L}_{\mathrm{ins}}$ aligns the predicted distributions at the single-sample level, $\mathcal{L}_{\mathrm{batch}}$ maintains consistency between samples within a batch, and $\mathcal{L}_{\mathrm{cls}}$ aligns semantic relationships between categories, thereby improving the ability of the large model to absorb knowledge from heterogeneous clients.
4. The federated forgetting learning method for collaborative training of large and small models according to claim 1, characterized in that, when any client requests forgetting, an offset representation $\delta$ of the target data is generated in the prediction space that remains approximately orthogonal to the original feature $P$:
$$\left\langle \delta,\ P\right\rangle \approx 0$$
the offset representation being generated by minimizing the forgetting loss
$$\mathcal{L}_{\mathrm{forget}} = \mathcal{L}_{\mathrm{inter}} + \lambda\,\mathcal{L}_{\mathrm{intra}}$$
where $\mathcal{L}_{\mathrm{inter}}$ and $\mathcal{L}_{\mathrm{intra}}$ disrupt inter-class relationships and weaken intra-class consistency, respectively, and $\lambda$ is a balancing hyperparameter, thereby preventing the feature contributions of the target data from being effectively utilized.
5. The federated forgetting learning method for collaborative training of large and small models according to claim 4, characterized in that, when the server performs the aggregation operation, the client, after generating the offset representation $\delta$, sends the prompt parameters $\tilde{\theta}_{p}$ carrying the offset representation to the server, which performs an aggregation operation to update the global model:
$$\theta^{t+1} = \mathrm{Agg}\left(\theta^{t},\ \tilde{\theta}_{p}\right)$$
where $\theta^{t}$ denotes the global model parameters of round $t$, $\theta^{t+1}$ denotes the updated global model parameters, and $\mathrm{Agg}(\cdot)$ denotes the aggregation operation; by means of structured distillation, the contribution of the target data is removed in a single aggregation, avoiding the computational and storage overhead caused by repeatedly retraining the global model.