Personalized federated learning model construction method and system based on conjugate multi-class Gaussian process classification and deep kernel learning
By employing a conjugate multi-class Gaussian process classification and deep kernel learning approach, we have solved the challenges of parameterized assumptions and multi-class classification tasks in personalized Bayesian federated learning. This approach enables efficient posterior inference and communication, thereby improving the model's prediction accuracy and robustness.
Patent Information
- Authority / Receiving Office: CN · China
- Patent Type: Applications (China)
- Current Assignee / Owner: RENMIN UNIVERSITY OF CHINA
- Filing Date: 2026-03-23
- Publication Date: 2026-06-19
AI Technical Summary
Existing personalized Bayesian federated learning methods rely on strongly parameterized assumptions in posterior inference, making non-parametric modeling impossible, and they also face technical challenges such as high communication overhead in multi-class classification tasks.
We employ a conjugate multi-class Gaussian process classification method combined with deep kernel learning. We achieve analytical posterior inference through One-vs-Each approximation and Pólya-Gamma augmentation, and use deep kernel learning to compress client model information into a small number of hyperparameters for communication.
It achieves accurate uncertainty estimation, reduces communication overhead, and improves the model's prediction accuracy, convergence speed, and robustness, making it suitable for resource-constrained client environments.

Figure CN122242815A_ABST
Abstract
Description
Technical Field
[0001] This invention belongs to the field of machine learning technology and relates to an innovative method and system for implementing personalized federated learning.
Background Technology
[0002] With the advent of the big data era and the increasing awareness of user privacy protection, how to effectively utilize data scattered across various terminals for machine learning model training while ensuring data security has become a focus of common concern in academia and industry. Federated Learning (FL), as an emerging distributed machine learning paradigm, allows multiple clients to collaboratively train a global model without sharing the original data, thanks to its core idea of "data remains stationary, model moves." This demonstrates broad application prospects in privacy-sensitive fields such as financial risk control, medical image analysis, and the Internet of Things. The classic federated averaging algorithm achieves joint modeling of scattered data through an iterative process of local updates on the client and parameter averaging on the server.
[0003] However, the effectiveness of federated learning heavily relies on the ideal assumption that client data is independent and identically distributed. In complex real-world application scenarios, due to differences in user behavior patterns, the diversity of data collection environments, and the specificities of business domains, the data held by each client often exhibits significant non-independent and identically distributed characteristics, i.e., data heterogeneity. For example, in medical image analysis, patient groups, equipment models, and imaging protocols from different medical institutions may lead to huge differences in image features; in mobile keyboard input prediction, different users have drastically different language habits and commonly used vocabulary. This significant inconsistency in data distribution severely undermines the optimization foundation of standard federated averaging algorithms, leading to slow model convergence, decreased stability, and even causing the global model to deviate from the optimal solution, ultimately failing to achieve satisfactory performance on all clients.
[0004] To address the challenges posed by heterogeneous data, Personalized Federated Learning (PFL) has emerged. This technology aims to build a customized model for each client, enabling it to benefit from global knowledge while adapting well to its local data distribution. Current mainstream PFL methods are based on a frequentist framework, and their core output is a point estimate of the model parameters. When the amount of local data on a client is extremely limited (e.g., in cold-start or few-shot learning scenarios), such methods are prone to overfitting and generalize poorly. More importantly, in many high-risk decision-making domains, such as disease diagnosis, autonomous driving, and financial credit assessment, models must provide not only accurate predictions but also reliable confidence levels to support subsequent human intervention or risk control. Point-estimate models cannot provide this uncertainty information, which significantly limits their value for deployment in critical tasks.
[0005] To overcome the shortcomings of frequentist methods, researchers have introduced Bayesian inference into personalized federated learning, forming Personalized Bayesian Federated Learning (PBFL). PBFL achieves automatic regularization during training by introducing a prior distribution into the model parameters and calculating the posterior distribution after observing local data, effectively mitigating overfitting. Simultaneously, the posterior distribution naturally implies parameter uncertainty, which can be passed to the prediction stage through marginal integrals, outputting probabilistic prediction results. Although the aforementioned Personalized Bayesian Federated Learning methods have made initial progress in uncertainty quantification, the following key technical bottlenecks still urgently need to be addressed:
[0006] Strongly parameterized assumptions in posterior inference: To obtain a computable posterior form, existing methods generally employ parameterized variational distributions (such as Gaussian distributions) to approximate the true posterior. However, the true posterior often exhibits complex multimodal, long-tailed, or asymmetric structures. A simple Gaussian assumption introduces significant approximation errors, causing the inferred posterior distribution to deviate from the true situation, thus affecting model performance and calibration quality. Methods such as pFedBayes, in pursuit of analytical KL divergence calculations, force the posterior to be Gaussian; this constraint is particularly unreasonable in highly nonlinear deep models.
[0007] The lack of non-parametric modeling capabilities: Most existing methods are based on parametric Bayesian neural networks, whose expressive power is limited by the network structure and the choice of parametric priors. Inappropriate prior settings may lead to model bias, and the huge scale of network parameters results in huge communication overhead for transmitting the entire posterior distribution, which is extremely unfriendly to resource-constrained mobile clients.
[0008] Limitations of multi-class classification tasks: Many Bayesian methods face challenges when extended to multi-class classification. For example, while pFedGP attempts to introduce Gaussian processes into federated learning, it uses GP trees to decompose the multi-class problem into multiple binary sub-problems. Essentially, it is an ensemble-based approach, which not only incurs high computational costs but also fails to achieve true conjugate multi-class Gaussian process classification, thus limiting its performance on complex vision tasks.
[0009] Gaussian processes, as a type of Bayesian nonparametric model, directly impose priors on the function space, enabling them to adaptively adjust model complexity based on data and avoid the biases inherent in parameterized priors. They also naturally provide uncertainty estimation. However, in multi-class classification, the non-Gaussian likelihood function of standard Gaussian processes makes posterior inference difficult to solve analytically, hindering their efficient deployment in federated learning environments. In recent years, the development of Pólya-Gamma augmentation techniques and One-vs-Each approximation has provided theoretical possibilities for conjugate inference in multi-class Gaussian process classification, but no work has yet applied them to the framework of personalized Bayesian federated learning to simultaneously address the issues of non-conjugate posterior inference, high communication overhead, and data heterogeneity.
[0010] In summary, there is an urgent need for a personalized Bayesian federated learning method that can overcome parameterized posterior assumptions, achieve efficient non-parametric inference, and effectively handle multi-class classification tasks, so as to improve the prediction accuracy, uncertainty calibration ability, convergence speed and robustness of the model on non-independent and identically distributed data.
Summary of the Invention
[0011] The purpose of this invention is to overcome the shortcomings of existing technologies and provide a method and system for constructing personalized Bayesian federated learning models based on conjugate multi-class Gaussian process classification and deep kernel learning. This addresses the technical problems of existing personalized Bayesian federated learning methods, such as reliance on strong parameterized assumptions in posterior inference, inability to perform non-parametric modeling, and the complexity of handling multi-class classification tasks. This invention aims to achieve the following objectives: In a highly heterogeneous federated learning environment, by introducing conjugate multi-class Gaussian process classification, the client can perform analytical posterior inference to obtain accurate uncertainty estimates; simultaneously, by constructing a shared global prior through deep kernel learning and employing an efficient parameter aggregation method, communication overhead is reduced, and the model's prediction accuracy, convergence speed, and robustness are improved, providing an accurate, reliable, and communication-efficient solution for distributed machine learning.
[0012] To achieve the above-mentioned objectives, this invention provides a method for constructing a personalized Bayesian federated learning model based on conjugate multi-class Gaussian process classification and deep kernel learning. This method is executed collaboratively by a central server and multiple clients, and its specific implementation is as follows:
[0013] Step 1: Initialization Phase
[0014] The central server initializes the global hyperparameters $\theta$ of the deep kernel function $k_\theta$. The deep kernel combines a deep neural network $g_\theta$ with a base kernel function $k_{\mathrm{base}}$, that is, $k_\theta(x, x') = k_{\mathrm{base}}(g_\theta(x), g_\theta(x'))$. The network maps the input space to a feature space, and the base kernel computes the covariance in that feature space.
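The following minimal NumPy sketch shows one way to build such a deep kernel. It assumes a small tanh multilayer perceptron as the feature extractor $g_\theta$ and an RBF base kernel. The names `mlp_features`, `deep_kernel`, and `init_theta`, the layer sizes, and the dictionary layout of $\theta$ are illustrative assumptions, not the invention's architecture.

```python
import numpy as np

def mlp_features(X, weights):
    """Hypothetical feature extractor g_theta: a small tanh MLP.

    `weights` is a list of (W, b) pairs; this architecture is an
    illustrative assumption.
    """
    H = X
    for W, b in weights:
        H = np.tanh(H @ W + b)
    return H

def rbf_base_kernel(Z1, Z2, lengthscale=1.0, outputscale=1.0):
    """RBF base kernel evaluated on feature vectors."""
    sq_dists = (
        np.sum(Z1**2, axis=1)[:, None]
        + np.sum(Z2**2, axis=1)[None, :]
        - 2.0 * Z1 @ Z2.T
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # guard against tiny negative round-off
    return outputscale * np.exp(-0.5 * sq_dists / lengthscale**2)

def deep_kernel(X1, X2, theta):
    """k_theta(x, x') = k_base(g_theta(x), g_theta(x'))."""
    Z1 = mlp_features(X1, theta["weights"])
    Z2 = mlp_features(X2, theta["weights"])
    return rbf_base_kernel(Z1, Z2, theta["lengthscale"], theta["outputscale"])

def init_theta(in_dim, hidden=(16, 8), seed=0):
    """Server-side initialization of the global hyperparameters theta."""
    rng = np.random.default_rng(seed)
    dims = (in_dim, *hidden)
    weights = [
        (rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_in, d_out)), np.zeros(d_out))
        for d_in, d_out in zip(dims[:-1], dims[1:])
    ]
    return {"weights": weights, "lengthscale": 1.0, "outputscale": 1.0}
```

Here $\theta$ bundles the network weights with the base kernel's lengthscale and output scale. This whole bundle is what the server would initialize, broadcast, and later aggregate.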
[0015] Step 2: Client-side local posterior inference and hyperparameter optimization
[0016] In each communication round, the central server randomly selects a subset of clients to participate in training and broadcasts the current global deep kernel hyperparameters $\theta$ to them. Each selected client $j$ performs the following sub-steps:
[0017] Sub-step 2.1: Construct the local Gaussian process prior.
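[0018] Based on the received global deep kernel hyperparameters $\theta$, client $j$ constructs an independent zero-mean Gaussian process prior $f^c \sim \mathcal{GP}(0, k_\theta)$ for each class $c = 1, \dots, C$. The latent values of all classes at the $N_j$ local training inputs $X_j$ are stacked into $\mathbf{f} = (\mathbf{f}^1, \dots, \mathbf{f}^C)$. This gives the joint prior $\mathbf{f} \sim \mathcal{N}(\mathbf{0}, \mathbf{K})$, where $\mathbf{K}$ is a block-diagonal covariance matrix whose $c$-th block is the kernel matrix of class $c$ evaluated on $X_j$.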
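The block-diagonal joint prior can be sketched as follows. As a simplifying assumption, every class block uses the same shared kernel, so that $\mathbf{K} = \mathbf{I}_C \otimes \mathbf{K}_x$. An RBF kernel stands in for the deep kernel $k_\theta$, and the names `joint_prior_cov` and `sample_joint_prior` are hypothetical.

```python
import numpy as np

def rbf(X1, X2, lengthscale=1.0):
    """Stand-in for k_theta; a deep kernel could be passed instead."""
    sq = np.sum(X1**2, 1)[:, None] + np.sum(X2**2, 1)[None, :] - 2.0 * X1 @ X2.T
    return np.exp(-0.5 * np.maximum(sq, 0.0) / lengthscale**2)

def joint_prior_cov(X, num_classes, kernel=rbf, jitter=1e-6):
    """Block-diagonal K for the stacked vector f = (f^1, ..., f^C).

    Assumes every class block uses the same kernel, so K = I_C kron K_x.
    The jitter keeps each block numerically positive definite.
    """
    Kx = kernel(X, X) + jitter * np.eye(len(X))
    return np.kron(np.eye(num_classes), Kx)

def sample_joint_prior(X, num_classes, rng, kernel=rbf):
    """Draw f ~ N(0, K): one independent GP draw per class."""
    K = joint_prior_cov(X, num_classes, kernel)
    L = np.linalg.cholesky(K)
    f = L @ rng.standard_normal(K.shape[0])
    return f.reshape(num_classes, len(X))  # row c holds f^c at the N inputs
```

Because the off-diagonal blocks are zero, the class-wise latent functions are independent a priori. They are coupled only later, through the likelihood.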
[0019] Sub-step 2.2: Decompose the multi-class likelihood with the One-vs-Each approximation.
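[0020] For the multi-class classification problem, client $j$ uses the One-vs-Each approximation to decompose the original softmax likelihood into a product of binary likelihoods:
[0021] $$p(y_n = c \mid \mathbf{f}_n) = \frac{\exp(f_n^c)}{\sum_{c'=1}^{C} \exp(f_n^{c'})} \approx \prod_{c' \neq c} \sigma\left(f_n^c - f_n^{c'}\right)$$
[0022] Here $\sigma$ is the logistic sigmoid and $f_n^c$ is the latent value of class $c$ at input $x_n$. The client also constructs a sparse matrix $\mathbf{A}$ with one row for every pair $(n, c')$ with $c' \neq y_n$. The product $\mathbf{A}\mathbf{f}$ then collects all class differences $f_n^{y_n} - f_n^{c'}$, so the likelihood can be expressed as a function of $\mathbf{A}\mathbf{f}$: $p(\mathbf{y} \mid \mathbf{f}) \approx \prod_{m} \sigma\left((\mathbf{A}\mathbf{f})_m\right)$.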
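The following NumPy sketch builds the comparison matrix $\mathbf{A}$ and evaluates the One-vs-Each product alongside the exact softmax likelihood. It assumes the latent vector is stacked class-major, so entry $c \cdot N + n$ holds $f_n^c$. $\mathbf{A}$ is stored densely for readability, even though each row has only two non-zeros. The helper names are illustrative.

```python
import numpy as np

def build_ove_matrix(y, num_classes):
    """Matrix A with one row per (n, c') pair, c' != y_n.

    f is stacked class-major: entry c * N + n holds f_n^c.
    The row for (n, c') has +1 at (y_n, n) and -1 at (c', n),
    so (A f) holds the differences f_n^{y_n} - f_n^{c'}.
    """
    N = len(y)
    rows = []
    for n, yn in enumerate(y):
        for c in range(num_classes):
            if c == yn:
                continue
            row = np.zeros(num_classes * N)
            row[yn * N + n] = 1.0
            row[c * N + n] = -1.0
            rows.append(row)
    return np.array(rows)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def ove_likelihood(f, y, num_classes):
    """One-vs-Each approximation: product over rows m of sigma((A f)_m)."""
    A = build_ove_matrix(y, num_classes)
    return np.prod(sigmoid(A @ f))

def softmax_likelihood(f, y, num_classes):
    """Exact softmax likelihood of the labels, for comparison."""
    F = f.reshape(num_classes, len(y))          # F[c, n] = f_n^c
    P = np.exp(F - F.max(axis=0))
    P /= P.sum(axis=0)
    return np.prod(P[y, np.arange(len(y))])
```

For any latent vector, the One-vs-Each product never exceeds the softmax likelihood, because it is a lower bound on the softmax probability.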
[0023] Sub-step 2.3: Introduce Pólya-Gamma augmentation to achieve conjugate inference.
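[0024] To handle the non-conjugacy of the binary (logistic) likelihoods, client $j$ introduces a Pólya-Gamma auxiliary variable $\omega_m$ for each comparison term $(\mathbf{A}\mathbf{f})_m$, where $\mathbf{A}$ is the sparse class-difference matrix from sub-step 2.2. It then uses the identity $\sigma(\psi) = \tfrac{1}{2} e^{\psi/2} \int_0^\infty e^{-\omega\psi^2/2}\, \mathrm{PG}(\omega \mid 1, 0)\, d\omega$. Conditioned on $\boldsymbol{\omega}$, the likelihood therefore takes a Gaussian form in $\mathbf{f}$:
[0025] $$p(\mathbf{y} \mid \mathbf{f}, \boldsymbol{\omega}) \propto \exp\left(\boldsymbol{\kappa}^\top \mathbf{A}\mathbf{f} - \tfrac{1}{2}\, \mathbf{f}^\top \mathbf{A}^\top \boldsymbol{\Omega} \mathbf{A} \mathbf{f}\right)$$
[0026] where $\boldsymbol{\kappa} = \tfrac{1}{2}\mathbf{1}$ and $\boldsymbol{\Omega} = \mathrm{diag}(\boldsymbol{\omega})$. Combined with the Gaussian process prior, this yields a conditionally conjugate model.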
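Pólya-Gamma augmentation rests on the identity $\sigma(\psi) = \tfrac{1}{2} e^{\psi/2}\,\mathbb{E}_{\omega \sim \mathrm{PG}(1,0)}[e^{-\omega\psi^2/2}]$, which is equivalent to $\mathbb{E}_{\omega \sim \mathrm{PG}(1,0)}[e^{-\omega\psi^2/2}] = 1/\cosh(\psi/2)$. The sketch below checks this identity numerically. It uses an approximate $\mathrm{PG}(1, z)$ sampler based on the truncated sum-of-gammas representation of the Pólya-Gamma distribution. This sampler is a simplification chosen for self-containment, not the sampler used by the invention, and the function names are hypothetical.

```python
import numpy as np

def sample_pg1(z, rng, num_terms=200):
    """Approximate draws from PG(1, z) via the truncated series
    omega = 1/(2 pi^2) * sum_k g_k / ((k - 1/2)^2 + z^2 / (4 pi^2)),
    with g_k ~ Gamma(1, 1).

    Truncating at `num_terms` slightly underestimates omega; exact
    samplers are preferable in practice.
    """
    z = np.atleast_1d(np.asarray(z, dtype=float))
    k = np.arange(1, num_terms + 1)
    g = rng.standard_exponential(size=(z.size, num_terms))  # Gamma(1, 1) draws
    denom = (k - 0.5) ** 2 + (z[:, None] ** 2) / (4.0 * np.pi**2)
    return (g / denom).sum(axis=1) / (2.0 * np.pi**2)

def sigmoid_via_pg(psi, rng, num_samples=20000):
    """Monte Carlo version of
    sigma(psi) = 1/2 * e^{psi/2} * E_{omega ~ PG(1,0)}[exp(-omega psi^2 / 2)]."""
    omega = sample_pg1(np.zeros(num_samples), rng)
    return 0.5 * np.exp(psi / 2.0) * np.mean(np.exp(-omega * psi**2 / 2.0))
```

Conditioned on $\omega$, the integrand $e^{\psi/2 - \omega\psi^2/2}$ is the exponential of a quadratic in $\psi$. This is exactly why the augmented likelihood becomes Gaussian in the latent differences.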
[0027] Sub-step 2.4: Gibbs sampling for posterior inference.
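[0028] Client $j$ uses Gibbs sampling to alternately update the latent function values $\mathbf{f}$ and the auxiliary variables $\boldsymbol{\omega}$, with $\mathbf{K}$, $\mathbf{A}$, $\boldsymbol{\kappa}$, and $\boldsymbol{\Omega}$ as defined in sub-steps 2.1 to 2.3:
[0029] - Sample $\mathbf{f} \mid \mathbf{y}, \boldsymbol{\omega} \sim \mathcal{N}(\boldsymbol{\Sigma}\mathbf{A}^\top\boldsymbol{\kappa}, \boldsymbol{\Sigma})$, where $\boldsymbol{\Sigma} = (\mathbf{K}^{-1} + \mathbf{A}^\top\boldsymbol{\Omega}\mathbf{A})^{-1}$;
[0030] - Sample $\omega_m \mid \mathbf{f} \sim \mathrm{PG}\left(1, (\mathbf{A}\mathbf{f})_m\right)$ for every comparison term $m$.
[0031] After multiple iterations, posterior samples of $\mathbf{f}$ and $\boldsymbol{\omega}$ are obtained.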
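A compact, single-chain version of this Gibbs sampler is sketched below. It is a naive dense implementation: it forms an explicit inverse of $\mathbf{K}$ and a Cholesky factorization of the full posterior precision. It is not the Hoffman-Ribak scheme described later in this document. It also relies on an assumed truncated-series Pólya-Gamma sampler, and the function names and class-major stacking of $\mathbf{f}$ are assumptions for illustration.

```python
import numpy as np

def sample_pg1(z, rng, num_terms=200):
    """Approximate PG(1, z) draws via a truncated sum-of-gammas series (assumption)."""
    k = np.arange(1, num_terms + 1)
    g = rng.standard_exponential(size=(z.size, num_terms))   # Gamma(1, 1) draws
    denom = (k - 0.5) ** 2 + (z[:, None] ** 2) / (4.0 * np.pi**2)
    return (g / denom).sum(axis=1) / (2.0 * np.pi**2)

def build_ove_matrix(y, C):
    """Rows encode f_n^{y_n} - f_n^{c'} for c' != y_n; f is stacked class-major."""
    N = len(y)
    A = np.zeros((N * (C - 1), C * N))
    m = 0
    for n, yn in enumerate(y):
        for c in range(C):
            if c != yn:
                A[m, yn * N + n], A[m, c * N + n] = 1.0, -1.0
                m += 1
    return A

def gibbs_ove_pg(K, y, C, rng, num_iters=100):
    """Alternate omega | f ~ PG(1, A f) and f | omega ~ N(Sigma A^T kappa, Sigma)."""
    A = build_ove_matrix(y, C)
    kappa = 0.5 * np.ones(A.shape[0])
    K_inv = np.linalg.inv(K)                              # naive dense inverse
    f = np.zeros(K.shape[0])
    samples = []
    for _ in range(num_iters):
        omega = sample_pg1(A @ f, rng)
        precision = K_inv + A.T @ (omega[:, None] * A)    # Sigma^{-1} = K^{-1} + A^T Omega A
        precision = 0.5 * (precision + precision.T)       # enforce exact symmetry
        L = np.linalg.cholesky(precision)                 # precision = L L^T
        mean = np.linalg.solve(precision, A.T @ kappa)
        # L^{-T} z with z ~ N(0, I) has covariance (L L^T)^{-1} = Sigma.
        f = mean + np.linalg.solve(L.T, rng.standard_normal(K.shape[0]))
        samples.append(f)
    return np.array(samples), omega
```

In practice, several such chains would be run in parallel, and their final states would be kept as the posterior samples.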
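[0032] Sub-step 2.5: Optimize the local deep kernel hyperparameters.
[0033] Based on the collected posterior samples, client $j$ uses Monte Carlo estimation to compute the gradient of the log-marginal likelihood $\log p(\mathbf{y}_j \mid X_j, \theta)$ with respect to the hyperparameters. Starting from the received $\theta$, it then updates the local deep kernel hyperparameters by gradient ascent:
[0034] $$\theta_j \leftarrow \theta_j + \eta \cdot \frac{1}{S} \sum_{s=1}^{S} \nabla_\theta \log p\left(\mathbf{y}_j, \mathbf{f}^{(s)}, \boldsymbol{\omega}^{(s)} \mid X_j, \theta\right)\Big|_{\theta = \theta_j}$$
[0035] where $\eta$ is the learning rate, $S$ is the number of posterior samples, and $(\mathbf{f}^{(s)}, \boldsymbol{\omega}^{(s)})$ is the $s$-th posterior sample. The updated local hyperparameters $\theta_j$ are the result of this round's client update.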
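The gradient of the log-marginal likelihood can be estimated through Fisher's identity. Only the Gaussian process prior depends on $\theta$, so the gradient equals the posterior average of $\nabla_\theta \log \mathcal{N}(\mathbf{f} \mid \mathbf{0}, \mathbf{K}_\theta) = \tfrac{1}{2}\mathbf{f}^\top\mathbf{K}_\theta^{-1}(\partial\mathbf{K}_\theta)\mathbf{K}_\theta^{-1}\mathbf{f} - \tfrac{1}{2}\operatorname{tr}(\mathbf{K}_\theta^{-1}\partial\mathbf{K}_\theta)$.

The sketch below computes this Monte Carlo gradient for a single hyperparameter. As a stand-in for the full deep kernel parameter vector, it uses the log-lengthscale of an RBF kernel; a real deep kernel would need automatic differentiation. All names are hypothetical.

```python
import numpy as np

def rbf_and_grad(X, log_ell, jitter=1e-3):
    """RBF kernel matrix (stand-in for k_theta) and its derivative w.r.t. log lengthscale."""
    sq = np.sum(X**2, 1)[:, None] + np.sum(X**2, 1)[None, :] - 2.0 * X @ X.T
    sq = np.maximum(sq, 0.0)
    ell2 = np.exp(2.0 * log_ell)
    Kx = np.exp(-0.5 * sq / ell2)
    dKx = Kx * sq / ell2          # d/d(log ell) of exp(-sq / (2 ell^2))
    return Kx + jitter * np.eye(len(X)), dKx

def mean_log_prior(f_samples, X, C, log_ell):
    """Average of log N(f^(s) | 0, K) over samples, with K = I_C kron Kx."""
    Kx, _ = rbf_and_grad(X, log_ell)
    L = np.linalg.cholesky(np.kron(np.eye(C), Kx))
    alpha = np.linalg.solve(L, f_samples.T)      # columns are L^{-1} f^(s)
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    n = L.shape[0]
    quad = np.sum(alpha**2, axis=0)              # f^T K^{-1} f for each sample
    return np.mean(-0.5 * quad - 0.5 * logdet - 0.5 * n * np.log(2.0 * np.pi))

def fisher_grad(f_samples, X, C, log_ell):
    """Monte Carlo estimate of d log p(y | X, theta) / d log_ell via Fisher's identity."""
    Kx, dKx = rbf_and_grad(X, log_ell)
    K_inv = np.linalg.inv(np.kron(np.eye(C), Kx))
    dK = np.kron(np.eye(C), dKx)
    trace_term = np.trace(K_inv @ dK)
    grads = [0.5 * (K_inv @ f) @ dK @ (K_inv @ f) - 0.5 * trace_term for f in f_samples]
    return float(np.mean(grads))

def local_update(f_samples, X, C, log_ell, lr=0.01):
    """One gradient-ascent step: theta_j <- theta_j + eta * (MC gradient)."""
    return log_ell + lr * fisher_grad(f_samples, X, C, log_ell)
```

`mean_log_prior` is the quantity whose derivative `fisher_grad` returns, so the two can be cross-checked by finite differences. In use, `f_samples` would come from the Gibbs sampler rather than from arbitrary draws.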
[0036] Step 3: Client uploads its update
[0037] After completing local optimization, client $j$ uploads its updated deep kernel hyperparameters $\theta_j$ to the central server.
[0038] Step 4: Server-side global aggregation
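[0039] The central server collects the hyperparameters $\theta_j$ uploaded by every client $j$ in the set $\mathcal{S}_t$ of clients participating in this round. It averages them, weighted by each client's local data size $N_j$, to obtain the new global deep kernel hyperparameters:
[0040] $$\theta \leftarrow \sum_{j \in \mathcal{S}_t} \frac{N_j}{\sum_{j' \in \mathcal{S}_t} N_{j'}}\, \theta_j$$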
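The following sketch implements this data-size-weighted aggregation. It assumes each client's hyperparameters arrive as a dictionary of NumPy arrays keyed by parameter name; this is a hypothetical format, and `aggregate` is an illustrative name.

```python
import numpy as np

def aggregate(client_thetas, client_sizes):
    """theta <- sum_j (N_j / sum_j' N_j') * theta_j, applied entry by entry.

    client_thetas: list of dicts mapping parameter name -> numpy array.
    client_sizes:  list of local data sizes N_j, in the same order.
    """
    sizes = np.asarray(client_sizes, dtype=float)
    weights = sizes / sizes.sum()
    return {
        name: sum(w * theta[name] for w, theta in zip(weights, client_thetas))
        for name in client_thetas[0]
    }
```

For example, two clients holding 10 and 30 samples receive weights of 0.25 and 0.75, respectively.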
[0041] The updated global hyperparameters are used as shared priors for the next round of communication.
[0042] Step 5: Iterate until convergence
[0043] Repeat steps 2 through 4 until the preset number of communication rounds is reached or the model performance converges.
[0044] Step 6: Make predictions using the trained model.
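[0045] After training is complete, any client $j$ can use its local model to make a prediction for a new sample $x_*$. Given the training data $(X_j, \mathbf{y}_j)$ and the sampled auxiliary variables $\boldsymbol{\omega}$, the latent function values $\mathbf{f}_* = (f_*^1, \dots, f_*^C)$ at the test point follow a Gaussian predictive distribution:
[0046] $$p(\mathbf{f}_* \mid x_*, X_j, \mathbf{y}_j, \boldsymbol{\omega}) = \mathcal{N}\left(\mathbf{f}_* \mid \boldsymbol{\mu}_*, \boldsymbol{\Sigma}_*\right)$$
[0047] The mean $\boldsymbol{\mu}_*$ and covariance matrix $\boldsymbol{\Sigma}_*$ can be calculated analytically with the conditioning formula for the joint Gaussian distribution of the training and test latent values. Finally, the predicted probability $p(y_* = c \mid x_*)$ of each class $c$ is obtained by numerical integration (such as Gauss-Hermite quadrature).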
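The final integration step can be sketched with product-grid Gauss-Hermite quadrature over the $C$-dimensional Gaussian $\mathcal{N}(\boldsymbol{\mu}_*, \boldsymbol{\Sigma}_*)$. The sketch assumes the softmax as the link from latent values to class probabilities; the One-vs-Each product could be substituted. It takes $\boldsymbol{\mu}_*$ and $\boldsymbol{\Sigma}_*$ as given, and `predictive_probs` is a hypothetical name.

```python
import itertools
import numpy as np
from numpy.polynomial.hermite import hermgauss

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def predictive_probs(mu, Sigma, num_points=10):
    """E_{f ~ N(mu, Sigma)}[softmax(f)] by product-grid Gauss-Hermite quadrature.

    With Sigma = L L^T and the change of variables f = mu + sqrt(2) L x,
    E[g(f)] ~= pi^{-C/2} * sum_i w_i g(mu + sqrt(2) L x_i).
    Cost grows as num_points ** C, so this is practical only for small C.
    """
    C = len(mu)
    x, w = hermgauss(num_points)        # nodes/weights for the weight e^{-x^2}
    L = np.linalg.cholesky(Sigma)
    probs = np.zeros(C)
    for idx in itertools.product(range(num_points), repeat=C):
        node = x[list(idx)]
        weight = np.prod(w[list(idx)])
        probs += weight * softmax(mu + np.sqrt(2.0) * L @ node)
    return probs / np.pi ** (C / 2.0)
```

Because the product-grid weights sum to $\pi^{C/2}$ and every softmax vector sums to one, the returned probabilities sum to one up to rounding. For larger $C$, Monte Carlo or sparse-grid integration would be the natural substitute.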
[0048] Furthermore, the Gibbs sampling in sub-step 2.4 is implemented efficiently with the Hoffman-Ribak method. This method exploits the block-diagonal structure of the covariance matrix $\mathbf{K}$ and the low-rank diagonal structure of $\mathbf{A}^\top\boldsymbol{\Omega}\mathbf{A}$ to reduce the complexity of a single sample from … to …, significantly reducing the computational burden on the client.
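The core of the Hoffman-Ribak method is that a posterior draw can be obtained by correcting a prior draw. The target is $\mathbf{f} \sim \mathcal{N}(\boldsymbol{\Sigma}\mathbf{A}^\top\boldsymbol{\kappa}, \boldsymbol{\Sigma})$ with $\boldsymbol{\Sigma} = (\mathbf{K}^{-1} + \mathbf{A}^\top\boldsymbol{\Omega}\mathbf{A})^{-1}$ and $\boldsymbol{\kappa} = \tfrac{1}{2}\mathbf{1}$. The augmented likelihood can be treated as a Gaussian observation $\mathbf{z} = \mathbf{A}\mathbf{f} + \boldsymbol{\epsilon}$, with $\mathbf{z} = \boldsymbol{\Omega}^{-1}\boldsymbol{\kappa}$ and $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \boldsymbol{\Omega}^{-1})$. One then draws $\mathbf{f}_0 \sim \mathcal{N}(\mathbf{0}, \mathbf{K})$ and sets $\mathbf{f} = \mathbf{f}_0 + \mathbf{K}\mathbf{A}^\top(\mathbf{A}\mathbf{K}\mathbf{A}^\top + \boldsymbol{\Omega}^{-1})^{-1}(\mathbf{z} - \mathbf{A}\mathbf{f}_0 - \boldsymbol{\epsilon})$.

The dense sketch below shows only this sampling identity. It does not reproduce the structured, reduced-complexity implementation, and the function name is hypothetical.

```python
import numpy as np

def hoffman_ribak_sample(K, A, omega, rng):
    """Draw f ~ N(Sigma A^T kappa, Sigma), Sigma = (K^{-1} + A^T Omega A)^{-1},
    by perturbing a prior draw (Hoffman-Ribak / Matheron's rule).

    The PG-augmented likelihood acts like a Gaussian observation
    z = A f + eps with z = Omega^{-1} kappa and eps ~ N(0, Omega^{-1}).
    """
    kappa = 0.5 * np.ones(A.shape[0])
    noise_var = 1.0 / omega                                        # diagonal of Omega^{-1}
    z = noise_var * kappa                                          # pseudo-observations
    f0 = np.linalg.cholesky(K) @ rng.standard_normal(K.shape[0])   # f0 ~ N(0, K)
    eps = np.sqrt(noise_var) * rng.standard_normal(A.shape[0])     # eps ~ N(0, Omega^{-1})
    S = A @ K @ A.T + np.diag(noise_var)
    return f0 + K @ A.T @ np.linalg.solve(S, z - A @ f0 - eps)
```

With a block-diagonal $\mathbf{K}$, the prior draw $\mathbf{f}_0$ factorizes into $C$ independent per-class draws. That factorization is where a structured implementation would begin to save work.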
[0049] The base kernel function $k_{\mathrm{base}}$ in deep kernel learning is chosen from the radial basis function (RBF) kernel, the linear kernel, the Laplace kernel, and the cosine similarity kernel. The best choice is determined experimentally.
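The four candidate base kernels can be written as plain functions of feature vectors, such as the outputs of the network $g_\theta$. The sketch below is illustrative. The hyperparameter names, the Euclidean-distance form of the Laplace kernel, and the `BASE_KERNELS` registry are assumptions.

```python
import numpy as np

def _sq_dists(Z1, Z2):
    sq = np.sum(Z1**2, 1)[:, None] + np.sum(Z2**2, 1)[None, :] - 2.0 * Z1 @ Z2.T
    return np.maximum(sq, 0.0)    # clip tiny negative round-off

def rbf_kernel(Z1, Z2, lengthscale=1.0):
    return np.exp(-0.5 * _sq_dists(Z1, Z2) / lengthscale**2)

def linear_kernel(Z1, Z2):
    return Z1 @ Z2.T

def laplace_kernel(Z1, Z2, lengthscale=1.0):
    # Uses Euclidean distance; an L1-distance variant is also common.
    return np.exp(-np.sqrt(_sq_dists(Z1, Z2)) / lengthscale)

def cosine_kernel(Z1, Z2, eps=1e-12):
    n1 = np.linalg.norm(Z1, axis=1, keepdims=True)
    n2 = np.linalg.norm(Z2, axis=1, keepdims=True)
    return (Z1 @ Z2.T) / np.maximum(n1 * n2.T, eps)

BASE_KERNELS = {
    "rbf": rbf_kernel,
    "linear": linear_kernel,
    "laplace": laplace_kernel,
    "cosine": cosine_kernel,
}
```

Keeping the kernels in a registry lets the experimental selection described above be expressed as a loop over `BASE_KERNELS`.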
[0050] The present invention also provides a personalized Bayesian federated learning system based on conjugate multi-class Gaussian process classification and deep kernel learning, including a central server and multiple clients. The central server and clients are respectively loaded with computer programs, which, when executed by a processor, implement the steps of the above method.
[0051] The advantages and positive effects of the technical solution to be protected by this invention are as follows:
[0052] Non-parametric posterior inference: This invention is the first to introduce conjugate multi-class Gaussian process classification into a personalized Bayesian federated learning framework. Leveraging the non-parametric properties of Gaussian processes, it performs Bayesian inference directly in the function space, without imposing parameterized assumptions on the network weights or the posterior distribution. The expressive power of a Gaussian process adapts to the amount of data. As a result, the method captures the complex form of the true posterior distribution more accurately, which significantly improves the model's expressive power and the reliability of its uncertainty estimates. Compared with existing parameterized methods, experiments on multiple benchmark datasets show an average improvement of 1-3% in prediction accuracy and a reduction of more than 30% in expected calibration error, demonstrating the advantages of non-parametric modeling.
[0053] Conjugate Multi-Class Classification: This invention combines the One-vs-Each approximation with the Pólya-Gamma augmentation technique, achieving for the first time a fully analytical conditional posterior distribution in multi-class Gaussian process classification. Specifically, the One-vs-Each approximation decomposes the multi-class likelihood into a product of binary likelihoods. Pólya-Gamma augmentation then ensures that each binary likelihood takes a Gaussian form once the auxiliary variables are introduced, forming a conditionally conjugate pair with the Gaussian process prior. This enables clients to perform accurate posterior inference through efficient Gibbs sampling without resorting to variational approximation, eliminating the errors that such approximations introduce. At the same time, the sampling process exploits the block-diagonal structure and low-rank property of the covariance matrix, which reduces the complexity of a single sample from … to …. This preserves inference accuracy while ensuring computational efficiency, making the method particularly suitable for federated learning scenarios with resource-constrained clients.
[0054] Communication Efficiency: Existing Bayesian federated learning methods such as FedPA and FedEP require the transmission of complete posterior distribution parameters (e.g., mean and variance) or particle sets (e.g., FedWBA) between the client and server. When the number of model parameters is large, the communication cost increases dramatically, becoming a bottleneck for system scalability. For example, in deep neural networks, transmitting the Gaussian posterior mean and variance means twice the amount of parameters transmitted, while particle methods require transmitting multiple complete model copies, which is extremely unfriendly to mobile clients with limited bandwidth. This invention uses deep kernel learning technology to compress the client model information into a small number of hyperparameters of the deep neural network. In each round of communication, the client only needs to upload these hyperparameters, rather than the complete posterior distribution or model parameters. Experiments show that, with the same model capacity, the communication overhead of this invention is only about 50% of pFedBayes, 20% of FedWBA, and roughly on par with the standard FedAvg, while providing complete Bayesian uncertainty quantification capabilities. This communication efficiency advantage enables this invention to be deployed on a large scale in bandwidth-constrained practical application scenarios such as the Internet of Things and mobile devices without sacrificing privacy protection and uncertainty quantification capabilities.
[0055] Robustness and Generalization Ability: At the theoretical level, this invention derives the generalization error bound of the proposed method based on the PAC-Bayesian framework, proving the quantitative relationship between the model's expected performance on new clients and the number of training clients, as well as the posterior-prior KL divergence. This provides a rigorous theoretical guarantee for the method's application in the real world. At the engineering practice level, the inherent kernel smoothness of Gaussian processes endows the model with natural robustness: on the one hand, the shared deep kernel prior can effectively suppress abnormal updates from noisy clients or malicious attackers; on the other hand, the prediction mechanism based on similarity metrics makes it difficult for attackers to disrupt the global function through local parameter perturbations. Experimental results show that when 30% of client data is contaminated by five different types of noise, the accuracy decrease of this invention is significantly smaller than all comparative methods. In the Byzantine attack scenario, after being combined with robust aggregation rules such as NNM, this invention can maintain a prediction accuracy of over 85% even with 20% malicious clients, demonstrating excellent robustness and stability. This combination of theoretical rigor and engineering practicality enables this invention to meet the stringent reliability requirements of high-security fields such as finance, healthcare, and autonomous driving.
[0056] In summary, this invention provides a novel non-parametric solution for personalized Bayesian federated learning. It achieves accurate posterior inference through conjugate multi-class Gaussian process classification and realizes efficient feature sharing through deep kernel learning. It significantly outperforms existing technologies in terms of prediction accuracy, uncertainty quantification, communication efficiency, and robustness, and has good theoretical value and broad application prospects.
Attached Figure Description
[0057] Figure 1 is a flowchart of the algorithm of the present invention;
[0058] Figure 2 shows the experimental results of uncertainty quantification in this invention;
[0059] Figure 3 shows the experimental results of the convergence rate of the present invention;
[0060] Figure 4 shows the experimental results on the input-noise robustness of the present invention.
Detailed Implementation
[0061] To make the objectives, technical solutions, and advantages of this invention clearer, the invention will be further described in detail below with reference to embodiments. It should be understood that the specific embodiments described herein are merely illustrative and not intended to limit the invention.
[0062] This invention proposes a personalized Bayesian federated learning method and system based on conjugate multi-class Gaussian process classification and deep kernel learning, in the field of federated learning technology. The method addresses the technical problems of existing personalized Bayesian federated learning methods, such as reliance on strong parameterized assumptions in posterior inference, complex handling of multi-class classification, and high communication overhead. The invention comprises two key modules. At the client level, the One-vs-Each approximation decomposes the multi-class likelihood into a product of binary likelihoods, and Pólya-Gamma augmentation introduces auxiliary variables. Together, they achieve for the first time an analytical conditional posterior distribution in a multi-class Gaussian process. The client then performs accurate posterior inference through efficient Gibbs sampling, avoiding the errors introduced by variational approximation. At the server level, each client encodes its local data features into deep kernel hyperparameters through deep kernel learning and uploads them, and the server updates the global deep kernel prior by weighted-average aggregation. In the next communication round, this prior is broadcast to all clients as a regularization term, realizing cross-client knowledge sharing and collaborative optimization of personalized learning. On the theoretical side, the invention derives a generalization error bound for the proposed method within the PAC-Bayesian framework, providing a theoretical guarantee of the algorithm's effectiveness.
Experimentally, comparisons with various baseline methods on multiple real-world datasets such as MNIST, CIFAR-10, CIFAR-100, and Tiny-ImageNet show that the proposed method excels in prediction accuracy, uncertainty calibration, and convergence rate, particularly in highly heterogeneous data and low-sample scenarios, demonstrating stable model performance. Furthermore, robustness experiments verify the invention's strong resistance to noisy data and Byzantine attacks, while ablation experiments further validate the necessity of each key component. This invention effectively addresses the limitations of existing personalized Bayesian federated learning methods in posterior inference and global aggregation, significantly improving the performance, reliability, and communication efficiency of federated learning systems.
[0063] This system is a Bayesian inference framework for personalized federated learning. Its core idea is to achieve collaborative optimization of client-side local posterior inference and server-side global prior aggregation through conjugate multi-class Gaussian process classification and deep kernel learning, while protecting data privacy. The system consists of a client-side local posterior inference module and a server-side global prior aggregation module, which achieve a balance between knowledge sharing and personalized modeling through iterative communication.
[0064] I. Overall System Workflow
[0065] The system adopts a federated learning architecture. Each communication round comprises the following steps. The server broadcasts the current global deep kernel hyperparameters to all participating clients. Each client takes the global hyperparameters as prior information for its local model and performs posterior inference on its local data. After completing local computation, each client uploads its updated local deep kernel hyperparameters to the server. The server then aggregates the uploaded hyperparameters by a weighted average based on each client's local data volume, obtaining new global deep kernel hyperparameters for the next round. This process repeats until the model converges or the preset number of communication rounds is reached.
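The round structure can be sketched as a loop over broadcast, local update, upload, and weighted aggregation. In this sketch, `local_update` is a hypothetical placeholder for the client-side Gibbs sampling and kernel optimization. The hyperparameters are flattened into a single vector, and `frac`, the fraction of clients sampled per round, is an assumed parameter.

```python
import numpy as np

def weighted_average(thetas, sizes):
    """Average client vectors with weights proportional to local data size."""
    w = np.asarray(sizes, dtype=float)
    w = w / w.sum()
    return sum(wi * th for wi, th in zip(w, thetas))

def run_rounds(theta, clients, local_update, num_rounds, frac, rng):
    """Broadcast -> local update -> upload -> weighted aggregation, repeated.

    `clients` is a list of (client_data, N_j) pairs. `local_update` is a
    hypothetical callable (theta, client_data) -> theta_j standing in for
    the client-side posterior inference and kernel optimization.
    theta is a flat parameter vector.
    """
    num_selected = max(1, int(round(frac * len(clients))))
    for _ in range(num_rounds):
        selected = rng.choice(len(clients), size=num_selected, replace=False)
        local_thetas = [local_update(theta.copy(), clients[j][0]) for j in selected]
        sizes = [clients[j][1] for j in selected]
        theta = weighted_average(local_thetas, sizes)   # new shared prior
    return theta
```

A fixed round budget is used here for simplicity. A convergence check on successive values of `theta` could replace it.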
[0066] II. Working principle of the client-side local posterior inference module
[0067] The core task of the client-side local posterior inference module is to update the posterior distribution of the latent functions and optimize the deep kernel hyperparameters using local data. This module uses conjugate multi-class Gaussian process classification to solve the multi-class classification problem, as follows:
[0068] 1. Multi-class likelihood decomposition and introduction of auxiliary variables
[0069] For multi-class classification tasks, the module first constructs a sparse matrix from the local labels. This sparse matrix encodes all inter-class differences, so that the multi-class likelihood is approximated as a product of binary likelihoods. This decomposition transforms the complex multi-class problem into a series of binary sub-problems and lays the foundation for the subsequent analytical computation. On this basis, an auxiliary variable following a Pólya-Gamma distribution is introduced for each binary likelihood. Conditioned on the auxiliary variables, the likelihood takes a Gaussian form, which is conditionally conjugate to the Gaussian process prior.
[0070] 2. Gibbs sampling posterior inference
[0071] With this conditionally conjugate structure, the module uses Gibbs sampling to alternately update the latent function values and the auxiliary variables. The sampler initializes multiple parallel Gibbs chains, each iterating independently. In each iteration, the auxiliary variables are sampled given the latent variables from the previous iteration, and the latent variables are then sampled given the newly drawn auxiliary variables. Because the auxiliary variables and latent variables are conditionally conjugate, both conditional posteriors belong to standard distribution families (Gaussian for the latent variables and Pólya-Gamma for the auxiliary variables) and can be sampled directly. To reduce the cost of a single sample, the module samples with the Hoffman-Ribak method. This method exploits two structures: the block-diagonal structure of the covariance matrix, and the low-rank, diagonal structure of the product $\mathbf{A}^\top\boldsymbol{\Omega}\mathbf{A}$ (the transposed sparse matrix, times the diagonal matrix of auxiliary variables, times the sparse matrix). After the preset number of iterations, the auxiliary variables and latent variables at the end of each Gibbs chain constitute the posterior samples.
[0072] 3. Deep kernel hyperparameter optimization
[0073] After obtaining the posterior samples, the module uses them to optimize the local deep kernel hyperparameters. The objective is to maximize the log-marginal likelihood. Because the marginal likelihood is difficult to compute directly, the module uses Fisher's identity to rewrite the gradient of the log-marginal likelihood with respect to the hyperparameters as the expected gradient of the complete-data log-likelihood under the posterior distribution. This expectation is estimated by Monte Carlo using the posterior samples obtained by Gibbs sampling. With the gradient estimate, the local deep kernel hyperparameters are updated by gradient ascent with a set learning rate (equivalently, gradient descent on the negative log-marginal likelihood), so that the kernel function better fits the local data distribution.
[0074] III. Working Principle of Server Global Prior Aggregation Module
[0075] The server-side global prior aggregation module integrates the local knowledge of all clients to generate a better global prior. After receiving the local deep kernel hyperparameters uploaded by the clients, the module determines each client's local data volume and sums the data volumes across all clients. It then multiplies each client's local deep kernel hyperparameters by that client's local data volume, sums the products across all clients, and divides the sum by the total data volume to obtain the weighted-average global deep kernel hyperparameters. This weighting ensures that clients with larger data volumes contribute more to the global prior, while retaining personalized information from a small number of data-rich clients. The updated global hyperparameters are broadcast to all clients as a shared prior for the next communication round, enabling cross-client knowledge transfer.
[0076] IV. Theoretical Guarantees and Convergence Properties
[0077] The system provides a generalization error guarantee based on the PAC-Bayes theoretical framework. The true risk is defined as the expected classification error on a new sample under the posterior distribution, the empirical risk is the classification error rate on the training set, and the complexity term is built from the KL divergence between the posterior and prior distributions. The generalization error bound states that, with high probability, the true risk of the model is bounded by the inverse of the binary (Bernoulli) KL divergence evaluated at the empirical risk and the complexity term. In other words, the bound is the largest risk whose binary KL divergence from the empirical risk does not exceed the complexity term. This theoretical guarantee ensures that the system generalizes well while protecting privacy.
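Evaluating such a bound requires inverting the binary KL divergence. The inverse has no closed form, but the divergence is monotone in its second argument on $[q, 1]$, so the inverse can be found by bisection. The sketch assumes the common Seeger/Maurer complexity term $(\mathrm{KL}(Q\|P) + \ln(2\sqrt{n}/\delta))/n$, which may differ from the exact bound derived for the invention. The function names are hypothetical.

```python
import math

def binary_kl(q, p):
    """kl(q || p) between Bernoulli(q) and Bernoulli(p)."""
    p = min(max(p, 1e-15), 1.0 - 1e-15)   # keep logs finite
    out = 0.0
    if q > 0.0:
        out += q * math.log(q / p)
    if q < 1.0:
        out += (1.0 - q) * math.log((1.0 - q) / (1.0 - p))
    return out

def kl_inverse(q, c, tol=1e-12):
    """Largest p in [q, 1] with kl(q || p) <= c, found by bisection."""
    lo, hi = q, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if binary_kl(q, mid) > c:
            hi = mid
        else:
            lo = mid
    return lo

def pac_bayes_kl_bound(emp_risk, kl_post_prior, n, delta=0.05):
    """Risk bound kl^{-1}(emp_risk, (KL(Q||P) + ln(2 sqrt(n) / delta)) / n).

    The Seeger/Maurer complexity term used here is an assumption.
    """
    complexity = (kl_post_prior + math.log(2.0 * math.sqrt(n) / delta)) / n
    return kl_inverse(emp_risk, complexity)
```

As the number of samples $n$ grows with the KL term fixed, the complexity term shrinks. The bound then tightens toward the empirical risk.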
[0078] Through the above mechanism, the system realizes personalized modeling under the Bayesian federated learning framework: the client captures personalized data features through local posterior inference, and the server achieves knowledge sharing through global prior aggregation. The two work together to optimize in iteration, ultimately achieving multiple goals such as protecting privacy, adapting to heterogeneous data, and improving generalization performance.
[0079] The computer device, computer-readable storage medium, and information data processing terminal of the present invention share a core working logic built around a closed-loop "storage-calling-execution-coordination" architecture. Although the three take different hardware forms, their core execution processes are the same: each provides the physical carrier and operating environment for the personalized Bayesian federated learning algorithm described above, ensuring its efficient implementation in distributed scenarios.
[0080] The computer-readable storage medium serves as the static carrier of the algorithm. It preferably uses non-volatile media such as solid-state drives, high-speed flash memory, or read-only memory, and stores the computer program in binary form. The program contains the execution logic for both the client and the server, specifically the complete code instructions for conjugate multi-class Gaussian process classification, Gibbs sampling optimization, iterative deep kernel hyperparameter updates, and global weighted aggregation. The medium also pre-stores the initial parameter templates and data interaction protocols required for algorithm execution, providing the foundation for execution by the device.
[0081] The computer device serves as the core computing platform, with its memory and processor establishing bidirectional communication via a high-speed bus. After startup, the processor first loads the computer program from the storage medium, completing process initialization; then, it calls the corresponding module instructions based on its deployment role (client or server). The client-side processor executes local posterior inference logic, completing sparse matrix construction, introducing Pólya-Gamma auxiliary variables, Gibbs sampling iteration, and hyperparameter gradient calculation; the server-side processor executes global aggregation logic, completing hyperparameter reception, data weighting operations, and result broadcasting. Preferably, the processor employs a multi-core architecture, accelerating Gibbs sampling and matrix operations through parallel computing, thereby improving algorithm execution efficiency.
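The client-side steps listed above (sparse matrix construction, Pólya-Gamma auxiliary variables, Gibbs iteration) can be illustrated with the following dense-matrix sketch. It assumes the One-vs-Each approximation $p(y=k\mid\mathbf{f}) \approx \prod_{j\neq k}\sigma(f_k - f_j)$, independent per-class GP priors sharing one RBF kernel with fixed hyperparameters, a single chain, and a truncated sum-of-gammas approximation of the PG(1, c) distribution. Drawing $\mathbf{f}\mid\boldsymbol{\omega}$ with Matheron's rule is a sampling choice made here for numerical stability, and all function names are hypothetical; this is not the code of this invention, and the hyperparameter gradient step is omitted.

```python
import numpy as np


def rbf_kernel(X1, X2, amplitude=1.0, lengthscale=1.0):
    """RBF kernel a^2 * exp(-||x - x'||^2 / (2 l^2))."""
    sq = (np.sum(X1**2, axis=1)[:, None] + np.sum(X2**2, axis=1)[None, :]
          - 2.0 * X1 @ X2.T)
    return amplitude**2 * np.exp(-0.5 * np.maximum(sq, 0.0) / lengthscale**2)


def ove_difference_matrix(y, num_classes):
    """A with (A f)_r = f_{y_i}(x_i) - f_j(x_i) for every sample i and class j != y_i.

    f is stacked class-major: f[c * n + i] is the latent value of class c at x_i.
    Each row has two non-zeros, so A is sparse; it is kept dense here for brevity.
    """
    n = len(y)
    A = np.zeros((n * (num_classes - 1), n * num_classes))
    r = 0
    for i, yi in enumerate(y):
        for j in range(num_classes):
            if j != yi:
                A[r, yi * n + i] = 1.0
                A[r, j * n + i] = -1.0
                r += 1
    return A


def sample_pg1(c, rng, terms=200):
    """Approximate PG(1, c) draws from a truncated sum-of-gammas series."""
    k = np.arange(1, terms + 1)
    g = rng.exponential(1.0, size=(len(c), terms))  # Gamma(1, 1) draws
    denom = (k - 0.5) ** 2 + (c[:, None] ** 2) / (4.0 * np.pi**2)
    return (g / denom).sum(axis=1) / (2.0 * np.pi**2)


def gibbs_ove_pg(X, y, num_classes, iters=200, burn_in=50, jitter=1e-6, seed=0):
    """Gibbs sampler for OvE-likelihood GP classification; returns posterior-mean latents (n, C)."""
    rng = np.random.default_rng(seed)
    n = len(y)
    K = rbf_kernel(X, X) + jitter * np.eye(n)
    L = np.linalg.cholesky(K)
    K_full = np.kron(np.eye(num_classes), K)  # independent GP prior per class
    A = ove_difference_matrix(y, num_classes)
    KAt = K_full @ A.T
    AKAt = A @ KAt
    f = np.zeros(n * num_classes)
    kept = []
    for t in range(iters):
        # Step 1: omega_r | f ~ PG(1, (A f)_r), one auxiliary variable per pairwise term.
        omega = sample_pg1(A @ f, rng)
        # Step 2: f | omega is Gaussian, with pseudo-observations z = kappa / omega
        # (kappa = 1/2) of A f and noise variance 1 / omega. Draw it by Matheron's
        # rule: a prior draw corrected by the GP regression update.
        z = 0.5 / omega
        f_prior = (L @ rng.standard_normal((n, num_classes))).T.reshape(-1)
        noise = rng.standard_normal(len(z)) / np.sqrt(omega)
        S = AKAt + np.diag(1.0 / omega)
        f = f_prior + KAt @ np.linalg.solve(S, z - A @ f_prior - noise)
        if t >= burn_in:
            kept.append(f)
    return np.mean(kept, axis=0).reshape(num_classes, n).T
```

Because every pairwise logistic term becomes a Gaussian factor once its PG variable is fixed, each Gibbs step is a closed-form Gaussian update; this conjugacy is what makes the posterior inference analytical. In practice `A` would be stored as a sparse matrix.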
[0082] As a scenario-specific implementation platform, the information data processing terminal integrates data acquisition, transmission, and processing. The terminal acquires local business data (such as images, text, and sensor data) through a data interface, preprocesses it, and passes it to the built-in processor to execute the algorithm steps. At the same time, it establishes encrypted communication with the server through a network module to exchange hyperparameters and computation results in both directions, and finally outputs personalized classification results, completing the whole process from data input to intelligent decision-making.
[0083] The algorithm steps of this invention include:
[0084]
[0085] To comprehensively and thoroughly evaluate the performance of the Personalized Bayesian Federated Learning (FedCMGP) method based on conjugate multi-class Gaussian process classification and deep kernel learning proposed in this invention, we meticulously planned and implemented a series of rigorous experiments. The specific experimental setup covers several key aspects, including the experimental environment, comparison methods, datasets, model selection, hyperparameter configuration, and evaluation metrics. Details are as follows:
[0086] Experimental environment
[0087] All experiments were run on a high-performance server equipped with an NVIDIA GeForce RTX 5090 GPU. This hardware configuration provides powerful parallel computing capabilities, ensuring the efficient and stable operation of the method of this invention and all comparative methods, and minimizing the potential impact of hardware performance differences on the experimental results. To ensure the reliability, repeatability, and statistical significance of the experimental results, each experiment was run five times independently. Performance metrics of the model were recorded during each run, and the final results are presented as the mean and standard error (Mean ± SEM) of the five runs. Throughout the experiments, strict control was maintained over the consistency of the experimental environment, including the operating system version, deep learning framework (such as PyTorch), and versions of other relevant dependent libraries, to ensure the accuracy and comparability of the experimental results.
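The Mean ± SEM summary of the five independent runs can be computed as below. This sketch assumes the SEM is the sample standard deviation (with the n - 1 denominator) divided by the square root of the number of runs, which is the usual definition; the text does not state the estimator explicitly.

```python
import math
import statistics


def mean_and_sem(values):
    """Mean and standard error of the mean: sample std (n - 1 denominator) / sqrt(n)."""
    n = len(values)
    if n < 2:
        raise ValueError("need at least two runs to estimate the SEM")
    return statistics.mean(values), statistics.stdev(values) / math.sqrt(n)


def format_mean_sem(values, digits=2):
    """Format as 'mean ± SEM' with a fixed number of decimals."""
    m, sem = mean_and_sem(values)
    return f"{m:.{digits}f} ± {sem:.{digits}f}"
```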
[0088] Comparison Methods
[0089] To accurately measure the advantages of the method of this invention, we selected several representative federated learning (FL), personalized federated learning (PFL), and Bayesian federated learning (BFL) methods as baselines for comparison. All comparison methods were implemented under a unified standardized framework to ensure the fairness of the comparison.
[0090] Federated learning basic algorithms:
[0091] FedAvg, a classic algorithm in the field of federated learning, updates the global model by simply averaging the client model parameters and serves as an important benchmark for evaluating the performance of other federated learning algorithms. It performs well under the assumption of independent and identically distributed data, but has limitations in real-world scenarios where the data is not independent and identically distributed.
[0092] FedProx: Adds a proximal term to the FedAvg local objective that penalizes the distance between the local model and the current global model, improving performance on heterogeneous data.
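The FedProx baseline's proximal term can be sketched as a single local gradient step on $F_k(w) + \frac{\mu}{2}\lVert w - w^t\rVert^2$, where $w^t$ is the global model of the current round. The function name and the default learning rate and $\mu$ below are illustrative assumptions, not the settings used in these experiments; setting $\mu = 0$ recovers a plain FedAvg-style local step.

```python
import numpy as np


def fedprox_local_step(w, w_global, loss_grad, lr=0.01, mu=0.01):
    """One gradient step on F_k(w) + (mu / 2) * ||w - w_global||^2.

    loss_grad: callable returning the gradient of the client's local loss at w.
    The mu * (w - w_global) term pulls the local model back toward the global model.
    """
    return w - lr * (loss_grad(w) + mu * (w - w_global))
```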
[0093] Personalized Federated Learning Algorithms:
[0094] Ditto achieves a balance between personalization and generalization by simultaneously learning a global model and a personalized model, and by introducing a fairness regularization term.
[0095] FedALA: Through an adaptive local aggregation mechanism, it learns personalized aggregation weights for each client, thereby improving the personalization effect.
[0096] FedDBE: Improves federated learning's generalization ability on heterogeneous data by eliminating domain bias in the representation space.
[0097] FedAS: Achieves superior personalized performance by bridging inconsistencies in personalized federated learning. The performance of FedCMGP in multi-class classification tasks can be examined by comparing it with FedAS.
[0098] Bayesian federated learning algorithm:
[0099] pFedBayes: Within a personalized federated learning framework, performs posterior inference by variational inference, assuming a Gaussian posterior and optimizing it by maximizing the evidence lower bound (ELBO). The comparison with pFedBayes shows the advantages of this invention's nonparametric method in avoiding such approximation errors and improving the quality of uncertainty quantification.
[0100] pFedGP: Introduces Gaussian processes into federated learning, but uses GP trees to decompose the multi-class problem into multiple binary subproblems, making it essentially an ensemble-based approach. The comparison with pFedGP assesses the accuracy and computational efficiency of this invention's conjugate multi-class Gaussian process classification.
[0101] FedWBA: Uses the Wasserstein barycenter to aggregate client posterior particles, achieving geometrically meaningful aggregation on the manifold of probability distributions. The comparison with FedWBA highlights the balance this invention strikes between communication efficiency and geometry-aware aggregation.
[0102] Dataset
[0103] To comprehensively evaluate the performance of the method of this invention on datasets of different complexities and sizes, we selected four benchmark datasets widely used in the field of image classification and simulated the non-independent and identically distributed characteristics in real-world scenarios through specific data partitioning methods.
[0104] The MNIST dataset, composed of images of handwritten digits, is a classic dataset widely used in machine learning. It contains 70,000 samples, with 60,000 used for training and 10,000 for testing. These images are all 28×28 pixel grayscale images, with each sample corresponding to one of the 10 digits from 0 to 9. The MNIST dataset has a moderate size and relatively simple image structure, making it very suitable as a foundational dataset for initially validating the performance of algorithms on basic image classification tasks.
[0105] The CIFAR-10 dataset contains 60,000 color images across 10 categories, with 50,000 used for training and 10,000 for testing. The images have a resolution of 32×32 pixels and cover common object categories such as airplanes, cars, birds, and cats. Compared with MNIST, CIFAR-10 consists of low-resolution natural color images whose categories differ less clearly from one another, which makes classification more challenging. Experiments on CIFAR-10 therefore allow an in-depth evaluation of FedCMGP on moderately complex image classification tasks.
[0106] The CIFAR-100 dataset is an extension of CIFAR-10, containing 100 categories and a total of 60,000 images, with 50,000 images in the training set and 10,000 in the test set. The images in this dataset are also 32×32 pixel color images, covering a wider range of object categories. The CIFAR-100 dataset has a larger number of categories, a more complex data distribution, and more blurred category boundaries, which places extremely high demands on the model's generalization ability and classification accuracy. It can comprehensively examine the performance of FedCMGP in large-scale, highly complex image classification tasks.
[0107] Tiny-ImageNet dataset: A scaled-down version of the ImageNet dataset, containing 200 categories, with 500 training images, 50 validation images, and 50 test images per category. Each image is 64×64 pixels. This dataset is larger in scale, has more categories, and more complex image content, making it ideal for evaluating the performance of algorithms in complex real-world scenes. Experimental results on this dataset fully validate the scalability and practical application potential of the method presented in this invention.
[0108] Construction of non-independent and identically distributed (non-IID) data: To simulate the label-distribution heterogeneity of client data in real-world scenarios, we partitioned the above datasets as follows. For MNIST and CIFAR-10, each client was randomly assigned 5 unique labels; for CIFAR-100, each client was randomly assigned 10 labels from different superclasses; and for Tiny-ImageNet, each client was randomly assigned 20 unique labels. This yields experimental datasets with varying degrees of data heterogeneity that more closely reflect the data distributions of real-world applications. To further verify the robustness of the method under different degrees of heterogeneity, we also sampled the client label distributions from a Dirichlet distribution (parameter...), simulating partitioning scenarios ranging from strongly heterogeneous to approximately homogeneous.
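The two partitioning schemes above can be sketched as follows. Both functions are illustrative assumptions rather than the exact procedure used here: the label-skew version splits each label's samples evenly among the clients that drew it (the CIFAR-100 "different superclasses" constraint is not implemented), and the Dirichlet version uses the common per-class convention of drawing client proportions from $\mathrm{Dir}(\alpha)$ for each class. The concentration value used in the experiments is not given in the text.

```python
import numpy as np


def label_skew_partition(labels, num_clients, labels_per_client, rng):
    """Give each client `labels_per_client` distinct random labels and split every
    label's samples evenly among the clients that hold it."""
    classes = np.unique(labels)
    chosen = [set(rng.choice(classes, size=labels_per_client, replace=False).tolist())
              for _ in range(num_clients)]
    parts = [[] for _ in range(num_clients)]
    for c in classes.tolist():
        owners = [k for k in range(num_clients) if c in chosen[k]]
        if not owners:
            continue  # samples of a label no client drew stay unassigned
        idx = rng.permutation(np.flatnonzero(labels == c))
        for k, chunk in zip(owners, np.array_split(idx, len(owners))):
            parts[k].extend(chunk.tolist())
    return [np.array(p, dtype=int) for p in parts]


def dirichlet_partition(labels, num_clients, alpha, rng):
    """For each class, draw client proportions from Dir(alpha) and split that class's
    samples accordingly; smaller alpha gives stronger label heterogeneity."""
    parts = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for k, chunk in enumerate(np.split(idx, cuts)):
            parts[k].extend(chunk.tolist())
    return [np.array(p, dtype=int) for p in parts]
```

With `labels_per_client=5` on a 10-class dataset this mirrors the MNIST and CIFAR-10 setting; the Dirichlet version assigns every sample to exactly one client.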
[0109] Model selection
[0110] Considering the limitations of client computing power and memory in federated learning, as well as the differences in complexity between different datasets, we selected appropriate model architectures for different datasets to achieve good classification performance while ensuring computational efficiency.
[0111] MNIST dataset: A multilayer perceptron (MLP) with a single hidden layer of 512 neurons and ReLU activation is used. This model has a simple structure and low computational cost, yet it handles the MNIST classification task effectively, making it suitable for experiments with limited client resources.
[0112] CIFAR-10 dataset: LeNet-5 convolutional neural network is used. LeNet-5 is a classic convolutional neural network architecture, containing convolutional layers, pooling layers, and fully connected layers, capable of automatically extracting image features. For the CIFAR-10 dataset, the standard LeNet-5 structure is maintained to balance computational cost and classification performance.
[0113] CIFAR-100 dataset: ResNet-10 residual network is used. Because the images in the CIFAR-100 dataset are more complex and have a larger number of categories, deeper network structures are needed to learn the complex features in the data. ResNet-10 alleviates the gradient vanishing problem in deep networks through residual connections, improving model performance with moderate computational overhead.
[0114] Tiny-ImageNet dataset: Employs ResNet-18 residual network. This dataset is larger in scale and has more categories, requiring a more powerful model. ResNet-18, as a widely used benchmark model, strikes a good balance between feature extraction capability and computational efficiency, making it suitable for evaluating algorithm performance on large-scale datasets.
[0115] Hyperparameter settings
[0116] To ensure that the FedCMGP method of this invention achieves optimal performance under different datasets and scenarios, we meticulously set and adjusted key hyperparameters and verified the rationality of the selection through experiments. The specific settings and selection criteria are as follows:
[0117] Communication Rounds and Client Sampling Ratio: In all experiments, the number of communication rounds was consistently set to 500. This choice was based on preliminary experimental observations: FedCMGP achieves rapid convergence within 200 rounds on most datasets, with subsequent rounds showing a stable performance improvement phase. 500 rounds are sufficient to ensure adequate model convergence and allow for observation of long-term trends. The client sampling ratio was fixed at 20% (i.e., for total clients of 200, 300, and 500, 40, 60, and 100 clients were sampled per round, respectively). This ratio strikes a good balance between communication efficiency and model update quality, avoiding the communication pressure caused by full client participation while ensuring the representativeness of aggregated updates. Experiments show that convergence speed decreases significantly when the sampling ratio is below 10%, while communication overhead increases above 30% with limited performance improvement.
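The per-round client counts quoted above follow from a 20% sampling ratio. The sketch below assumes uniform sampling of distinct clients without replacement in each round, a common choice that the text does not state explicitly; the function name is hypothetical.

```python
import numpy as np


def sample_clients(num_clients, ratio, rng):
    """Sample round(ratio * num_clients) distinct client ids for one communication round."""
    m = max(1, int(round(ratio * num_clients)))
    return np.sort(rng.choice(num_clients, size=m, replace=False))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    for total in (200, 300, 500):
        selected = sample_clients(total, 0.2, rng)
        print(total, "clients ->", len(selected), "sampled this round")
```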
[0118] Hyperparameters related to deep kernel learning:
[0119] Kernel function: A radial basis function (RBF) kernel is used, with the expression $k(\mathbf{x}, \mathbf{x}') = \sigma_f^2 \exp\left(-\frac{\lVert \mathbf{x} - \mathbf{x}' \rVert^2}{2\ell^2}\right)$, where $\sigma_f$ is the amplitude parameter and $\ell$ is the length-scale parameter. The RBF kernel is smooth and infinitely differentiable, measures similarity in the feature space effectively, and is the most commonly used choice among kernel methods. In comparative experiments against other base kernels (linear, Laplacian, and cosine similarity kernels), the RBF kernel performed best on most datasets and was therefore selected as the default base kernel of this invention.
[0120] Deep network embedding dimension: We set different embedding dimensions for different tasks based on the complexity of the dataset. For simple datasets like MNIST, the embedding dimension is set to 128; for complex datasets like CIFAR-10, CIFAR-100, and Tiny-ImageNet, the embedding dimension is increased to 512. This setting is based on the following considerations: too low a dimension cannot fully capture the complex patterns in the data, while too high a dimension may introduce redundant parameters and increase the computational burden. Experiments show that MNIST reaches performance saturation at 128 dimensions, while CIFAR-10 requires 512 dimensions to fully utilize the representational power of the deep kernel.
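A deep kernel of the kind described in these two items applies the RBF base kernel to the output of a deep network embedding $\phi(\mathbf{x})$. The sketch below is an assumption-laden illustration: `mlp_embedding` is a hypothetical stand-in for the per-dataset backbones (MLP, LeNet-5, ResNet), and in a real system the network weights, amplitude, and length scale would be trained jointly.

```python
import numpy as np


def mlp_embedding(X, weights):
    """Hypothetical feature extractor phi(x): ReLU hidden layers, linear output layer."""
    h = X
    for W, b in weights[:-1]:
        h = np.maximum(h @ W + b, 0.0)
    W, b = weights[-1]
    return h @ W + b


def deep_rbf_kernel(X1, X2, weights, amplitude=1.0, lengthscale=1.0):
    """k(x, x') = amplitude^2 * exp(-||phi(x) - phi(x')||^2 / (2 * lengthscale^2))."""
    Z1, Z2 = mlp_embedding(X1, weights), mlp_embedding(X2, weights)
    sq = np.sum((Z1[:, None, :] - Z2[None, :, :]) ** 2, axis=-1)
    return amplitude**2 * np.exp(-0.5 * sq / lengthscale**2)
```

For MNIST, the final layer of `weights` would map to the 128-dimensional embedding mentioned above, and to 512 dimensions for the other datasets.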
[0121] Gibbs sampling-related hyperparameters:
[0122] Number of parallel chains: set to 5. Comparing performance across different numbers of chains showed that with 5 chains the variance of the Monte Carlo gradient estimate is effectively controlled, whereas further increasing the number of chains yields only limited performance gains while the computational cost grows linearly. Five chains are therefore chosen as a balance between computational accuracy and efficiency.
[0123] Number of local Gibbs iteration steps: adjusted according to dataset complexity. For simple datasets such as MNIST, ... steps suffice for good performance; for complex datasets such as CIFAR-10, CIFAR-100, and Tiny-ImageNet, the number of steps is increased to... Experiments show that with too few iterations the posterior samples are not sufficiently mixed, which degrades the quality of the gradient estimate, while with too many iterations numerical errors may accumulate (especially from repeatedly decomposing the covariance matrix), reducing performance.
[0124] Local hyperparameter optimization related hyperparameters:
[0125] Optimizer and learning rate: The Adam optimizer is used to perform gradient ascent on the local deep kernel hyperparameters. Adam combines momentum with adaptive per-parameter learning rates, enabling stable convergence on complex loss surfaces. The learning rate is uniformly set to 0.001, which proved robust across datasets in preliminary experiments. To verify this choice, we compared convergence behavior across learning rates (...); a learning rate of 0.001 gave the best balance between convergence speed and final accuracy.
[0126] Monte Carlo gradient estimation sample size: gradient estimation is performed directly on the PG samples obtained from Gibbs sampling, so the number of samples equals the number of parallel chains, which keeps the gradient estimate reliable. In theory, more samples reduce the estimation variance, but experiments show that this number of samples already provides sufficiently accurate gradient directions, so no additional sampling is required.
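A Monte Carlo hyperparameter gradient of this kind can be sketched as follows, under stated assumptions. Given PG samples $\boldsymbol{\omega}$, the augmented marginal likelihood of the labels is, up to terms that do not depend on the kernel, $\mathcal{N}(\mathbf{z}\mid\mathbf{0}, \mathbf{A}\mathbf{K}\mathbf{A}^\top + \boldsymbol{\Omega}^{-1})$ with $\mathbf{z} = \boldsymbol{\kappa}/\boldsymbol{\omega}$ and $\kappa = 1/2$ under the One-vs-Each likelihood. The sketch differentiates this objective with respect to the log length scale only, averages over the PG samples (one per chain), and takes one Adam ascent step with the learning rate 0.001 from the text. Gradients for network weights would in practice come from automatic differentiation; all names are hypothetical.

```python
import numpy as np


def ove_difference_matrix(y, num_classes):
    """(A f)_r = f_{y_i}(x_i) - f_j(x_i) for j != y_i; f stacked class-major."""
    n = len(y)
    A = np.zeros((n * (num_classes - 1), n * num_classes))
    r = 0
    for i, yi in enumerate(y):
        for j in range(num_classes):
            if j != yi:
                A[r, yi * n + i], A[r, j * n + i] = 1.0, -1.0
                r += 1
    return A


def rbf_and_dlog_lengthscale(X, lengthscale, amplitude=1.0):
    """RBF Gram matrix and its derivative with respect to log(lengthscale)."""
    sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    K = amplitude**2 * np.exp(-0.5 * sq / lengthscale**2)
    return K, K * sq / lengthscale**2


def augmented_log_marginal(X, y, num_classes, omega, log_ls):
    """log N(z | 0, A K A^T + diag(1/omega)) with z = 1/(2 omega), and its d/d log_ls."""
    K, dK = rbf_and_dlog_lengthscale(X, np.exp(log_ls))
    A = ove_difference_matrix(y, num_classes)
    eye_c = np.eye(num_classes)
    S = A @ np.kron(eye_c, K) @ A.T + np.diag(1.0 / omega)
    dS = A @ np.kron(eye_c, dK) @ A.T
    z = 0.5 / omega
    S_inv = np.linalg.inv(S)
    alpha = S_inv @ z
    _, logdet = np.linalg.slogdet(S)
    value = -0.5 * z @ alpha - 0.5 * logdet - 0.5 * len(z) * np.log(2.0 * np.pi)
    # d/dtheta log N = 0.5 alpha^T dS alpha - 0.5 tr(S^{-1} dS)
    grad = 0.5 * alpha @ dS @ alpha - 0.5 * np.trace(S_inv @ dS)
    return value, grad


def mc_gradient(X, y, num_classes, omega_samples, log_ls):
    """Average the gradient over PG samples, e.g. one sample per parallel Gibbs chain."""
    return np.mean([augmented_log_marginal(X, y, num_classes, w, log_ls)[1]
                    for w in omega_samples])


def adam_ascent_step(theta, grad, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step that increases the objective; state = (m, v, t)."""
    m, v, t = state
    t += 1
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad**2
    m_hat, v_hat = m / (1 - b1**t), v / (1 - b2**t)
    return theta + lr * m_hat / (np.sqrt(v_hat) + eps), (m, v, t)
```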
[0127] Effect Experiment Analysis 1: Accuracy Experiment
[0128] In the experimental investigation of this invention, to comprehensively evaluate the performance of the proposed method FedCMGP, experiments were conducted on four major image classification datasets: MNIST, CIFAR-10, CIFAR-100, and Tiny-ImageNet. Experiments were carefully set up with 200, 300, and 500 clients respectively in a federated learning scenario. The test accuracy (% ± SEM) was used as the core metric, and detailed comparisons were made with various mainstream federated learning and personalized Bayesian federated learning methods, including FedAvg, FedProx, Ditto, FedALA, FedDBE, FedAS, pFedBayes, pFedGP, and FedWBA. Specific experimental results are shown in Table 1 below. All experiments were run 5 times independently, and the average and standard error were taken to ensure the statistical reliability of the results.
[0129] 1. MNIST dataset
[0130] On the MNIST dataset, the method of this invention achieved the highest test accuracy under all client-count settings. With 200 clients, its test accuracy reached 98.48% ± 0.02%, approximately 4.77 percentage points above FedAvg, the weakest compared method (93.71% ± 0.03%), and 0.32 percentage points above the second-best method, FedWBA (98.16% ± 0.02%). With 300 clients, it remained in the lead at 98.29% ± 0.04%, approximately 5.40 percentage points above FedAvg (92.89% ± 0.02%) and 0.51 percentage points above the second-best FedWBA (97.78% ± 0.01%). With 500 clients, its accuracy was 97.08% ± 0.02%, still the highest: approximately 5.15 percentage points above FedAvg (91.93% ± 0.01%) and 0.02 percentage points above the second-best pFedGP (97.06% ± 0.03%). These results show that on the relatively simple MNIST dataset the method of this invention ranks first in every setting as the number of clients grows and the amount of local data shrinks, although its margin over pFedGP at 500 clients (0.02 percentage points) is within the standard error.
[0131] 2. CIFAR-10 dataset
[0132] The method of this invention also performs strongly on CIFAR-10. With 200 clients, its accuracy is 73.45% ± 0.23%, the best among all methods: approximately 11.52 percentage points above the weakest method, FedAvg (61.93% ± 0.14%), and 0.74 percentage points above the second-best, pFedGP (72.71% ± 0.14%). With 300 clients, its accuracy is 72.22% ± 0.13%, still the best: approximately 12.84 percentage points above FedAvg (59.38% ± 0.16%) and 0.71 percentage points above pFedGP (71.51% ± 0.19%). With 500 clients, its accuracy is 68.75% ± 0.10%, again the best: approximately 13.71 percentage points above FedAvg (55.04% ± 0.11%) and 0.22 percentage points above pFedGP (68.53% ± 0.38%). On this moderately complex dataset, the advantage over FedAvg widens as the number of clients grows and the data become more dispersed, exceeding 13 percentage points at 500 clients. This is consistent with the ability of nonparametric Gaussian processes to cope with data heterogeneity and few-shot learning.
[0133] 3. CIFAR-100 dataset
[0134] CIFAR-100 has more categories and a more complex task. With 200 clients, the method of this invention ranks first with 74.56% ± 0.14%: approximately 6.07 percentage points above the weakest method, FedAvg (68.49% ± 0.65%), and 0.70 percentage points above the second-best, Ditto (73.86% ± 0.29%). With 300 clients, it remains first with 72.36% ± 0.19%: approximately 3.14 percentage points above the weakest method, FedDBE (69.22% ± 0.26%), and 1.33 percentage points above the second-best, FedAS (71.03% ± 0.26%). With 500 clients, it ranks second with 68.50% ± 0.25%, only 0.20 percentage points below the first-place pFedGP (68.70%), ahead of the third-place FedWBA (68.28%), and approximately 7.44 percentage points above the weakest method, FedAvg (61.06% ± 0.16%). In summary, the method ranks first with 200 and 300 clients and a close second with 500 clients.
[0135] 4. Tiny-ImageNet dataset
[0136] Tiny-ImageNet, with its larger size and more categories, is key to evaluating the scalability of the algorithm. With 200 clients, the method of this invention achieves 47.38% ± 0.24%, clearly outperforming all compared methods: approximately 14.08 percentage points above the weakest method, FedDBE (33.30% ± 0.52%), and 2.79 percentage points above the second-best, FedWBA (44.59% ± 0.12%). With 300 clients, it keeps the highest accuracy at 45.48% ± 0.35%: approximately 12.68 percentage points above FedDBE (32.80% ± 0.45%) and 1.58 percentage points above FedWBA (43.90% ± 0.15%). With 500 clients, its accuracy is 42.63% ± 0.17%, still the highest: approximately 10.74 percentage points above FedDBE (31.89% ± 0.33%) and 0.33 percentage points above the second-best, FedALA (42.30% ± 0.19%). Notably, FedAS, pFedBayes, and pFedGP failed to produce results on Tiny-ImageNet because of excessive computation time, whereas the method of this invention both runs efficiently and achieves the best accuracy, demonstrating its practicality on complex, large-scale datasets.
[0137] In summary, across the four datasets (MNIST, CIFAR-10, CIFAR-100, and Tiny-ImageNet) and the three client counts (200, 300, and 500), the FedCMGP method of this invention substantially improves test accuracy over the weakest method in each scenario (by more than 14 percentage points in the best case) and generally leads the second-best method, by margins ranging from 0.02 to 2.79 percentage points. Only in the CIFAR-100 scenario with 500 clients does it rank second, 0.20 percentage points behind pFedGP, while still outperforming the remaining methods. These results demonstrate the effectiveness of the method in improving model accuracy in federated learning tasks.
[0138] Furthermore, examining how performance changes as the number of clients grows (that is, as the amount of local data per client shrinks) shows that the degradation of the proposed method is smaller than that of FedAvg and comparable to that of the strongest baselines. For example, on CIFAR-10, when the number of clients increases from 200 to 500, the accuracy of the proposed method decreases by 4.70 percentage points, compared with 6.89 percentage points for FedAvg and 4.18 percentage points for pFedGP. On Tiny-ImageNet, the accuracy of the proposed method decreases by 4.75 percentage points from 200 to 500 clients, compared with 4.1 percentage points for FedWBA. Since the proposed method still retains the highest accuracy at 500 clients in both cases, these results suggest that shared deep kernel priors and Bayesian nonparametric modeling help mitigate overfitting and preserve generalization when data are more dispersed and local samples are few.
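As an arithmetic check, the 200-to-500-client accuracy drops can be recomputed from the accuracies reported in the per-dataset results. FedWBA's 500-client accuracy on Tiny-ImageNet is not stated in this section, so its drop is not recomputed here.

```python
# Test accuracies (%) at 200 and 500 clients, as reported in the per-dataset results.
reported = {
    ("CIFAR-10", "FedCMGP"): (73.45, 68.75),
    ("CIFAR-10", "FedAvg"): (61.93, 55.04),
    ("CIFAR-10", "pFedGP"): (72.71, 68.53),
    ("Tiny-ImageNet", "FedCMGP"): (47.38, 42.63),
}

# Drop in percentage points when going from 200 to 500 clients.
drops = {key: round(acc200 - acc500, 2) for key, (acc200, acc500) in reported.items()}
for (dataset, method), drop in drops.items():
    print(f"{dataset:14s} {method:8s} drop 200->500: {drop:.2f} pp")
```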
[0139] Furthermore, in-depth analysis of the experimental results revealed that the standard error of the accuracy of the method described in this invention is generally small under different datasets and client count settings, indicating that its performance fluctuations are minimal and it possesses excellent stability. This stability is of great significance in practical applications, especially in distributed learning scenarios where data heterogeneity and instability are common problems. The method described in this invention can provide more reliable results, ensuring relatively stable model performance under different environments, thus providing strong support for practical applications.
[0140] Table 1. Experimental results on the accuracy of the MNIST, CIFAR-10, CIFAR-100, and Tiny-ImageNet datasets.
[0141]
[0142] Effect Experiment Analysis 2: Performance Verification under Different Degrees of Heterogeneity (Dirichlet Distribution Experiment)
[0143] To evaluate the robustness and adaptability of FedCMGP under different levels of data heterogeneity, we model the client label distribution with a Dirichlet distribution whose concentration parameter α controls the degree of heterogeneity: the smaller α, the stronger the heterogeneity. Experiments were conducted on the MNIST and FMNIST datasets with 200 clients, using three values of α corresponding to strong heterogeneity, medium heterogeneity, and approximately homogeneous data, and the method of this invention was compared with representative methods such as FedAvg, FedProx, Ditto, FedALA, and pFedBayes. The experimental results are shown in Table 2.
[0144] As Table 2 shows, the method of this invention achieves optimal or near-optimal performance for every value of α. In the strongly heterogeneous setting, its accuracy on MNIST reached 88.23% ± 0.16%, approximately 3.95 percentage points above the weakest method, FedAvg (84.28%), and approximately 1.12 percentage points above pFedBayes (87.11%). On FMNIST, its accuracy was 85.01% ± 0.08%, approximately 5.20 percentage points above FedAvg (79.81%) and approximately 0.49 percentage points above pFedBayes (84.52%). As α increases, heterogeneity decreases and the performance of all methods generally improves, but the method of this invention stays in the lead. In the approximately homogeneous setting, its accuracy reached 97.19% ± 0.07% on MNIST and 88.70% ± 0.09% on FMNIST, outperforming or matching the best comparison method. These results indicate that, by sharing deep kernel priors and using Bayesian nonparametric modeling, the method of this invention resists the negative impact of data heterogeneity and maintains strong performance in small-sample, highly heterogeneous scenarios, which is of practical value for the non-IID data common in real-world federated learning applications.
[0145] Table 2. Test accuracy (% ± SEM) for different degrees of heterogeneity under the Dirichlet distribution.
[0146]
[0147] Effect Experiment Analysis 3: Global Model Performance Verification
[0148] This invention not only provides a personalized model for each client but also learns a global deep kernel prior on the server side by aggregating local deep kernel hyperparameters. This global prior contains shared knowledge across clients and can be used for rapid adaptation of new clients or as a benchmark for centralized evaluation. To verify the quality of the global model, we tested the server-aggregated global deep kernel on the MNIST and CIFAR-10 datasets, i.e., using global hyperparameters to construct a Gaussian process classifier and evaluating its classification accuracy on a centralized test set. Comparison methods include Ditto, pFedBayes, and FedWBA, which also output some form of global model (such as global parameters or a global prior). Experimental results are shown in Table 3.
[0149] As Table 3 shows, the global model of this invention achieves accuracy comparable to, and in some cases better than, the best comparison method in most scenarios. On MNIST with 200 clients, its accuracy reaches 91.12% ± 0.10%, within 0.5 percentage points of Ditto (91.56% ± 0.12%). With 300 clients, it reaches 89.70% ± 0.06%, essentially on par with Ditto (89.74% ± 0.11%). With 500 clients, it surpasses Ditto (85.90% ± 0.05%) with 86.29% ± 0.08%. On CIFAR-10 with 200 clients, its accuracy is 64.33% ± 0.08%, slightly lower than Ditto (64.61% ± 0.13%) but higher than pFedBayes (62.34% ± 0.11%) and FedWBA (62.07% ± 0.05%). With 300 clients, it is 61.99% ± 0.10%, higher than pFedBayes (61.37% ± 0.17%) and FedWBA (61.46% ± 0.03%) and slightly lower than Ditto (62.80% ± 0.22%). With 500 clients, it is 58.80% ± 0.06%, higher than Ditto (58.76% ± 0.12%) and FedWBA (58.34% ± 0.06%) and slightly lower than pFedBayes (59.34% ± 0.13%).
[0150] These results show that the global prior obtained by aggregating deep kernel hyperparameters effectively extracts features shared across clients. Its performance is comparable to that of methods with purpose-designed global models (such as Ditto) and, in most cases, better than pFedBayes and FedWBA. More importantly, this global prior is broadcast to the clients in the next communication round, where it acts as a regularizer on the training of the personalized models and helps prevent overfitting; this is likely a key reason for the strong overall performance of the method. The performance of the global model also shows that the proposed deep kernel hyperparameter aggregation strategy is communication-efficient while preserving information common to the clients' data, providing a robust foundation for personalized learning. In addition, the global model can be used directly for cold-start prediction on new clients or as a central evaluation model of the federated learning system, broadening the application scenarios of this invention.
[0151] Table 3 Global model test accuracy (% ± SEM)
[0152]
[0153] Effect Experiment Analysis 4: Uncertainty Calibration Performance Analysis
[0154] In this embodiment, we further examine the uncertainty calibration of the FedCMGP method of the present invention and compare it in depth with the baseline methods to verify that it provides reliable uncertainty quantification. Uncertainty calibration measures the agreement between a model's prediction confidence and its actual accuracy, and it is a key indicator of model reliability in high-risk decision-making scenarios such as medical diagnosis, autonomous driving, and financial risk control. This embodiment uses three indicators, Expected Calibration Error (ECE), Maximum Calibration Error (MCE), and the Brier score, to evaluate the calibration of each method. ECE is the sample-weighted average gap between empirical accuracy and model confidence across confidence bins; the smaller the ECE, the better the calibration. MCE is the largest such gap over all bins, i.e., the worst-case calibration error. The Brier score evaluates the overall accuracy of the predicted probabilities; the smaller the value, the more accurate the probability predictions.
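The three metrics can be computed from predicted class probabilities as in the following minimal Python sketch. It assumes equal-width confidence bins (the bin count is a hypothetical default, since the document does not state one) and the multi-class Brier score summed over classes; some implementations divide by the number of classes instead, so absolute values are comparable only under the same convention.

```python
import numpy as np


def calibration_metrics(probs, labels, n_bins=15):
    """Return (ECE, MCE, Brier score) for multi-class probabilistic predictions.

    probs:  (N, C) predicted class probabilities, each row summing to 1.
    labels: (N,) integer class labels.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=int)
    conf = probs.max(axis=1)                        # confidence of the predicted class
    correct = (probs.argmax(axis=1) == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)       # equal-width bins (lo, hi]
    ece, mce = 0.0, 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if not in_bin.any():
            continue                                # empty bins contribute nothing
        gap = abs(correct[in_bin].mean() - conf[in_bin].mean())
        ece += in_bin.mean() * gap                  # weight by share of samples in the bin
        mce = max(mce, gap)                         # worst-case bin gap

    onehot = np.eye(probs.shape[1])[labels]
    brier = float(np.mean(np.sum((probs - onehot) ** 2, axis=1)))
    return float(ece), float(mce), brier


probs = [[0.95, 0.05], [0.15, 0.85], [0.65, 0.35]]
labels = [0, 1, 1]
print(calibration_metrics(probs, labels, n_bins=10))
```

The example call places its three predictions in three different bins, so ECE is the plain average of the three per-bin gaps and MCE is the largest of them.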
[0155] Experiments were conducted on three datasets, CIFAR-10, CIFAR-100, and Tiny-ImageNet, with 200, 300, and 500 clients. All other configurations were the same as in the accuracy experiments. Representative methods, namely FedAvg, FedProx, Ditto, FedALA, FedDBE, FedAS, pFedBayes, pFedGP, and FedWBA, were selected as baselines to evaluate the uncertainty quantification performance of the proposed method. The experimental results are shown in Table 4.
[0156] 1. Calibration performance on the MNIST dataset
[0157] To visually demonstrate the advantages of the method of this invention in uncertainty calibration, Figure 2 presents reliability diagrams of the proposed method and the comparison methods on the MNIST dataset. In each reliability diagram, the horizontal axis is the prediction confidence, the vertical axis is the actual accuracy, the red diagonal represents perfect calibration (confidence equals accuracy), and the gray bars represent the ideal, perfectly calibrated bars for each confidence interval. Figure 2 shows that the confidence-accuracy bars of the method of this invention lie closest to the red diagonal and largely overlap the gray ideal bars in every confidence interval, indicating close agreement between predicted confidence and actual accuracy. In contrast, the bars of pFedGP deviate markedly from the diagonal in the high-confidence region, indicating overconfidence; FedWBA is well calibrated overall but still deviates in the 0.7 to 0.9 range; and pFedBayes fluctuates strongly across the whole confidence range, giving the weakest calibration. The ECE, MCE, and BRI (Brier score) values listed below Figure 2 for each method confirm these observations: the method of this invention attains an ECE of 0.0078, an MCE of 0.0718, and a BRI of 0.0998, all better than those of the comparison methods. This visual result corroborates the quantitative analysis in Table 4 and demonstrates the strong uncertainty quantification of the method of this invention.
[0158] 2. Calibration performance on the CIFAR-10 dataset
[0159] On the CIFAR-10 dataset, the method of this invention achieves the best or near-best calibration under all client-number settings. With 200 clients, its ECE is 0.0271, well below baselines such as FedAvg (0.0719), FedProx (0.0451), Ditto (0.1079), and FedALA (0.0322); its Brier score of 0.3680 is the lowest among all methods, indicating the most accurate overall probability predictions. With 300 clients, its ECE is 0.0341, lower than that of every method except FedWBA (0.0270); its MCE is 0.1147, markedly lower than FedWBA's 0.1598, and its Brier score of 0.3109 is again the lowest. With 500 clients, its ECE of 0.0208 is the best among all methods; its MCE of 0.0962 is second only to pFedBayes (0.0615) and better than those of the other methods; and its Brier score of 0.4295 is lower than those of FedWBA (0.4678) and pFedGP (0.5468).
3. Calibration performance on the CIFAR-100 dataset
[0161] The CIFAR-100 dataset has more categories and a harder task, placing higher demands on uncertainty calibration. With 200 clients, this invention achieves an ECE of 0.0323, well below baselines such as FedAvg (0.0898), FedProx (0.0797), Ditto (0.0934), and FedALA (0.1302), and slightly below pFedBayes (0.0340), although the two methods differ more markedly in MCE and Brier score. With 300 clients, this invention achieves an ECE of 0.0565, lower than FedAvg (0.0775); FedProx (0.0515) and Ditto (0.0434) attain slightly lower ECE values, but this invention is better on the other metrics, with an MCE of 0.1167 (far below FedProx's 0.2201) and a Brier score of 0.3453. With 500 clients, the ECE of this invention is 0.0557, better than FedAvg (0.0850), FedProx (0.0705), Ditto (0.0955), and others; its MCE of 0.1013 is markedly better than those of FedProx (0.1517) and FedAvg (0.1509); and its Brier score of 0.4100 is better than that of every comparison method.
4. Calibration performance on the Tiny-ImageNet dataset
[0163] The Tiny-ImageNet dataset, with its larger scale and more categories, poses a significant challenge for uncertainty calibration. With 200 clients, the ECE of this invention is 0.1481; this is higher than FedWBA (0.1257) but well below FedAvg (0.2185), FedProx (0.1946), Ditto (0.1957), and the other methods. Its MCE (0.2084) is far lower than those of all comparison methods (FedWBA: 0.3663; FedAvg: 0.5452), indicating that its calibration error remains small even in the worst-case bin. With 300 clients, its ECE is 0.0943, better than FedAvg (0.1405), FedProx (0.1084), Ditto (0.1167), and others; its MCE of 0.2711 is markedly better than those of FedWBA (0.5086) and FedAvg (0.5781); and its Brier score of 0.7434 is the best among all methods. With 500 clients, its ECE is 0.0647, better than FedAvg (0.1420), FedProx (0.1042), FedWBA (0.0990), and others; Ditto attains a lower ECE (0.0274), but this invention is better in MCE and Brier score. Its MCE of 0.2219 is much lower than those of FedWBA (0.3926) and FedAvg (0.6069), and its Brier score of 0.6463 is better than that of every comparison method.
[0164] Based on the above results, the FedCMGP method of this invention shows clear advantages in uncertainty calibration. In most dataset and client-number settings, its ECE, MCE, and Brier scores are the best or near-best. On complex datasets such as Tiny-ImageNet in particular, its MCE is markedly lower than that of all comparison methods, indicating high calibration reliability even in the worst case.
[0165] FedCMGP's superior uncertainty calibration is attributed to its Bayesian nonparametric design based on conjugate multi-class Gaussian process classification. The key factors are as follows:
[0166] First, accurate posterior inference. Through One-vs-Each approximation and Pólya-Gamma augmentation, this invention obtains fully analytical conditional posterior distributions in multi-class Gaussian process classification. Clients can therefore draw posterior samples with efficient Gibbs sampling instead of relying on variational or Laplace approximations, avoiding the calibration errors that such approximations introduce.
[0167] Second, the regularization effect of the shared deep kernel prior. The global deep kernel prior obtained by server aggregation carries knowledge shared across clients and acts as a strong regularizer in local posterior inference, preventing clients from making overconfident predictions on small local datasets and thereby improving calibration.
[0168] Third, the inherent uncertainty quantification of Gaussian processes. Gaussian processes perform Bayesian inference directly in function space, so the posterior variance naturally quantifies predictive uncertainty, avoiding the distorted uncertainty estimates that model misspecification causes in parametric methods.
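The third factor can be illustrated with a short sketch. Assuming a unit-variance RBF kernel and posterior samples of one class's latent function at the training inputs (for example, Gibbs samples), the latent value at a test input x* is conditionally Gaussian with variance k(x*, x*) - k*^T K^{-1} k*, which grows toward the prior variance as x* moves away from the data; the law of total variance then combines this with the spread of the conditional means across samples. The function names are illustrative only and are not the method's actual interface.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance RBF kernel matrix between the rows of X1 and X2."""
    d2 = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / lengthscale**2)


def latent_predictive(X, f_samples, X_star, lengthscale=1.0, jitter=1e-6):
    """Predictive mean/variance of one class's latent function at X_star.

    X:         (N, D) training inputs (or their deep kernel embeddings).
    f_samples: (S, N) posterior samples of the latent function at X.
    X_star:    (M, D) test inputs.
    """
    K = rbf(X, X, lengthscale) + jitter * np.eye(len(X))
    K_s = rbf(X, X_star, lengthscale)                 # (N, M) cross-covariances
    A = np.linalg.solve(K, K_s)                       # K^{-1} k_* for every test point
    cond_mean = np.asarray(f_samples) @ A             # (S, M): E[f* | f] per sample
    # Var[f* | f] = k(x*, x*) - k_*^T K^{-1} k_*, with k(x*, x*) = 1 for this kernel
    cond_var = np.maximum(1.0 - np.sum(K_s * A, axis=0), 0.0)
    # Law of total variance over the posterior samples
    return cond_mean.mean(axis=0), cond_var + cond_mean.var(axis=0)


X = np.array([[0.0], [1.0], [2.0]])
f_samples = np.zeros((4, 3))
mean, var = latent_predictive(X, f_samples, np.array([[0.0], [10.0]]))
# var is close to 0 at the training input 0.0 and close to the prior variance 1 at 10.0
```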
[0169] In contrast, traditional federated learning methods such as FedAvg and FedProx are frequentist point-estimation methods without principled uncertainty quantification, and their ECE values are generally high. Personalized federated learning methods such as Ditto and FedALA improve prediction accuracy but likewise do not model uncertainty. Bayesian methods introduce uncertainty modeling, but pFedBayes relies on a Gaussian distribution assumption and pFedGP relies on a GP-tree strategy; neither achieves the exact conjugate posterior inference of this invention, which limits their calibration performance.
[0170] In summary, the analysis of ECE, MCE, and the Brier score in this embodiment verifies the superior uncertainty calibration of FedCMGP. The method of this invention provides more accurate and reliable uncertainty estimates than the comparison methods, offering stronger support for risk control and decision-making in practice, and it has significant application value in fields with very high reliability requirements, such as medical diagnosis and autonomous driving.
[0171] Table 4. Results of uncertainty calibration experiments on the MNIST dataset
[0172]
[0173] Effect Experiment Analysis 5: Convergence Rate Analysis Experiment
[0174] This embodiment examines the convergence rate of the FedCMGP method in a federated learning setting and compares it with several representative methods to verify its advantages in training efficiency and convergence stability. The experiments used four datasets, MNIST, CIFAR-10, CIFAR-100, and Tiny-ImageNet, with 200, 300, and 500 clients; all other configurations were the same as in the accuracy experiments. To simulate a resource-constrained real-world setting, each selected client performed only one local update per communication round, and training ran for 500 communication rounds. As shown in Figure 3, convergence is measured by test accuracy as a function of communication rounds, with the number of rounds on the horizontal axis and test accuracy on the vertical axis. FedCMGP's fast and stable convergence is attributed mainly to the following three design features:
[0175] First, the shared global prior. The server aggregates the deep kernel hyperparameters uploaded by clients with FedAvg, obtaining a global prior that contains knowledge shared across clients. This prior is broadcast to all clients in the next communication round and gives local posterior inference a strong regularization direction, so that client updates move coherently toward a common optimum instead of drifting or oscillating because of data heterogeneity.
[0176] Second, efficient posterior inference through Gibbs sampling. By employing One-vs-Each approximation and Pólya-Gamma augmentation, FedCMGP achieves conjugate posterior inference for multi-class Gaussian process classification, enabling clients to quickly obtain accurate posterior samples through Gibbs sampling. Compared to variational inference methods requiring numerous iterations for optimization, Gibbs sampling converges with only a small number of iterations in each round of local updates, significantly reducing local computation time and thus accelerating the overall federated learning process.
[0177] Third, the feature representation capability of deep kernel learning. Deep kernel functions map the original input to a more discriminative feature space through neural networks, enabling the client to quickly learn task-related data representations, further accelerating model convergence.
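The server-side step behind the first feature, a FedAvg-style average of the uploaded deep kernel hyperparameters weighted by each client's local sample count (the weighting used in Example 1), can be sketched as follows. Representing the hyperparameters as a dictionary of NumPy arrays, and the parameter names in the usage example, are assumptions for illustration.

```python
import numpy as np


def aggregate_hyperparams(client_params, client_sizes):
    """FedAvg-style aggregation of deep kernel hyperparameters.

    client_params: list of dicts mapping parameter name -> np.ndarray
                   (all clients share the same keys and shapes).
    client_sizes:  number of local training samples per client, used as weights.
    """
    w = np.asarray(client_sizes, dtype=float)
    w = w / w.sum()                                   # normalize weights to sum to 1
    return {
        name: sum(wi * np.asarray(p[name], dtype=float) for wi, p in zip(w, client_params))
        for name in client_params[0]
    }


# Hypothetical parameter names: a kernel log-lengthscale and one network weight matrix.
clients = [
    {"log_lengthscale": np.array(0.0), "W": np.ones((2, 2))},
    {"log_lengthscale": np.array(1.0), "W": 3 * np.ones((2, 2))},
]
global_params = aggregate_hyperparams(clients, client_sizes=[100, 300])
# global_params["W"] is 2.5 everywhere: 0.25 * 1 + 0.75 * 3
```

The aggregated dictionary is what the server would broadcast as the next round's global prior; only hyperparameters travel, never raw data.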
[0178] In summary, this embodiment fully verifies FedCMGP's significant advantage in convergence speed through convergence curve analysis across multiple datasets and varying client numbers. Compared to existing methods, FedCMGP achieves higher accuracy in fewer communication rounds and exhibits excellent stability with minimal performance fluctuations during training. This efficient convergence characteristic makes the method of this invention particularly suitable for practical federated learning applications with limited communication resources and requiring rapid response, such as real-time model updates on mobile devices and rapid anomaly detection in the Internet of Things, providing strong support for the practical deployment of federated learning.
[0179] Effect Experiment Analysis 6: Input Noise Robustness Analysis
[0180] In real-world federated learning applications, client data is often subject to various types of noise interference, such as sensor noise during image acquisition, data corruption during transmission, and image blurring caused by environmental changes. To evaluate the robustness of the FedCMGP method in heterogeneous data quality scenarios, this embodiment designs an input noise robustness experiment to simulate client data being contaminated by different degrees and types of noise, and examines the extent of performance degradation of the model.
[0181] The experiments were conducted on the CIFAR-100 dataset with 200 clients, and all configurations were kept consistent with the accuracy experiments. To simulate the variability in client data quality in real-world scenarios, we randomly selected 30% of the clients as "noisy clients" and applied one of five common types of image contamination to their local training data, specifically including:
[0182] Gaussian noise: Add Gaussian white noise with a mean of 0 and a variance of 0.1 to the image;
[0183] Salt and pepper noise: Randomly set 5% of the pixels to black or white;
[0184] Defocus blur: The image is blurred using a disk-shaped point spread function with a radius of 3;
[0185] Gaussian blur: Smoothing is performed using a Gaussian kernel with a standard deviation of 2;
[0186] Motion blur: Simulates camera motion, using a motion blur kernel with a length of 10 and an angle of 45 degrees for convolution operation.
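The five corruptions can be reproduced approximately with NumPy and SciPy as sketched below. Several details are assumptions not fixed by the text: images are float arrays of shape (H, W, C) scaled to [0, 1], so the noise variance of 0.1 applies on that scale; noisy outputs are clipped to [0, 1]; blurs use reflective boundaries; and the motion-blur line is rasterized through the kernel center, with the angle measured counterclockwise from the horizontal axis.

```python
import numpy as np
from scipy import ndimage


def gaussian_noise(img, var=0.1, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    return np.clip(img + rng.normal(0.0, np.sqrt(var), img.shape), 0.0, 1.0)


def salt_and_pepper(img, frac=0.05, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    out = img.copy()
    hit = rng.random(img.shape[:2]) < frac            # pixels to corrupt
    white = rng.random(img.shape[:2]) < 0.5           # half salt (white), half pepper (black)
    out[hit & white] = 1.0
    out[hit & ~white] = 0.0
    return out


def _blur(img, kernel):
    kernel = kernel / kernel.sum()                    # normalize to preserve brightness
    # A trailing kernel axis of size 1 convolves each channel spatially only.
    return ndimage.convolve(img, kernel[:, :, None], mode="reflect")


def defocus_blur(img, radius=3):
    r = np.arange(-radius, radius + 1)
    yy, xx = np.meshgrid(r, r, indexing="ij")
    return _blur(img, (xx**2 + yy**2 <= radius**2).astype(float))   # disk-shaped PSF


def gaussian_blur(img, sigma=2.0):
    return ndimage.gaussian_filter(img, sigma=(sigma, sigma, 0))    # no blur across channels


def motion_kernel(length=10, angle_deg=45.0):
    k = np.zeros((length, length))
    c = (length - 1) / 2.0
    theta = np.deg2rad(angle_deg)
    for t in np.linspace(-c, c, length):              # rasterize a line through the center
        x = int(np.rint(c + t * np.cos(theta)))
        y = int(np.rint(c - t * np.sin(theta)))       # rows grow downward
        k[y, x] = 1.0
    return k


def motion_blur(img, length=10, angle_deg=45.0):
    return _blur(img, motion_kernel(length, angle_deg))
```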
[0187] The remaining 70% of clients kept their original data unchanged. All comparison methods were trained and tested under this noisy scenario, with accuracy used as the evaluation metric. The experiment was repeated 5 times and the average value was taken.
[0188] Figure 4 (a)-(e) visualize the five noise types, and Figure 4 (f) summarizes the accuracy of each method in the noisy scenario together with its degradation relative to the noise-free scenario. As Figure 4 (f) shows, the FedCMGP method of the present invention achieves the highest accuracy and the smallest performance degradation.
[0189] Specifically, under Gaussian noise, FedCMGP achieves an accuracy of 74.56%, about 7.41 percentage points below the noise-free scenario. The next best method, pFedBayes, reaches 72.15%, a drop of 9.14 percentage points; FedDBE reaches 70.79%, a drop of 14.89 percentage points; and FedAvg reaches only 68.49%, a drop of 20.28 percentage points. The superior robustness of FedCMGP to noisy inputs is due mainly to the following factors:
[0190] First, the regularization effect of shared deep kernel priors. The global deep kernel priors obtained by server aggregation contain shared knowledge across clients, playing a strong regularization role in local posterior inference. This allows client updates to resist interference from local noisy data and avoids drastic shifts in model parameters due to noisy samples. Even if the local data of some clients is contaminated, their uploaded hyperparameter updates will be smoothed by a weighted averaging process, thereby reducing the impact of malicious or noisy updates.
[0191] Second, the inherent smoothness of Gaussian processes. Gaussian processes measure the similarity between input points through a kernel function, and their predictions are essentially a weighted average based on similarity. When the input data is disturbed by noise, the kernel function can adaptively reduce the weight of noisy points based on their distance in the feature space, thereby suppressing the impact of noise on the prediction results. In contrast, point estimation methods based on neural networks lack this built-in smoothing mechanism and are more prone to overfitting to noisy samples.
[0192] Third, the robustness of deep kernel learning. Deep kernel functions map the original input to a feature space through neural networks. This feature space is regularized by global priors, resulting in stronger discriminative power and robustness. Even in the presence of noise in the input space, deep networks can still extract relatively stable feature representations, ensuring that the similarity calculated by the kernel function remains reliable.
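A deep kernel of the kind described in the third factor applies a base kernel to learned embeddings, k(x, x') = k_base(g(x), g(x')). The sketch below assumes a hypothetical two-layer ReLU network for g and an RBF base kernel, and it shows only the forward kernel computation; training the network weights, which form part of the deep kernel hyperparameters that clients upload, is omitted. The small embedding size is for illustration, whereas the document uses 128 or 512 dimensions.

```python
import numpy as np


def init_feature_net(d_in, d_hidden, d_embed, rng):
    """Randomly initialize a hypothetical two-layer MLP feature extractor g."""
    return {
        "W1": rng.normal(0.0, 1.0 / np.sqrt(d_in), (d_in, d_hidden)),
        "b1": np.zeros(d_hidden),
        "W2": rng.normal(0.0, 1.0 / np.sqrt(d_hidden), (d_hidden, d_embed)),
        "b2": np.zeros(d_embed),
    }


def embed(X, phi):
    h = np.maximum(0.0, X @ phi["W1"] + phi["b1"])    # ReLU hidden layer
    return h @ phi["W2"] + phi["b2"]                  # d_embed-dimensional features


def deep_rbf_kernel(X1, X2, phi, lengthscale=1.0):
    """k(x, x') = exp(-||g(x) - g(x')||^2 / (2 l^2)): an RBF kernel on learned embeddings."""
    Z1, Z2 = embed(X1, phi), embed(X2, phi)
    d2 = ((Z1[:, None, :] - Z2[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / lengthscale**2)


rng = np.random.default_rng(0)
phi = init_feature_net(d_in=16, d_hidden=32, d_embed=8, rng=rng)
X = rng.normal(size=(10, 16))
K = deep_rbf_kernel(X, X, phi)
```

Because the base kernel is applied to embeddings, any input perturbation that the network maps to a nearby embedding leaves the kernel values, and hence the GP similarity structure, nearly unchanged.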
[0193] This embodiment evaluates the robustness of FedCMGP under heterogeneous data quality by simulating five common image contamination scenarios on the CIFAR-100 dataset. The results show that FedCMGP achieves the best accuracy under all noise types with the smallest performance degradation, demonstrating strong noise resistance. This makes the method of this invention particularly suitable for real-world applications with complex data acquisition environments and difficult quality control, such as classifying images captured by mobile devices and processing IoT sensor data, and supports the reliable deployment of federated learning systems.
[0194] Effects Experiment Analysis 7: Byzantine Attack Experiment
[0195] This embodiment evaluates the robustness of the FedCMGP method of this invention against Byzantine attacks. In a Byzantine attack, malicious clients in a federated learning system upload deliberately crafted erroneous updates to disrupt the convergence of the global model, posing a significant security threat to distributed learning systems. This experiment verifies the stability of FedCMGP under malicious attacks and examines its defensive effectiveness when combined with various robust aggregation rules.
[0196] Experiments were conducted on the MNIST and CIFAR-10 datasets, with 200 clients configured. 20% of the clients (40 clients) were configured as malicious clients, and the remaining 160 were configured as honest clients. All configurations were consistent with the accuracy experiments. To comprehensively evaluate robustness under different attack intensities, we simulated five common Byzantine attack strategies:
[0197] Sign-Flipping Attack (SF): after local updating, a malicious client reverses the sign of the deep kernel hyperparameter update it uploads, pushing the global aggregation in the opposite direction.
[0198] Label-Flipping Attack (LF): before local training, a malicious client randomly flips the labels of its training data (for MNIST and CIFAR-10, each label is replaced by a flipped label) and then trains and uploads updates normally on the corrupted labels.
[0199] FOE Attack (Fall of Empires Attack): The core idea of an FOE attack is to manipulate the update direction of a malicious client so that it is orthogonal to or even opposite to the update direction of a legitimate client, while keeping the update magnitude similar to that of the legitimate client, in order to circumvent distance-based anomaly detection mechanisms.
[0200] ALIE Attack (A Little Is Enough Attack): The core idea of the ALIE attack is that "disruption can be caused with only a small modification." The attacker first calculates the mean and standard deviation of honest client updates, and then sets the malicious client updates to deviate from the mean by a certain multiple of the standard deviation.
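The three update-level attacks can be written as simple functions of flattened update vectors, as in this sketch; label flipping acts on training labels rather than on updates and is omitted. The FOE scale eps and the ALIE multiplier z are free parameters whose values the document does not give; the helper alie_z follows the z_max rule proposed in the original ALIE paper and is included only as an assumption about how z might be chosen.

```python
import numpy as np
from statistics import NormalDist


def sign_flip(update):
    """SF: upload the negated local update."""
    return -np.asarray(update)


def fall_of_empires(honest, eps=1.0):
    """FOE: upload -eps times the mean honest update (eps is a free attack parameter)."""
    return -eps * np.asarray(honest).mean(axis=0)


def alie_z(n, f):
    """ALIE z_max: s = floor(n/2 + 1) - f, z = Phi^{-1}((n - s) / n)."""
    s = n // 2 + 1 - f
    return NormalDist().inv_cdf((n - s) / n)


def alie(honest, z):
    """ALIE: shift the coordinate-wise honest mean by z standard deviations."""
    honest = np.asarray(honest)
    return honest.mean(axis=0) - z * honest.std(axis=0)


# 200 clients with 40 malicious, as in this embodiment
z = alie_z(200, 40)
```

The ALIE shift stays within about half a standard deviation per coordinate for this setting, which is why distance-based filters struggle to separate it from honest heterogeneity.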
[0201] To defend against Byzantine attacks, we tested four classic robust aggregation rules and combined them with the method of this invention:
[0202] Geometric Median (GM) aggregation: uses the geometric median of the client updates as the aggregation result, which is robust to outliers.
[0203] Krum aggregation: selects the client update that is closest to most other client updates as the aggregation result; it can tolerate up to f malicious clients, where n is the number of participating clients and Krum requires n > 2f + 2.
[0204] Coordinate-wise Trimmed Mean (CWTM): for each coordinate, discards a fixed proportion of the largest and smallest values and averages the rest.
[0205] Nearest-Neighbor Mixing (NNM): selects trusted neighbors for each client update based on geometric proximity, mixes each update with its nearest neighbors, and then performs aggregation.
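The four aggregation rules can be sketched on a matrix whose rows are flattened client updates. Assumptions: f denotes the number of malicious clients assumed by each rule; the geometric median is computed with Weiszfeld iterations; CWTM trims f values at each end of every coordinate; and NNM is read as nearest-neighbor mixing, a pre-processing step followed here by CWTM, since the text does not name the downstream rule.

```python
import numpy as np


def _sq_dists(X):
    return ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)


def geometric_median(X, iters=200, eps=1e-8):
    """GM via Weiszfeld iterations."""
    z = X.mean(axis=0)
    for _ in range(iters):
        d = np.maximum(np.linalg.norm(X - z, axis=1), eps)   # avoid division by zero
        w = 1.0 / d
        z_new = (w[:, None] * X).sum(axis=0) / w.sum()
        done = np.linalg.norm(z_new - z) < eps
        z = z_new
        if done:
            break
    return z


def krum(X, f):
    """Pick the update with the smallest sum of squared distances to its n - f - 2 nearest neighbors."""
    n = len(X)
    D = _sq_dists(X)
    scores = [np.sort(np.delete(D[i], i))[: n - f - 2].sum() for i in range(n)]
    return X[int(np.argmin(scores))]


def trimmed_mean(X, f):
    """CWTM: per coordinate, drop the f largest and f smallest values, then average."""
    Xs = np.sort(X, axis=0)
    return Xs[f : len(X) - f].mean(axis=0)


def nnm(X, f):
    """NNM: replace each update by the mean of its n - f nearest neighbors (itself included)."""
    idx = np.argsort(_sq_dists(X), axis=1)[:, : len(X) - f]
    return X[idx].mean(axis=1)


honest = np.ones((8, 2)) + 0.01 * np.arange(8)[:, None]
X = np.vstack([honest, np.full((2, 2), 100.0)])            # two identical malicious updates
robust = trimmed_mean(nnm(X, f=2), f=2)
```

In this toy setting, eight honest updates near (1, 1) and two outliers at (100, 100) pull the plain mean to about (20.8, 20.8), while each robust rule stays within about 0.1 of the honest cluster.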
[0206] All experiments were repeated 5 times, and the average value and standard error were taken. The test accuracy was used as the evaluation index.
[0207] The experimental results are shown in Table 5, which is divided into two parts: the upper part shows the inherent robustness of the method of the present invention compared with FedAvg and pFedBayes when using simple mean aggregation; the lower part shows the defense effect of the method of the present invention combined with different robust aggregation rules.
[0208] Table 5. Test accuracy under Byzantine attack (% ± SEM)
[0209]
[0210] As can be seen from the upper part of Table 5, even when using simple Mean Aggregation without any defense mechanisms, the FedCMGP method of this invention significantly outperforms FedAvg and pFedBayes under all five attack types. This inherent robustness stems from two main aspects: first, the global regularization effect of the shared deep kernel prior constrains client updates within a reasonable range, reducing the impact of malicious updates on the global model; second, the kernel smoothness of the Gaussian process itself ensures that even when local parameters are perturbed, the predictions in the function space remain stable.
[0211] As can be seen from the lower half of Table 5, the defense effect is further improved when the method of this invention is combined with robust aggregation rules. Among them, the NNM aggregation rule shows the best or near-best performance under all attack types. The superiority of the NNM method lies in its ability to filter trusted clients through geometric neighborhoods, effectively distinguishing between malicious updates and normal heterogeneous updates, and is particularly suitable for scenarios with non-independent and identically distributed data. Compared with Krum, which requires the assumption of independent and identical distribution, NNM is more robust; compared with CWTM, NNM has more refined geometric filtering in the multidimensional parameter space.
[0212] By simulating five Byzantine attack strategies on the MNIST and CIFAR-10 datasets, the robustness of FedCMGP under malicious attack environments was comprehensively evaluated. Experimental results show that:
[0213] Inherently robust: Even under average aggregation without defense, FedCMGP significantly outperforms FedAvg and pFedBayes across all attack types, demonstrating its inherent resistance to attacks.
[0214] The defense combination is highly effective: When combined with robust aggregation rules such as NNM, FedCMGP can restore its accuracy to over 90% in most attack scenarios, approaching the normal level when there is no attack.
[0215] Multi-layered protection mechanism: FedCMGP’s robustness comes from multiple protections, including shared prior regularization, Gaussian process function space stability, kernel function locality, and robust aggregation rules, forming a deep defense system.
[0216] This characteristic makes the method of the present invention particularly suitable for security-sensitive practical application scenarios, such as financial risk control and medical diagnosis, providing a solid technical guarantee for the reliable deployment of federated learning systems in untrusted environments.
[0217] Effect Experiment Analysis 8: Ablation Experiment
[0218] This embodiment conducts ablation experiments to deeply analyze the impact of various key components in the FedCMGP method of this invention on model performance, further revealing its characteristics and advantages. The experimental environment is consistent with the previous one, with 200 clients participating in the study on the MNIST and CIFAR-10 datasets, using test accuracy as the main evaluation metric.
[0219] 1. The impact of GP basis kernel functions on model performance
[0220] In FedCMGP, the choice of the base kernel function directly affects the function-space smoothness and generalization ability of the Gaussian process. We compared four commonly used base kernels: the linear kernel, the Laplace kernel, the cosine similarity kernel, and the radial basis function (RBF) kernel. The experimental results are shown in Table 6.
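For reference, the four base kernels compared in Table 6 can be written as functions of the embedding matrices Z1 and Z2 (rows are embedded inputs), as in this sketch. The lengthscale parameters and the Euclidean-distance form of the Laplace kernel are assumptions; scikit-learn's laplacian_kernel, for example, uses the L1 distance instead.

```python
import numpy as np


def _sq_dists(Z1, Z2):
    return np.maximum(((Z1[:, None, :] - Z2[None, :, :]) ** 2).sum(-1), 0.0)


def linear_kernel(Z1, Z2):
    return Z1 @ Z2.T


def laplace_kernel(Z1, Z2, lengthscale=1.0):
    # Exponential form using the Euclidean distance between embeddings.
    return np.exp(-np.sqrt(_sq_dists(Z1, Z2)) / lengthscale)


def cosine_kernel(Z1, Z2, eps=1e-12):
    n1 = np.linalg.norm(Z1, axis=1, keepdims=True)
    n2 = np.linalg.norm(Z2, axis=1, keepdims=True)
    return (Z1 @ Z2.T) / np.maximum(n1 * n2.T, eps)


def rbf_kernel(Z1, Z2, lengthscale=1.0):
    # Infinitely differentiable and local: similarity decays smoothly with distance.
    return np.exp(-0.5 * _sq_dists(Z1, Z2) / lengthscale**2)


rng = np.random.default_rng(0)
Z = rng.normal(size=(6, 4))
grams = {k.__name__: k(Z, Z) for k in (linear_kernel, laplace_kernel, cosine_kernel, rbf_kernel)}
```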
[0221] As shown in Table 6, on the MNIST dataset, the RBF kernel achieved the highest accuracy of 98.48% ± 0.03%, followed by the cosine similarity kernel (97.40% ± 0.10%), while the linear kernel performed the worst (97.31% ± 0.10%). On the CIFAR-10 dataset, the RBF kernel also far surpassed the others with an accuracy of 73.86% ± 0.21%, while the cosine similarity kernel and Laplace kernel achieved 71.45% ± 0.46% and 70.95% ± 0.40%, respectively, and the linear kernel only achieved 66.77% ± 0.37%. The superior performance of the RBF kernel is attributed to its infinite differentiability and local smoothing properties, which enable it to more flexibly capture nonlinear patterns in the data. Therefore, in a preferred embodiment of the present invention, the RBF kernel is used as the basis kernel function.
[0222] Table 6 Ablation experiments on the impact of the GP base kernel on model performance
[0223]
2. The impact of deep kernel embedding dimension on model performance
[0225] The embedding dimension of the deep kernel function determines the expressive power of the feature space output by the deep neural network. In principle, a higher dimension strengthens the model's ability to fit complex data but also increases computational overhead and the risk of overfitting. We tested model performance under different embedding dimensions on the MNIST and CIFAR-10 datasets; the results are shown in Table 7.
[0226] For the MNIST dataset, when the embedding dimension increased from 64 to 128, the accuracy improved from 97.25% ± 0.13% to 98.48% ± 0.03%; this shows that for the simple dataset MNIST, 128-dimensional embedding is sufficient to capture data features, and excessively high dimensions may introduce redundant information.
[0227] For the CIFAR-10 dataset, accuracy improves steadily with the embedding dimension: 68.91% ± 0.46% at 128 dimensions, 72.50% ± 0.42% at 256 dimensions, and a best value of 73.86% ± 0.21% at 512 dimensions. This indicates that complex datasets require higher embedding dimensions to represent their features fully, so in practice the embedding dimension should be chosen according to task complexity. In the preferred embodiment of this invention, MNIST uses 128 dimensions, and CIFAR-10 and more complex datasets use 512 dimensions.
[0228] Table 7 Ablation experiments on the impact of deep kernel embedding dimension on model performance
[0229]
[0230] 3. The impact of Gibbs sampling sample size on model performance
[0231] The number of parallel Gibbs chains in FedCMGP's local posterior inference, i.e., the number of posterior samples, determines the accuracy of the Monte Carlo gradient estimate. We compared model performance for … ; the results are shown in Table 8.
[0232] On the MNIST dataset, increasing the sample size from 1 to 5 raised accuracy markedly from 96.46% ± 0.13% to 97.92% ± 0.07%; increasing it further to 10 raised accuracy to 98.45% ± 0.03%, where the gains began to level off. On the CIFAR-10 dataset, increasing the sample size from 5 to 10 raised accuracy from 72.81% ± 0.56% to 73.42% ± 0.56%, and increasing it to 15 gave 73.47% ± 0.38%. Thus a moderately larger sample size reduces the variance of the gradient estimate and improves model performance, but the marginal benefit diminishes beyond a certain number. Balancing computational cost and performance, the preferred embodiment of this invention adopts … .
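The role of the sample count can be seen from how the gradient is formed. By Fisher's identity, the gradient of the log marginal likelihood with respect to a kernel hyperparameter equals the posterior expectation of the gradient of the log GP prior, log N(f; 0, K), because the likelihood does not depend on the hyperparameter; averaging this quantity over S posterior samples gives a Monte Carlo estimate whose variance shrinks as S grows. The sketch below assumes a single RBF lengthscale as the hyperparameter; deep kernel network weights would be handled the same way, with automatic differentiation supplying dK.

```python
import numpy as np


def sq_dists(X):
    return ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)


def rbf_gram(X, ell, jitter=1e-6):
    return np.exp(-0.5 * sq_dists(X) / ell**2) + jitter * np.eye(len(X))


def log_gp_prior(f, X, ell):
    """log N(f; 0, K_ell)."""
    K = rbf_gram(X, ell)
    _, logdet = np.linalg.slogdet(K)
    return -0.5 * f @ np.linalg.solve(K, f) - 0.5 * logdet - 0.5 * len(f) * np.log(2 * np.pi)


def fisher_mc_grad(f_samples, X, ell):
    """Monte Carlo estimate of d/d(ell) log p(y | ell) via Fisher's identity.

    Each posterior sample f contributes d/d(ell) log N(f; 0, K) =
    0.5 * a^T dK a - 0.5 * tr(K^{-1} dK), with a = K^{-1} f.
    """
    D2 = sq_dists(X)
    K = rbf_gram(X, ell)
    dK = np.exp(-0.5 * D2 / ell**2) * D2 / ell**3     # dK/d(ell); the jitter term is constant
    K_inv = np.linalg.inv(K)
    trace_term = np.trace(K_inv @ dK)
    grads = [0.5 * (K_inv @ f) @ dK @ (K_inv @ f) - 0.5 * trace_term for f in f_samples]
    return float(np.mean(grads))
```

With a single sample, the estimate equals the exact derivative of log N(f; 0, K) at that sample, which is a convenient numerical check; with S Gibbs chains, f_samples would hold S draws.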
[0233] Table 8 Ablation Experiments on the Impact of Gibbs Sampling Sample Size on Model Performance
[0234]
[0235] 4. The impact of Gibbs sampling iterations on model performance
[0236] The number of local Gibbs sampling iterations affects the mixing quality of the posterior samples. We compared model performance for … ; the results are shown in Table 9.
[0237] On the MNIST dataset, increasing the number of iterations from 1 to 5 raised accuracy from 97.72% ± 0.07% to 98.40% ± 0.05%, whereas increasing it further to 10 lowered accuracy to 97.54% ± 0.11%. A possible reason is that additional iterations require repeated decompositions of the covariance matrix, so numerical errors accumulate and degrade sampling quality. On the CIFAR-10 dataset, increasing the number of iterations from 5 to 10 raised accuracy from 73.37% ± 0.37% to 73.80% ± 0.25%, and increasing it further to 15 lowered accuracy slightly to 73.41% ± 0.56%. There is therefore an optimal number of iterations: 5 for the simpler MNIST dataset and 5 to 10 for the more complex CIFAR-10 dataset. In a preferred embodiment of the present invention, MNIST uses … and CIFAR-10 and more complex datasets use … .
[0238] Table 9 Ablation experiments on the impact of Gibbs sampling iterations on model performance
[0239]
[0240] The ablation experiments in this embodiment show that each component in FedCMGP has a significant impact on model performance. The RBF kernel, due to its excellent smoothness, is the optimal choice. The embedding dimension of the deep kernel needs to be reasonably set according to the dataset complexity; 128 dimensions for MNIST and 512 dimensions for CIFAR-10 achieve a balance between performance and efficiency. Setting the number of Gibbs samples to 5 effectively reduces gradient estimation variance, and setting the number of iterations to 5-10 ensures sampling quality while avoiding the accumulation of numerical errors. These experimental results provide a basis for further optimization of the FedCMGP model and also verify the rationality and necessity of the component design in this invention, fully demonstrating the adaptability and superior performance of FedCMGP under different configurations.
[0241] Example 1: Accuracy Verification Scenario for 5-Class Image Classification
[0242] The system deployed 200 clients and 1 server, with each client holding approximately 300 CIFAR-10 image samples. Each client received the initial global deep kernel hyperparameters broadcast by the server, constructed a sparse matrix encoding the inter-class differences among its five classes of samples, and introduced Pólya-Gamma auxiliary variables so that, conditional on them, the likelihood becomes Gaussian in the latent function values and is conjugate to the Gaussian process prior. Five parallel Gibbs chains were initialized and alternately updated the latent function values and the auxiliary variables; posterior samples were obtained after 10 iterations, and gradients were computed with Fisher's identity and Monte Carlo estimation. The local deep kernel hyperparameters were optimized with a learning rate of 0.001 and uploaded, and the server updated the global hyperparameters by averaging them with weights proportional to each client's data volume. After 500 communication rounds, the final classification accuracy reached 73.45%, 11.52 percentage points higher than traditional FedAvg, highlighting the combined benefit of the analytical posterior and weighted aggregation.
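The local inference loop of this example can be sketched as one Gibbs chain for GP classification with the One-vs-Each likelihood and Pólya-Gamma augmentation; five parallel chains would be five independent calls. Assumptions: the difference matrix A is built densely rather than as a sparse matrix, latent values are stored class by class with an independent GP prior per class, Pólya-Gamma draws use a truncated version of the distribution's infinite-sum representation instead of an exact sampler, and the kernel matrix K (for example, from the deep kernel on the client's inputs) is precomputed. Each OVE term sigma(f_y - f_j) contributes kappa = 1/2, which makes the conditional posterior of f Gaussian.

```python
import numpy as np


def sample_pg(c, rng, b=1.0, n_terms=200):
    """Approximate PG(b, c) draws via the truncated sum representation
    omega = 1/(2 pi^2) * sum_k g_k / ((k - 1/2)^2 + c^2 / (4 pi^2)), g_k ~ Gamma(b, 1)."""
    c = np.asarray(c, dtype=float)
    k = np.arange(1, n_terms + 1)
    g = rng.gamma(b, 1.0, size=c.shape + (n_terms,))
    denom = (k - 0.5) ** 2 + (c[..., None] ** 2) / (4 * np.pi**2)
    return (g / denom).sum(axis=-1) / (2 * np.pi**2)


def ove_matrix(y, n_classes):
    """Rows encode f_n^{y_n} - f_n^{j} for every j != y_n (latents stored class-major)."""
    N = len(y)
    rows = []
    for n, yn in enumerate(y):
        for j in range(n_classes):
            if j != yn:
                r = np.zeros(N * n_classes)
                r[yn * N + n], r[j * N + n] = 1.0, -1.0
                rows.append(r)
    return np.array(rows)


def gibbs_ove_pg(K, y, n_classes, n_iters=10, rng=None, jitter=1e-4):
    """One Gibbs chain: omega | f ~ PG(1, A f), then f | omega ~ N(mu, Sigma) with
    Sigma = (K_blk^{-1} + A^T Omega A)^{-1} and mu = Sigma A^T kappa."""
    rng = np.random.default_rng() if rng is None else rng
    N = K.shape[0]
    A = ove_matrix(y, n_classes)
    kappa = np.full(A.shape[0], 0.5)                  # every OVE term acts as a "label 1" logit
    K_inv = np.linalg.inv(K + jitter * np.eye(N))
    prior_prec = np.kron(np.eye(n_classes), K_inv)    # independent GP prior per class
    f = np.zeros(N * n_classes)
    for _ in range(n_iters):
        omega = sample_pg(A @ f, rng)
        prec = prior_prec + A.T @ (omega[:, None] * A)
        L = np.linalg.cholesky(prec)
        mu = np.linalg.solve(prec, A.T @ kappa)
        f = mu + np.linalg.solve(L.T, rng.standard_normal(N * n_classes))  # cov = prec^{-1}
    return f.reshape(n_classes, N)


rng = np.random.default_rng(0)
X = np.array([[0.0], [0.1], [3.0], [3.1], [6.0], [6.1]])
K = np.exp(-0.5 * (X - X.T) ** 2)                     # RBF kernel on 1-D inputs
F = gibbs_ove_pg(K, np.array([0, 0, 1, 1, 2, 2]), n_classes=3, n_iters=10, rng=rng)
```

The returned sample F has one row of latent values per class; collecting such samples from several chains gives the posterior draws used for the Monte Carlo gradient.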
[0243] Example 2: Dirichlet Heterogeneity Validation Scenario for 10-Class Image Classification
[0244] 200 clients and 1 server were deployed, with each client holding 500 FMNIST image samples locally. The degree of data heterogeneity was controlled by the concentration parameter of the Dirichlet distribution, which was set to produce strong heterogeneity. After receiving the global deep kernel hyperparameters, each client performed conjugate posterior inference through the One-vs-Each approximation and Pólya-Gamma augmentation, running 10 Gibbs iterations on each of 5 parallel chains to obtain posterior samples for optimizing its local hyperparameters. The server aggregated the hyperparameters by weighted averaging before broadcasting them. After 500 rounds of communication, accuracy reached 88.23%, 3.95 percentage points higher than FedAvg and 1.12 percentage points higher than pFedBayes, verifying the advantage of this invention under strongly heterogeneous data.
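The One-vs-Each construction can be made concrete with a small sketch. It assumes the latent values of all classes are stacked class-major into one vector (f[c·n + i] for class c and sample i), which keeps the prior covariance block-diagonal across classes, and it stores each row of the sparse matrix as a {column: value} dictionary; both layout choices are illustrative, as are the function names. For two classes the approximation equals the softmax likelihood, and in general the product of sigmoids never exceeds the softmax probability, so the OvE log-likelihood is a lower bound on the softmax log-likelihood.

```python
import math


def ove_rows(labels, num_classes):
    """Rows of the sparse One-vs-Each matrix A, each stored as {column: value}.

    Row (i, c) selects f[y_i * n + i] - f[c * n + i] for every class c != y_i.
    """
    n = len(labels)
    rows = []
    for i, y in enumerate(labels):
        for c in range(num_classes):
            if c != y:
                rows.append({y * n + i: 1.0, c * n + i: -1.0})
    return rows


def log_sigmoid(t):
    # Numerically stable log(1 / (1 + exp(-t))).
    return -math.log1p(math.exp(-t)) if t >= 0 else t - math.log1p(math.exp(t))


def ove_log_lik(f, labels, num_classes):
    """Sum of log sigmoid((A f)_r) over all OvE rows r."""
    return sum(
        log_sigmoid(sum(v * f[j] for j, v in row.items()))
        for row in ove_rows(labels, num_classes)
    )


def softmax_log_lik(f, labels, num_classes):
    """Exact multi-class (softmax) log-likelihood for the same flattened f."""
    n = len(labels)
    total = 0.0
    for i, y in enumerate(labels):
        logits = [f[c * n + i] for c in range(num_classes)]
        m = max(logits)
        total += logits[y] - m - math.log(sum(math.exp(l - m) for l in logits))
    return total
```

Each row of A has exactly two nonzeros, so the matrix stays sparse regardless of the number of classes.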
[0245] Example 3: Global Model Validation Scenario for 10-Class Image Classification
[0246] 200 clients and 1 server were deployed, with each client holding 300 MNIST image samples locally. Clients updated and uploaded their local deep kernel hyperparameters following the procedure in Example 1. The server obtained the global deep kernel hyperparameters by weighted-average aggregation and then evaluated the global model on a centralized test set. After 500 rounds of communication, the global model achieved a classification accuracy of 91.12%, comparable to Ditto (91.56%) and higher than pFedBayes (90.90%) and FedWBA (90.82%), demonstrating that the aggregated global prior effectively extracts features shared across clients.
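The server's data-size-weighted aggregation (also spelled out in claim 5) is a FedAvg-style weighted mean. This sketch assumes each client's deep kernel hyperparameters have been flattened into a list of floats of equal length; the function name is hypothetical.

```python
def aggregate_hyperparameters(local_params, data_sizes):
    """Weighted mean of per-client hyperparameter vectors, weighted by local data size."""
    total = sum(data_sizes)
    dim = len(local_params[0])
    weighted_sum = [0.0] * dim
    for params, n_k in zip(local_params, data_sizes):
        for d in range(dim):
            weighted_sum[d] += n_k * params[d]
    return [s / total for s in weighted_sum]
```

For example, `aggregate_hyperparameters([[1.0, 0.0], [3.0, 4.0]], [100, 300])` returns `[2.5, 3.0]`: the client with three times as much data pulls the result three times as hard.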
[0247] Example 4: Uncertainty Calibration and Verification Scenario for 10-Class Image Classification
[0248] 200 clients and 1 server were deployed, with each client holding 300 CIFAR-10 image samples. Each client computed prediction uncertainty from the conjugate Gaussian process posterior distribution, and the server aggregated the uploaded hyperparameters to obtain the global model. After 500 rounds of communication, this invention achieved an expected calibration error (ECE) of 0.0271, a maximum calibration error (MCE) of 0.1017, and a Brier score of 0.3680, clearly outperforming FedAvg (ECE 0.0719) and pFedBayes (ECE 0.0507) and confirming that accurate posterior inference improves uncertainty quantification.
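For reference, the three calibration metrics can be computed as below. The sketch assumes equal-width bins over the top-class confidence (15 bins here) and a multi-class Brier score summed over classes and averaged over samples; the embodiment states neither its binning nor its Brier normalization, so both are assumptions, and the function name is illustrative.

```python
def calibration_metrics(probs, labels, num_bins=15):
    """Return (ECE, MCE, Brier score) for per-sample class-probability lists.

    ECE/MCE use equal-width bins on the top-class probability; the Brier score
    is the squared error summed over classes and averaged over samples.
    """
    n = len(labels)
    counts = [0] * num_bins
    conf_sum = [0.0] * num_bins
    correct_sum = [0.0] * num_bins
    brier = 0.0
    for p, y in zip(probs, labels):
        conf = max(p)
        pred = p.index(conf)
        b = min(int(conf * num_bins), num_bins - 1)     # confidence 1.0 goes in the last bin
        counts[b] += 1
        conf_sum[b] += conf
        correct_sum[b] += 1.0 if pred == y else 0.0
        brier += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    ece = mce = 0.0
    for c, s_conf, s_corr in zip(counts, conf_sum, correct_sum):
        if c:
            gap = abs(s_corr / c - s_conf / c)          # |accuracy - confidence| in this bin
            ece += (c / n) * gap
            mce = max(mce, gap)
    return ece, mce, brier / n
```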
Example 5: Convergence Rate Verification Scenario for 20-Class Image Classification
[0250] 200 clients and 1 server were deployed, with each client holding 500 Tiny-ImageNet image samples. The clients used the Hoffman-Ribak method to accelerate Gibbs sampling. Accuracy rose rapidly from its initial value to 38% within the first 100 rounds and reached 42.38% after 200 rounds, converging markedly faster than FedWBA (40.59% after 200 rounds) and FedAvg (approximately 39% after 200 rounds) and demonstrating the convergence advantage of efficient sampling.
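The Hoffman-Ribak method draws a posterior sample by correcting an unconditional prior sample rather than factorizing the posterior covariance. The scalar Python sketch below shows only this core update for a prior f ~ N(0, k) and a Gaussian pseudo-observation v = f + ε with ε ~ N(0, 1/ω), the form that arises after Pólya-Gamma augmentation; the block-diagonal and sparse-matrix structure that claim 6 exploits for efficiency is not reproduced, and the function names are hypothetical.

```python
import random


def hoffman_ribak_scalar(k, omega, v, rng):
    """One posterior draw of f for prior f ~ N(0, k) and pseudo-observation
    v = f + eps, eps ~ N(0, 1/omega), via the Hoffman-Ribak (Matheron) update."""
    noise_var = 1.0 / omega
    f0 = rng.gauss(0.0, k ** 0.5)               # unconditional prior sample
    eps0 = rng.gauss(0.0, noise_var ** 0.5)     # matching noise sample
    gain = k / (k + noise_var)                  # Cov(f, v) / Var(v)
    return f0 + gain * (v - f0 - eps0)          # shift the prior sample toward v


def exact_posterior(k, omega, v):
    """Closed-form posterior mean and variance for the same scalar model."""
    noise_var = 1.0 / omega
    return k * v / (k + noise_var), k * noise_var / (k + noise_var)
```

In the full multi-class setting f0 is a joint prior sample over all latent values and the scalar gain becomes K Aᵀ(A K Aᵀ + Ω⁻¹)⁻¹, which is where exploiting the matrix structure described in claim 6 becomes relevant.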
Example 6: Input Noise Robustness Verification Scenario for 10-Class Image Classification
[0252] 200 clients and 1 server were deployed. 30% of the clients (60 clients) were randomly selected, and one of the following corruptions was added to their local CIFAR-10 data: Gaussian noise, salt-and-pepper noise, defocus blur, Gaussian blur, or motion blur. After receiving the global hyperparameters, the clients performed posterior inference by Gibbs sampling, using the regularization effect of the global prior to suppress the influence of the noise. After weighted aggregation by the server over 500 rounds of communication, accuracy in the noisy scenario reached 67.15%, only 7.41 percentage points lower than in the noise-free scenario, whereas FedAvg dropped by 20.28 percentage points and pFedGP by 9.14 percentage points, confirming the robustness conferred by the inherent smoothness of the Gaussian process.
[0253] Example 7: Verification Scenario for Byzantine Attack Defense in 10-Class Image Classification
[0254] 200 clients and 1 server were deployed, with 20% of the clients (40 nodes) acting as malicious nodes that carried out four types of attacks, including sign flipping and label flipping. The clients performed local updates with the method of this invention, while the server used the NNM robust aggregation rule. After 500 rounds of communication, accuracy reached 94.40%, a significant improvement over plain average aggregation and close to the 98.48% achieved without attacks, demonstrating the effectiveness of the multi-layered defense that combines shared-prior regularization with NNM aggregation.
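The text does not expand "NNM". Assuming it denotes nearest neighbor mixing, a pre-aggregation step in which each uploaded vector is replaced by the average of its n − f nearest uploads (f being the assumed number of Byzantine clients), a minimal sketch follows. The choice of coordinate-wise median as the downstream aggregation rule, treating each client's hyperparameters as a flat list, and the function names are all assumptions.

```python
import statistics


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_neighbor_mixing(vectors, f):
    """Replace each vector by the mean of its n - f nearest vectors (itself included)."""
    n = len(vectors)
    mixed = []
    for v in vectors:
        dists = [sq_dist(v, w) for w in vectors]
        order = sorted(range(n), key=dists.__getitem__)
        nearest = [vectors[j] for j in order[: n - f]]
        mixed.append([sum(col) / len(nearest) for col in zip(*nearest)])
    return mixed


def coordinate_median(vectors):
    return [statistics.median(col) for col in zip(*vectors)]


def robust_aggregate(vectors, f):
    """NNM pre-processing followed by a coordinate-wise median (assumed base rule)."""
    return coordinate_median(nearest_neighbor_mixing(vectors, f))
```

With eight honest uploads of `[1.0, 2.0]` and two Byzantine uploads of `[100.0, -100.0]`, `robust_aggregate(..., f=2)` returns `[1.0, 2.0]`, whereas a plain mean of the same ten vectors would give 20.8 in the first coordinate.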
Example 8: Ablation Experiment Scenario for 10-Class Image Classification
[0256] 200 clients and 1 server were deployed, with each client holding 300 CIFAR-10 image samples. Among the base kernel functions compared (RBF, linear, Laplacian, and cosine), the RBF kernel achieved the highest accuracy (73.86%); among the deep kernel embedding dimensions compared, 512 dimensions was optimal; for the number of Gibbs samples, 10 chains were sufficient; and for the number of iteration steps, 10 steps was optimal. The results show that the configuration of an RBF kernel, a 512-dimensional embedding, 10 samples, and 10 iteration steps achieves the best balance between performance and efficiency, validating the rationality of the component design.
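The four base kernels compared in the ablation, and their composition with a feature extractor into a deep kernel, can be sketched as follows. The unit-amplitude parameterization, the single length-scale, and the use of the L1 distance in the Laplacian kernel are assumptions (some libraries use the L2 distance instead); `embed` is a hypothetical stand-in for the neural network that produces the 128- or 512-dimensional embeddings.

```python
import math


def rbf(x, z, lengthscale=1.0):
    sq = sum((a - b) ** 2 for a, b in zip(x, z))
    return math.exp(-sq / (2.0 * lengthscale ** 2))


def laplacian(x, z, lengthscale=1.0):
    # L1 distance here; an L2-distance variant is also common.
    dist = sum(abs(a - b) for a, b in zip(x, z))
    return math.exp(-dist / lengthscale)


def linear(x, z):
    return sum(a * b for a, b in zip(x, z))


def cosine(x, z):
    norm = math.sqrt(linear(x, x) * linear(z, z))
    return linear(x, z) / norm if norm > 0 else 0.0


def deep_kernel(base, embed):
    """Deep kernel k(x, x') = base(embed(x), embed(x')) for a feature extractor embed."""
    return lambda x, z: base(embed(x), embed(z))
```

Because the base kernel only ever sees embeddings, the hyperparameters exchanged with the server are the feature extractor's weights plus the base kernel's own parameters.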
[0257] The above description is merely a specific embodiment of the present invention, but the scope of protection of the present invention is not limited thereto. Any modifications, equivalent substitutions, and improvements made by those skilled in the art within the scope of the technology disclosed in the present invention, and within the spirit and principles of the present invention, should be covered within the scope of protection of the present invention.
Claims
1. A personalized Bayesian federated learning system based on conjugate multi-class Gaussian process classification and deep kernel learning, characterized in that it comprises a client-side local posterior inference module and a server-side global prior aggregation module. The client-side local posterior inference module adopts the conjugate multi-class Gaussian process classification method: it decomposes the multi-class likelihood into a product of binary likelihoods, obtains an analytical conditional posterior distribution for multi-class Gaussian process classification by introducing auxiliary variables, obtains posterior samples by alternately updating the latent function values and auxiliary variables through Gibbs sampling, and optimizes the local deep kernel hyperparameters based on the posterior samples. The server-side global prior aggregation module receives the local deep kernel hyperparameters uploaded by the clients, updates the global deep kernel hyperparameters by weighted averaging based on the amount of local data on each client, and broadcasts the updated global hyperparameters to the clients as the shared prior for the next round of communication.
2. The system according to claim 1, characterized in that the client-side local posterior inference module decomposes the multi-class likelihood into a product of binary likelihoods by constructing a sparse matrix from the local labels and using the sparse matrix to encode all inter-class differences, so that the multi-class likelihood is approximated as a product of multiple binary likelihoods.
3. The system according to claim 1, characterized in that the auxiliary variables introduced by the client-side local posterior inference module follow a Pólya-Gamma distribution; after they are introduced, the likelihood conditioned on the auxiliary variables takes a Gaussian form, achieving conditional conjugacy with the Gaussian prior.
4. The system according to claim 1, characterized in that, when optimizing the local deep kernel hyperparameters, the client-side local posterior inference module uses Fisher's identity and Monte Carlo estimation to compute the gradient of the log marginal likelihood with respect to the local deep kernel hyperparameters, and then updates the local deep kernel hyperparameters by a gradient method with a set learning rate.
5. The system according to claim 1, characterized in that the server-side global prior aggregation module updates the global deep kernel hyperparameters as follows: compute the total local data volume of all clients; multiply each client's local deep kernel hyperparameters by that client's local data volume and sum the products; then divide this sum by the total local data volume of all clients to obtain the updated global deep kernel hyperparameters.
6. The system according to claim 1, characterized in that, in the client-side local posterior inference module, Gibbs sampling adopts the Hoffman-Ribak method, which reduces the complexity of a single sampling step by exploiting the block-diagonal structure of the covariance matrix and the low-rank diagonal property of the product obtained by multiplying the transpose of the sparse matrix by the diagonal matrix of auxiliary variables and then by the sparse matrix.
7. A personalized Bayesian federated learning method based on conjugate multi-class Gaussian process classification and deep kernel learning, characterized in that it uses the system of claim 1 and comprises the following steps: a client receives the global deep kernel hyperparameters broadcast by the server, decomposes the multi-class likelihood using the conjugate multi-class Gaussian process classification method, introduces auxiliary variables to obtain an analytical conditional posterior distribution, updates the latent function values and auxiliary variables through Gibbs sampling to obtain posterior samples, optimizes the local deep kernel hyperparameters based on the posterior samples, and uploads them to the server; the server receives the local deep kernel hyperparameters from each client, updates the global deep kernel hyperparameters by a weighted average based on each client's local data volume, and broadcasts the updated global deep kernel hyperparameters to each client. These steps are repeated to jointly optimize cross-client knowledge sharing and personalized learning.
8. The method according to claim 7, characterized in that updating the latent function values and auxiliary variables by Gibbs sampling specifically comprises: initializing the number of parallel Gibbs chains and the number of local Gibbs iteration steps; initializing the auxiliary variables and latent variables of each Gibbs chain; in each iteration, first sampling the current auxiliary variables conditioned on the latent variables from the previous iteration, and then sampling the current latent variables conditioned on the current auxiliary variables; and, after all iterations are completed, taking the final auxiliary variables and latent variables of each Gibbs chain as posterior samples.
9. A personalized Bayesian federated learning method based on conjugate multi-class Gaussian process classification and deep kernel learning, characterized in that it is applied to the system of claim 1 and its generalization error satisfies a PAC-Bayes generalization error bound: for any delta between 0 and 1, the bound holds with probability at least 1 minus delta. The true risk is the expected classification error rate under the posterior distribution; the empirical risk is the classification error rate of the posterior distribution on the training set; the complexity term depends on the KL divergence between the posterior and the prior; and the inverse Bernoulli KL divergence is the largest probability value whose Bernoulli KL divergence from the empirical risk does not exceed the complexity term.
10. The method according to claim 9, characterized in that the true risk is the expectation, over new input samples and labels, of the probability under the posterior distribution that the sign of the latent function differs from the label; the empirical risk is the average, over the training samples, of the probability under the posterior distribution that the sign of the latent function differs from the label; and the complexity term is the sum of the KL divergence between the posterior and the prior and the natural logarithm of (the number of training samples plus 1) divided by delta, all divided by the number of training samples.
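Claims 9 and 10 state the bound in words. Writing R̂ for the empirical risk, n for the number of training samples, and Q, P for the posterior and prior, the complexity term is (KL(Q‖P) + ln((n + 1)/δ))/n, and the risk bound is the largest p with kl(R̂ ‖ p) not exceeding that term. Because kl(R̂ ‖ p) increases in p on [R̂, 1], the inverse can be found by bisection, as in the sketch below; the function names, the clipping constant, and the iteration count are illustrative choices.

```python
import math


def bernoulli_kl(q, p):
    """kl(q || p) between Bernoulli(q) and Bernoulli(p), with p clipped away from 0 and 1."""
    eps = 1e-12
    p = min(max(p, eps), 1 - eps)
    total = 0.0
    if q > 0:
        total += q * math.log(q / p)
    if q < 1:
        total += (1 - q) * math.log((1 - q) / (1 - p))
    return total


def kl_inverse(emp_risk, bound, iters=100):
    """Largest p in [emp_risk, 1] with kl(emp_risk || p) <= bound, by bisection."""
    lo, hi = emp_risk, 1.0
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        if bernoulli_kl(emp_risk, mid) <= bound:
            lo = mid
        else:
            hi = mid
    return lo


def pac_bayes_risk_bound(emp_risk, kl_post_prior, n, delta):
    """kl^{-1}(emp_risk, (KL(Q||P) + ln((n + 1) / delta)) / n)."""
    complexity = (kl_post_prior + math.log((n + 1) / delta)) / n
    return kl_inverse(emp_risk, complexity)
```

With zero empirical risk the inverse has the closed form 1 − exp(−c) for complexity c, which gives a quick sanity check on the bisection.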