A Federated Learning Method and System for Client-Side Selection Aggregation
By performing cluster analysis and weighted aggregation of client-side classifiers on the server, and by optimizing the feature extractor with Gaussian-noise-enhanced features and a contrastive loss function, the method addresses two problems of traditional federated learning, namely low-quality data sources and poorly adaptable feature extractors, thereby improving the overall performance and generalization ability of the model.
Patent Information
- Authority / Receiving Office
- CN · China
- Patent Type
- Patents(China)
- Current Assignee / Owner
- CETC BIGDATA RES INST CO LTD
- Filing Date
- 2026-01-20
- Publication Date
- 2026-06-30
AI Technical Summary
Traditional federated learning algorithms have failed to effectively address the negative impact of low-quality or outlier data sources on the global model, and the coupled training of feature extractors and classifiers limits the model's adaptability and generalization ability.
The server performs cluster analysis on the client classifiers, selects the clients with high validation set accuracy for weighted aggregation, generates enhanced features by adding Gaussian noise to feature prototypes, and optimizes the feature extractor; cosine similarity and a contrastive loss function are used to improve model performance.
It improves the performance and generalization ability of the global model, enhances the robustness and adaptability of the model, reduces the computational and communication burden, and improves resource utilization.
Figure CN121543770B_ABST
Abstract
Description
Technical Field
[0001] This invention belongs to the field of federated learning technology, specifically relating to a federated learning method and system for client selection aggregation.
Background Technology
[0002] Federated learning is a distributed machine learning framework that allows multiple clients to collaboratively train a global model without sharing local data, effectively protecting user privacy and promoting cross-organizational data collaboration. However, traditional federated learning algorithms, such as FedAvg, simply aggregate the models updated by all clients without considering the quality differences between client models. This allows low-quality or outlier data sources (such as clients with severely skewed data distribution or insufficient training resources) to negatively impact the global model, thereby reducing overall model performance.
[0003] Furthermore, traditional federated learning methods typically train and update the feature extractor and the classifier as a whole. This coupled training causes two problems: (1) the feature extractor cannot be optimized dynamically to match the diversity of client data distributions, limiting its ability to adapt to heterogeneous data; (2) classifier training is constrained by fixed feature representations, making it difficult to fully exploit the richness of local client data. As a result, traditional federated learning suffers from limited model performance and insufficient generalization ability in practical applications.
Summary of the Invention
[0004] This invention proposes a federated learning method and system for client selection aggregation. Through a collaborative optimization strategy for the feature extractor and the classifier, the server performs cluster analysis on the client classifiers and selects the clients whose validation set accuracy is higher than the average prediction accuracy of their cluster. Each selected client is weighted according to its accuracy and cosine similarity, and weighted aggregation is performed to optimize the global classifier. Gaussian noise is added to the feature prototypes to generate enhanced features, and the feature extractor is optimized with a contrastive loss.
[0005] According to a first aspect of the present invention, the present invention claims protection for a federated learning method for client-selected aggregation, comprising the steps of:
[0006] S1, The server initializes the global model and distributes the global model to multiple clients;
[0007] S2, each of the multiple clients freezes the feature extractor of the received global model, updates the classifier using its local data, and calculates its validation set accuracy and the feature prototype of each class;
[0008] S3, after collecting the classifiers, validation set accuracies, and feature prototypes uploaded by the clients, the server clusters the client classifiers based on JS divergence and, from each cluster, selects the classifiers whose validation set accuracy is higher than the average accuracy of that cluster to form a high-quality client set;
[0009] S4, the server calculates the aggregation weight of each high-quality client based on its validation set accuracy and the cosine similarity between its classifier and the global classifier, and performs a weighted average of the classifiers to obtain the updated global classifier;
[0010] S5, the server adds Gaussian noise to the feature prototypes uploaded by each client to generate enhanced feature prototypes, and constructs a contrastive loss function to optimize the feature extractor;
[0011] S6, the server combines the updated feature extractor and classifier to form a new generation of global model, and repeats steps S2-S6 until the model converges or reaches the set communication round.
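The six steps above can be sketched as a server-side loop. This is a minimal orchestration sketch under stated assumptions, not the patented implementation: `GlobalModel`, `run_rounds`, and the callables `local_update` (step S2), `select_and_aggregate` (steps S3 and S4), `refine_extractor` (step S5) and `converged` are hypothetical names introduced here for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class GlobalModel:
    extractor: Any   # global feature extractor (phi)
    classifier: Any  # global classifier (w)


def run_rounds(
    model: GlobalModel,  # S1: the initialized global model
    clients: List[Any],
    local_update: Callable[[Any, GlobalModel], Dict[str, Any]],
    select_and_aggregate: Callable[[List[Dict[str, Any]], Any], Any],
    refine_extractor: Callable[[Any, List[Any]], Any],
    max_rounds: int,
    converged: Callable[[GlobalModel], bool],
) -> Tuple[GlobalModel, int]:
    """Hypothetical server loop for steps S1-S6; returns the final model and rounds run."""
    rounds = 0
    for _ in range(max_rounds):
        rounds += 1
        # S2: each client receives the model, freezes the extractor, trains its
        # classifier and uploads {"classifier": ..., "acc": ..., "prototypes": ...}
        uploads = [local_update(client, model) for client in clients]
        # S3-S4: cluster, keep above-average clients per cluster, weighted average
        classifier = select_and_aggregate(uploads, model.classifier)
        # S5: refine the extractor from noise-augmented prototypes
        extractor = refine_extractor(model.extractor, [u["prototypes"] for u in uploads])
        # S6: assemble the next-generation global model
        model = GlobalModel(extractor, classifier)
        if converged(model):
            break
    return model, rounds
```

Keeping each step behind a callable allows the per-step sketches, or a real training framework, to be swapped in without changing the round structure.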
[0012] Further, step S2 includes:
[0013] After receiving the global model from the server, the feature extractor is frozen, and the classifier is optimized using local data through the cross-entropy loss function.
[0014] Calculate the validation set accuracy of the classifier on the local validation set and the feature prototype for each class.
[0015] Further, step S4 includes:
[0016] Using JS divergence as a distance metric, the K-means clustering algorithm is used to perform cluster analysis on the classifier of the client to obtain clusters;
[0017] Select client classifiers whose validation set accuracy is higher than the average accuracy of the cluster from each cluster to form a high-quality client set;
[0018] Calculate the aggregation weight of each high-quality client, and aggregate the high-quality client classifiers using a weighted average.
[0019] Further, step S5 includes:
[0020] Gaussian noise is added to the original feature prototype of each client to generate an enhanced feature prototype;
[0021] Construct a contrastive loss function based on the enhanced feature prototype and the original feature prototype:
[0022] The feature extractor is updated by minimizing the contrastive loss function using the Adam optimizer.
[0023] Furthermore, step S2 also includes:
[0024] Only clients whose accuracy on the validation set is higher than the average accuracy of their respective clusters are selected for aggregation, and the number of selected clients S is less than the total number of clients K.
[0025] Furthermore, step S4 also includes:
[0026] When calculating the aggregation weights, both the performance of the client's local model and the degree of difference between the local model and the global model are taken into account: model performance is measured by validation set accuracy, and model difference is measured by cosine similarity.
[0027] Furthermore, step S5 also includes:
[0028] The feature extractor optimization adopts a contrastive learning approach, which improves the robustness of the feature extractor by maximizing the consistency between the enhanced features and the original features in the feature space.
[0029] Furthermore, the variance of the Gaussian noise is set to 0.1, and the noise is independently and identically distributed across each feature dimension.
[0030] Furthermore, the temperature hyperparameter in the contrastive loss function is used to adjust the sharpness of the similarity distribution;
[0031] The method selects only a subset of high-quality client classifiers for aggregation in each iteration, while improving the model's generalization ability through feature enhancement and contrastive learning.
[0032] According to a second aspect of the present invention, the present invention claims protection for a client-selected aggregation federated learning system, comprising:
[0033] One or more processors;
[0034] A memory that stores one or more programs, which, when executed by one or more processors, enable the one or more processors to implement a client-selective aggregation federated learning method.
[0035] This invention claims protection for a federated learning method and system for client selection aggregation, belonging to the field of federated learning technology. On the server side, cluster analysis is performed on client classifiers, and clients whose validation set accuracy is higher than the average prediction accuracy of their cluster are selected for weighted aggregation. The feature extractor is optimized dynamically: enhanced features are generated by adding Gaussian noise to feature prototypes, and the feature extractor is optimized with a contrastive loss function, improving the model's adaptability and robustness to changes in data distribution. During client selection, cosine similarity is used as one of the metrics to ensure sufficient diversity among the selected client models and to avoid overfitting to specific data features, thereby enhancing the generalization ability of the global model. The selective aggregation strategy reduces unnecessary computation and communication, improves resource utilization, and maintains or improves the overall performance of the model.
Attached Figure Description
[0036] Figure 1 is a flowchart of a client-selective aggregation federated learning method claimed in an embodiment of the present invention;
[0037] Figure 2 is a second flowchart of a client-selective aggregation federated learning method claimed in an embodiment of the present invention.
Detailed Implementation
[0038] The technical solutions of the embodiments of this application will be clearly and completely described below with reference to the accompanying drawings. Obviously, the described embodiments are only a part of the embodiments of this application, and not all of the embodiments. Based on the embodiments of this application, all other embodiments obtained by those of ordinary skill in the art without creative effort are within the scope of protection of this application.
[0039] In this document, the term "embodiment" means that a particular feature, structure, or characteristic described in connection with an embodiment may be included in at least one embodiment of this application. The appearance of this phrase in various places throughout the specification does not necessarily refer to the same embodiment, nor is it a mutually exclusive, independent, or alternative embodiment. It will be explicitly and implicitly understood by those skilled in the art that the embodiments described herein can be combined with other embodiments.
[0040] According to a first embodiment of the present invention, the present invention claims protection for a federated learning method for client-selected aggregation which, referring to Figure 1, comprises the following steps:
[0041] S1, The server initializes the global model and distributes the global model to multiple clients;
[0042] S2, each of the multiple clients freezes the feature extractor of the received global model, updates the classifier using its local data, and calculates its validation set accuracy and the feature prototype of each class;
[0043] S3, after collecting the classifiers, validation set accuracies, and feature prototypes uploaded by the clients, the server clusters the client classifiers based on JS divergence and, from each cluster, selects the classifiers whose validation set accuracy is higher than the average accuracy of that cluster to form a high-quality client set;
[0044] S4, the server calculates the aggregation weight of each high-quality client based on its validation set accuracy and the cosine similarity between its classifier and the global classifier, and performs a weighted average of the classifiers to obtain the updated global classifier;
[0045] S5, the server adds Gaussian noise to the feature prototypes uploaded by each client to generate enhanced feature prototypes, and constructs a contrastive loss function to optimize the feature extractor;
[0046] S6, the server combines the updated feature extractor and classifier to form a new generation of global model, and repeats steps S2-S6 until the model converges.
[0047] In this embodiment, in step S1, the client receives the global model sent by the server:
[0048] $\theta^{0} = \{\phi^{0}, w^{0}\}$
[0049] where $\phi^{0}$ is the global feature extractor, $w^{0}$ is the global classifier, and $\theta^{0}$ is the global model at initialization, in which the feature extractor is frozen;
[0050] In step S6, the new generation of the global model $\theta^{t} = \{\phi^{t}, w^{t}\}$ is generated.
[0051] Further, step S2 includes:
[0052] After receiving the global model from the server, the feature extractor is frozen, and the local classifier is optimized using local data through the cross-entropy loss function.
[0053] Calculate the validation set accuracy of the classifier on the local validation set and the feature prototype for each class.
[0054] In this embodiment, the frozen feature extractor is used to extract features from the local data, and the client's local classifier is trained with an SGD optimizer. The optimization objective is:
[0055] $w_k^{t} = \arg\min_{w} \frac{1}{|D_k|} \sum_{(x_i, y_i) \in D_k} \ell_{CE}\big(w(\phi^{t-1}(x_i)), y_i\big)$
[0056] where $D_k$ is the dataset of client $k$, $y_i \in \{1, \dots, C\}$ is the label of sample $x_i$, $\ell_{CE}$ is the cross-entropy loss function, $\phi^{t-1}$ is the (frozen) feature extractor, $w$ is the classifier, $C$ is the number of sample categories, and $w_k^{t}$ is the local classifier of client $k$ after the $t$-th round of training;
[0057] The accuracy $acc_k^{t}$ of the classifier trained on client $k$ is evaluated on its validation dataset;
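Step S2 on a single client can be illustrated with a softmax-regression classifier trained by SGD on features from a frozen extractor, followed by the validation accuracy. This is an illustrative plain-Python sketch, assuming a linear classifier (`W`, `b`) and a hypothetical `extractor` callable; the text does not fix the classifier architecture, and all function names here are hypothetical.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def softmax(z: Sequence[float]) -> Vector:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def logits(W: List[Vector], b: Vector, h: Sequence[float]) -> Vector:
    return [sum(wj * hj for wj, hj in zip(row, h)) + bc for row, bc in zip(W, b)]


def train_local_classifier(extractor: Callable, data: List[Tuple], num_classes: int,
                           lr: float = 0.1, epochs: int = 50, seed: int = 0):
    """SGD on cross-entropy; the extractor is only called, never updated (frozen)."""
    rng = random.Random(seed)
    feats = [(list(extractor(x)), y) for x, y in data]  # features computed once
    dim = len(feats[0][0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    for _ in range(epochs):
        rng.shuffle(feats)
        for h, y in feats:
            p = softmax(logits(W, b, h))
            for c in range(num_classes):
                g = p[c] - (1.0 if c == y else 0.0)  # d(cross-entropy)/d(logit_c)
                b[c] -= lr * g
                for j in range(dim):
                    W[c][j] -= lr * g * h[j]
    return W, b


def accuracy(extractor: Callable, W: List[Vector], b: Vector, val_data: List[Tuple]) -> float:
    """Fraction of validation samples whose arg-max logit matches the label."""
    correct = 0
    for x, y in val_data:
        z = logits(W, b, extractor(x))
        correct += int(max(range(len(z)), key=z.__getitem__) == y)
    return correct / len(val_data)
```

Because the features are computed once before training, the extractor is never touched, which mirrors the freezing in step S2.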
[0058] The aggregation weight $\alpha_k^{t}$ of client $k$ is evaluated from its validation accuracy and the cosine similarity between its classifier and the global classifier;
[0059]
[0060] where $acc_k^{t}$ is the accuracy of client $k$ on the validation set, $\cos(\cdot,\cdot)$ denotes cosine similarity, $w_k^{t}$ is the local classifier of client $k$ after the $t$-th round of training, $w^{t-1}$ is the global classifier after the $(t-1)$-th round of training, and $K$ is the number of clients.
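Because the exact weighting formula is not reproduced above, the following sketch assumes one plausible form: the product of validation accuracy and cosine similarity (negative similarities clipped to zero), normalized to sum to one. The function names, the clipping, and the uniform fallback are assumptions for illustration; classifier parameters are assumed to be flattened into plain lists.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na > 0 and nb > 0 else 0.0


def aggregation_weights(accs: List[float], local_cls: List[Sequence[float]],
                        global_cls: Sequence[float]) -> List[float]:
    """Hypothetical weights: acc_k * max(cos(w_k, w_global), 0), normalized."""
    scores = [acc * max(cosine(w, global_cls), 0.0) for acc, w in zip(accs, local_cls)]
    total = sum(scores)
    if total == 0:
        return [1.0 / len(scores)] * len(scores)  # fallback: uniform weights
    return [s / total for s in scores]
```

A client with high accuracy but a classifier orthogonal to the current global classifier receives zero weight under this assumed form.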
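[0061] The feature prototype vector of each category $c$ in client $k$ is calculated as:
[0062] $p_k^{c} = \frac{1}{|D_k^{c}|} \sum_{(x_i, y_i) \in D_k^{c}} \phi^{t-1}(x_i)$
[0063] where $D_k^{c}$ is the subset of the data of client $k$ whose category is $c$, and $|D_k^{c}|$ is the number of samples of class $c$ in client $k$.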
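A minimal sketch of the per-class prototype computation, assuming the frozen extractor's outputs have already been collected into `features` (one vector per sample); `class_prototypes` is a hypothetical helper name.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence


def class_prototypes(features: List[Sequence[float]],
                     labels: List[Hashable]) -> Dict[Hashable, List[float]]:
    """Mean feature vector of each class: p_k^c = sum(phi(x_i)) / |D_k^c|."""
    sums: Dict[Hashable, List[float]] = {}
    counts: Dict[Hashable, int] = defaultdict(int)
    for h, y in zip(features, labels):
        if y not in sums:
            sums[y] = [0.0] * len(h)
        for j, v in enumerate(h):
            sums[y][j] += v
        counts[y] += 1
    return {c: [s / counts[c] for s in sums[c]] for c in sums}
```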
[0064] Furthermore, referring to Figure 2, step S4 includes:
[0065] S41, Using JS divergence as a distance metric, the K-means clustering algorithm is used to perform cluster analysis on the classifier of the client to obtain clusters;
[0066] S42, Select client classifiers whose validation set accuracy is higher than the average accuracy of the cluster from each cluster to form a high-quality client set;
[0067] S43, calculate the aggregation weight of each high-quality client, and aggregate the classifiers of the high-quality clients using a weighted average.
[0068] In this embodiment, selecting high-quality and diverse client classifiers for aggregation improves the overall system's recognition capability while reducing the time and space complexity of client aggregation. Therefore, this embodiment proposes a client-side filtering and aggregation strategy, specifically including the following steps:
[0069] Using JS divergence as the distance metric, the clients are partitioned with K-means clustering, so that client classifiers in different clusters are strongly complementary while those in the same cluster are highly similar. The distance is calculated as:
[0070] $d(k, m) = JS(P_k \,\|\, P_m) = \frac{1}{2} KL(P_k \,\|\, M) + \frac{1}{2} KL(P_m \,\|\, M), \quad M = \frac{1}{2}(P_k + P_m)$
[0071] where $P_k$ and $P_m$ are the probability distributions of the classifier outputs of clients $k$ and $m$, $M$ is their average distribution, and $KL(\cdot \,\|\, \cdot)$ is the Kullback-Leibler divergence.
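The JS distance above can be computed directly from two discrete output distributions. This sketch uses the natural logarithm (an assumption; base 2 would bound the value by 1) and assumes each $P_k$ is given as a list of class probabilities, for example a client classifier's mean softmax output over some shared probe inputs, which the text does not specify.

```python
import math
from typing import Sequence


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon divergence via the average distribution M = (P + Q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)
```

Unlike KL divergence, this distance is symmetric and always finite, because $M$ is positive wherever either input is.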
[0072] To reduce the impact of weak client classifiers on the global aggregation model and to increase the diversity of classifiers across clients, this patent selects from each cluster $g$ the clients whose prediction accuracy $acc_k^{t}$ is higher than the average prediction accuracy $\overline{acc}_{g}^{t}$ of that cluster, forming the high-quality client sequence $\{w_1^{t}, w_2^{t}, \dots, w_S^{t}\}$, where $S$ is the number of selected clients, $K$ is the total number of clients, and $S < K$.
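Clustering with JS divergence and the within-cluster selection can be sketched as below. Standard K-means is defined for squared Euclidean distance, so this variant, which assigns each client to the nearest centroid under JS divergence and updates each centroid as the arithmetic mean of its member distributions, is an assumption, as are the deterministic farthest-first initialization and the helper names `kmeans_js` and `select_high_quality`.

```python
import math
from typing import List, Sequence


def js(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    m = [(a + b) / 2 for a, b in zip(p, q)]

    def kl(x: Sequence[float]) -> float:
        return sum(xi * math.log(xi / mi) for xi, mi in zip(x, m) if xi > 0)

    return 0.5 * kl(p) + 0.5 * kl(q)


def kmeans_js(dists: List[Sequence[float]], n_clusters: int, iters: int = 20) -> List[int]:
    # Farthest-first initialization keeps the sketch deterministic (an assumption).
    centroids = [list(dists[0])]
    while len(centroids) < n_clusters:
        far = max(dists, key=lambda p: min(js(p, c) for c in centroids))
        centroids.append(list(far))
    assign = [0] * len(dists)
    for _ in range(iters):
        # Assignment step: nearest centroid under JS divergence.
        assign = [min(range(n_clusters), key=lambda c: js(p, centroids[c])) for p in dists]
        # Update step: arithmetic mean of member distributions (still a distribution).
        for c in range(n_clusters):
            members = [dists[i] for i, a in enumerate(assign) if a == c]
            if members:
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return assign


def select_high_quality(assign: List[int], accs: List[float]) -> List[int]:
    """Indices of clients whose accuracy is strictly above their cluster's mean."""
    selected: List[int] = []
    for c in set(assign):
        members = [i for i, a in enumerate(assign) if a == c]
        mean_acc = sum(accs[i] for i in members) / len(members)
        selected.extend(i for i in members if accs[i] > mean_acc)
    return sorted(selected)
```

A cluster containing a single client never contributes to the selection, because that client's accuracy cannot be strictly higher than its own cluster mean.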
[0073] The selected client classifiers are weighted and aggregated to generate the global classifier model:
[0074] $w^{t} = \sum_{k=1}^{S} \alpha_k^{t} \, w_k^{t}$
[0075] where $S$ is the number of selected clients and $\alpha_k^{t}$ is the weight of client $k$.
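The weighted aggregation itself is a weighted average of classifier parameters. This sketch assumes each classifier is flattened into a list of floats and renormalizes the weights over the selected clients so they sum to one; `aggregate_classifiers` is a hypothetical name.

```python
from typing import List, Sequence


def aggregate_classifiers(classifiers: List[Sequence[float]],
                          weights: Sequence[float]) -> List[float]:
    """w_global = sum_k alpha_k * w_k, with alpha renormalized over the selected clients."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    out = [0.0] * len(classifiers[0])
    for params, a in zip(classifiers, weights):
        for j, v in enumerate(params):
            out[j] += (a / total) * v
    return out
```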
[0076] Further, step S5 includes:
[0077] Gaussian noise is added to the original feature prototype of each client to generate an enhanced feature prototype;
[0078] Construct a contrastive loss function based on the enhanced feature prototype and the original feature prototype:
[0079] The feature extractor is updated by minimizing the contrastive loss function using the Adam optimizer.
[0080] In this embodiment, Gaussian noise is applied to the global feature prototypes to generate enhanced feature prototypes:
[0081] $\tilde{p}_k^{c} = p_k^{c} + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^{2} I)$
[0082] where $\sigma^{2}$ is the variance of the Gaussian noise, typically set to 0.1, and $I$ is the identity matrix, which ensures that the noise is independently and identically distributed across all feature dimensions.
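Prototype augmentation with isotropic Gaussian noise can be sketched with Python's standard `random` module. Note that `random.Random.gauss` takes a standard deviation, so the variance of 0.1 is converted with a square root; the dictionary layout `{class: prototype}` and the function name are assumptions.

```python
import math
import random
from typing import Dict, Hashable, List, Optional, Sequence


def augment_prototypes(prototypes: Dict[Hashable, Sequence[float]],
                       variance: float = 0.1,
                       seed: Optional[int] = None) -> Dict[Hashable, List[float]]:
    """Return p + eps with eps ~ N(0, variance * I), i.i.d. per dimension."""
    rng = random.Random(seed)
    std = math.sqrt(variance)  # gauss() expects a standard deviation, not a variance
    return {c: [v + rng.gauss(0.0, std) for v in p] for c, p in prototypes.items()}
```

The input prototypes are left unchanged, so the original and enhanced prototypes can both be fed to the contrastive loss.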
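[0083] A contrastive loss function is constructed:
[0084] $\mathcal{L}_{con} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp\big(\mathrm{sim}(p_i, \tilde{p}_i)/\tau\big)}{\sum_{j=1}^{N} \exp\big(\mathrm{sim}(p_i, \tilde{p}_j)/\tau\big)}$
[0085] where $p_i$ and $\tilde{p}_i$ are the $i$-th original and enhanced feature prototypes, $\mathrm{sim}(\cdot,\cdot)$ is the cosine similarity, $\tau$ is the temperature hyperparameter, and $N$ is the total number of samples.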
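The loss above is an InfoNCE-style objective. A plain-Python sketch of its value (not its gradient) is shown below, assuming `orig[i]` and `aug[i]` form the positive pair and every other `aug[j]` serves as a negative. The default `tau=0.5` is a hypothetical value; the text does not state one.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na > 0 and nb > 0 else 0.0


def info_nce(orig: List[Sequence[float]], aug: List[Sequence[float]], tau: float = 0.5) -> float:
    """Mean over i of -log softmax_j(sim(orig_i, aug_j) / tau) evaluated at j = i."""
    n = len(orig)
    total = 0.0
    for i in range(n):
        logits = [cosine(orig[i], aug[j]) / tau for j in range(n)]
        mx = max(logits)  # log-sum-exp trick for numerical stability
        log_denom = mx + math.log(sum(math.exp(z - mx) for z in logits))
        total += log_denom - logits[i]
    return total / n
```

A smaller `tau` sharpens the softmax over similarities, which is the role the text assigns to the temperature hyperparameter.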
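[0086] The feature extractor is unfrozen and fine-tuned with the Adam optimizer:
[0087] $\phi^{t} = \phi^{t-1} - \eta \cdot \mathrm{Adam}\big(\nabla_{\phi} \mathcal{L}_{con}\big)$
[0088] where $\eta$ is the server-side learning rate, $\nabla_{\phi} \mathcal{L}_{con}$ is the gradient of the contrastive loss function with respect to the feature extractor parameters, and $\mathrm{Adam}(\cdot)$ denotes the bias-corrected Adam update direction.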
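For reference, the standard Adam update rule (Kingma and Ba, 2015) applied to a flat parameter list looks as follows. This sketch assumes the gradient of $\mathcal{L}_{con}$ with respect to the extractor parameters is already available (in practice it would come from an automatic differentiation framework); the class name is hypothetical and the default hyperparameters are the commonly used Adam defaults, not values given in the text.

```python
import math
from typing import List, Sequence


class Adam:
    """Standard Adam for a flat list of parameters."""

    def __init__(self, n_params: int, lr: float = 1e-3, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * n_params  # first-moment (mean) estimates
        self.v = [0.0] * n_params  # second-moment (uncentered variance) estimates
        self.t = 0

    def step(self, params: Sequence[float], grads: Sequence[float]) -> List[float]:
        self.t += 1
        out = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            m_hat = self.m[i] / (1 - self.b1 ** self.t)  # bias correction
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            out.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return out
```

On the first step the bias-corrected update is approximately `lr * sign(g)`, so each parameter moves by about the learning rate regardless of the gradient's magnitude.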
[0089] Furthermore, step S2 also includes:
[0090] Only clients whose accuracy on the validation set is higher than the average accuracy of their respective clusters are selected for aggregation, and the number of selected clients S is less than the total number of clients K.
[0091] Furthermore, step S4 also includes:
[0092] When calculating the aggregation weights, both the performance of the client's local model and the degree of difference between the local model and the global model are taken into account: model performance is measured by validation set accuracy, and model difference is measured by cosine similarity.
[0093] Furthermore, step S5 also includes:
[0094] The feature extractor optimization adopts a contrastive learning approach, which improves the robustness of the feature extractor by maximizing the consistency between the enhanced features and the original features in the feature space.
[0095] Furthermore, the variance of the Gaussian noise is set to 0.1, and the noise is independently and identically distributed across each feature dimension.
[0096] Furthermore, the temperature hyperparameter in the contrastive loss function is used to adjust the sharpness of the similarity distribution;
[0097] The method selects only a subset of high-quality client classifiers for aggregation in each iteration, while improving the model's generalization ability through feature enhancement and contrastive learning.
[0098] According to a second aspect of the present invention, the present invention claims protection for a client-selected aggregation federated learning system, comprising:
[0099] One or more processors;
[0100] A memory that stores one or more programs, which, when executed by one or more processors, enable the one or more processors to implement a client-selective aggregation federated learning method.
[0101] This invention improves the overall performance of the global model by selecting clients whose validation set accuracy is higher than the average prediction accuracy of their cluster.
[0102] Using cosine similarity as one of the metrics can filter out client models that differ significantly from the current global model, ensuring diversity among the selected clients, helping to avoid overfitting to certain specific data features, and enhancing the generalization ability of the global model.
[0103] By adding Gaussian noise to the feature prototype to generate enhanced features and optimizing the feature extractor based on the contrastive loss function, the model can better adapt to changes in data distribution, thus enhancing its robustness and adaptability.
[0104] Cluster analysis is used to dynamically select high-quality clients, which can adapt to the dynamic changes in client data distribution, provide a more personalized model training strategy, and effectively solve the non-independent and identically distributed (Non-IID) problem.
[0105] Compared to traditional methods that indiscriminately aggregate updates from all clients, this method selects only a subset of high-quality and diverse clients for aggregation, reducing unnecessary computational overhead and improving resource utilization.
[0106] In the several embodiments provided in this application, it should be understood that the disclosed systems, apparatuses, and methods can be implemented in other ways. For example, the apparatus embodiments described above are merely illustrative; for instance, the division of units is only a logical functional division, and in actual implementation, there may be other division methods. For example, multiple units or components may be combined or integrated into another system, or some features may be ignored or not executed. Furthermore, the coupling or direct coupling or communication connection shown or discussed may be through some interfaces, or indirect coupling or communication connection between apparatuses or units, and may be electrical, mechanical, or other forms.
[0107] Furthermore, the functional units in the various embodiments of this application can be integrated into one processing unit, or each unit can exist physically separately, or two or more units can be integrated into one unit. The integrated units described above can be implemented in hardware or as software functional units. The above are merely embodiments of this application and do not limit the patent scope of this application. Any equivalent structural or procedural transformations made based on the description and drawings of this application, or direct or indirect applications in other related technical fields, are similarly included within the patent protection scope of this application.
[0108] The specific embodiments of the invention have been described in detail above, but they are only examples, and this application is not limited to the specific embodiments described above. For those skilled in the art, any equivalent modifications or substitutions to the invention are also within the scope of this application. Therefore, all equivalent changes, modifications, and improvements made without departing from the spirit and principles of this application should be covered within the scope of this application.
Claims
1. A federated learning method of client selection aggregation, characterized in that it comprises the following steps: S1, the server initializes the global model and distributes the global model to multiple clients; S2, each of the multiple clients freezes the feature extractor of the received global model, updates the classifier using local data, and calculates the validation set accuracy and the feature prototype of each class; S3, after collecting the classifiers, validation set accuracies, and feature prototypes uploaded by the clients, the server clusters the uploaded classifiers based on JS divergence and, from each cluster, selects the classifiers of the clients whose validation set accuracy is higher than the average accuracy of that cluster to form a high-quality client set; S4, the server calculates the aggregation weight of each high-quality client based on its validation set accuracy and the cosine similarity between its classifier and the global classifier, and performs a weighted average of the classifiers to obtain the updated global classifier; S5, the server adds Gaussian noise to the feature prototypes uploaded by each client to generate enhanced feature prototypes, and constructs a contrastive loss function to optimize the feature extractor; S6, the server combines the updated feature extractor and classifier into a new generation of the global model, and repeats steps S2 to S6 until the model converges or the set number of communication rounds is reached; step S4 includes: using JS divergence as a distance metric, performing cluster analysis on the client classifiers with the K-means clustering algorithm to obtain clusters; from each cluster, selecting the classifiers of the clients whose validation set accuracy is higher than the average accuracy of that cluster to form a high-quality client set; and calculating the aggregation weight of each high-quality client and aggregating the classifiers of the high-quality clients using a weighted average.
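Using JS divergence as a distance metric, the clients are partitioned with the K-means clustering approach; the distance is calculated as $d(k, m) = \frac{1}{2} KL(P_k \,\|\, M) + \frac{1}{2} KL(P_m \,\|\, M)$, wherein $P_k$ and $P_m$ denote the classifier output probability distributions of the $k$-th client and the $m$-th client, $M = \frac{1}{2}(P_k + P_m)$, and $KL(\cdot \,\|\, \cdot)$ is the Kullback-Leibler divergence; the clients whose prediction accuracy $acc_k^{t}$ is higher than the average prediction accuracy $\overline{acc}_{g}^{t}$ of their cluster $g$ are selected from each cluster to form the high-quality client sequence $\{w_1^{t}, w_2^{t}, \dots, w_S^{t}\}$, wherein $S$ is the number of selected clients, $K$ is the total number of clients, and $w_S^{t}$ is the classifier of the $S$-th client after the $t$-th round of training; the selected client classifiers are weighted and aggregated to generate the global classifier model $w^{t} = \sum_{k=1}^{S} \alpha_k^{t} w_k^{t}$, wherein $\alpha_k^{t}$ is the weight of the $k$-th client and $w_k^{t}$ is the classifier of the $k$-th client after the $t$-th round of training.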
2. The federated learning method for client-selected aggregation according to claim 1, characterized in that, Step S2 includes: After receiving the global model from the server, the feature extractor is frozen, and the classifier is optimized using local data through the cross-entropy loss function. Calculate the validation set accuracy of the classifier on the local validation set and the feature prototype for each class.
3. The federated learning method for client-selected aggregation according to claim 1, characterized in that, Step S5 includes: Gaussian noise is added to the original feature prototype of each client to generate an enhanced feature prototype; Construct a contrastive loss function based on the enhanced feature prototype and the original feature prototype; The feature extractor is updated by minimizing the contrastive loss function using the Adam optimizer.
4. The federated learning method for client-selected aggregation according to claim 1, characterized in that, Step S2 further includes: Only clients whose accuracy on the validation set is higher than the average accuracy of their respective clusters are selected for aggregation, and the number of selected clients S is less than the total number of clients K.
5. The federated learning method for client-selected aggregation according to claim 1, characterized in that, Step S4 further includes: While calculating aggregate weights, the performance of the client's local model and the degree of difference between the local model and the global model are also taken into account. The model performance is measured by the accuracy of the validation set, and the model difference is measured by the cosine similarity.
6. The federated learning method for client-selected aggregation according to claim 1, characterized in that, Step S5 further includes: The feature extractor optimization adopts a contrastive learning approach, which improves the robustness of the feature extractor by maximizing the consistency between the enhanced features and the original features in the feature space.
7. The federated learning method for client-selected aggregation according to claim 4, characterized in that, The variance of the Gaussian noise is set to 0.1, and the noise is independently and identically distributed across all feature dimensions.
8. The federated learning method for client-selected aggregation according to claim 4, characterized in that, The temperature hyperparameter in the contrastive loss function is used to adjust the sharpness of the similarity distribution; the method updates only the classifiers of some high-quality clients in each iteration, while improving the generalization ability of the model through feature enhancement and contrastive learning.
9. A client-selective aggregation federated learning system, characterized in that it comprises: one or more processors; and a memory storing one or more programs that, when executed by the one or more processors, cause the one or more processors to implement the client-selective aggregation federated learning method according to any one of claims 1 to 8.