An image classification method based on personalized federated distillation

By employing cross-attention-driven feature distillation and cross-head prediction techniques, the problems of erroneous knowledge transfer and insufficient adaptation to heterogeneous data in federated distillation methods are solved, thereby improving the image classification performance of personalized models.

CN121708359B · Active · Publication Date: 2026-06-26 · KUNMING UNIV OF SCI & TECH

Patent Information

Authority / Receiving Office
CN · China
Patent Type
Patents(China)
Current Assignee / Owner
KUNMING UNIV OF SCI & TECH
Filing Date
2025-12-04
Publication Date
2026-06-26

AI Technical Summary

Technical Problem

Existing federated distillation methods suffer from problems such as erroneous knowledge transfer, invalid feature transfer, and insufficient adaptation to heterogeneous data in image classification scenarios, leading to a decline in the classification performance of personalized models.

Method used

The method decouples the global and personalized models and applies cross-attention-driven feature distillation (CAD) and cross-head prediction (CHD): feature similarity weights are calculated so that only global knowledge consistent with the personalized features is transferred, and optimization conflicts are avoided.

Benefits of technology

It improves the image classification performance of personalized models, enhances their generalization ability on heterogeneous data, and improves classification accuracy.

✦ Generated by Eureka AI based on patent content.

  • Figure CN121708359B_ABST

Abstract

The application relates to an image classification method based on personalized federated distillation, and belongs to the technical field of image classification. The method comprises the following steps: decoupling a global model and a personalized model of each client into a global feature extractor, a global head, a personalized feature extractor and a personalized head; extracting global features and personalized features of each level in an image; introducing a cross-attention mechanism to calculate cross-attention driven feature distillation loss between the global features and the personalized features; obtaining cross-entropy loss by supervising personalized prediction and real labels of the image, inputting the personalized features into the global head to obtain cross-head prediction, and performing distillation on the cross-head prediction and global prediction to obtain cross-head prediction loss; constructing total loss to train the personalized feature extractor and the personalized head, uploading the personalized feature extractor and the personalized head to a server for aggregation, and repeating the above steps to obtain an enhanced personalized model for image classification. The application aims to fully mine global knowledge in the image classification process and improve the classification performance of the personalized model.

Description

Technical Field

[0001] This invention relates to an image classification method based on personalized federated distillation, belonging to the field of image classification technology.

Background Technology

[0002] Image classification is one of the core tasks in the field of computer vision. Its core objective is to determine the category of input images using algorithmic models. It is widely used in key areas such as autonomous driving (e.g., traffic sign recognition, pedestrian and vehicle differentiation), medical image diagnosis (e.g., lung CT lesion classification, skin cancer image screening), security monitoring (e.g., suspicious item identification, personnel identity association), and smart retail (e.g., product category inventory). With the development of deep learning technology, deep models represented by ResNet and ViT have raised the accuracy of image classification to new heights by accurately extracting multi-level features of images.

[0003] However, the high performance of deep image classification models depends on the support of large-scale labeled datasets. In practical applications, image data is often scattered across different institutions or devices (i.e., "clients") and involves strict privacy protection requirements: for example, patient CT images in hospitals, road test images of autonomous driving companies, and user album images on mobile terminals cannot be directly shared for centralized training. This "data silo" problem severely limits the application of deep models in image classification tasks.

[0004] To resolve the conflict between data privacy and model performance, Federated Learning (FL) has been introduced into image classification. Models are trained locally on each client and only model parameters are uploaded to the server for aggregation, so the data never leaves the client while the model is still trained collaboratively. Federated Distillation, which combines Federated Learning with Knowledge Distillation, further addresses personalization needs. In image classification, the global model aggregated on the server typically serves as the "teacher model" and each client's local model as the "student model"; transferring the teacher's knowledge (predictions or features) improves the classification ability of each client's personalized model. For example, in a multi-hospital lung CT classification task, each hospital (client) can learn common lesion features from the global teacher model through federated distillation while keeping its local model adapted to CT images from its specific equipment.

[0005] However, existing federated distillation methods still face three key problems in image classification, which severely limit the classification performance of personalized models. First, erroneous knowledge propagation: most existing methods use the predictions (logits) of the global teacher model as the distillation knowledge. In image classification, however, the global model is prone to biased predictions for similar categories (such as cats and dogs, pneumonia and common inflammation, or circular and elliptical traffic signs). If these erroneous predictions are passed directly to the personalized model, the client model's accuracy on local images of similar categories drops significantly. Second, invalid feature propagation: some federated distillation methods try to improve performance through feature distillation but ignore the hierarchical correlation of features in image classification. Image features are abstracted step by step from low-level cues (edges, textures) to high-level ones (semantics, categories), yet existing methods do not compute the similarity between the global and personalized features at each level, so the transferred features are often ineffective. Third, insufficient adaptation to heterogeneous data: client data in image classification scenarios is highly heterogeneous. For example, CT equipment in different hospitals has different resolutions (leading to different image sizes), and the class distribution of traffic sign datasets is unbalanced across regions (e.g., "speed limit 60" signs account for 40% of one city's client data but only 5% of another city's). Existing federated distillation methods struggle to balance "global common knowledge" (such as general image edge extraction) with "personalized knowledge" (such as noise correction for CT images from specific equipment and sample enhancement for imbalanced categories), so personalized models generalize poorly on local data. For example, on a heterogeneous subset of the CIFAR10 dataset (10 image classes), the classification accuracy of existing methods is 12% lower than that of the global model.

[0006] To address these core pain points of federated distillation in image classification, a personalized federated distillation scheme is needed that can accurately filter effective knowledge and adapt to data heterogeneity, thereby improving the image classification performance of client-side personalized models. On this basis, this invention proposes an image classification method based on personalized federated distillation (FedCACHD). Through cross-attention-driven feature distillation (CAD) and cross-head prediction (CHD), it addresses erroneous knowledge transmission, invalid feature transmission and heterogeneous data adaptation, providing an efficient, privacy-preserving personalized model training scheme for image classification tasks.

Summary of the Invention

[0007] The purpose of this invention is to provide an image classification method based on personalized federated distillation (FedCACHD), which aims to fully exploit global knowledge during the image classification process and improve the classification performance of personalized models.

[0008] To achieve the above objectives, the technical solution of this invention is an image classification method based on personalized federated distillation that performs cross-attention-driven feature distillation (CAD) and cross-head prediction (CHD) on each client device. CAD learns global knowledge consistent with the personalized features by constructing similarity weights between global and personalized features; CHD avoids optimization conflicts in the personalized model through cross-head distillation. The two techniques reinforce each other and jointly improve personalization performance. The method includes the following steps:

[0009] Step 1: Decouple the global model and personalized model of each client into a global feature extractor, a global head, a personalized feature extractor, and a personalized head;

[0010] Step 2: Input the local image training data of each client into the decoupled global feature extractor and personalized feature extractor for feature extraction, and obtain the global features and personalized features at each level respectively;

[0011] Step 3: Introduce a cross-attention mechanism to calculate the similarity weight between the global features and personalized features at each level, and calculate the cross-attention-driven feature distillation loss based on the similarity weight;

[0012] Step 4: Input the global features into the global head to obtain the global prediction; input the personalized features into the personalized head to obtain the personalized prediction; supervise the personalized prediction with the true labels of the local image training data to obtain the cross-entropy loss; input the personalized features into the global head to obtain the cross-head prediction; and perform distillation between the cross-head prediction and the global prediction to obtain the cross-head prediction loss.

[0013] Step 5: Construct a total loss using the cross-attention driven feature distillation loss, the cross-entropy loss, and the cross-head prediction loss. Train a personalized feature extractor and a personalized head based on the total loss. Then, upload the trained personalized feature extractor and personalized head to the server for aggregation. Download the aggregated feature extractor and aggregated head to all clients to initialize the global feature extractor, personalized feature extractor, and global head. Repeat Steps 1-4 until the personalized model converges, thereby obtaining an enhanced personalized model.

[0014] Step 6: Input each client's image test dataset to be classified into the enhanced personalized model to obtain the image classification results it outputs.

[0015] Optionally, each client deploys a global model, a personalized model and a local private dataset. The global model comprises a global feature extractor $f_k^{g}$ and a global head $h_k^{g}$; the personalized model comprises a personalized feature extractor $f_k^{p}$ and a personalized head $h_k^{p}$. The local private dataset is represented as:

[0016] $$D_k = \{(x_i, y_i)\}_{i=1}^{N_k} \tag{1}$$

[0017] where $D_k$ is the local private dataset of the k-th client, $x_i$ is the i-th local private data sample, $y_i$ is the label of the i-th sample, and $N_k$ is the total number of samples held by the k-th client.

[0018] Optionally, the expression for obtaining the global features and personalized features at each level is:

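[0019] $$F^{g} = f_k^{g}(x) = \{F^{g}_1, F^{g}_2, \ldots, F^{g}_M\},\quad F^{g}_m \in \mathbb{R}^{H^{g}_m \times W^{g}_m \times d^{g}_m} \tag{2}$$

[0020] $$F^{p} = f_k^{p}(x) = \{F^{p}_1, F^{p}_2, \ldots, F^{p}_N\},\quad F^{p}_n \in \mathbb{R}^{H^{p}_n \times W^{p}_n \times d^{p}_n} \tag{3}$$

[0021] where $F^{g}$ is the set of global features and $F^{g}_m$ is the global feature output by the m-th convolutional layer of the global feature extractor $f_k^{g}$ (M layers in total); $F^{p}$ is the set of personalized features and $F^{p}_n$ is the personalized feature output by the n-th convolutional layer of the personalized feature extractor $f_k^{p}$ (N layers in total); H, W and d denote the height, width and channel dimension of a feature map, respectively.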

[0022] Optionally, the introduction of the cross-attention mechanism to calculate the similarity weights between the global features and personalized features at each level specifically involves:

[0023] Select the m-th feature $F^{g}_m$ from the global candidate feature set and the n-th feature $F^{p}_n$ from the personalized candidate feature set. Two pooling operations are applied to $F^{g}_m$ and $F^{p}_n$: global average pooling and channel-wise average pooling.

[0024] The similarity weights computed from the globally average-pooled features determine how strongly knowledge is transferred, while the distance over which knowledge is transferred is defined on the channel-wise pooled features. For global average pooling, $F^{g}_m$ and $F^{p}_n$ are first globally average-pooled; the pooled features are then linearly transformed with two linear transition parameters, and the transformed features serve as the query $Q_m$ and the key $K_n$. Following the cross-attention mechanism, the query and key measure the similarity between $F^{g}_m$ and $F^{p}_n$. For channel-wise pooling, channel-wise average pooling is applied to $F^{g}_m$ and $F^{p}_n$, and the pooled features are used for knowledge transfer. The query and key are expressed as:

[0025] $$Q_m = \sigma_Q\big(W^{Q}_m \cdot \mathrm{GAP}(F^{g}_m)\big),\quad W^{Q}_m \in \mathbb{R}^{d \times d^{g}_m} \tag{4}$$

[0026] $$K_n = \sigma_K\big(W^{K}_n \cdot \mathrm{GAP}(F^{p}_n)\big),\quad W^{K}_n \in \mathbb{R}^{d \times d^{p}_n} \tag{5}$$

[0027] where $\mathrm{GAP}(\cdot)$ denotes global average pooling; $W^{Q}_m$ and $W^{K}_n$ are the linear transition parameters of the m-th query and the n-th key; d is the channel dimension of the query and key; $d^{g}_m$ is the channel dimension of the m-th global feature and $d^{p}_n$ that of the n-th personalized feature; $\sigma_Q(\cdot)$ and $\sigma_K(\cdot)$ are the activation functions of the query and the key, respectively.

[0028] The similarity weights are computed from all queries and keys:

[0029] $$\alpha_m = \mathrm{softmax}\Big(\big[\,Q_m^{\top} W^{QK}_{m,n} K_n + P^{g}_m \cdot P^{p}_n\,\big]_{n=1}^{N}\Big) \tag{6}$$

[0030] where $\alpha_m = [\alpha_{m,1}, \ldots, \alpha_{m,N}]$ is the similarity weight vector capturing the relationship between the m-th global feature and all N personalized features; $\mathrm{softmax}(\cdot)$ normalizes its inputs into a probability distribution that sums to 1; $Q_m$ is the m-th global feature after linear transformation (the query) and $K_n$ is the n-th personalized feature after linear transformation (the key); $W^{QK}_{m,n}$ denotes the bilinear weights; and $P^{g}_m$ and $P^{p}_n$ are position encodings. Through $\alpha_m$, the global features pass their knowledge to the personalized features.

[0031] Optionally, the expression for calculating the cross-attention-driven feature distillation loss based on similarity weights is:

[0032] $$\mathcal{L}_{\mathrm{CAD}} = \sum_{m=1}^{M} \sum_{n=1}^{N} \alpha_{m,n}\, \big\lVert \phi(F^{g}_m) - \phi(\hat{F}^{p}_n) \big\rVert_2 \tag{7}$$

[0033] where $\mathcal{L}_{\mathrm{CAD}}$ is the cross-attention-driven feature distillation loss; $\alpha_{m,n}$ is the similarity weight between the m-th global feature $F^{g}_m$ and the n-th personalized feature $F^{p}_n$; $\phi(\cdot)$ is a composite function consisting of channel-wise average pooling followed by L2 normalization; and $\hat{F}^{p}_n$ denotes $F^{p}_n$ upsampled or downsampled so that its feature map size matches that of the global feature $F^{g}_m$.

[0034] Optionally, the obtained cross-head prediction loss is specifically as follows:

[0035] The last-layer feature $z^{p}$ of the personalized model is passed to the global head $h_k^{g}$ to obtain the cross-head prediction $\hat{y} = h_k^{g}(z^{p})$. The KD loss between the cross-head prediction $\hat{y}$ and the original prediction $y^{g}$ of the global model is taken as the cross-head prediction loss:

[0036] $$\mathcal{L}_{\mathrm{CHD}} = \frac{1}{Z} \sum_{r \in R} S(r)\, D_{\mathrm{KL}}\big(\hat{y}(r),\, y^{g}(r)\big) \tag{8}$$

[0037] where $\mathcal{L}_{\mathrm{CHD}}$ is the cross-head prediction loss; r denotes a spatial location or region of the prediction map, and the sum over $r \in R$ traverses the whole prediction map to compute the knowledge distillation loss; $D_{\mathrm{KL}}(\cdot,\cdot)$ denotes the KL divergence; $S(\cdot)$ and $Z$ are the region selection principle and the normalization factor, respectively. The region selection principle weights the entire prediction map equally: for distillation between $\hat{y}$ and $y^{g}$ in cross-head prediction, $S(\cdot)$ is a constant function with value 1.

[0038] Optionally, the expression for the total loss is:

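[0039] $$\mathcal{L} = \mathcal{L}_{\mathrm{CE}} + \lambda\, \mathcal{L}_{\mathrm{CHD}} + \beta\, \mathcal{L}_{\mathrm{CAD}} \tag{9}$$

[0040] where $\mathcal{L}$ is the total loss, $\mathcal{L}_{\mathrm{CE}}$ is the cross-entropy loss, $\mathcal{L}_{\mathrm{CHD}}$ is the cross-head prediction loss, $\mathcal{L}_{\mathrm{CAD}}$ is the cross-attention-driven feature distillation loss, λ is the weight controlling the distillation strength of the cross-head prediction loss, and β is the weight controlling the distillation strength of the cross-attention-driven feature distillation loss.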

[0041] The beneficial effects of this invention are as follows: by using cross-attention-driven feature distillation and cross-head prediction, the invention avoids invalid or inaccurate transfer of global knowledge, as well as optimization conflicts in the personalized model caused by erroneous predictions of the global model during image classification, thereby improving the classification performance of personalized models. The invention also offers new ideas for combining personalized federated distillation with cross-attention mechanisms in image classification scenarios.

Attached Figure Description

[0042] Figure 1 is a schematic diagram of the overall framework of the invention;

[0043] Figure 2 is a flowchart of the steps of the invention;

[0044] Figure 3 is a detailed framework diagram of cross-attention-driven feature distillation in an embodiment of the invention;

[0045] Figure 4 is a graph of the accuracy of the proposed method and the baselines over communication rounds in the real-world MNIST data scenario in an embodiment of the invention;

[0046] Figure 5 is a graph of the accuracy of the proposed method and the baselines over communication rounds in the real-world FMNIST data scenario in an embodiment of the invention;

[0047] Figure 6 is a graph of the accuracy of the proposed method and the baselines over communication rounds in the real-world CIFAR10 data scenario in an embodiment of the invention;

[0048] Figure 7 is a graph of the accuracy of the proposed method and the baselines over communication rounds in the real-world CIFAR100 data scenario in an embodiment of the invention;

[0049] Figure 8 is a graph of the accuracy of the proposed method and the baselines over communication rounds in the real-world TinyImageNet data scenario in an embodiment of the invention.

Detailed Implementation

[0050] To more clearly illustrate the uses, solutions, and advantages of this technology, a detailed explanation will be provided below with reference to accompanying figures and examples. It should be noted that the examples provided herein are for illustrative purposes only and are not intended to limit the scope of this technology.

[0051] The illustrations and specific parameter values in the following examples mainly serve to illustrate the basic concept of the invention and to verify it by simulation. In specific application environments, they can be adjusted to the actual scenario and requirements.

[0052] Example 1: As shown in Figure 2, this embodiment provides an image classification method based on personalized federated distillation (FedCACHD), which includes the following steps:

[0053] Step 1: Decouple the global model and personalized model of each client into a global feature extractor, a global head, a personalized feature extractor, and a personalized head;

[0054] Optionally, this embodiment uses a pre-trained convolutional neural network model (such as ResNet18) as both the global model and the personalized model of each client. In the personalized federated distillation framework, the global model and the personalized model act as the teacher and student models, respectively. The global and personalized models are decoupled into a global feature extractor, a global head, a personalized feature extractor and a personalized head; the global and personalized feature extractors extract the global and personalized features, respectively. In this embodiment, the pre-trained ResNet18 model includes 20 convolutional layers and one fully connected layer: the 20 convolutional layers serve as the feature extractor and the fully connected layer serves as the head. The 20 convolutional layers of the feature extractor are further decoupled so that features can be extracted from each layer.

[0055] Optionally, the personalized federated distillation framework includes a server and multiple clients, each client deploying a global model, a personalized model and a local private dataset. The global model comprises a global feature extractor $f_k^{g}$ and a global head $h_k^{g}$; the personalized model comprises a personalized feature extractor $f_k^{p}$ and a personalized head $h_k^{p}$. The local private dataset is represented as:

[0056] $$D_k = \{(x_i, y_i)\}_{i=1}^{N_k} \tag{1}$$

[0057] where $D_k$ is the local private dataset of the k-th client, $x_i$ is the i-th local private data sample, $y_i$ is the label of the i-th sample, and $N_k$ is the total number of samples held by the k-th client.

[0058] Step 2: Input the local image training data of each client into the decoupled global feature extractor and personalized feature extractor for feature extraction, and obtain the global features and personalized features at each level respectively;

[0059] Optionally, the expression for obtaining the global features and personalized features at each level is:

[0060] $$F^{g} = f_k^{g}(x) = \{F^{g}_1, F^{g}_2, \ldots, F^{g}_M\},\quad F^{g}_m \in \mathbb{R}^{H^{g}_m \times W^{g}_m \times d^{g}_m} \tag{2}$$

[0061] $$F^{p} = f_k^{p}(x) = \{F^{p}_1, F^{p}_2, \ldots, F^{p}_N\},\quad F^{p}_n \in \mathbb{R}^{H^{p}_n \times W^{p}_n \times d^{p}_n} \tag{3}$$

[0062] where $F^{g}$ is the set of global features and $F^{g}_m$ is the global feature output by the m-th convolutional layer of the global feature extractor $f_k^{g}$ (M layers in total); $F^{p}$ is the set of personalized features and $F^{p}_n$ is the personalized feature output by the n-th convolutional layer of the personalized feature extractor $f_k^{p}$ (N layers in total); H, W and d denote the height, width and channel dimension of a feature map, respectively.

[0063] Each level of features thus has its own feature map size and channel dimension, in the form $\mathbb{R}^{H \times W \times d}$, where H, W and d represent the height, width and channel dimension, respectively.

[0064] Step 3: Introduce a cross-attention mechanism to calculate the similarity weight between the global features and personalized features at each level, and calculate the cross-attention-driven feature distillation loss based on the similarity weight;

[0065] Optionally, the introduction of the cross-attention mechanism to calculate the similarity weights between the global features and personalized features at each level specifically involves:

[0066] Select the m-th feature $F^{g}_m$ from the global candidate feature set and the n-th feature $F^{p}_n$ from the personalized candidate feature set. Two pooling operations are applied to $F^{g}_m$ and $F^{p}_n$: global average pooling and channel-wise average pooling.

[0067] The similarity weights computed from the globally average-pooled features determine how strongly knowledge is transferred, while the distance over which knowledge is transferred is defined on the channel-wise pooled features. For global average pooling, $F^{g}_m$ and $F^{p}_n$ are first globally average-pooled; the pooled features are then linearly transformed with two linear transition parameters, and the transformed features serve as the query $Q_m$ and the key $K_n$. Following the cross-attention mechanism, the query and key measure the similarity between $F^{g}_m$ and $F^{p}_n$. For channel-wise pooling, channel-wise average pooling is applied to $F^{g}_m$ and $F^{p}_n$, and the pooled features are used for knowledge transfer. The query and key are expressed as:

[0068] $$Q_m = \sigma_Q\big(W^{Q}_m \cdot \mathrm{GAP}(F^{g}_m)\big),\quad W^{Q}_m \in \mathbb{R}^{d \times d^{g}_m} \tag{4}$$

[0069] $$K_n = \sigma_K\big(W^{K}_n \cdot \mathrm{GAP}(F^{p}_n)\big),\quad W^{K}_n \in \mathbb{R}^{d \times d^{p}_n} \tag{5}$$

[0070] where $\mathrm{GAP}(\cdot)$ denotes global average pooling; $W^{Q}_m$ and $W^{K}_n$ are the linear transition parameters of the m-th query and the n-th key; d is the channel dimension of the query and key; $d^{g}_m$ is the channel dimension of the m-th global feature and $d^{p}_n$ that of the n-th personalized feature; $\sigma_Q(\cdot)$ and $\sigma_K(\cdot)$ are the activation functions of the query and the key, respectively.
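The two pooling operations and the query/key projections of Eqs. (4) and (5) can be sketched in plain Python as below. This is an illustrative sketch under stated assumptions, not the patented implementation: feature maps are nested lists of shape H x W x d, the weight matrices are toy values, ReLU stands in for the query and key activation functions (which the text does not specify), and all function names are hypothetical.

```python
# Minimal sketch of the pooling and projection steps behind Eqs. (4)-(5).
# Function names are hypothetical; ReLU is an assumed activation.
from typing import Callable, List

Vector = List[float]
FeatureMap = List[List[Vector]]  # H x W x d, stored as nested lists


def global_avg_pool(fmap: FeatureMap) -> Vector:
    """GAP: average over height and width, giving one value per channel."""
    h, w, d = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [sum(fmap[i][j][c] for i in range(h) for j in range(w)) / (h * w)
            for c in range(d)]


def channel_avg_pool(fmap: FeatureMap) -> List[List[float]]:
    """Channel-wise average pooling: average over channels, giving an H x W map."""
    return [[sum(pixel) / len(pixel) for pixel in row] for row in fmap]


def relu(x: float) -> float:
    return max(0.0, x)


def project(weight: List[Vector], vec: Vector,
            act: Callable[[float], float] = relu) -> Vector:
    """sigma(W . v) for a weight matrix W of shape d x d_in."""
    return [act(sum(w_ij * v_j for w_ij, v_j in zip(row, vec))) for row in weight]


def query_and_key(f_g: FeatureMap, f_p: FeatureMap,
                  w_q: List[Vector], w_k: List[Vector]):
    """Query from the m-th global feature, key from the n-th personalized feature."""
    return project(w_q, global_avg_pool(f_g)), project(w_k, global_avg_pool(f_p))


if __name__ == "__main__":
    f_g = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # 2 x 2 x 2
    f_p = [[[2.0, 0.0, 1.0]]]                                     # 1 x 1 x 3
    w_q = [[1.0, 0.0], [0.0, -1.0]]            # d = 2, global channels = 2
    w_k = [[1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]   # d = 2, personalized channels = 3
    print(global_avg_pool(f_g))                # [4.0, 5.0]
    print(channel_avg_pool(f_g))               # [[1.5, 3.5], [5.5, 7.5]]
    print(query_and_key(f_g, f_p, w_q, w_k))   # ([4.0, 0.0], [3.0, 0.0])
```

Because the query and key are projected to the same dimension d, global and personalized features with different channel counts can be compared directly.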

[0071] The similarity weights are computed from all queries and keys:

[0072] $$\alpha_m = \mathrm{softmax}\Big(\big[\,Q_m^{\top} W^{QK}_{m,n} K_n + P^{g}_m \cdot P^{p}_n\,\big]_{n=1}^{N}\Big) \tag{6}$$

[0073] where $\alpha_m = [\alpha_{m,1}, \ldots, \alpha_{m,N}]$ is the similarity weight vector capturing the relationship between the m-th global feature and all N personalized features; $\mathrm{softmax}(\cdot)$ normalizes its inputs into a probability distribution that sums to 1; $Q_m$ is the m-th global feature after linear transformation (the query) and $K_n$ is the n-th personalized feature after linear transformation (the key); $W^{QK}_{m,n}$ denotes the bilinear weights; and $P^{g}_m$ and $P^{p}_n$ are position encodings. Through $\alpha_m$, the global features pass their knowledge to the personalized features.
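The similarity weights of Eq. (6) can be illustrated with the sketch below: a query vector Q_m (from the m-th global feature) is scored against every key K_n (from each personalized feature) with a bilinear term plus a position-encoding term, and the scores are normalized with softmax so that the weights sum to 1. Treating the position encodings as vectors combined by a dot product is an assumption, the inputs are toy values, and the function names are hypothetical.

```python
# Minimal sketch of Eq. (6): similarity weights of one global feature over
# all N personalized features. Function names are hypothetical.
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax: outputs are positive and sum to 1."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def bilinear(q: Vector, w: Matrix, k: Vector) -> float:
    """Bilinear score q^T W k."""
    return sum(q[i] * w[i][j] * k[j] for i in range(len(q)) for j in range(len(k)))


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def similarity_weights(q_m: Vector, keys: List[Vector], w_qk: List[Matrix],
                       pos_g_m: Vector, pos_p: List[Vector]) -> Vector:
    """softmax over n of (Q_m^T W_mn K_n + P_m^g . P_n^p)."""
    scores = [bilinear(q_m, w_qk[n], k_n) + dot(pos_g_m, pos_p[n])
              for n, k_n in enumerate(keys)]
    return softmax(scores)


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    alpha_m = similarity_weights(
        q_m=[1.0, 0.0],
        keys=[[1.0, 0.0], [0.0, 1.0]],  # first key aligned with the query
        w_qk=[identity, identity],
        pos_g_m=[0.0, 0.0],
        pos_p=[[0.0, 0.0], [0.0, 0.0]],
    )
    print(alpha_m)  # ≈ [0.731, 0.269]
```

The personalized feature whose key best matches the query receives the largest weight, which is how CAD routes each global feature's knowledge mainly to the personalized features it resembles.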

[0074] Optionally, the expression for calculating the cross-attention-driven feature distillation loss based on similarity weights is:

[0075] $$\mathcal{L}_{\mathrm{CAD}} = \sum_{m=1}^{M} \sum_{n=1}^{N} \alpha_{m,n}\, \big\lVert \phi(F^{g}_m) - \phi(\hat{F}^{p}_n) \big\rVert_2 \tag{7}$$

[0076] where $\mathcal{L}_{\mathrm{CAD}}$ is the cross-attention-driven feature distillation loss; $\alpha_{m,n}$ is the similarity weight between the m-th global feature $F^{g}_m$ and the n-th personalized feature $F^{p}_n$; $\phi(\cdot)$ is a composite function consisting of channel-wise average pooling followed by L2 normalization; and $\hat{F}^{p}_n$ denotes $F^{p}_n$ upsampled or downsampled so that its feature map size matches that of the global feature $F^{g}_m$.
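The cross-attention-driven feature distillation loss of Eq. (7) can be sketched as follows. The sketch assumes that the distance is the Euclidean norm between the flattened outputs of the pooling-plus-normalization function and that nearest-neighbour resampling is used for the up/downsampling step; the text specifies neither. Only the forward value is computed: in training, this loss would update the personalized feature extractor, while the frozen global features act as targets. All names are hypothetical.

```python
# Minimal sketch of Eq. (7): similarity-weighted distance between pooled,
# L2-normalized global and (resized) personalized feature maps.
import math
from typing import List

Vector = List[float]
FeatureMap = List[List[Vector]]  # H x W x d


def resize_nearest(fmap: FeatureMap, out_h: int, out_w: int) -> FeatureMap:
    """Nearest-neighbour up/downsampling (an assumed resampling scheme)."""
    in_h, in_w = len(fmap), len(fmap[0])
    return [[fmap[i * in_h // out_h][j * in_w // out_w] for j in range(out_w)]
            for i in range(out_h)]


def phi(fmap: FeatureMap) -> Vector:
    """Channel-wise average pooling, flattened, then L2 normalization."""
    pooled = [sum(pixel) / len(pixel) for row in fmap for pixel in row]
    norm = math.sqrt(sum(v * v for v in pooled))
    return [v / max(norm, 1e-12) for v in pooled]  # guard against a zero map


def cad_loss(global_feats: List[FeatureMap], personal_feats: List[FeatureMap],
             alpha: List[Vector]) -> float:
    """Sum over (m, n) of alpha[m][n] * || phi(F_m^g) - phi(resized F_n^p) ||_2."""
    loss = 0.0
    for m, f_g in enumerate(global_feats):
        target = phi(f_g)
        h, w = len(f_g), len(f_g[0])
        for n, f_p in enumerate(personal_feats):
            student = phi(resize_nearest(f_p, h, w))  # match the global map size
            loss += alpha[m][n] * math.dist(target, student)
    return loss


if __name__ == "__main__":
    f_g = [[[1.0], [2.0]], [[3.0], [4.0]]]   # one 2 x 2 x 1 global feature
    same = [[[1.0], [2.0]], [[3.0], [4.0]]]  # identical personalized feature
    flat = [[[5.0]]]                         # 1 x 1 x 1 personalized feature
    print(cad_loss([f_g], [same], [[1.0]]))  # 0.0
    print(cad_loss([f_g], [same, flat], [[0.5, 0.5]]))
```

A personalized feature that receives a near-zero similarity weight contributes almost nothing to the loss, so dissimilar global knowledge is effectively filtered out.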

[0077] Step 4: Input the global features into the global head to obtain the global prediction; input the personalized features into the personalized head to obtain the personalized prediction; supervise the personalized prediction with the true labels of the local image training data to obtain the cross-entropy loss; input the personalized features into the global head to obtain the cross-head prediction; and perform distillation between the cross-head prediction and the global prediction to obtain the cross-head prediction loss.

[0078] Optionally, the obtained cross-head prediction loss is specifically as follows:

[0079] The last-layer feature $z^{p}$ of the personalized model is passed to the global head $h_k^{g}$ to obtain the cross-head prediction $\hat{y} = h_k^{g}(z^{p})$. The KD loss between the cross-head prediction $\hat{y}$ and the original prediction $y^{g}$ of the global model is taken as the cross-head prediction loss. This mitigates the conflict that arises when the personalized head must fit both the true labels and the global model's predictions at the same time. The expression is:

[0080] $$\mathcal{L}_{\mathrm{CHD}} = \frac{1}{Z} \sum_{r \in R} S(r)\, D_{\mathrm{KL}}\big(\hat{y}(r),\, y^{g}(r)\big) \tag{8}$$

[0081] where $\mathcal{L}_{\mathrm{CHD}}$ is the cross-head prediction loss; r denotes a spatial location or region of the prediction map, and the sum over $r \in R$ traverses the whole prediction map to compute the knowledge distillation loss; $D_{\mathrm{KL}}(\cdot,\cdot)$ denotes the KL divergence; $S(\cdot)$ and $Z$ are the region selection principle and the normalization factor, respectively. The region selection principle weights the entire prediction map equally: for distillation between $\hat{y}$ and $y^{g}$ in cross-head prediction, $S(\cdot)$ is a constant function with value 1.
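The sketch below illustrates cross-head prediction and the cross-head prediction loss of Eq. (8) for plain image classification, where each logit vector is a prediction map with a single region. It keeps the region weight at 1, as stated above, but additionally assumes that the normalization factor equals the number of regions and that the divergence follows the usual knowledge-distillation direction KL(global prediction ‖ cross-head prediction) without a temperature; these choices and all names are assumptions, not details from the text.

```python
# Minimal sketch of cross-head prediction (CHD) and Eq. (8).
# Function names are hypothetical; Z = number of regions is an assumption.
import math
from typing import List

Vector = List[float]


def linear_head(weight: List[Vector], bias: Vector, z: Vector) -> Vector:
    """A fully connected head: logits = W z + b."""
    return [sum(w * x for w, x in zip(row, z)) + b for row, b in zip(weight, bias)]


def log_softmax(logits: Vector) -> Vector:
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_total for v in logits]


def kl_teacher_student(teacher_logits: Vector, student_logits: Vector) -> float:
    """KL(p_teacher || p_student) over the class distribution."""
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def chd_loss(cross_map: List[Vector], global_map: List[Vector]) -> float:
    """Average KL over regions r, with region weight S(r) = 1 everywhere."""
    regions = len(cross_map)
    return sum(1.0 * kl_teacher_student(g, c)
               for c, g in zip(cross_map, global_map)) / regions


if __name__ == "__main__":
    global_head_w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 classes, 2-dim features
    global_head_b = [0.0, 0.0, -1.0]
    z_global = [2.0, 0.5]    # last-layer feature of the global model
    z_personal = [1.5, 1.0]  # last-layer feature of the personalized model
    y_global = linear_head(global_head_w, global_head_b, z_global)
    y_cross = linear_head(global_head_w, global_head_b, z_personal)  # cross-head
    # A classification logit vector is a prediction map with a single region.
    print(chd_loss([y_cross], [y_global]))
```

Because the cross-head prediction reuses the frozen global head, the distillation signal shapes the personalized features rather than the personalized head, which is how CHD keeps the personalized head free to fit the true labels.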

[0082] Step 5: Construct a total loss using the cross-attention driven feature distillation loss, the cross-entropy loss, and the cross-head prediction loss. Train a personalized feature extractor and a personalized head based on the total loss. Then, upload the trained personalized feature extractor and personalized head to the server for aggregation. Download the aggregated feature extractor and aggregated head to all clients to initialize the global feature extractor, personalized feature extractor, and global head. Repeat Steps 1-4 until the personalized model converges, thereby obtaining an enhanced personalized model.

[0083] Optionally, the expression for the total loss is:

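[0084] $$\mathcal{L} = \mathcal{L}_{\mathrm{CE}} + \lambda\, \mathcal{L}_{\mathrm{CHD}} + \beta\, \mathcal{L}_{\mathrm{CAD}} \tag{9}$$

[0085] where $\mathcal{L}$ is the total loss, $\mathcal{L}_{\mathrm{CE}}$ is the cross-entropy loss, $\mathcal{L}_{\mathrm{CHD}}$ is the cross-head prediction loss, $\mathcal{L}_{\mathrm{CAD}}$ is the cross-attention-driven feature distillation loss, λ is the weight controlling the distillation strength of the cross-head prediction loss, and β is the weight controlling the distillation strength of the cross-attention-driven feature distillation loss.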
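The total loss of Eq. (9) combines the three terms. The sketch below computes the cross-entropy term from the personalized logits and forms the weighted sum, with `lam` weighting the cross-head prediction loss and `beta` weighting the feature distillation loss (parameter names chosen here for illustration). The numeric values are placeholders, not the settings used in the embodiment.

```python
# Minimal sketch of Eq. (9); lam and beta values are placeholders.
import math
from typing import List


def cross_entropy(logits: List[float], label: int) -> float:
    """Cross-entropy between the personalized prediction and the true label."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_total - logits[label]


def total_loss(ce: float, chd: float, cad: float, lam: float, beta: float) -> float:
    """L = L_CE + lam * L_CHD + beta * L_CAD."""
    return ce + lam * chd + beta * cad


if __name__ == "__main__":
    ce = cross_entropy([2.0, 0.5, -1.0], label=0)
    print(total_loss(ce, chd=0.2, cad=0.4, lam=1.0, beta=0.5))
```

Only the personalized feature extractor and personalized head would receive gradients from this loss, since the global model stays frozen during local training.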

[0086] Step 6: Input each client's image test dataset to be classified into the enhanced personalized model to obtain the image classification results it outputs.

[0087] The technical solution of this embodiment is further explained below with reference to the accompanying drawings. Figure 1 illustrates the overall framework of FedCACHD. FedCACHD aims to improve the personalization performance of Personalized Federated Learning (PFL) across heterogeneous clients. To this end, this embodiment takes a supervised multi-class classification task as the primary setting and divides the training model into a shared feature extractor $f: \mathcal{X} \rightarrow \mathcal{Z}$ and a prediction head $h: \mathcal{Z} \rightarrow \mathcal{Y}$. The former maps input samples to the feature space and the latter maps features to the label space, where $\mathcal{X} \subseteq \mathbb{R}^{N}$, $\mathcal{Z} \subseteq \mathbb{R}^{D}$ and $\mathcal{Y} \subseteq \mathbb{R}^{C}$ denote the N-dimensional input sample space, the D-dimensional feature space and the C-dimensional label space, respectively. In this embodiment, the last fully connected layer of the training model is referred to as the head. Specifically, each client k in FedCACHD mainly consists of six parts, including a global feature extractor $f_k^{g}$, a global head $h_k^{g}$, a personalized feature extractor $f_k^{p}$ and a personalized head $h_k^{p}$. FedCACHD employs two techniques: cross-attention-driven feature distillation (CAD) and cross-head prediction (CHD). CAD constructs similarities between global and personalized features, guiding each personalized feature to learn from the global features most similar to it. CHD avoids the inconsistencies that arise from potential mispredictions of the global head.

[0088] In each communication round, each client uploads its personalized feature extractor and personalized head to the server, while keeping its global feature extractor and global head locally. The server receives the personalized feature extractors and personalized heads uploaded by the clients, aggregates them using FedAvg's weighted aggregation, and sends the aggregated feature extractor and head back to every client. On receipt, each client uses the aggregated model to initialize its global feature extractor $f_k^{g}$, global head $h_k^{g}$ and personalized feature extractor $f_k^{p}$. The global feature extractor $f_k^{g}$ and global head $h_k^{g}$ are used only for local inference and remain frozen throughout local training, and the personalized head $h_k^{p}$ is not overwritten by the aggregated head.
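The per-round exchange described above can be sketched as follows. The sketch assumes that FedAvg weights each client's upload by its local sample count (the standard FedAvg rule) and represents each model part as a dictionary of flat parameter lists; the dictionary keys and function names are hypothetical, and a real implementation would operate on ResNet18 tensors instead.

```python
# Minimal sketch of one communication round's aggregation and client sync.
# Keys and function names are hypothetical; parameters are flat float lists.
from typing import Dict, List

Params = Dict[str, List[float]]


def fedavg(uploads: List[Params], num_samples: List[int]) -> Params:
    """FedAvg: average each parameter, weighting client k by its sample count."""
    total = sum(num_samples)
    return {
        name: [sum(u[name][i] * n for u, n in zip(uploads, num_samples)) / total
               for i in range(len(uploads[0][name]))]
        for name in uploads[0]
    }


def sync_client(client: Dict[str, Params], aggregated: Dict[str, Params]) -> None:
    """Initialize the global extractor, global head and personalized extractor
    from the aggregate; the personalized head is deliberately left untouched."""
    client["global_extractor"] = {k: v[:] for k, v in aggregated["extractor"].items()}
    client["global_head"] = {k: v[:] for k, v in aggregated["head"].items()}
    client["personal_extractor"] = {k: v[:] for k, v in aggregated["extractor"].items()}


if __name__ == "__main__":
    # Two hypothetical clients holding 1 and 3 samples, respectively.
    head = fedavg([{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}], [1, 3])
    extractor = fedavg([{"conv": [0.0]}, {"conv": [4.0]}], [1, 3])
    print(head)  # {'w': [2.5, 3.5]}
    client = {"personal_head": {"w": [9.0, 9.0]}}
    sync_client(client, {"extractor": extractor, "head": head})
    print(client["personal_head"])  # {'w': [9.0, 9.0]}, not overwritten
```

Copying the lists keeps the client's global model independent of later changes to its personalized model, matching the requirement that the global parts stay frozen during local training.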

[0089] Figure 3 gives a detailed overview of the CAD method. This embodiment selects the m-th feature $F^{g}_m$ from the global candidate feature set and the n-th feature $F^{p}_n$ from the personalized candidate feature set, and applies two pooling operations to them: global average pooling and channel-wise average pooling. The similarity weights computed from the globally average-pooled features then set the strength with which knowledge is transferred over the distance defined on the channel-wise pooled features. Specifically, for global average pooling, this embodiment first globally average-pools $F^{g}_m$ and $F^{p}_n$, then linearly transforms the pooled features with two linear transition parameters, and uses the transformed features as the query $Q_m$ and the key $K_n$; to measure the similarity between $F^{g}_m$ and $F^{p}_n$, this embodiment adopts the query and key concepts of the cross-attention mechanism. For channel-wise pooling, this embodiment applies channel-wise average pooling to $F^{g}_m$ and $F^{p}_n$ and uses the pooled features for knowledge transfer.

[0090] The following experiments, based on the implementation details above, demonstrate the effectiveness of the technical solution of the present invention.

[0091] Specifically, this embodiment evaluates FedCACHD on five image classification benchmark datasets: MNIST, FMNIST, CIFAR 10, CIFAR 100, and TinyImagenet. Each client's data is divided into a training set (75%) and a test set (25%), and the default number of clients is 20. This embodiment also constructs two heterogeneous data scenarios, ill-conditioned and real-world. In the ill-conditioned scenario, each client is assigned data from two classes of MNIST, FMNIST, or CIFAR 10, or from ten classes of CIFAR 100. In the real-world scenario, the proportion of each label's data assigned to each client is sampled from a Dirichlet distribution with concentration parameter `dir`, whose default value is 0.1. The smaller `dir` is, the stronger the data heterogeneity. This embodiment uses the classic deep neural network ResNet18 on all five datasets.
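The real-world (Dirichlet) scenario of [0091] can be illustrated with a minimal label-skew partitioning sketch. It follows a common approach in federated learning benchmarks and is an assumption, not necessarily the exact procedure used here. For each class, the shares given to the clients are drawn from Dirichlet(`dir`, ..., `dir`), sampled as normalized Gamma draws. The function name and seed handling are hypothetical.

```python
import random
from typing import Dict, List


def dirichlet_partition(labels: List[int], num_clients: int, dir_alpha: float = 0.1,
                        seed: int = 0) -> List[List[int]]:
    """Return one list of sample indices per client (label-skewed, non-IID split)."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)

    clients: List[List[int]] = [[] for _ in range(num_clients)]
    for idxs in by_class.values():
        rng.shuffle(idxs)
        # A Dirichlet(dir_alpha, ..., dir_alpha) sample equals normalized Gamma(dir_alpha, 1) draws.
        draws = [rng.gammavariate(dir_alpha, 1.0) for _ in range(num_clients)]
        total = sum(draws)
        props = [g / total for g in draws] if total > 0 else [1.0 / num_clients] * num_clients
        # Convert cumulative proportions into cut points over this class's indices.
        bounds, acc = [0], 0.0
        for p in props[:-1]:
            acc += p
            bounds.append(min(len(idxs), int(round(acc * len(idxs)))))
        bounds.append(len(idxs))
        for k in range(num_clients):
            clients[k].extend(idxs[bounds[k]:bounds[k + 1]])
    return clients
```

With `dir_alpha=0.1`, most clients receive samples from only a few classes. Each client's indices would then be split 75%/25% into local training and test sets, as described above.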

[0092] For FedCACHD, the default client participation ratio is set to 1, the local batch size to 128, the number of local training epochs to 1, and the local learning rate of ResNet18 to 0.1. The weight of the CHD logits distillation loss and the weight of the CAD feature distillation loss are also set as hyperparameters. This embodiment reports the average test accuracy of all benchmark frameworks after 200 communication rounds on the five datasets: MNIST, FMNIST, CIFAR 10, CIFAR 100, and TinyImagenet. Each framework is run three times, and the mean and standard deviation of the three runs are reported as the final result, with the best result in bold. To ensure a fair comparison, FedCACHD and all benchmark frameworks are run on the same machine with NVIDIA GeForce RTX 4060 GPUs in a PyTorch 2.5.1 environment. Table 1 compares the results with all baselines on the five image benchmark datasets in the two heterogeneous scenarios (ill-conditioned and real-world).

[0093] Table 1 Experimental Results


[0095] As shown in Table 1, FedCACHD achieves the best performance in both heterogeneous scenarios and on all five datasets with ResNet18. In the ill-conditioned scenario on CIFAR 100, FedCACHD's personalized performance is 29.76% higher than the traditional federated learning method FedProx, 5.37% higher than the knowledge distillation-based method FedKD, and 0.71% higher than the personalization method FedROD. In the real-world scenario on CIFAR 100 and TinyImagenet, respectively, FedCACHD's personalized performance is:

- 24.91% and 22.86% higher than the traditional federated learning methods FedProx and FedAvg;
- 6.83% and 11.97% higher than the knowledge distillation-based method FedKD;
- 0.06% and 0.15% higher than the personalization methods FedROD and FedCP.

The following paragraphs compare FedCACHD with the other baselines.

[0096] In Table 1, the traditional FL methods FedAvg and FedProx perform poorly in both heterogeneous scenarios. A single aggregated global model cannot fit each client's heterogeneous local data well, and the mismatch grows as heterogeneity increases. In contrast, FedCACHD trains the global model and the personalized model simultaneously. The personalized model fits the local data better and therefore performs better.

[0097] FedGen mitigates data heterogeneity through logits distillation. Because it optimizes only a single global model, however, that model does not fit each client's local data well. FedNTD alleviates catastrophic forgetting by transferring knowledge of non-target classes, but this may also transfer incorrect knowledge, and in highly heterogeneous scenarios its performance can fall below that of FedAvg and FedProx. FedKD maintains a mentor (teacher) model and a mentee (student) model on each client and transfers knowledge mutually between them. It transfers not only logits knowledge but also hidden-feature knowledge from intermediate layers, and therefore outperforms FedGen and FedNTD. None of these methods considers the similarity between global and personalized features or the possibility that erroneous global knowledge is transferred, which limits their personalized performance. In contrast, FedCACHD evaluates the similarity between global and personalized features. It improves personalized performance by transferring global knowledge that is consistent with each client's personalized features.

[0098] Ditto fits local data better by training a personalized model on each device while learning the global model, and it uses regularization to keep the personalized and global models close. FedPer, FedRep, FedROD, and FedCP decouple the model into a feature extractor and heads. FedPer and FedRep share the feature extractor parameters, FedROD learns the parameters of two heads, and FedCP separates global and personalized features to achieve the personalized federated learning (PFL) objective. These methods outperform traditional FL methods but do not fully exploit global knowledge. In contrast, FedCACHD fully exploits both global features and global logits and therefore achieves better personalized performance.

[0099] Figures 4-8 show the accuracy curves of FedCACHD and five baselines (FedAvg, FedProx, FedGen, FedNTD, and FedKD) over communication rounds. Under all the conditions shown, FedCACHD outperforms these five baselines and converges quickly.

[0100] As a specific example, consider the CIFAR10 image classification task. The CIFAR10 dataset contains 10 categories (airplane, car, bird, cat, deer, dog, frog, horse, ship, and truck), with 5000 training images and 1000 test images per category. Each image is a 32×32×3 RGB color image. The experiment uses 20 clients. Each client's local training data are sampled from the CIFAR10 training set (50,000 images) using a Dirichlet distribution (dir=0.1) to ensure data heterogeneity. The CIFAR10 test set (10,000 images) is used as the common test data for all clients. The inputs, processing, and outputs of each of the six steps of this embodiment's technical solution are detailed below.

[0101] S1: Model Decoupling and Initialization

[0102] Input: Pre-trained ResNet18 model (adapted to CIFAR10), 20 client devices.

[0103] Processing: ResNet18 is split into 4 core components, 2 global and 2 personalized.

[0104] Global side: a global feature extractor (18 convolutional layers, extracting common features) and a global head (1 fully connected layer with 10-class output). Their parameters are frozen.

[0105] Personalized side: a personalized feature extractor (same structure as the global one, with parameters copied from it) and a personalized head (same structure as the global head). Their parameters are trainable.

[0106] Output: each client holds all 4 components, and each feature extractor independently extracts features from its 18 layers.
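The S1 decoupling can be illustrated with a minimal sketch. It assumes that the pre-trained model is a dictionary of named parameters whose final classifier parameters start with `fc.` (the name torchvision's ResNet18 uses for its last layer). The `Component` and `ClientModels` classes are hypothetical. Copying the personalized head from the global head is also an assumption, because S1 only states that the personalized head has the same structure.

```python
import copy
from dataclasses import dataclass
from typing import Dict, List

Params = Dict[str, List[float]]


@dataclass
class Component:
    params: Params
    trainable: bool


@dataclass
class ClientModels:
    global_extractor: Component
    global_head: Component
    personal_extractor: Component
    personal_head: Component


def decouple(pretrained: Params, head_prefix: str = "fc.") -> ClientModels:
    """Split a pre-trained model into extractor/head and build the four components."""
    extractor = {k: v for k, v in pretrained.items() if not k.startswith(head_prefix)}
    head = {k: v for k, v in pretrained.items() if k.startswith(head_prefix)}
    return ClientModels(
        # Global side: frozen, used only for inference during local training.
        global_extractor=Component(copy.deepcopy(extractor), trainable=False),
        global_head=Component(copy.deepcopy(head), trainable=False),
        # Personalized side: independent copies that are updated by backpropagation.
        personal_extractor=Component(copy.deepcopy(extractor), trainable=True),
        personal_head=Component(copy.deepcopy(head), trainable=True),
    )
```

Deep copies keep the two sides independent, so training the personalized extractor never modifies the frozen global extractor.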

[0107] S2: Local Image Feature Extraction

[0108] Input: Local CIFAR10 training data from the client (Example: Client 1 has 5000 images, with "cat" class accounting for 15% and "car" class accounting for 8%), and two feature extractors from S1.

[0109] Processing: the data are fed in batches of 128 images, and features are extracted from the 18 layers. The last-layer feature has a size of 4×4×512.

[0110] Output: two sets of features:

[0111] Global features (e.g., the last-layer global feature);

[0112] Personalized features (e.g., the last-layer personalized feature).

[0113] S3: CAD Loss Calculation

[0114] Input: the global and personalized features from S2, and the cross-attention parameters.

[0115] Processing: the similarity weight of each pair of global and personalized features is calculated with cross-attention. For example, a weight of 0.82 for one global-personalized feature pair indicates high correlation. The weighted feature-difference loss is then calculated.

[0116] Output: the CAD loss for a single batch (e.g., 0.35; the smaller the value, the more similar the features).
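The S3 feature-difference loss can be illustrated with a minimal sketch. It assumes the form described for the CAD loss: each pair's distance is computed after channel-wise average pooling and L2 normalization, weighted by the pair's similarity weight, and summed. The sketch further assumes a non-squared L2 distance. It omits the upsampling/downsampling step and requires each pair to already share the same H x W. The function names are hypothetical.

```python
import math
from typing import List

FeatureMap = List[List[List[float]]]  # H x W x C


def channel_pool_l2(feature: FeatureMap) -> List[float]:
    """Average over channels at every spatial position, then L2-normalize the H*W vector."""
    pooled = [sum(px) / len(px) for row in feature for px in row]
    norm = math.sqrt(sum(v * v for v in pooled)) or 1.0  # guard against an all-zero map
    return [v / norm for v in pooled]


def cad_loss(global_feats: List[FeatureMap], personal_feats: List[FeatureMap],
             alphas: List[List[float]]) -> float:
    """Sum over (m, n) of alpha[m][n] * ||pool(F_g^m) - pool(F_p^n)||_2."""
    loss = 0.0
    for m, fg in enumerate(global_feats):
        g = channel_pool_l2(fg)
        for n, fp in enumerate(personal_feats):
            p = channel_pool_l2(fp)
            if len(g) != len(p):
                raise ValueError("resample the feature maps to the same H x W first")
            dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(g, p)))
            loss += alphas[m][n] * dist
    return loss
```

Identical pooled maps contribute zero loss, and the similarity weights let well-matched pairs dominate the objective.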

[0117] S4: Cross-Entropy Loss and CHD Loss Calculation

[0118] Input: the features from S2, the global head and personalized head from S1, and the true labels of the training images (e.g., the one-hot label for "cat" is [0,0,0,1,0,0,0,0,0,0]).

[0119] Processing:

[0120] Global prediction: the global features are input into the global head, which outputs class probabilities (e.g., the global predicted probability for "cat" is 0.78).

[0121] Personalized prediction: the personalized features are input into the personalized head, which outputs class probabilities (e.g., the personalized predicted probability for "cat" is 0.89).

[0122] Two losses are calculated: the cross-entropy loss = 0.116 (personalized prediction vs. true label) and the CHD loss = 0.072 (cross-head prediction vs. global prediction). The cross-head prediction is obtained by inputting the personalized features into the global head.

[0123] Output: cross-entropy loss = 0.116, CHD loss = 0.072.
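The two S4 losses can be illustrated with a minimal sketch. It assumes a linear classification head and uses KL(global prediction || cross-head prediction) as the distillation loss, which treats the global prediction as the target. The temperature parameter is an added assumption that the text does not mention. All function names are hypothetical.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    top = max(scaled)
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs: List[float], label: int) -> float:
    """Cross-entropy against a one-hot label: -log p(true class)."""
    return -math.log(probs[label])


def kl_div(target: List[float], pred: List[float]) -> float:
    """KL(target || pred); target = global prediction, pred = cross-head prediction."""
    return sum(t * math.log(t / p) for t, p in zip(target, pred) if t > 0)


def linear_head(weight: List[List[float]], bias: List[float], feature: List[float]) -> List[float]:
    """Hypothetical linear head: one logit per class."""
    return [sum(w * x for w, x in zip(row, feature)) + b for row, b in zip(weight, bias)]


def chd_loss(global_w, global_b, personal_feature, global_logits, temperature=1.0) -> float:
    """Feed personalized features to the global head, then distill toward the global prediction."""
    cross_logits = linear_head(global_w, global_b, personal_feature)  # cross-head prediction
    return kl_div(softmax(global_logits, temperature), softmax(cross_logits, temperature))
```

With a personalized probability of 0.89 for the true class "cat", `cross_entropy` returns -ln(0.89), about 0.1165. This agrees with the 0.116 cross-entropy value in S4 to within rounding.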

[0124] S5: Model Training and Aggregation

[0125] Input: the 3 losses from S3-S4, the hyperparameters (α=0.01, β=0.09, learning rate 0.1), and 200 communication rounds in total.

[0126] Processing:

[0127] Total loss = cross-entropy loss + α × CHD loss + β × CAD loss = 0.116 + 0.01×0.072 + 0.09×0.35 ≈ 0.148;

[0128] Local training: Update personalized component parameters using backpropagation;

[0129] Aggregation: The client uploads the updated personalized components, the server aggregates them according to the data volume (e.g., client 1 weight = 5%), and then distributes them to all clients;

[0130] Repeat S2-S4 until convergence (accuracy fluctuation ≤ 0.1%).

[0131] Output: Converged personalized model (Example: Client 1 model has an accuracy of 89.16% on the local validation set).
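The S5 total loss and stopping rule can be illustrated with a short sketch. The loss weighting follows the text: α weights the CHD loss and β weights the CAD loss. Interpreting "accuracy fluctuation ≤ 0.1%" as the spread of the last few accuracies is an assumption, and the `window` length is hypothetical.

```python
from typing import List


def total_loss(l_ce: float, l_chd: float, l_cad: float,
               alpha: float = 0.01, beta: float = 0.09) -> float:
    """L = L_CE + alpha * L_CHD + beta * L_CAD."""
    return l_ce + alpha * l_chd + beta * l_cad


def has_converged(acc_history: List[float], window: int = 5, tol: float = 0.1) -> bool:
    """True if the last `window` accuracies (in %) vary by at most `tol` points."""
    if len(acc_history) < window:
        return False
    recent = acc_history[-window:]
    return max(recent) - min(recent) <= tol


print(round(total_loss(0.116, 0.072, 0.35), 3))  # 0.148
```

The printed value reproduces the S5 arithmetic: 0.116 + 0.00072 + 0.0315 = 0.14822.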

[0132] S6: Image Classification and Verification

[0133] Input: CIFAR10 test set (10,000 images to be classified), personalized model in S5.

[0134] Processing: Input the image to be classified into the model, take the category with the highest probability as the result, and calculate the accuracy.

[0135] Output: Classification accuracy = 90.16%, an improvement of 1.94 percentage points compared to the traditional FedAvg (88.22%), achieving accurate classification of CIFAR10 images.
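The S6 classification and accuracy computation can be illustrated with a minimal sketch. Each image is assigned the class with the highest predicted probability, and accuracy is the percentage of test images whose prediction matches the true label. The per-image probability rows stand in for the personalized model's outputs, and the function names are hypothetical.

```python
from typing import List, Sequence


def predict(probs: Sequence[float]) -> int:
    """Index of the class with the highest probability."""
    return max(range(len(probs)), key=lambda k: probs[k])


def accuracy(prob_rows: List[Sequence[float]], labels: List[int]) -> float:
    """Classification accuracy in percent."""
    correct = sum(predict(p) == y for p, y in zip(prob_rows, labels))
    return 100.0 * correct / len(labels)
```

Applied to the 10,000 CIFAR10 test images, this yields the accuracy reported in S6. The 1.94-point gain is simply 90.16 - 88.22.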

[0136] In summary, this invention introduces cross-attention-driven feature distillation (CAD) and cross-head prediction (CHD) into the personalized federated distillation framework. CAD decouples the global and personalized models into finer-grained global and personalized features. It then constructs similarity weights between these features through a cross-attention mechanism and extracts the global features that best reflect the personalized features. CHD passes the decoupled personalized features to the global head and distills the resulting cross-head predictions. By integrating CAD and CHD, the invention enhances personalized performance. Extensive experiments on different image datasets show that FedCACHD outperforms state-of-the-art FL methods. It avoids the transfer of invalid or inaccurate global knowledge during personalized federated distillation for image classification and improves the classification performance of the personalized model.

[0137] Obviously, the foregoing embodiments are only used to illustrate the technical solutions of the present invention and are not intended to limit the application of the present invention. Any adjustments, substitutions, or optimizations made under the guidance of the basic principles and spirit of the present invention should be considered as part of the protection scope of the present invention.

Claims

1. An image classification method based on personalized federated distillation, characterized in that the method includes the following steps: Step 1: decouple the global model and personalized model of each client into a global feature extractor, a global head, a personalized feature extractor, and a personalized head; Step 2: input the local image training data of each client into the decoupled global feature extractor and personalized feature extractor for feature extraction, obtaining the global features and personalized features at each level, respectively; Step 3: introduce a cross-attention mechanism to calculate the similarity weight between the global features and personalized features at each level, and calculate the cross-attention-driven feature distillation loss based on the similarity weights; Step 4: input the global features into the global head to obtain the global prediction, and input the personalized features into the personalized head to obtain the personalized prediction; supervise the personalized prediction with the true labels of the local image training data to obtain the cross-entropy loss; input the personalized features into the global head to obtain the cross-head prediction, and distill the cross-head prediction with the global prediction to obtain the cross-head prediction loss; Step 5: construct a total loss from the cross-attention-driven feature distillation loss, the cross-entropy loss, and the cross-head prediction loss; train the personalized feature extractor and the personalized head based on the total loss; upload the trained personalized feature extractor and personalized head to the server for aggregation; download the aggregated feature extractor and aggregated head to all clients to initialize the global feature extractor, the personalized feature extractor, and the global head; and repeat Steps 1-4 until the personalized model converges, thereby obtaining an enhanced personalized model;
Step 6: input each client's image test dataset to be classified into the enhanced personalized model to obtain the image classification results output by the enhanced personalized model.

2. The image classification method based on personalized federated distillation according to claim 1, characterized in that each client deploys a global model, a personalized model, and a local private dataset. The global model includes a global feature extractor and a global head, denoted $E_g$ and $H_g$, respectively. The personalized model includes a personalized feature extractor and a personalized head, denoted $E_p$ and $H_p$, respectively. The local private dataset is represented as:
$D_k = \{(x_i, y_i)\}_{i=1}^{N_k}$ (1)
where $D_k$ denotes the local private dataset of the k-th client, $x_i$ denotes the i-th local private data sample, $y_i$ denotes the label corresponding to $x_i$, and $N_k$ denotes the total amount of data owned by the k-th client.

3. The image classification method based on personalized federated distillation according to claim 2, characterized in that the global features and personalized features at each level are obtained as follows:
$F_g = \{F_g^1, F_g^2, \ldots, F_g^M\}$ (2)
$F_p = \{F_p^1, F_p^2, \ldots, F_p^N\}$ (3)
where $F_g$ denotes the set of global features; $F_g^m \in \mathbb{R}^{H_m \times W_m \times C_m}$ denotes the global feature output by the m-th convolutional layer, and $H_m$, $W_m$, and $C_m$ denote its height, width, and channel dimensions, respectively; $F_p$ denotes the set of personalized features; and $F_p^n \in \mathbb{R}^{H_n \times W_n \times C_n}$ denotes the personalized feature output by the n-th convolutional layer, with $H_n$, $W_n$, and $C_n$ its height, width, and channel dimensions, respectively.

4. The image classification method based on personalized federated distillation according to claim 3, characterized in that introducing the cross-attention mechanism to calculate the similarity weight between the global features and personalized features at each level specifically comprises the following. The m-th feature $F_g^m$ is selected from the global candidate feature set and the n-th feature $F_p^n$ from the personalized candidate feature set. Two pooling methods, global average pooling and channel-wise pooling, are applied to the global feature $F_g^m$ and the personalized feature $F_p^n$. The similarity weights calculated from global average pooling determine the strength of knowledge transfer over the distance defined by channel-wise pooling. For the global average pooling method, global average pooling is first applied to $F_g^m$ and $F_p^n$. The pooled features are then linearly transformed with two linear transition parameters, and the transformed features are used as the query $q_m$ and the key $k_n$, respectively. The cross-attention mechanism is introduced to measure the similarity between $q_m$ and $k_n$. For the channel-wise pooling method, channel-wise average pooling is applied to $F_g^m$ and $F_p^n$, respectively, and the pooled features are used for knowledge transfer. The expressions are as follows:
$q_m = \sigma_q\big(W_q^m \cdot \mathrm{GAP}(F_g^m)\big)$ (4)
$k_n = \sigma_k\big(W_k^n \cdot \mathrm{GAP}(F_p^n)\big)$ (5)
where $\mathrm{GAP}(\cdot)$ denotes global average pooling; $W_q^m \in \mathbb{R}^{d \times C_m}$ and $W_k^n \in \mathbb{R}^{d \times C_n}$ are the linear transition parameters of the m-th query and the n-th key; $d$ denotes the channel dimension of the query and key; $C_m$ denotes the channel dimension of the m-th global feature; $C_n$ denotes the channel dimension of the n-th personalized feature;
and $\sigma_q(\cdot)$ and $\sigma_k(\cdot)$ are the activation functions of the query and the key, respectively. The similarity weights are calculated from all queries and keys as follows:
$\alpha_m = \mathrm{softmax}\big(q_m^{\top} W^{b} K + P\big)$ (6)
where $\alpha_m = [\alpha_{m,1}, \ldots, \alpha_{m,N}]$ is the similarity weight vector capturing the relationship between the m-th global feature and all personalized features; $\mathrm{softmax}(\cdot)$ normalizes values into a probability distribution that sums to 1; $q_m$ denotes the m-th global feature after linear transformation; $K = [k_1, \ldots, k_N]$ denotes the N personalized features after linear transformation; $W^{b}$ denotes the bilinear weights; and $P$ denotes the position encoding. Using $\alpha_m$, the global features pass their knowledge to the personalized features.

5. The image classification method based on personalized federated distillation according to claim 4, characterized in that the cross-attention-driven feature distillation loss is calculated from the similarity weights as follows:
$\mathcal{L}_{CAD} = \sum_{m=1}^{M}\sum_{n=1}^{N} \alpha_{m,n} \big\lVert \phi(F_g^m) - \phi\big(\rho(F_p^n)\big) \big\rVert_2$ (7)
where $\mathcal{L}_{CAD}$ is the cross-attention-driven feature distillation loss; $\alpha_{m,n}$ denotes the similarity weight between the m-th global feature $F_g^m$ and the n-th personalized feature $F_p^n$; $\phi(\cdot)$ denotes a combined function consisting of a channel-wise average pooling layer followed by L2 normalization; and $\rho(\cdot)$ denotes upsampling or downsampling of $F_p^n$ so that the feature map size of the personalized feature matches that of the global feature.

6. The image classification method based on personalized federated distillation according to claim 3, characterized in that the cross-head prediction loss is obtained as follows. The last-layer feature of the personalized model, $F_p^N$, is passed to the global head $H_g$ to obtain the cross-head prediction $\hat{y}_c = H_g(F_p^N)$. The KD loss between the cross-head prediction $\hat{y}_c$ and the original prediction of the global model $\hat{y}_g$ is used as the objective of the cross-head prediction loss, expressed as:
$\mathcal{L}_{CHD} = \frac{1}{|\mathcal{S}|}\sum_{r \in \mathcal{R}} \mathcal{S}(r)\, D_{KL}\big(\hat{y}_c(r), \hat{y}_g(r)\big)$ (8)
where $\mathcal{L}_{CHD}$ is the cross-head prediction loss; $r$ denotes a spatial location or region in the set $\mathcal{R}$ of all locations on the prediction map, and the sum traverses the entire prediction map to calculate the knowledge distillation loss; $D_{KL}$ denotes the KL divergence; and $\mathcal{S}(\cdot)$ and $|\mathcal{S}|$ are the region selection principle and the normalization factor, respectively. The region selection principle weights all locations of the prediction map equally for the distillation between $\hat{y}_c$ and $\hat{y}_g$; that is, in the cross-head prediction $\mathcal{S}(\cdot)$ is a constant function with value 1.

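7. The image classification method based on personalized federated distillation according to claim 6, characterized in that the total loss is expressed as:
$\mathcal{L} = \mathcal{L}_{CE} + \alpha\,\mathcal{L}_{CHD} + \beta\,\mathcal{L}_{CAD}$ (9)
where $\mathcal{L}$ is the total loss; $\mathcal{L}_{CE}$ denotes the cross-entropy loss; $\mathcal{L}_{CHD}$ denotes the cross-head prediction loss; $\mathcal{L}_{CAD}$ denotes the cross-attention-driven feature distillation loss; $\alpha$ is the weight controlling the distillation intensity of the cross-head prediction loss; and $\beta$ is the weight controlling the distillation intensity of the cross-attention-driven feature distillation loss.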