Heterogeneous model aggregation method and system based on domain difference-aware distillation
By employing a heterogeneous model aggregation method based on domain difference-aware distillation, and utilizing autoencoders and knowledge distillation algorithms to perform model aggregation on the server side, the problem of data domain difference interference in cross-domain joint learning is solved, achieving more efficient model aggregation and performance improvement.
Patent Information
- Authority / Receiving Office: CN · China
- Patent Type: Patents(China)
- Current Assignee / Owner: FUDAN UNIVERSITY
- Filing Date: 2022-09-09
- Publication Date: 2026-06-16
AI Technical Summary
In cross-domain joint learning scenarios, differences in data domains among different participants can interfere with the server distillation and aggregation process, leading to problems such as the inability to aggregate models or poor performance of aggregated models.
A heterogeneous model aggregation method based on domain difference-aware distillation is adopted. Each client uploads an autoencoder and a classification model, and the server performs knowledge distillation based on domain difference awareness. Different weights are adaptively assigned to different teacher models, and distillation shields the influence of model structure differences.
The method mitigates the impact of differences between the client datasets and the server dataset, supports various forms of model heterogeneity, adapts to more real-world application scenarios, and improves the efficiency and performance of model aggregation.
Figure CN117746172B_ABST
Description
Technical Field
[0001] This invention belongs to the field of machine learning and image processing technology, specifically relating to a heterogeneous model aggregation method and system based on domain difference-aware distillation.
Background Technology
[0002] With the increasing adoption of deep learning technology and growing demand for privacy protection, federated learning-based joint machine learning techniques have gained significant attention in recent years. Federated learning allows multiple clients to jointly train a model, where client datasets are not shared. In each round of communication, the server sends the global model to the clients. Subsequently, each client uses its local data to train its local model based on the global model, and the server collects the updated local models from the clients for aggregation. Through multiple rounds of communication, the aggregated model gradually converges, eventually enabling the simultaneous processing of data from multiple clients.
[0003] In traditional federated learning techniques, servers aggregate models using model parameter averaging. This approach requires numerous iterations across communication rounds to converge, and model parameter averaging is only applicable when models have completely identical structures. In real-world scenarios, due to limitations in client storage and computing resources, different clients may employ different model structures. For example, mobile devices might use lightweight network models, while desktop clients can use larger-scale models. In such cases, model averaging cannot aggregate heterogeneous models.
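The structural limitation of parameter averaging described above can be illustrated with a minimal sketch. The dictionary-of-lists model representation and the function name `average_parameters` are assumptions made for illustration only; they are not taken from the patent.

```python
from typing import Dict, List

# Hypothetical representation: layer name -> flattened weights.
Params = Dict[str, List[float]]


def average_parameters(models: List[Params]) -> Params:
    """Element-wise average; only defined when all models share layers and shapes."""
    reference = models[0]
    for m in models[1:]:
        same_layers = m.keys() == reference.keys()
        if not same_layers or any(len(m[k]) != len(reference[k]) for k in reference):
            raise ValueError("parameter averaging requires identical model structures")
    return {
        name: [sum(m[name][j] for m in models) / len(models)
               for j in range(len(reference[name]))]
        for name in reference
    }


# Two clients with the same structure can be averaged.
same = [{"fc": [1.0, 3.0]}, {"fc": [3.0, 5.0]}]
averaged = average_parameters(same)

# A lightweight and a larger model cannot be averaged element-wise.
mixed = [{"fc": [1.0, 3.0]}, {"fc": [1.0, 2.0, 3.0], "extra": [0.5]}]
try:
    average_parameters(mixed)
    heterogeneous_ok = True
except ValueError:
    heterogeneous_ok = False
```

The second call fails by construction, which is the case that motivates distillation-based aggregation.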
[0004] Knowledge distillation is a method for aggregating heterogeneous models. Assuming a server possesses a large amount of unlabeled data, it can use this data for distillation training, so that the student model (i.e., the aggregated model) produces the same predictions on this data as the teacher models (i.e., the client models collected by the server). However, the server's unlabeled data and the client data often differ in data domain. Since the server cannot directly access the client data, this difference is difficult to control. When the data domain differences are significant, the performance of knowledge distillation-based methods degrades significantly, causing the aggregated model to be unable to adapt to data distributions across multiple domains.
[0005] As mentioned above, in cross-domain joint learning scenarios, the differences in data domains among different participants interfere with the server distillation and aggregation process, resulting in problems such as the inability to aggregate models or poor performance of aggregated models.
Summary of the Invention
[0006] This invention addresses the aforementioned problems by providing a heterogeneous model aggregation method and system that resolves the interference of data domain differences among different participants in the server distillation and aggregation process during cross-domain joint learning scenarios. The invention employs the following technical solution:
[0007] This invention provides a heterogeneous model aggregation method based on domain difference-aware distillation, used to aggregate multi-party models when there are data domain differences among multiple clients participating in multi-party joint learning. The method includes the following steps: an autoencoder collection step, where each client trains its autoencoder using local data and uploads the trained autoencoder to a server; and a joint training step, where each client trains its local client classification model and uploads it to the server. The server then uses the client classification model and the autoencoder to perform model aggregation based on domain difference-aware distillation to obtain an aggregated model, which is then distributed to each client for the next round of training.
[0008] The heterogeneous model aggregation method based on domain difference-aware distillation provided by the present invention may also have the following technical features, wherein the autoencoder includes an encoder module and a decoder module, which are trained locally by the client, and the autoencoder collection step is performed only once.
[0009] The heterogeneous model aggregation method based on domain difference-aware distillation provided by this invention may also have the following technical features: the autoencoder includes an encoder module and a decoder module, which are trained locally by the client. The client is applied to scenarios where data is constantly changing, and the autoencoder collection step is executed periodically.
[0010] The heterogeneous model aggregation method based on domain difference-aware distillation provided by this invention may also have the following technical features, wherein the joint training step includes the following sub-steps: a model upload sub-step, in which the client uploads its local initial classification model as the client classification model to the server; a model aggregation sub-step, in which the server aggregates the collected client classification models to obtain an aggregated model, and distributes the aggregated model to each client; a model training sub-step, in which the client trains its local client classification model based on the received aggregated model, obtains the trained client classification model, and uploads it to the server; and an iterative judgment sub-step, in which it is determined whether a predetermined communication round has been reached, and if the determination is negative, it returns to the model aggregation sub-step to perform the next round of training.
[0011] The heterogeneous model aggregation method based on domain difference-aware distillation provided by this invention may also have the following technical features: Based on the storage and computing power limitations of each client, the client classification model adopts a model structure of appropriate scale. In the model training sub-step, if the model structure is the same as the aggregated model, the client directly uses the aggregated model for initialization and further trains based on its local data. If the model structure is different from the aggregated model, the client uses the aggregated model as the teacher model and its local client classification model as the student model for knowledge distillation.
[0012] The heterogeneous model aggregation method based on domain difference-aware distillation provided by this invention may also have the following technical features: in the model aggregation sub-step, the server aggregates multiple client classification models as multiple teacher models; the server uses the collected autoencoders to determine the distance between the current sample in the server and the client data distribution; for client data that is closer to the current sample, the corresponding teacher model is assigned a larger weight; and the current sample is unlabeled data.
[0013] The heterogeneous model aggregation method based on domain difference-aware distillation provided by this invention may also have the following technical feature, wherein, for sample x, the reconstruction error of all the autoencoders is first calculated:
[0014]
[0015] In the formula, the autoencoder term denotes the autoencoder of the i-th client.
[0016] Then, the weights of the teacher model are calculated using the reconstruction error:
[0017]
[0018] Finally, the weighted outputs of multiple teacher models are used to guide the training of the student model, with the loss function being:
[0019]
[0020] In the formula, D_KL is the KL divergence, used to measure the distance between two vectors; h(x) is the output of the student model; and the remaining term is the weighted output of the teacher models.
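The three formulas are not reproduced in this text. One set of expressions consistent with the surrounding description is sketched below, purely as an assumption: the symbols A_i (autoencoder of the i-th client), f_i (teacher model of the i-th client), e_i and w_i, the squared Euclidean norm, the softmax form of the weights, and the direction of the KL divergence are all illustrative choices and may differ from the granted text.

```latex
% Assumed form: reconstruction error of sample x under the i-th client's autoencoder
e_i(x) = \lVert x - A_i(x) \rVert_2^2

% Assumed form: a smaller error (sample closer to client i's data) gives a larger weight
w_i(x) = \frac{\exp\bigl(-e_i(x)\bigr)}{\sum_{j} \exp\bigl(-e_j(x)\bigr)}

% Assumed form: the student output h(x) is trained toward the weighted teacher output
\mathcal{L}(x) = D_{KL}\Bigl( \sum_{i} w_i(x)\, f_i(x) \,\Big\|\, h(x) \Bigr)
```

Any monotonically decreasing, normalized mapping from error to weight would satisfy the stated property that teachers of closer client domains receive larger weights.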
[0021] This invention provides a heterogeneous model aggregation system based on domain difference-aware distillation, characterized by comprising: a server; multiple clients, wherein the multiple clients have data domain differences; an autoencoder collection module, used to collect the autoencoders of each client, wherein each client trains its autoencoder using local data and uploads the trained autoencoder to the server; and a joint training module, used to jointly train the client classification models of each client, wherein each client trains its local client classification model and uploads it to the server, and the server uses the client classification model and the autoencoder to perform model aggregation based on domain difference-aware distillation to obtain an aggregated model, and distributes the aggregated model to each client for the next round of training.
[0022] Invention Function and Effect
[0023] The heterogeneous model aggregation method and system based on domain difference-aware distillation according to the present invention, because it employs a knowledge distillation algorithm for model aggregation on the server side and also supports knowledge distillation for model training on the client side, can shield the impact of differences in model structure on joint learning. Furthermore, because the server-side distillation algorithm adaptively assigns different weights to different teacher models, it can mitigate the impact of differences between datasets from different clients and the server on model aggregation. In addition, since the present invention supports various forms of model heterogeneity, namely heterogeneity between server and client models and heterogeneity between different clients, it can adapt to more real-world application scenarios and has greater practical application value.
Description of the Attached Figures
[0024] Figure 1 is a system equipment framework diagram of the heterogeneous model aggregation method based on domain difference-aware distillation in an embodiment of the present invention;
[0025] Figure 2 is a flowchart of the heterogeneous model aggregation method based on domain difference-aware distillation in an embodiment of the present invention;
[0026] Figure 3 is a specific implementation framework diagram of the joint training step in this embodiment of the invention;
[0027] Figure 4 is a flowchart illustrating the specific implementation of the joint training step in this embodiment of the invention;
[0028] Figure 5 is a structural block diagram of the heterogeneous model aggregation system based on domain difference-aware distillation in an embodiment of the present invention.
Detailed Implementation
[0029] To make the technical means, creative features, objectives and effects of the present invention easy to understand, the following describes in detail the heterogeneous model aggregation method and system based on domain difference-aware distillation of the present invention with reference to embodiments and accompanying drawings.
[0030] <Example>
[0031] Figure 1 is a system equipment framework diagram of the heterogeneous model aggregation method based on domain difference-aware distillation in this embodiment.
[0032] As shown in Figure 1, the system 100 includes server device 110 and computing devices 111-113. Computing devices 111-113 and server device 110 are all used to process image and model data, and each mainly includes a processor and memory. The processor is a hardware processor used for computation, such as a central processing unit (CPU) or a graphics processing unit (GPU). The memory is a non-volatile storage device used to store code, model parameters, and other intermediate data for server device 110 and computing devices 111-113. The dataset consists of image data stored on each device; this image data contains private information of each device and cannot leave the device.
[0033] Specifically, the server device 110 in this embodiment includes a processor 120 and memory 130. The memory 130 stores a dataset 1301 and executable code 140. The executable code 140 includes a domain difference-aware distillation module 1401, responsible for aggregating client models and sending the aggregated model to the clients. The computing device 111 includes a processor 121 and memory 131. The memory 131 stores a dataset 1311 and executable code 141. The executable code 141 includes an autoencoder module 1411, a model training module 1412, and a distillation training module 1413. The autoencoder module 1411 is responsible for training the autoencoder of the computing device 111 using local data (i.e., data in the dataset 1311). The model training module 1412 is responsible for training the local classification model when the local classification model and the server model have the same structure. The distillation training module 1413 is responsible for performing knowledge distillation training when the local classification model and the server model have different structures. As shown in Figure 1, the structures and functions of computing devices 112 and 113 are the same as those of computing device 111, so they will not be described again.
[0034] In this embodiment, the data in the three computing devices 111 to 113 come from different domains.
[0035] In the following description, the server refers to server device 110, and the client refers to computing devices 111 to 113.
[0036] Furthermore, client devices fall into two categories: those whose model structure is identical to the server model structure (hereinafter referred to as homogeneous clients), and those whose model structure differs from the server model structure (hereinafter referred to as heterogeneous clients). Each client can adopt a different scale of model structure based on its own storage and computing resources.
[0037] In this embodiment, computing devices 111 and 112 are homogeneous clients, and computing device 113 is a heterogeneous client.
[0038] Figure 2 is a flowchart of the heterogeneous model aggregation method based on domain difference-aware distillation in this embodiment.
[0039] As shown in Figure 2, the heterogeneous model aggregation method based on domain difference-aware distillation specifically includes the following steps:
[0040] Step S1, autoencoder collection step: Each client trains an autoencoder using its local data and uploads the trained autoencoder to the server.
[0041] The autoencoder comprises an encoder module and a decoder module, which are trained locally by the client. In this embodiment, the autoencoder collection process (i.e., the process of uploading to the server) is performed only once; that is, an additional round of communication is conducted before the joint training begins to aggregate the autoencoders from multiple clients to the server. However, when the client is used in scenarios where the data is constantly changing, the autoencoder collection step can be executed periodically and multiple times.
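Step S1 can be sketched with a deliberately simple stand-in. The patent does not specify the autoencoder architecture; the example below swaps in a linear, tied-weight autoencoder with a one-dimensional code, fitted by power iteration rather than by gradient training. The class name `LinearAutoencoder`, the toy data, and the `uploaded` dictionary are assumptions for illustration.

```python
import math
from typing import Dict, List

Vector = List[float]


class LinearAutoencoder:
    """Tied-weight linear autoencoder with a 1-D code: z = w.x, x_hat = z * w."""

    def __init__(self, dim: int) -> None:
        self.w: Vector = [1.0 / math.sqrt(dim)] * dim

    def fit(self, data: List[Vector], iterations: int = 50) -> None:
        # Power iteration on the uncentered second-moment matrix sum_x x x^T.
        # Its leading direction minimises squared reconstruction error here.
        for _ in range(iterations):
            new_w = [0.0] * len(self.w)
            for x in data:
                z = sum(wi * xi for wi, xi in zip(self.w, x))
                for j, xj in enumerate(x):
                    new_w[j] += z * xj
            norm = math.sqrt(sum(v * v for v in new_w))
            self.w = [v / norm for v in new_w]

    def reconstruct(self, x: Vector) -> Vector:
        z = sum(wi * xi for wi, xi in zip(self.w, x))
        return [z * wi for wi in self.w]

    def reconstruction_error(self, x: Vector) -> float:
        return sum((a - b) ** 2 for a, b in zip(x, self.reconstruct(x)))


# Client side: fit on local data only; the raw data never leaves the client.
local_data = [[1.0, 0.1], [2.0, 0.2], [3.0, 0.3]]
autoencoder = LinearAutoencoder(dim=2)
autoencoder.fit(local_data)

# "Upload": only the trained parameters are sent to the server.
uploaded: Dict[str, Vector] = {"client_1": autoencoder.w}

# Server side: a sample resembling the client's domain reconstructs well,
# a sample from a different domain does not.
in_domain_error = autoencoder.reconstruction_error([4.0, 0.4])
out_of_domain_error = autoencoder.reconstruction_error([0.4, 4.0])
```

The gap between the two errors is what lets the server estimate how close an unlabeled sample is to a client's data distribution without seeing that data.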
[0042] Step S2, joint training step: Each client trains its local client classification model and uploads it to the server. The server uses the collected client classification models and autoencoders to perform model aggregation based on domain difference-aware distillation, and distributes the aggregated model to each client for the next round of training.
[0043] Figure 3 and Figure 4 are the implementation framework diagram and flowchart for joint training in this embodiment, respectively.
[0044] As shown in Figures 3-4, in this embodiment, step S2 specifically includes the following sub-steps:
[0045] Step S2-1: The client uploads its local initial classification model to the server as the client classification model.
[0046] Step S2-2: The server collects multiple client classification models and aggregates them as multiple teacher models to obtain an aggregated model, which is then distributed to each client.
[0047] Step S2-3: The client trains its local client classification model based on the received aggregated model, obtains the trained client classification model, and uploads it to the server.
[0048] Step S2-4: Determine whether the preset number of communication rounds has been reached. If the determination is yes, enter the end state; if the determination is no, return to step S2-2 and proceed to the next round of training.
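The control flow of sub-steps S2-1 to S2-4 can be sketched as a loop. The function `joint_training` and its callable parameters are hypothetical names; in the toy usage the models are plain floats, `aggregate` is a simple mean, and `local_update` adds a fixed offset, all of which are placeholders rather than the distillation-based aggregation and training of the embodiment.

```python
from typing import Callable, List, Tuple, TypeVar

M = TypeVar("M")


def joint_training(
    initial_models: List[M],
    aggregate: Callable[[List[M]], M],
    local_update: Callable[[int, M], M],
    num_rounds: int,
) -> Tuple[M, List[M]]:
    # S2-1: each client uploads its initial classification model.
    client_models = list(initial_models)
    completed = 0
    while True:
        # S2-2: the server aggregates the collected models and distributes the result.
        aggregated = aggregate(client_models)
        # S2-3: each client trains from the aggregated model and uploads again.
        client_models = [local_update(i, aggregated) for i in range(len(client_models))]
        completed += 1
        # S2-4: stop once the preset number of communication rounds is reached.
        if completed >= num_rounds:
            return aggregated, client_models


# Toy usage with placeholder aggregation and training.
offsets = [1.0, -1.0]
final_model, final_clients = joint_training(
    initial_models=[0.0, 3.0],
    aggregate=lambda models: sum(models) / len(models),
    local_update=lambda i, agg: agg + offsets[i],
    num_rounds=2,
)
```

Passing the aggregation and training routines as callables keeps the round structure separate from how a given deployment implements distillation.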
[0049] In step S2-2, the weights of each teacher model are obtained as follows:
[0050] After the autoencoders and client classification models are collected, the domain difference-aware distillation module 1401 of the server device 110 uses the client autoencoders to determine the distance between the server's current sample (unlabeled data) and the client data distribution. For client data that is closer to the current sample, the corresponding teacher model (i.e., the client classification model in that client) is given a larger weight; conversely, the weight of the corresponding teacher model is reduced. Specifically, for sample x, the reconstruction error of all autoencoders is first calculated:
[0051]
[0052] In the formula, the autoencoder term denotes the autoencoder of the i-th client.
[0053] Then, in step S2-2, the weights of the teacher model are calculated using the reconstruction error:
[0054]
[0055] As can be seen from the above formula, the teacher model of the client corresponding to the autoencoder with a smaller reconstruction error will have a larger weight, since a smaller reconstruction error indicates that the current sample is closer to that client's data distribution.
[0056] Finally, in step S2-3, the weighted outputs of multiple teacher models are used to guide the training of the student model, with the loss function being:
[0057]
[0058] In the formula, D_KL is the KL divergence, used to measure the distance between two vectors; h(x) is the output of the student model; and the remaining term is the weighted output of the teacher models.
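A minimal sketch of the weighting and loss computation for a single server sample follows. The exact formulas are not reproduced in this text, so the softmax over negative reconstruction errors and the direction of the KL divergence are assumptions; the function names and the numeric inputs are likewise illustrative.

```python
import math
from typing import List, Sequence


def teacher_weights(errors: Sequence[float]) -> List[float]:
    """Assumed form: softmax over negative errors, so a smaller error gives a larger weight."""
    smallest = min(errors)  # subtract the minimum for numerical stability
    exps = [math.exp(-(e - smallest)) for e in errors]
    total = sum(exps)
    return [v / total for v in exps]


def weighted_teacher_output(
    weights: Sequence[float], teacher_probs: Sequence[Sequence[float]]
) -> List[float]:
    """Per-class weighted sum of the teachers' probability vectors."""
    num_classes = len(teacher_probs[0])
    return [
        sum(w * probs[c] for w, probs in zip(weights, teacher_probs))
        for c in range(num_classes)
    ]


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """D_KL(p || q) for two probability vectors; eps guards against log(0)."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


# Reconstruction errors of one unlabeled server sample under three client autoencoders.
errors = [0.1, 2.0, 5.0]
weights = teacher_weights(errors)

# Class probabilities predicted by the three teacher models and by the student.
teacher_probs = [[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]]
student_probs = [0.6, 0.4]

target = weighted_teacher_output(weights, teacher_probs)
loss = kl_divergence(target, student_probs)
```

With these inputs the first teacher dominates the target because its client's autoencoder reconstructs the sample best, and the loss falls to zero only when the student matches the weighted target.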
[0059] As mentioned above, client devices fall into two categories. For homogeneous clients, the client model is directly initialized using the aggregated model and further trained based on local data. For heterogeneous clients, the aggregated model is used as the teacher model, and the local client classification model is used as the student model for knowledge distillation.
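The client-side choice between the two training paths can be sketched as a simple dispatch. The `Model` dataclass, its `structure` tag, and the returned mode strings are hypothetical; the sketch only selects the path and does not perform the training or distillation itself.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Model:
    structure: str                    # hypothetical architecture tag
    params: Dict[str, List[float]]    # hypothetical flattened parameters


def start_local_round(local: Model, aggregated: Model) -> str:
    """Choose the client-side training path for the current round."""
    if local.structure == aggregated.structure:
        # Homogeneous client: copy the aggregated parameters, then train on local data.
        local.params = {name: list(values) for name, values in aggregated.params.items()}
        return "train_on_local_data"
    # Heterogeneous client: keep the local structure; the aggregated model acts as teacher.
    return "distill_from_aggregated"


aggregated = Model("large", {"fc": [0.5, 0.5]})
homogeneous = Model("large", {"fc": [0.0, 0.0]})
heterogeneous = Model("small", {"fc": [0.1]})

mode_a = start_local_round(homogeneous, aggregated)
mode_b = start_local_round(heterogeneous, aggregated)
```

Copying the parameters rather than sharing references keeps later local training from mutating the server's aggregated model in this in-process sketch.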
[0060] Figure 5 is a structural block diagram of the heterogeneous model aggregation system based on domain difference-aware distillation in this embodiment.
[0061] As shown in Figure 5, this embodiment also provides a heterogeneous model aggregation system 10 corresponding to the above method. The heterogeneous model aggregation system 10 includes a server 11, multiple clients 12, an autoencoder collection module 13, a joint training module 14, and a control module 15.
[0062] The server 11 communicates with multiple clients 12, and these clients 12 have different data domains, meaning their data comes from different domains. The autoencoder collection module 13 collects the autoencoders from the multiple clients 12 into the server 11 according to the method described in step S1 above. The joint training module 14 performs joint training on the client classification models of the multiple clients 12 according to the method described in step S2 above. The control module 15 controls the operation of the server 11, the multiple clients 12, the autoencoder collection module 13, and the joint training module 14.
[0063] Functions and effects of the embodiments
[0064] The heterogeneous model aggregation method and system based on domain difference-aware distillation provided in this embodiment can mitigate the impact of model structure differences on joint learning by employing a knowledge distillation algorithm on the server side and supporting knowledge distillation for model training on the client side. Furthermore, the use of a domain difference-aware distillation algorithm on the server side allows for adaptive assignment of different weights to different teacher models, thus mitigating the impact of differences between datasets from different clients and the server on model aggregation. Additionally, this embodiment supports various forms of model heterogeneity, including server-client model heterogeneity and model heterogeneity between different clients, making it adaptable to more real-world application scenarios and possessing greater practical application value.
[0065] The above embodiments are only used to illustrate specific implementations of the present invention, and the present invention is not limited to the scope of the description of the above embodiments.
[0066] In the above embodiments, the data of the participants in the joint learning (i.e., each client) comes from different domains. In an alternative, the data of the participants in the joint learning can also come from the same domain. When the data comes from different domains, the method of the present invention has a more significant performance improvement.
Claims
1. A heterogeneous model aggregation method based on domain difference-aware distillation, used to aggregate multi-party models when data domain differences exist among multiple clients participating in multi-party joint learning, characterized in that it includes the following steps: an autoencoder collection step, in which each client trains its autoencoder using local data and uploads the trained autoencoder to the server; and a joint training step, in which each client trains its local client classification model and uploads it to the server, and the server uses the client classification models and the autoencoders to perform model aggregation based on domain difference-aware distillation to obtain an aggregated model, and then distributes the aggregated model to each client for the next round of training; wherein the joint training step includes the following sub-steps: model uploading, model aggregation, model training, and iterative judgment; based on the storage and computing power limitations of each client, the client classification model adopts a model structure of appropriate scale; in the model training sub-step, if the model structure is the same as the aggregated model structure, the client directly uses the aggregated model for initialization and performs further training based on its local data, and if the model structure differs from that of the aggregated model, the client uses the aggregated model as the teacher model and its local client classification model as the student model for knowledge distillation; in the model aggregation sub-step, the server aggregates the collected client classification models as multiple teacher models; the server uses the collected autoencoders to determine the distance between the current sample in the server and the client data distribution; for client data that is closer to the current sample, the corresponding teacher model is assigned a larger weight; and the current sample is unlabeled data.
2. The heterogeneous model aggregation method based on domain difference-aware distillation according to claim 1, characterized in that: the autoencoder comprises an encoder module and a decoder module, and is trained locally by the client; and the autoencoder collection step is performed only once.
3. The heterogeneous model aggregation method based on domain difference-aware distillation according to claim 1, characterized in that: the autoencoder comprises an encoder module and a decoder module, and is trained locally by the client; and the client is used in scenarios where data is constantly changing, and the autoencoder collection step is executed periodically.
4. The heterogeneous model aggregation method based on domain difference-aware distillation according to claim 1, characterized in that: in the model upload sub-step, the client uploads its local initial classification model to the server as the client classification model; in the model aggregation sub-step, the server aggregates the collected client classification models to obtain an aggregated model, and then distributes the aggregated model to each client; in the model training sub-step, the client trains its local client classification model based on the received aggregated model, obtains the trained client classification model, and uploads it to the server; and in the iterative judgment sub-step, it is determined whether the predetermined communication round has been reached, and if the determination is negative, the process returns to the model aggregation sub-step to proceed to the next round of training.
5. The heterogeneous model aggregation method based on domain difference-aware distillation according to claim 1, characterized in that: for a sample x, the reconstruction error of all the autoencoders is first calculated, the autoencoder in the formula being the autoencoder of the i-th client; then, the weights of the teacher models are calculated using the reconstruction error; finally, the weighted outputs of multiple teacher models are used to guide the training of the student model through a loss function in which KL divergence is used to measure the distance between two vectors, h(x) is the output of the student model, and the remaining term is the weighted output of the teacher models.
6. A heterogeneous model aggregation system based on domain difference-aware distillation, characterized in that it includes: a server; multiple clients, wherein the data domains of these multiple clients differ; an autoencoder collection module, used to collect the autoencoders of each of the clients, wherein each client trains its autoencoder using local data and uploads the trained autoencoder to the server; and a joint training module, used to jointly train the client classification models of each client, wherein each client trains its local client classification model and uploads it to the server, and the server uses the client classification models and the autoencoders to perform model aggregation based on domain difference-aware distillation to obtain an aggregated model, and then distributes the aggregated model to each client for the next round of training; wherein the joint training module performs the following sub-steps: model uploading, model aggregation, model training, and iterative judgment; based on the storage and computing power limitations of each client, the client classification model adopts a model structure of appropriate scale; in the model training sub-step, if the model structure is the same as the aggregated model structure, the client directly uses the aggregated model for initialization and performs further training based on its local data, and if the model structure differs from that of the aggregated model, the client uses the aggregated model as the teacher model and its local client classification model as the student model for knowledge distillation; in the model aggregation sub-step, the server aggregates the collected client classification models as multiple teacher models; the server uses the collected autoencoders to determine the distance between the current sample in the server and the client data distribution; for client data that is closer to the current sample, the corresponding teacher model is assigned a larger weight; and the current sample is unlabeled data.