A medical image classification method based on federated learning and transfer learning
By employing federated learning and enhanced transfer learning methods, we have addressed the issues of data privacy and heterogeneity in medical image classification, achieving efficient model training and performance improvement.
Patent Information
- Authority / Receiving Office
- CN · China
- Patent Type
- Patents(China)
- Current Assignee / Owner
- GUILIN UNIV OF ELECTRONIC TECH
- Filing Date
- 2023-05-19
- Publication Date
- 2026-06-16
AI Technical Summary
In medical image classification, data privacy requirements and data heterogeneity both degrade model performance, and the two problems are difficult to solve together.
By employing federated learning and enhanced transfer learning, the model is trained without sharing data, and a small number of data samples are stored on a central server to correct the global model, thereby reducing the impact of data heterogeneity.
While protecting data privacy, it significantly improves the model's generalization ability and performance, outperforming traditional federated learning methods.
Description
Technical Field
[0001] This invention relates to image classification methods commonly used in the field of computer vision, specifically to a medical image classification method that combines federated learning and transfer learning.
Background Technology
[0002] Image classification is an image processing method that distinguishes different categories of images based on the features they exhibit. It uses machine learning algorithms to acquire semantic information about the features of each image category, performs quantitative analysis, and assigns the image to one of several categories. In the medical field, image classification is often used to identify whether a region in a medical image contains a lesion or to identify the type of lesion. In practice, medical image classification research often requires a large amount of data. However, because of patient data privacy concerns, medical datasets are often not publicly available. Furthermore, differences in imaging equipment, imaging methods, and even the specific patient, affected area, and disease stage can lead to significant differences in the generated images. This data heterogeneity can degrade the performance of the trained classification model. This invention aims to address these two issues.
Summary of the Invention
[0003] This invention proposes a medical image classification method based on federated learning and transfer learning, which effectively reduces the impact of data heterogeneity on model performance while protecting patient data privacy.
[0004] To address the problems existing in the prior art, the technical solution adopted in this invention is as follows:
[0005] This invention proposes a medical image classification method based on federated learning and transfer learning. On the one hand, federated learning effectively addresses data privacy issues because it allows for distributed model training without sharing data. On the other hand, enhanced transfer learning is used to eliminate differences between models trained on data from different hospitals. Furthermore, a small number of data samples are stored on a central server to correct the global model and reduce the impact of data heterogeneity.
[0006] A medical image classification method based on federated learning and transfer learning, the method comprising the following steps:
[0007] S100: The central server sends the current global model to each client.
[0008] S200: Each client takes the received global model as its local model and trains it locally. Local training adopts an enhanced transfer learning strategy with three stages: a first stage in which the local model is used to train the global model, so that local knowledge is retained; a second stage of mutual learning between the global model and the local model, so that knowledge is exchanged; and a third stage in which the global model is used to train the local model, so that global knowledge is transferred.
[0009] S300: After local training is completed, each client sends its trained local model to the central server.
[0010] S400: The central server aggregates the trained local models by weighted averaging, where each model's weight is the size of the corresponding client's data.
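The weighted average in step S400 matches the usual FedAvg-style aggregation rule. As a sketch, the symbols below are notation introduced here for illustration and do not come from the original formulas: K is the number of clients, n_k is the number of local samples at client k, and θ_k is the parameter vector of its trained local model (θ is used to avoid confusion with the feature distributions written w in the loss functions).

```latex
\theta_{\mathrm{glob}} = \sum_{k=1}^{K} \frac{n_k}{\sum_{j=1}^{K} n_j}\, \theta_k
```

Under this rule a client holding three times as much data as another contributes three times as much to each aggregated parameter.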
[0011] Preferably, the local training includes the following steps:
[0012] S210: First stage. Directly overwriting the local model with the global model would eliminate the local knowledge that the local model has learned, so this method instead uses the local model to train the global model. The training loss function is shown in Equation (1).
[0013]
[0014] Here x represents a data sample, $L_{CE1}(x)$ is the cross-entropy loss in the first stage, $w_{glob}(x_i)$ and $w_{local}(x_i)$ are the feature distributions produced by the global model and the local model, respectively, for sample $x_i$, and N is the number of samples. KL denotes the KL divergence (i.e., relative entropy), calculated as follows:
[0015]
[0016] Here p and q are two different feature distributions. Because the KL divergence is asymmetric, the roles of the two arguments matter: p is the true distribution of the feature and q is the fitted distribution, and the divergence measures how closely q approximates p.
[0017] S220: Second stage. Set a threshold λ1. When the accuracy of the global model and the local model meets the condition..., the second stage of training begins, in which the global model and the local model learn from each other. The loss function of the global model is shown in Equation (1), and the loss function of the local model is shown in Equation (3), where $L_{CE2}(x)$ is the cross-entropy loss in the second stage. This training method can improve the generalization ability of all client-side local models.
[0018]
[0019] S230: Third stage. Set a threshold λ2 (0 < λ1 < λ2 < 1). When the accuracy of the global model and the local model meets the corresponding condition, training enters the third stage, in which the global model is used to train the local model. At this point the accuracies of the two models are very close, and the loss function of the local model is shown in Equation (3).
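For reference, the KL divergence named in S210 has the standard discrete form given first below, with p the true distribution and q the fitted one. The second expression is only an assumed, distillation-style reading of a stage-one loss for the global model (a cross-entropy term plus the KL term averaged over N samples, with the local model as the reference distribution); the weighting and argument order used in the original Equation (1) may differ.

```latex
\mathrm{KL}(p \,\|\, q) = \sum_{j} p_j \log \frac{p_j}{q_j}

% Assumed form, for illustration only:
L_{\mathrm{glob}}(x) = L_{CE1}(x) + \frac{1}{N} \sum_{i=1}^{N}
    \mathrm{KL}\big( w_{local}(x_i) \,\|\, w_{glob}(x_i) \big)
```

Swapping the arguments of KL generally gives a different value, which is why the text distinguishes the true distribution from the fitted one.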
[0020] Preferably, the model aggregation method further includes the following steps:
[0021] To mitigate the impact of data heterogeneity, the central server collects a small amount of uniform sample data to train the global model for correction.
[0022] Advantages and effects of the present invention:
[0023] In federated learning, the server aggregates the parameters of each local model to generate a global model, which is then distributed to each client. However, the distributed model is not suitable for directly replacing the previous local model, because direct replacement would eliminate the knowledge learned by the local model and reduce its optimization performance in the next iteration. The feature distribution learned by the global model can therefore be regarded as the source domain, and the feature distribution learned by the local model as the target domain. With enhanced transfer learning, the local model update in each round of federated learning proceeds in three steps. First, the local model is treated as the source domain and the global model as the target domain, so that the global model learns the local knowledge of the local model and does not suffer performance degradation. Second, the global model and the local model learn from each other, exchanging global and local knowledge. Finally, the global model is treated as the source domain and the local model as the target domain, maximizing the transfer of global knowledge to the local model.
[0024] Experiments were conducted on two publicly available datasets, and the proposed method achieved better results across various metrics than traditional federated learning methods. This demonstrates that the proposed method offers superior performance and generalization ability when dealing with data heterogeneity, and shows promising application prospects.
Brief Description of the Drawings
[0025] Figure 1 is a framework diagram of the medical image classification method based on federated learning and transfer learning.
[0026] Figure 2 is a diagram of the DenseNet network framework.
[0027] Figure 3 is the overall algorithm flowchart for the model.
[0028] Figure 4 is a flowchart of the overall process of the medical image classification method based on federated learning and transfer learning.
Detailed Implementation
[0029] Federated learning (FL) is a distributed machine learning technique that allows collaborating institutions (such as hospitals) to train a model jointly while keeping their data stored locally. During training, only model parameters are exchanged, not local data, so the model is trained collaboratively without data sharing and data privacy is protected.
[0030] Transfer learning is a machine learning method that, unlike traditional approaches that train each task from scratch, takes a model trained for task A as the starting point for training task B, improving the new task by transferring knowledge from the previously learned one. In simpler terms, a pre-trained model is used as the starting point for a new task and training continues from there. This not only improves the model's generalization ability but also lets it improve faster and converge better.
[0031] In transfer learning, the data features and feature distribution that have already been learned are called the source domain, while the data feature distribution still to be learned is called the target domain. The core idea is to find and exploit the similarities between the source and target domains, quantify the degree of similarity, and apply suitable learning methods to increase the similarity between the two domains. This resembles Generative Adversarial Networks (GANs), in which the generator produces samples and the discriminator determines whether a sample is real data or was produced by the generator; the two compete against each other to complete adversarial training. In transfer learning, training a model for the target domain can be viewed as extracting features from the target domain data and continuing to learn them until a discriminator can no longer distinguish between the two domains.
[0032] Enhanced transfer learning, as an improvement on transfer learning, has two main objectives: first, to transfer feature information from the source domain to the target domain without eliminating the knowledge learned by the local model; and second, to enable mutual learning between the source and target domains, increasing the similarity between the two domains and integrating the feature information learned by the global and local models, thereby improving the performance of the global model. In each round of local model updates in federated learning, enhanced transfer learning proceeds in three steps. First, the local model is treated as the source domain and the global model as the target domain, so that the global model learns local knowledge from the local model and does not suffer performance degradation. Second, the global and local models learn from each other, exchanging global and local knowledge. Finally, the global model is treated as the source domain and the local model as the target domain, maximizing the transfer of global knowledge to the local model.
[0033] As shown in Figure 1, the invention presents a medical image classification method based on federated learning and transfer learning, in which the trained model is built on a densely connected convolutional network (DenseNet) for classification. As shown in Figure 2, the model consists of a DenseNet backbone, a linear mapping layer, and a fully connected layer: DenseNet extracts features from the image, the linear mapping layer maps the feature vectors to the specified dimension, and the fully connected layer generates the predicted value for each category.
[0034] As shown in Figure 4, the medical image classification method based on federated learning and transfer learning of the present invention mainly includes the following steps:
[0035] S100: The central server sends the current global model to each client.
[0036] S200: Each client takes the received global model as its local model and trains it locally. Local training adopts an enhanced transfer learning strategy with three stages: a first stage in which the local model is used to train the global model, so that local knowledge is retained; a second stage of mutual learning between the global model and the local model, so that knowledge is exchanged; and a third stage in which the global model is used to train the local model, so that global knowledge is transferred.
[0037] S300: After local training is completed, each client sends its trained local model to the central server.
[0038] S400: The central server aggregates the trained local models by weighted averaging, where each model's weight is the size of the corresponding client's data.
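Step S400 can be illustrated with a short Python sketch. It assumes each client's model is represented as a dictionary mapping a parameter name to a flat list of floats; this layout and the function name `weighted_average` are hypothetical simplifications chosen for illustration, not the representation used in the original PyTorch implementation.

```python
from typing import Dict, List

# Hypothetical layout: parameter name -> flat list of values.
Params = Dict[str, List[float]]


def weighted_average(client_params: List[Params], client_sizes: List[int]) -> Params:
    """Aggregate client parameters, weighting each client by its local sample count."""
    if not client_params or len(client_params) != len(client_sizes):
        raise ValueError("need exactly one sample count per client model")
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError("total sample count must be positive")

    aggregated: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        # Each aggregated value is sum_k(n_k * value_k) / sum_k(n_k).
        aggregated[name] = [
            sum(params[name][i] * size
                for params, size in zip(client_params, client_sizes)) / total
            for i in range(length)
        ]
    return aggregated


if __name__ == "__main__":
    clients = [{"fc.weight": [1.0, 2.0]}, {"fc.weight": [3.0, 6.0]}]
    print(weighted_average(clients, [100, 300]))
```

In the usage example the second client holds three times as much data as the first, so the aggregated values lie three quarters of the way towards its parameters.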
[0039] As a preferred approach, the local training phase may include the following steps:
[0040] S210: First stage. Directly overwriting the local model with the global model would eliminate the local knowledge that the local model has learned, so this method instead uses the local model to train the global model. The training loss function is shown in Equation (1).
[0041]
[0042] Here x represents a data sample, $L_{CE1}(x)$ is the cross-entropy loss in the first stage, $w_{glob}(x_i)$ and $w_{local}(x_i)$ are the feature distributions produced by the global model and the local model, respectively, for sample $x_i$, and N is the number of samples. KL denotes the KL divergence (i.e., relative entropy), calculated as follows:
[0043]
[0044] Here p and q are two different feature distributions. Because the KL divergence is asymmetric, the roles of the two arguments matter: p is the true distribution of the feature and q is the fitted distribution, and the divergence measures how closely q approximates p.
[0045] S220: Second stage. Set a threshold λ1. When the accuracy of the global model and the local model meets the condition..., the second stage of training begins, in which the global model and the local model learn from each other. The loss function of the global model is shown in Equation (1), and the loss function of the local model is shown in Equation (3), where $L_{CE2}(x)$ is the cross-entropy loss in the second stage. This training method can improve the generalization ability of all client-side local models.
[0046]
[0047] S230: Third stage. Set a threshold λ2 (0 < λ1 < λ2 < 1). When the accuracy of the global model and the local model meets the corresponding condition, training enters the third stage, in which the global model is used to train the local model. At this point the accuracies of the two models are very close, and the loss function of the local model is shown in Equation (3).
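The KL term used in the stage losses can be illustrated in a few lines of Python. This is a minimal sketch, assuming the feature distributions are discrete probability vectors (for example, softmax outputs). The function names `kl_divergence` and `mean_kl` and the `eps` guard are illustrative choices, not part of the original method.

```python
import math
from typing import Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) for discrete distributions: p is the reference, q the fitted one."""
    if len(p) != len(q):
        raise ValueError("distributions must have the same length")
    # Terms with p_j == 0 contribute 0 by convention; eps guards against q_j == 0.
    return sum(pj * math.log(pj / max(qj, eps)) for pj, qj in zip(p, q) if pj > 0)


def mean_kl(reference: Sequence[Sequence[float]],
            fitted: Sequence[Sequence[float]]) -> float:
    """Average KL(reference_i || fitted_i) over a batch of N samples."""
    if not reference or len(reference) != len(fitted):
        raise ValueError("need two non-empty batches of equal size")
    return sum(kl_divergence(r, f) for r, f in zip(reference, fitted)) / len(reference)


if __name__ == "__main__":
    local_out = [0.9, 0.1]   # stand-in for w_local(x_i)
    global_out = [0.5, 0.5]  # stand-in for w_glob(x_i)
    # The two directions differ because KL divergence is asymmetric.
    print(kl_divergence(local_out, global_out))
    print(kl_divergence(global_out, local_out))
```

Which model supplies the reference distribution determines the direction of knowledge transfer, so the first and third stages would swap the roles of the two arguments.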
[0048] As a preferred approach, the model aggregation method further includes the following steps:
[0049] To mitigate the impact of data heterogeneity, the central server collects a small amount of uniform sample data to train the global model for correction.
[0050] The overall training algorithm flow of the model is shown in Figure 3.
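One communication round of the flow described above can be sketched as follows. This is a structural outline under stated assumptions, not the algorithm of Figure 3 itself. A model is reduced to a flat list of floats, each client is a hypothetical callable that returns its trained parameters together with its local sample count, and `server_correct` is a hypothetical callable standing in for the server-side training on the small sample set.

```python
from typing import Callable, List, Tuple

Model = List[float]  # hypothetical stand-in for a model's parameter vector
ClientFn = Callable[[Model], Tuple[Model, int]]


def run_round(global_model: Model,
              client_train_fns: List[ClientFn],
              server_correct: Callable[[Model], Model]) -> Model:
    """One round: distribute, train locally, aggregate by data size, correct."""
    updates: List[Tuple[Model, int]] = []
    for train in client_train_fns:
        # S100-S300: the client receives a copy of the global model, trains
        # locally, and returns its trained parameters and local sample count.
        updates.append(train(list(global_model)))

    total = sum(count for _, count in updates)
    if total <= 0:
        raise ValueError("clients reported no training samples")

    # S400: average weighted by each client's data size.
    aggregated = [
        sum(params[i] * count for params, count in updates) / total
        for i in range(len(global_model))
    ]
    # Correction step: the server refines the aggregated model on its own samples.
    return server_correct(aggregated)


if __name__ == "__main__":
    small_client = lambda model: ([v + 1.0 for v in model], 1)
    large_client = lambda model: ([v + 3.0 for v in model], 3)
    print(run_round([0.0, 0.0], [small_client, large_client], lambda m: m))
```

Passing the client and server behaviour in as callables keeps the round logic independent of how local training and server-side correction are implemented.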
[0051] Experiments were carried out based on the above method. In the experiments, the classification model uses DenseNet as the backbone network and then generates predicted values for each class through a linear projection layer and a fully connected layer. The implementation is based on the PyTorch deep learning framework, with the other experimental parameters set as follows: the batch size is 8, the optimizer is Adam (momentum coefficients of 0.9 and 0.99), the learning rate is 1e-4, and the total number of communication rounds is 200. After ablation experiments with different threshold parameters, the thresholds λ1 and λ2 were set to 0.7 and 0.9, respectively.
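The reported settings can be collected in a small configuration object. The sketch below is illustrative: the class and field names are hypothetical, and reading the two "momentum" values as Adam's beta coefficients is an assumption.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ExperimentConfig:
    """Hyperparameters as reported in the experiments (field names are illustrative)."""
    batch_size: int = 8
    # Assumption: the two reported momentum values are Adam's (beta1, beta2),
    # e.g. as accepted by the `betas` argument of torch.optim.Adam.
    adam_betas: Tuple[float, float] = (0.9, 0.99)
    learning_rate: float = 1e-4
    communication_rounds: int = 200
    lambda1: float = 0.7  # threshold for entering the second stage
    lambda2: float = 0.9  # threshold for entering the third stage

    def __post_init__(self) -> None:
        # The method requires 0 < lambda1 < lambda2 < 1.
        if not 0 < self.lambda1 < self.lambda2 < 1:
            raise ValueError("thresholds must satisfy 0 < lambda1 < lambda2 < 1")


if __name__ == "__main__":
    print(ExperimentConfig())
```

Validating the threshold ordering at construction time catches a swapped λ1 and λ2 before any training starts.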
[0052] The method of this invention was tested on two publicly available datasets. To simulate the experimental setup of federated learning, the two datasets were divided according to the Dirichlet distribution. In Task 1, the dataset was divided into 10 subsets, while in Task 2, it was divided into 5 subsets. The purpose of using the Dirichlet distribution was to simulate the client class imbalance situation. Five metrics—Accuracy, Specificity, Sensitivity, AUC, and F1 score—were used to comprehensively analyze the experimental results.
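A Dirichlet-based split of this kind can be sketched with the Python standard library alone. This is one common way to build such partitions, and the exact procedure used in the experiments is not specified. The concentration parameter `alpha`, the `seed`, and the function name `dirichlet_partition` are assumptions introduced for illustration. Smaller `alpha` values give more imbalanced clients.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def dirichlet_partition(labels: Sequence[int], num_clients: int,
                        alpha: float, seed: int = 0) -> List[List[int]]:
    """Split sample indices across clients using per-class Dirichlet proportions."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    clients: List[List[int]] = [[] for _ in range(num_clients)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # A Dirichlet(alpha, ..., alpha) draw equals normalized Gamma(alpha, 1) draws.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas)
        # Convert cumulative proportions into cut points over this class's indices.
        cuts, running = [], 0.0
        for gamma in gammas[:-1]:
            running += gamma / total
            cuts.append(min(len(indices), int(round(running * len(indices)))))
        bounds = [0] + cuts + [len(indices)]
        for client, start, end in zip(clients, bounds, bounds[1:]):
            client.extend(indices[start:end])
    return clients


if __name__ == "__main__":
    toy_labels = [0] * 50 + [1] * 50
    for client_id, part in enumerate(dirichlet_partition(toy_labels, 5, alpha=0.5)):
        positives = sum(toy_labels[i] for i in part)
        print(client_id, len(part), positives)
```

Each sample is assigned to exactly one client, and the per-client class counts printed by the usage example show the imbalance that the partition is meant to simulate.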
[0053] The experiments compared the proposed method with traditional federated learning methods. The results are shown in Tables 1 and 2, where Central denotes centralized training and FedAvg and FedProx are two traditional federated learning methods. The proposed method achieves the best result on every metric in Task 1 and on every metric except sensitivity in Task 2, where FedProx is slightly higher (89.81% vs. 89.27%). This indicates that the proposed method handles heterogeneous data better and that, overall, it surpasses the other federated learning methods in both model accuracy and generalization ability.
[0054] Table 1 compares our proposed method with current state-of-the-art federated learning methods in Task 1.
[0055]
Method         Accuracy  Specificity  Sensitivity  AUC     F1
Central        88.87%    87.27%       74.11%       91.05%  73.48%
FedAvg         87.92%    85.66%       75.55%       90.55%  73.86%
FedProx        88.59%    86.29%       74.08%       91.52%  73.61%
ETL-FL (ours)  90.13%    88.07%       76.54%       91.89%  74.93%
[0056] Table 2 compares our proposed method with current state-of-the-art federated learning methods in Task 2.
[0057]
Method         Accuracy  Specificity  Sensitivity  AUC     F1
Central        92.99%    94.32%       88.98%       96.85%  88.42%
FedAvg         92.36%    94.19%       89.25%       96.47%  87.66%
FedProx        92.84%    94.11%       89.81%       97.07%  88.39%
ETL-FL (ours)  94.73%    95.44%       89.27%       97.19%  89.42%
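Four of the five metrics reported in Tables 1 and 2 follow directly from a confusion matrix. The sketch below assumes a binary task with label 1 as the positive class, which is an assumption since the label setup of the two tasks is not described here. AUC is omitted because it needs continuous prediction scores rather than hard labels. The function name `binary_metrics` is illustrative.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Accuracy, specificity, sensitivity and F1 for binary labels (1 = positive)."""
    if len(y_true) != len(y_pred):
        raise ValueError("label sequences must have the same length")
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    def ratio(numerator: float, denominator: float) -> float:
        # Return 0.0 for an undefined ratio instead of dividing by zero.
        return numerator / denominator if denominator else 0.0

    precision = ratio(tp, tp + fp)
    sensitivity = ratio(tp, tp + fn)  # true positive rate (recall)
    return {
        "accuracy": ratio(tp + tn, tp + tn + fp + fn),
        "specificity": ratio(tn, tn + fp),  # true negative rate
        "sensitivity": sensitivity,
        "f1": ratio(2 * precision * sensitivity, precision + sensitivity),
    }


if __name__ == "__main__":
    truth = [1, 1, 1, 0, 0, 0, 0, 0]
    preds = [1, 1, 0, 0, 0, 0, 0, 1]
    print(binary_metrics(truth, preds))
```

Specificity depends only on the negative cases and sensitivity only on the positive ones, so a method can lead on one and trail on the other, as the Task 2 sensitivity column shows.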
[0058] Furthermore, the effectiveness of the method was demonstrated through relevant ablation experiments.
[0059] 1) The effectiveness of the improvement strategy in this method
[0060] This invention makes two main improvements to federated learning: 1) it employs enhanced transfer learning (ETL) for model training during the local training phase; and 2) it uses a small number of data samples to train the global model during the model aggregation phase. Ablation experiments were conducted with enhanced transfer learning removed (w/o ETL) and with the server-side training strategy removed (w/o ST), with results shown in Table 3. As can be seen, both improvement strategies performed well in Task 1 and Task 2. In Task 1, the two strategies brought improvements of 2.62% and 2.79%, respectively, while in Task 2 the improvements were 3.28% and 2.60%, respectively. This demonstrates that the method plays a crucial role in addressing the data heterogeneity problem in real-world medical federated learning scenarios.
[0061] Table 3 shows the impact of two model improvement strategies on the experimental results.
[0062]
[0063] Note: ETL: Enhanced Transfer Learning Improvement Strategy; ST: Server-Side Training Strategy
[0064] 2) The impact of training thresholds λ1 and λ2 on the model
[0065] During the local training phase, the enhanced transfer learning strategy of this method has three stages, and the decision to enter a given stage is based on a comparison of model accuracy. That is, when the accuracy of the global model and the local model meets the condition..., the second stage of training begins; when the accuracy of both the global and local models meets the condition..., training enters the third stage. Here, λ1 and λ2 satisfy 0 < λ1 < λ2 < 1. Experiments were conducted with different values of λ1 and λ2 to study their impact on model performance, and the results are shown in Table 4.
[0066] Table 4 shows the effect of different λ1 and λ2 on the experimental results.
[0067]
[0068] Table 4 shows that the results with λ1 = 0.7 and λ2 = 0.9 are generally better than those with other settings. This indicates that appropriately raising the thresholds for entering the second and third stages of enhanced transfer learning can improve the model's accuracy. It can also be observed that appropriately lowering λ1 improves specificity, suggesting that entering the second stage earlier helps the model identify negative cases.
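The threshold-based switch between the three training stages could be sketched as below. The exact entry conditions are not reproduced in this text, so the rule used here is an assumption made purely for illustration: a stage is entered once both the global and the local model reach the corresponding accuracy threshold. The function name `select_stage` is likewise hypothetical.

```python
def select_stage(acc_global: float, acc_local: float,
                 lambda1: float = 0.7, lambda2: float = 0.9) -> int:
    """Return the training stage (1, 2 or 3) for the current accuracies."""
    if not 0 < lambda1 < lambda2 < 1:
        raise ValueError("thresholds must satisfy 0 < lambda1 < lambda2 < 1")
    # ASSUMPTION: the original condition is elided in the text. Here a stage is
    # entered only when BOTH accuracies have reached its threshold.
    lowest = min(acc_global, acc_local)
    if lowest >= lambda2:
        return 3  # the global model trains the local model
    if lowest >= lambda1:
        return 2  # mutual learning between the two models
    return 1      # the local model trains the global model


if __name__ == "__main__":
    for accuracies in [(0.50, 0.80), (0.75, 0.80), (0.95, 0.92)]:
        print(accuracies, select_stage(*accuracies))
```

Under this reading, lowering λ1 moves training into the mutual-learning stage earlier, which is the effect discussed for the specificity results.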
Claims
1. A medical image classification method based on federated learning and transfer learning, characterized in that the method comprises the following steps: S100: the central server sends the current global model to each client; S200: each client takes the received global model as its local model and trains it locally, the local training adopting an enhanced transfer learning strategy that includes a first stage in which the local model is used to train the global model to retain local knowledge, a second stage of mutual learning between the global model and the local model to exchange knowledge, and a third stage in which the global model is used to train the local model to transfer global knowledge; S300: after local training is completed, each client sends its trained local model to the central server; S400: the central server aggregates the trained local models by weighted averaging, where each model's weight is the size of the corresponding client's data.
2. The medical image classification method based on federated learning and transfer learning according to claim 1, characterized in that the local training comprises the following steps: S210: first stage: since directly overwriting the local model with the global model would eliminate the local knowledge learned by the local model, the local model is used to train the global model, with the training loss function shown in Equation (1), where x represents a data sample, $L_{CE1}(x)$ is the cross-entropy loss in the first stage, $w_{glob}(x_i)$ and $w_{local}(x_i)$ are the feature distributions produced by the global model and the local model, respectively, for sample $x_i$, N is the number of samples, and KL denotes the KL divergence (i.e., relative entropy), calculated using Equation (2), in which p and q are two different feature distributions; because the KL divergence is asymmetric, p is the true distribution of the feature and q is the fitted distribution, and the divergence measures how closely q approximates p; S220: second stage: a threshold λ1 is set, and when the accuracy of the global model and the local model meets the condition..., the second stage of training begins, in which the global model and the local model learn from each other, the loss function of the global model being shown in Equation (1) and the loss function of the local model being shown in Equation (3), where $L_{CE2}(x)$ is the cross-entropy loss in the second stage; this training method can improve the generalization ability of all client-side local models; S230: third stage: a threshold λ2 (0 < λ1 < λ2 < 1) is set, and when the accuracy of the global model and the local model meets the corresponding condition, training enters the third stage, in which the global model is used to train the local model; at this point the accuracies of the two models are very close, and the loss function of the local model is shown in Equation (3).
3. The medical image classification method based on federated learning and transfer learning according to claim 1, characterized in that the model aggregation further comprises the following step: to mitigate the impact of data heterogeneity, the central server collects a small amount of uniform sample data and uses it to train the global model for correction.