An adaptive prototype generation and dual-space-guided heterogeneous federated learning method
This heterogeneous federated learning method, which combines adaptive prototype generation with dual-space guidance, addresses the passive weighted-mean aggregation used to generate global prototypes when multiple medical institutions collaborate on diagnosis. It applies adaptive boundaries and a category separation penalty to the global disease prototypes, thereby improving diagnostic accuracy and rare disease identification in primary healthcare institutions.
Patent Information
- Authority / Receiving Office: CN · China
- Patent Type: Applications (China)
- Current Assignee / Owner: CHONGQING UNIV OF POSTS & TELECOMM
- Filing Date: 2026-04-15
- Publication Date: 2026-06-12
Figure CN122201851A_ABST
Abstract
Description
Technical Field
[0001] This invention belongs to the interdisciplinary field of distributed artificial intelligence and medicine, and specifically relates to a heterogeneous federated learning method with adaptive prototype generation and dual-space guidance.
Background Technology
[0002] In practical federated learning applications, participants typically have different model architectures and private data that are highly non-independent and identically distributed (non-IID), which has driven the development of heterogeneous federated learning. Early heterogeneous federated learning methods attempted to collaborate by sharing part of the model structure (such as the classification head), but this limited flexibility and was prone to negative transfer under severe data heterogeneity. Subsequent knowledge distillation (KD)-based methods (such as FedMD, FML, and FedKD) supported fully independent models, but they relied heavily on public proxy datasets or required frequent exchange of high-dimensional logits, which incurred large communication overhead and easily propagated noise in the early stages of training. Prototype-based heterogeneous federated learning methods (such as FedProto) then emerged as a communication-efficient alternative. This paradigm aggregates knowledge by transmitting compact class feature representations (prototypes) between clients and the server. However, existing prototype learning methods have a fundamental flaw: the mechanism that generates the global prototype is inherently passive (for example, a simple heuristic weighted average). In highly heterogeneous environments, this passive aggregation forcibly compresses prototypes from different feature spaces and with different attributes into a single centroid. This easily collapses the similarity boundaries, erases the multimodal distribution of intra-class data, and leaves the aggregated global prototypes with overlapping, blurred boundaries, which severely limits the achievable performance of the system.
In addition, existing methods align only the first moments of the feature space and ignore the important inter-class topological structure of the global feature space (that is, the dark knowledge implied in the predicted probabilities of non-target categories).
[0003] In collaborative diagnosis across multiple medical institutions, patient imaging data (CT, MRI, pathology slides, etc.) is strictly regulated by the Personal Information Protection Law and medical industry data security regulations, which makes it difficult for institutions to break through existing data barriers. From high-performance imaging workstations in tertiary hospitals to lightweight terminals in primary healthcare institutions, the diagnostic models deployed by the participants differ fundamentally in architectural depth and parameter scale. Data privacy restrictions and differences in model architecture are therefore the main challenges in collaborative modeling for medical artificial intelligence.
[0004] Current prototype-based heterogeneous federated learning schemes (such as FedProto) have two deep-seated shortcomings. First, global prototypes are created by passive weighted mean aggregation. When the disease spectra of different hospitals are strongly non-independent and identically distributed, the class boundaries in the feature space tend to shrink. Disease categories with similar visual semantics, such as early lung adenocarcinoma and exudative pneumonia lesions, then overlap in the prototype space, which destabilizes the decision hyperplane and increases the misdiagnosis rate. Second, because this alignment method treats the class centroid as the only object that is transmitted, it captures only the first-moment information of the feature distribution. It ignores the "diagnostic dark knowledge" contained in the predicted probabilities of non-target categories (for example, the image similarity between "emphysema" and "pulmonary bullae" is much greater than that between "emphysema" and "rib fracture"). This information is systematically discarded, which directly weakens the ability of lightweight models in primary healthcare institutions to generalize and to identify rare diseases.
[0005] In summary, there is an urgent need for a heterogeneous federated learning method suited to collaborative image interpretation among multiple hospitals, in order to improve the accuracy of collaborative judgment among medical institutions and assist doctors in further diagnosis.
Summary of the Invention
[0006] To address the shortcomings of existing technologies, this invention proposes a heterogeneous federated learning method with adaptive prototype generation and dual-space guidance, which includes:
[0007] S1: Initialize the federated learning system; the federated learning system includes multiple heterogeneous medical institution node clients and a central server; the medical institution node clients deploy local diagnostic models, and the central server deploys a global prototype generation network and a global classification head;
[0008] S2: The medical institution node client trains the local diagnostic model and uploads the local disease feature prototype to the central server;
[0009] S3: The central server trains a global prototype generation network based on local disease feature prototypes, and the global prototype generation network generates global disease prototypes.
[0010] S4: Train a global classification head based on local disease feature prototypes and global disease prototypes; generate global soft labels from the global classification head; the central server distributes global soft labels and global disease prototypes to medical institution node clients;
[0011] S5: The medical institution node client performs multi-objective dual-guided training on the local diagnostic model based on the global soft labels and global disease prototypes, and uploads the local disease feature prototypes output by the trained local diagnostic model to the central server;
[0012] S6: Each medical institution node client determines whether the local diagnostic model has converged. If so, the learning ends; otherwise, it returns to step S3.
[0013] Preferably, in step S3, training the global prototype generation network specifically includes:
[0014] Calculate the adaptive boundary for each classification category;
[0015] The global prototype generation network generates a global disease prototype, and calculates the boundary enhancement contrast loss based on the global disease prototype, the adaptive boundary of the corresponding category of the global disease prototype, and the local disease feature prototype of the corresponding category of the global disease prototype.
[0016] Calculate category separation penalty based on the global disease prototype;
[0017] The weighted sum of the boundary enhancement contrast loss and the category separation penalty is used as the total loss for global disease prototype generation on the central server. The parameters of the global prototype generation network are adjusted based on the total loss for global disease prototype generation on the central server to complete the training of the global prototype generation network.
[0018] Furthermore, the formula for calculating the adaptive boundary for each classification category is:
[0019]
[0020] In this formula, the terms denote, in order: the dynamic adaptive boundary of category c; the unweighted average aggregation center of category c; the unweighted average aggregation center of any category other than category c; the preset cutoff threshold; and the L2 norm.
[0021] Furthermore, the formula for calculating the boundary enhancement contrast loss is as follows:
[0022]
[0023] In this formula, the terms denote, in order: the boundary enhancement contrast loss; the local disease feature prototype with true category y uploaded by medical institution node client i; the global disease prototype of category y generated by the global prototype generation network; the dynamic adaptive boundary of category y; and the L2 norm.
[0024] Furthermore, the formula for calculating the category separation penalty is as follows:
[0025]
[0026] In this formula, the terms include: the number of categories; the safe interval; the global disease prototype of category c; the global disease prototype of another category; and the L2 norm.
[0027] Preferably, in step S4, the process of training the global classification head includes:
[0028] By combining local disease feature prototypes and global disease prototypes, a hybrid prototype dataset is obtained.
[0029] The hybrid prototype dataset is input into the global classification head, and the global classification head is optimized using standard cross-entropy loss;
[0030] The global disease prototype is input into the optimized global classification head, and a temperature scaling factor is introduced to obtain a global soft label.
[0031] Furthermore, the global soft label is obtained as follows:
[0032] $$q_c = \mathrm{softmax}\left(\frac{H(\hat{p}_c;\,\theta_H)}{T}\right)$$
[0033] where $q_c$ denotes the global soft label of category c, $\hat{p}_c$ denotes the global disease prototype of category c, $\theta_H$ denotes the learnable parameters of the global classification head, $T$ denotes the temperature scaling factor, $H$ denotes the global classification head, and $\mathrm{softmax}$ denotes the softmax function.
[0034] Preferably, in step S5, the loss function for multi-objective dual-guided training of the local diagnostic model is a weighted sum of local classification loss, feature space prototype alignment loss, and prediction space knowledge distillation loss.
[0035] Furthermore, the feature space prototype alignment loss is expressed as:
[0036] $$\mathcal{L}_{\mathrm{align}} = \mathbb{E}_{(x,y)\sim D_i}\left[\frac{1}{d}\,\left\lVert f_i(x;\theta_i) - \hat{p}_y \right\rVert_2^2\right]$$
[0037] where $\mathcal{L}_{\mathrm{align}}$ denotes the feature space prototype alignment loss, $\mathbb{E}_{(x,y)\sim D_i}$ denotes the mathematical expectation over samples drawn from the local private dataset $D_i$ of medical institution node client i, $d$ denotes the dimension of the unified feature space, $f_i(x;\theta_i)$ denotes the high-dimensional features extracted locally, $x$ denotes the input sample, $\theta_i$ denotes the learnable parameters of the local feature extractor of medical institution node client i, $\hat{p}_y$ denotes the global disease prototype corresponding to the true category y of the input sample, and $\lVert\cdot\rVert_2$ denotes the L2 norm.
[0038] The beneficial effects of this invention are as follows:
[0039] To address the prototype boundary shrinkage problem, this invention designs a parameterized, trainable prototype generation network on the central server. This generation network is jointly driven by dynamic adaptive boundary constraints and a class separation penalty, so that it actively creates a global disease prototype space with high separability between diseases. To address the lack of inter-class topological information, the central server combines local disease feature prototypes and generated prototypes into a hybrid feature dataset and trains the global classification head on it without exposing any original images. Temperature scaling is then used to obtain class-level soft labels that carry the global topological relationships between diseases, which gives the different types of medical institution node clients higher-quality diagnostic priors. During training on the medical institution node clients, the mean squared error in the feature space geometrically anchors local disease features to the global disease prototypes, and the KL divergence in the prediction space distills the global soft label knowledge into the local diagnostic model. These two terms form regularization constraints in two dimensions, enabling heterogeneous diagnostic models from different institutions to learn the global disease semantic manifold consistently without data leaving the local system. This addresses the low accuracy of collaborative diagnosis in primary medical institutions that is caused by long-tail disease distributions and resource shortages. This invention effectively improves the identification accuracy of rare diseases in primary institutions, enhances the accuracy of collaborative judgment among multiple medical institutions, and helps doctors in further diagnosis.
Attached Figure Description
[0040] Figure 1 is a diagram illustrating the overall architecture of the heterogeneous federated learning method with adaptive prototype generation and dual-space guidance in this invention;
[0041] Figure 2 is a schematic diagram of the workflow of the global prototype generation network and global classification head on the central server side in this invention;
[0042] Figure 3 is a schematic diagram of the local multi-objective joint training mechanism of the medical institution node client, based on dual guidance from the feature space and the prediction space, in this invention.
Detailed Implementation
[0043] The technical solutions of the embodiments of the present invention will be clearly and completely described below with reference to the accompanying drawings. Obviously, the described embodiments are only some embodiments of the present invention, and not all embodiments. Based on the embodiments of the present invention, all other embodiments obtained by those skilled in the art without creative effort are within the scope of protection of the present invention.
[0044] This invention proposes a heterogeneous federated learning method with adaptive prototype generation and dual-space guidance. As shown in Figure 1, the method includes the following:
[0045] S1: Initialize the federated learning system; the federated learning system includes multiple heterogeneous medical institution node clients and a central server; the medical institution node clients deploy local diagnostic models, and the central server deploys a global prototype generation network and a global classification head.
[0046] The federated network consists of multiple heterogeneous medical institution node clients, each holding a private dataset, and a central server. The central server is equipped with a trainable global prototype generation network and a global classification head. Each medical institution node client has a completely different local model architecture, which is logically decoupled into a feature extractor and a classifier. The feature extractors of all medical institution node clients output into a unified d-dimensional feature space.
[0047] In some preferred embodiments of the present invention, a federated network of 20 medical institutions acting as medical institution node clients is used as a specific scenario. The participants include tertiary general hospitals (which deploy ResNet-series deep networks), secondary specialized hospitals (which deploy MobileNet v2 or GoogleNet), and primary health service centers (which deploy a lightweight 4-layer CNN). Each institution holds a local private lung imaging dataset (including CT scans and X-rays). The disease types cover many lesion types such as lung adenocarcinoma, pneumonia, emphysema, and bullae, and the distribution of each category shows obvious regional long-tail characteristics. The output dimension of the feature extractor of every local diagnostic model is uniformly set to 512, which satisfies the requirement that prototypes exchanged across architectures have the same dimension. The collaborative training process is set to 500 communication rounds, and all medical institution node clients participate in the aggregation in every round.
[0048] First, the federated learning system is initialized. Specifically, each medical institution initializes its local diagnostic model according to its own computing power and decouples it into two modules: an image feature extractor and a classifier. At the same time, the central server performs its own initialization, creating a global disease prototype generator with an FC-ReLU-FC structure and a global classification head. Furthermore, 75% of the locally stored image data is allocated to the training set at a predefined ratio, and the remaining 25% is allocated to the test set. This raw data remains local and is never sent elsewhere.
[0049] S2: The medical institution node client trains the local diagnostic model and uploads the local disease feature prototype to the central server.
[0050] Each medical institution node client uses an SGD optimizer with a batch size of 32 and a learning rate of 0.01 to perform a single round of local training. The mean of the sample features of each disease category in the local image dataset is then computed to obtain a category-level local prototype vector, i.e., the local disease feature prototype. This local disease feature prototype is uploaded to the central server. The communication load of each upload is only the number of categories multiplied by one 512-dimensional floating-point vector. Compared with transmitting model gradients or image data, this communication overhead is negligible.
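As a minimal sketch of the prototype computation and upload size described above, the following Python code averages per-class feature vectors and counts the bytes of one upload. The function names, the toy 2-dimensional features, and the 4-byte (float32) element size are assumptions made for illustration; the text specifies only a 512-dimensional floating-point vector per category.

```python
from collections import defaultdict


def local_prototypes(features, labels):
    """Average the feature vectors of each class to get one prototype per class."""
    sums = {}
    counts = defaultdict(int)
    for feat, y in zip(features, labels):
        if y not in sums:
            sums[y] = [0.0] * len(feat)
        sums[y] = [s + v for s, v in zip(sums[y], feat)]
        counts[y] += 1
    return {y: [s / counts[y] for s in total] for y, total in sums.items()}


def upload_size_bytes(num_classes, dim=512, bytes_per_float=4):
    """Payload of one upload: one dim-sized vector per class (float32 is an assumption)."""
    return num_classes * dim * bytes_per_float


# Toy 2-dimensional features stand in for the 512-dimensional ones.
features = [[1.0, 3.0], [3.0, 5.0], [0.0, 2.0]]
labels = [0, 0, 1]
prototypes = local_prototypes(features, labels)
print(prototypes)            # one mean vector per class
print(upload_size_bytes(4))  # bytes uploaded for 4 disease categories
```

Under these assumptions, a client with four disease categories uploads 4 × 512 × 4 bytes per round, which is a few kilobytes.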
[0051] S3: The central server trains a global prototype generation network based on local disease feature prototypes, and the global prototype generation network generates global disease prototypes.
[0052] As shown in Figure 2, training the global prototype generation network specifically includes:
[0053] Calculate the adaptive boundary for each classification category, specifically:
[0054] In each round of communication, the central server collects the sets of local disease feature prototypes uploaded by the participating medical institution node clients and calculates the unweighted average aggregation center of each disease category. From these centers, the adaptive boundary of each category is determined dynamically:
[0055]
[0056] In this formula, the terms denote, in order: the dynamic adaptive boundary of category c; the unweighted average aggregation center of category c; the unweighted average aggregation center of any category other than category c; the preset cutoff threshold, which prevents excessive boundary expansion; and the L2 norm.
[0057] This mechanism enables the model to adaptively adjust the repulsion strength of each category based on the current evolution state of the feature space.
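The following Python sketch illustrates one way such an adaptive boundary could be computed. The exact formula is not reproduced in this text, so the code rests on assumptions: the boundary of a category is taken as the largest L2 distance from its aggregation center to the centers of the other categories, capped by the cutoff threshold. The function names and the `reduce` parameter are hypothetical; passing `reduce=min` gives the nearest-center variant.

```python
import math


def l2(a, b):
    """Euclidean (L2) distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def class_centers(client_prototypes):
    """Unweighted mean of the uploaded prototypes of each class.

    client_prototypes: list of {class_id: vector}, one dict per client.
    """
    grouped = {}
    for protos in client_prototypes:
        for c, p in protos.items():
            grouped.setdefault(c, []).append(p)
    return {c: [sum(col) / len(ps) for col in zip(*ps)] for c, ps in grouped.items()}


def adaptive_boundaries(centers, cutoff, reduce=max):
    """Per-class boundary, capped at `cutoff` (needs at least two classes).

    Using max over the other centers is an assumption, not the patent's stated formula.
    """
    boundaries = {}
    for c, q in centers.items():
        dists = [l2(q, q2) for c2, q2 in centers.items() if c2 != c]
        boundaries[c] = min(reduce(dists), cutoff)
    return boundaries


uploads = [
    {0: [0.0, 0.0], 1: [3.0, 4.0]},
    {0: [0.0, 0.0], 1: [3.0, 4.0], 2: [0.0, 10.0]},
]
centers = class_centers(uploads)
boundaries = adaptive_boundaries(centers, cutoff=8.0)
print(boundaries)
```

Because the centers are recomputed every round, the boundaries track the current state of the feature space, and the cutoff keeps them from growing without limit.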
[0058] A global prototype generation network generates a global disease prototype. Based on the global disease prototype, the adaptive boundary of the corresponding category, and the local disease feature prototype of the corresponding category, a boundary enhancement contrast loss is calculated. Specifically:
[0059] The global prototype generation network is trained so that each generated global disease prototype lies close to the local disease feature prototypes of its own category while staying away from the global disease prototypes of other categories. For a local disease feature prototype (of category y) uploaded by a medical institution node client, the boundary enhancement contrast loss is defined as:
[0060]
[0061] In this formula, the terms denote, in order: the boundary enhancement contrast loss; the local disease feature prototype with true category y uploaded by medical institution node client i; the global disease prototype of category y generated by the global prototype generation network; and the dynamic adaptive boundary of category y. Here i indexes a medical institution node client currently participating in the aggregation, and y is the true label category of the local disease feature prototype being processed.
[0062] The adaptive boundary is introduced as a penalty term in the positive-sample term of the denominator. This is equivalent to raising the positive-sample matching requirement, which forces the generator to push the prototypes of other classes further away so that each class keeps at least the corresponding safe distance in the feature space.
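A margin-based contrastive loss of this kind can be sketched as follows. This is an illustration under assumptions, not the patent's formula, which is not reproduced in this text: similarity is taken as the negative L2 distance, and the boundary is added to the positive-pair distance before a softmax over all categories. The function name `boundary_contrast_loss` is hypothetical.

```python
import math


def l2(a, b):
    """Euclidean (L2) distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def boundary_contrast_loss(local_proto, y, global_protos, boundary):
    """Negative log-softmax over negative distances, with a margin on the positive pair."""
    logits = {}
    for c, g in global_protos.items():
        dist = l2(local_proto, g)
        if c == y:
            dist += boundary  # the margin makes the positive match harder to satisfy
        logits[c] = -dist
    # log-sum-exp computed stably by subtracting the maximum logit
    m = max(logits.values())
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits.values()))
    return log_sum_exp - logits[y]


global_protos = {0: [0.0, 0.0], 1: [3.0, 4.0]}
local_proto = [0.0, 0.0]  # a local prototype of category 0
loss_no_margin = boundary_contrast_loss(local_proto, 0, global_protos, boundary=0.0)
loss_margin = boundary_contrast_loss(local_proto, 0, global_protos, boundary=2.0)
print(loss_no_margin, loss_margin)
```

With the margin, the same geometric configuration yields a larger loss, so the generator must increase the gap to other categories before the loss becomes small.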
[0063] The category separation penalty is calculated based on the global disease prototype, specifically:
[0064] To rule out prototype collapse from a global perspective, every pair of categories that violates the safety interval is subjected to an absolute topological constraint, i.e., a full-matrix separation penalty:
[0065]
[0066] In this formula, the terms include: the number of categories; the safety interval, a hyperparameter with a preset value of 100 that prevents topological collapse of the global disease prototypes; the global disease prototype of category c; and the global disease prototype of another category. The two prototypes in each pair belong to different categories.
[0067] The penalty iterates over all prototype pairs (instead of penalizing only the nearest pair). This avoids the gradient truncation problem in backpropagation and ensures that all classes in the high-dimensional space receive a uniform repulsive force.
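The all-pairs penalty can be illustrated with the hinge-style sketch below. The hinge form and the averaging over pairs are assumptions, since the formula itself is not reproduced in this text; the function name `separation_penalty` is hypothetical, and the toy safety interval of 6.0 is chosen for the example rather than taken from the preset value of 100.

```python
import math
from itertools import combinations


def l2(a, b):
    """Euclidean (L2) distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def separation_penalty(global_protos, safe_interval):
    """Mean hinge penalty over every pair of distinct classes (assumed form)."""
    pairs = list(combinations(global_protos, 2))
    total = 0.0
    for a, b in pairs:
        # A pair contributes only when it is closer than the safety interval.
        total += max(0.0, safe_interval - l2(global_protos[a], global_protos[b]))
    return total / len(pairs)


global_protos = {0: [0.0, 0.0], 1: [3.0, 4.0], 2: [0.0, 10.0]}
penalty = separation_penalty(global_protos, safe_interval=6.0)
print(penalty)  # only the pair (0, 1), at distance 5, violates the interval
```

Because every violating pair contributes its own term, each of them receives a gradient, unlike a penalty that looks only at the single nearest pair.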
[0068] The weighted sum of the boundary enhancement contrast loss and the class separation penalty is used as the total loss for global disease prototype generation on the central server. The parameters of the global prototype generation network are adjusted based on this total loss to complete the training of the global prototype generation network. Specifically:
[0069] The total loss for global disease prototype generation on the central server side is expressed as:
[0070] $$\mathcal{L}_{\mathrm{server}} = \mathcal{L}_{\mathrm{con}} + \lambda\,\mathcal{L}_{\mathrm{sep}}$$
[0071] where $\mathcal{L}_{\mathrm{con}}$ is the boundary enhancement contrast loss, $\mathcal{L}_{\mathrm{sep}}$ is the category separation penalty, and $\lambda$ is a balance coefficient that adjusts the strength of the explicit topological constraint. By minimizing $\mathcal{L}_{\mathrm{server}}$, the central server continuously generates global disease prototypes that combine semantic consistency with high distinguishability.
[0072] Under the combined effect of the boundary enhancement contrast loss and the full-matrix separation penalty (weighted by the preferred balance coefficient), the global prototype generation network gradually refines high-quality global prototypes and maintains sufficient geometric distance between disease categories. This mechanism explicitly constrains visually similar conditions (for example, the prototypes of lung adenocarcinoma and exudative pneumonia lesions, which are easily confused), thereby improving the ability to distinguish between them.
[0073] This invention effectively prevents the degradation of the prototype space when there is heterogeneity in multi-center image data. The central server of this invention adopts a trainable global prototype generation network. This generator uses dynamic adaptive boundary constraints and replaces passive mean aggregation with full matrix separation penalty. Even if the disease spectrum distribution is highly non-independent and identically distributed, it can maintain the inter-class geometric distance of the global diagnostic prototype. In this way, it can suppress the feature aliasing problem caused by visually semantically similar lesions (such as lung lesions of different stages) from a mechanistic perspective, so that the discriminative structure of the prototype space gradually becomes stable.
[0074] S4: Train a global classification head based on local disease feature prototypes and global disease prototypes; generate global soft labels from the global classification head; and distribute global soft labels and global disease prototypes to medical institution node clients.
[0075] As shown in Figure 2, the process of training the global classification head includes:
[0076] By combining local disease feature prototypes and global disease prototypes, a hybrid prototype dataset is obtained. Specifically:
[0077] The central server fuses the local disease feature prototypes with the generated global disease prototypes to form a training set. This hybrid set is built from the local disease feature prototypes uploaded by the medical institution node clients (which reflect real statistical heterogeneity) and the globally generated prototypes (which provide highly distinguishable anchor points), and it requires no original image data.
[0078] The hybrid prototype dataset is input into the global classification head, and the global classification head is optimized using standard cross-entropy loss:
[0079] $$\mathcal{L}_{H} = \ell_{\mathrm{CE}}\big(H(p),\, y\big)$$
[0080] where $y$ is the label, $H$ is the global classification head, $\ell_{\mathrm{CE}}$ is the standard cross-entropy loss function, and $p$ is a prototype in the hybrid prototype set. By minimizing $\mathcal{L}_{H}$, the global classification head fits the topological structure of the global feature space and establishes a globally consistent hyperplane partition among the prototypes of each category.
[0081] The global disease prototype is input into the optimized global classification head, and a temperature scaling factor is introduced to obtain global soft labels, specifically:
[0082] After the global classification head converges, the global disease prototype of each disease category is fed into the global classification head, and a temperature scaling factor is introduced to smooth the overconfident logits output. This produces global soft labels that carry the topological association information between disease categories:
[0083] $$q_c = \mathrm{softmax}\left(\frac{H(\hat{p}_c;\,\theta_H)}{T}\right)$$
[0084] where $q_c$ denotes the global soft label of category c, $\hat{p}_c$ denotes the global disease prototype of category c, $\theta_H$ denotes the learnable parameters of the global classification head, $T$ denotes the temperature scaling factor, preferably ...; $H$ denotes the global classification head, and $\mathrm{softmax}$ denotes the softmax function.
[0085] The probability components of non-target categories in the soft labels implicitly encode global diagnostic prior knowledge, such as "emphysema and bullae are semantically more similar than emphysema and rib fractures". The temperature-smoothed soft label therefore has high information entropy: it not only gives the highest confidence to the target category, but also quantifies that category's topological similarity (dark knowledge) to all other categories in the global feature space.
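The effect of temperature scaling on a soft label can be seen in a short Python sketch. The logits below are made-up values standing in for the output of the global classification head on one global disease prototype, and the temperatures 1.0 and 2.0 are illustrative choices, not the preferred value of the invention.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed stably by subtracting the max."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical head output for one prototype: target class 0, similar class 1, unrelated class 2.
head_logits = [4.0, 2.0, 0.0]
sharp = softmax_with_temperature(head_logits, 1.0)
soft = softmax_with_temperature(head_logits, 2.0)
print(sharp)
print(soft)
```

Raising the temperature keeps the target category on top but moves probability mass to the non-target categories, which is where the inter-class similarity information resides.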
[0086] This invention achieves effective transfer of diagnostic dark knowledge across institutions while strictly adhering to medical data privacy compliance requirements. It trains a global classification head using a hybrid prototype dataset, ensuring that the original image data remains within the local institution, with only lightweight prototype vectors participating in communication. The class-level soft labels generated through temperature scaling distribute the existing global topological relationships between disease categories—high-order semantic information systematically lost by current methods due to merely aligning first-order moments—to all participants in a privacy-secure manner, thereby injecting global diagnostic priors into the lightweight models of primary healthcare institutions.
[0087] S5: The medical institution node client performs multi-objective dual-guided training on the local diagnostic model based on the global soft labels and global disease prototypes, and uploads the local disease feature prototypes output by the trained local diagnostic model to the central server.
[0088] As shown in Figure 3, the medical institution node client of each institution receives the global disease prototypes and global soft labels from the central server. The local model is then optimized with a weighted loss function that combines three objectives.
[0089] The three objective loss functions are local classification loss, feature space prototype alignment loss, and prediction space knowledge distillation loss.
[0090] Local classification loss:
[0091] Standard cross-entropy loss is used to ensure the model's basic fit to the local data distribution:
[0092] $$\mathcal{L}_{\mathrm{CE}} = -\frac{1}{N}\sum_{j=1}^{N}\log\frac{\exp\left(o_{j,y_j}\right)}{\sum_{k}\exp\left(o_{j,k}\right)}$$
[0093] where i is the index of the medical institution node client, j is the index of the sample, and N is the number of samples in the local dataset. $h_j = f_i(x_j;\theta_i)$ is the feature representation that the feature extractor generates for the j-th input sample $x_j$, and $\theta_i$ denotes the learnable parameters of the local feature extractor of medical institution node client i. $o_{j,y_j}$ is the classifier's logit output for $h_j$ on its true category $y_j$, and $o_{j,k}$ is the classifier's logit output for $h_j$ on category $k$.
[0094] Feature space prototype alignment loss:
[0095] The global disease prototypes issued by the central server are treated as global semantic anchors for the corresponding categories, and explicit geometric constraints are imposed on local instance-level features using mean squared error (MSE):
[0096] $$\mathcal{L}_{\mathrm{align}} = \mathbb{E}_{(x,y)\sim D_i}\left[\frac{1}{d}\,\left\lVert f_i(x;\theta_i) - \hat{p}_y \right\rVert_2^2\right]$$
[0097] where $\mathcal{L}_{\mathrm{align}}$ denotes the feature space prototype alignment loss, $\mathbb{E}_{(x,y)\sim D_i}$ denotes the mathematical expectation over samples drawn from the local private dataset $D_i$ of medical institution node client i, $d$ denotes the dimension of the unified feature space, $f_i(x;\theta_i)$ denotes the high-dimensional features extracted locally, $x$ denotes the input sample, $\theta_i$ denotes the learnable parameters of the local feature extractor of medical institution node client i, and $\hat{p}_y$ denotes the global disease prototype corresponding to the true category y of the input sample. The formula introduces a $1/d$ normalization factor to remove the influence of the feature dimension. This term acts as a strong structural regularizer: it normalizes the gradient update direction of the local feature extractor and prevents feature drift on non-IID data.
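As a minimal sketch, the alignment term can be computed over a batch as below, with the expectation approximated by the batch mean. The function name `prototype_alignment_loss` and the toy 2-dimensional features are assumptions for illustration.

```python
def prototype_alignment_loss(features, labels, global_protos):
    """Batch mean of (1/d) * squared L2 distance between each feature and its class anchor."""
    total = 0.0
    for feat, y in zip(features, labels):
        anchor = global_protos[y]  # global prototype of the sample's true category
        squared = sum((f - p) ** 2 for f, p in zip(feat, anchor))
        total += squared / len(feat)  # the 1/d normalization
    return total / len(features)


global_protos = {0: [1.0, 0.0], 1: [3.0, 4.0]}
features = [[1.0, 2.0], [3.0, 4.0]]
labels = [0, 1]
align_loss = prototype_alignment_loss(features, labels, global_protos)
print(align_loss)
```

The second sample sits exactly on its anchor and contributes nothing, so only the first sample's offset drives the gradient toward the global prototype.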
[0098] Prediction space knowledge distillation loss:
[0099] The global soft labels generated by the central server carry implicit knowledge of inter-class relationships (dark knowledge). When calculating its soft prediction distribution, the local model uses exactly the same temperature as the central server (a unified temperature field design) and is aligned with the global soft labels via KL divergence:
[0100] $$\mathcal{L}_{\mathrm{KD}} = \mathrm{KL}\left(q_y \,\Big\Vert\, \mathrm{softmax}\left(\frac{o}{T}\right)\right)$$
[0101] where $q_y$ is the global soft label of the true category y, $\mathrm{softmax}(o/T)$ is the soft prediction distribution obtained by smoothing the local logits $o$ with temperature $T$, and $o$ is the unnormalized logit vector output by the last layer of the local classifier network, i.e., the output of the local classifier applied to the locally extracted features.
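The distillation term can be sketched in Python as follows. The direction of the KL divergence (global soft label as the reference distribution, local prediction as the approximation) follows common knowledge distillation practice and is an assumption here; the function names and the numeric values are illustrative.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed stably by subtracting the max."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(global_soft_label, local_logits, temperature):
    """KL(global soft label || local soft prediction), using the server's temperature."""
    local_probs = softmax_with_temperature(local_logits, temperature)
    return sum(
        q * (math.log(q) - math.log(p))
        for q, p in zip(global_soft_label, local_probs)
        if q > 0.0  # terms with zero target probability contribute nothing
    )


T = 2.0  # illustrative temperature shared by server and client
soft_label = softmax_with_temperature([4.0, 2.0, 0.0], T)
kd_matched = distillation_loss(soft_label, [4.0, 2.0, 0.0], T)
kd_mismatched = distillation_loss(soft_label, [0.0, 2.0, 4.0], T)
print(kd_matched, kd_mismatched)
```

Using the same temperature on both sides means the loss is zero exactly when the local soft prediction reproduces the global soft label.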
[0102] The overall objective of joint optimization on the medical institution node clients is:
[0103] $$\mathcal{L}_i = \mathcal{L}_{\mathrm{CE}} + \alpha\,\mathcal{L}_{\mathrm{align}} + \beta\,\mathcal{L}_{\mathrm{KD}}$$
[0104] where $\mathcal{L}_{\mathrm{CE}}$ is the local classification loss, $\mathcal{L}_{\mathrm{align}}$ is the feature space prototype alignment loss, and $\mathcal{L}_{\mathrm{KD}}$ is the prediction space knowledge distillation loss. $\alpha$ controls the strength of the geometric constraint (geometric anchoring) in the feature space, and $\beta$ controls the intensity of knowledge distillation in the prediction space. Unlike the heuristic temperature compensation used in traditional offline distillation, this invention keeps a uniform temperature field and relies on the independent hyperparameter $\beta$ to adjust the magnitude of the distillation gradient that flows back, which greatly simplifies the optimization landscape. The experimental optimal value is... . The MSE alignment loss in the feature space drags local image features toward the global semantic anchor points, while the KL divergence distillation loss in the prediction space incorporates the topological associations between disease categories into the decision boundary of the local classifier. The two work together so that lightweight models in primary healthcare institutions can learn a global disease discrimination manifold similar to that of the deep models in tertiary hospitals, even with only a small number of locally labeled samples.
[0105] This invention applies constraints from two aspects: geometric anchoring in the feature space and knowledge distillation in the prediction space. This effectively improves the recognition accuracy of rare diseases in primary healthcare institutions. The hard constraint is used for semantic space alignment across architectures, while the soft guidance conveys the probabilistic topology of inter-class relationships. Combined, these constraints enable heterogeneous diagnostic models, from lightweight CNNs to deep ResNets, to converge to a globally consistent decision manifold under long-tailed disease distributions and scarce resources. Their generalization ability is significantly better than that of current baseline methods.
[0106] After completing local training, each medical institution node client calculates its local feature prototype (that is, the feature mean of all samples of category c in the local private dataset) and uploads it to the central server for the next round of communication. The uploaded content consists only of lightweight prototype vectors (one 512-dimensional vector per class) and contains no original data or model parameters, thus strictly protecting data and model privacy.
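The local prototype computation described above, a per-class mean of feature vectors, can be sketched as follows. The function name `local_prototypes` and the 2-dimensional toy features are hypothetical; in the described system each vector would be 512-dimensional.

```python
def local_prototypes(features, labels):
    """Return {class_id: mean feature vector} over the local samples."""
    sums, counts = {}, {}
    for feat, c in zip(features, labels):
        if c not in sums:
            sums[c] = [0.0] * len(feat)
            counts[c] = 0
        sums[c] = [s + f for s, f in zip(sums[c], feat)]   # accumulate per class
        counts[c] += 1
    # Divide each accumulated sum by the number of samples of that class.
    return {c: [s / counts[c] for s in sums[c]] for c in sums}


# Toy example: two samples of class 0 and one sample of class 1.
features = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0]]
labels = [0, 0, 1]
prototypes = local_prototypes(features, labels)
```

Only the resulting dictionary of class means would be uploaded; the per-sample features and the model parameters stay on the client.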
[0107] S6: Each medical institution node client determines whether the local diagnostic model has converged. If so, the learning ends; otherwise, it returns to step S3.
[0108] When the accuracy of each participating institution's local diagnostic model on its own test set has stabilized, the federated training ends and the converged heterogeneous diagnostic model of each institution is output; otherwise, the method returns to step S3 and continues iterative training.
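The description does not specify how "stabilized" is measured. One possible criterion, given purely as an assumption, is that the test accuracy varies by no more than a small tolerance over the last few communication rounds; the names `has_converged`, `window`, and `tolerance`, and their default values, are hypothetical.

```python
def has_converged(accuracy_history, window=3, tolerance=0.005):
    """True if accuracy varied by at most `tolerance` over the last `window` rounds.

    Assumption: this stopping rule is illustrative only; the patent text
    states only that accuracy on the local test set gradually stabilizes.
    """
    if len(accuracy_history) < window:
        return False                     # not enough rounds to judge stability
    recent = accuracy_history[-window:]
    return max(recent) - min(recent) <= tolerance


# Toy accuracy traces for one client (illustrative values).
still_improving = has_converged([0.50, 0.70, 0.80])
stabilized = has_converged([0.50, 0.801, 0.802, 0.803])
```

A client for which the check returns False would continue with another round starting from step S3.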
[0109] After training, each institution uses the trained heterogeneous diagnostic model to process medical data, which can output accurate prediction results, thereby assisting doctors in making further diagnoses.
[0110] In summary, this invention employs a trainable global prototype generation network on the central server side. This generator uses dynamic adaptive boundary constraints and a full-matrix separation penalty in place of passive mean aggregation. The invention trains a global classification head on a hybrid prototype dataset, achieving effective transfer of dark knowledge for cross-institutional diagnosis while strictly adhering to medical data privacy compliance requirements. By applying constraints through both feature space geometric anchoring and prediction space knowledge distillation, the invention effectively improves the identification accuracy of rare diseases in primary healthcare institutions. This invention is applicable to scenarios where multiple parties cannot share data, the data are non-independent and identically distributed, and each party's model architecture is completely different, enabling efficient and privacy-preserving collaborative modeling. This includes, but is not limited to, federated learning applications such as edge computing, smart healthcare, smart transportation, and the industrial internet.
[0111] The above-described embodiments further illustrate the purpose, technical solution, and advantages of the present invention. It should be understood that the above-described embodiments are merely preferred embodiments of the present invention and are not intended to limit the present invention. Any modifications, equivalent substitutions, improvements, etc., made to the present invention within the spirit and principles of the present invention should be included within the protection scope of the present invention.
Claims
1. A heterogeneous federated learning method with adaptive prototype generation and dual-space guidance, characterized in that it includes the following steps: S1: Initialize the federated learning system; the federated learning system includes multiple heterogeneous medical institution nodes, medical institution node clients, and a central server; the medical institution node clients deploy local diagnostic models, and the central server deploys a global prototype generation network and a global classification head; S2: The medical institution node client trains the local diagnostic model and uploads the local prototype to the central server; S3: The central server trains the global prototype generation network based on the local prototypes, and the global prototype generation network generates a global disease prototype; S4: Train the global classification head based on the local prototypes and the global disease prototype; the global classification head generates global soft labels; the central server distributes the global soft labels and global disease prototypes to the medical institution node clients; S5: The medical institution node client performs multi-objective dual-guided training on the local diagnostic model based on the global soft labels and global disease prototypes, and uploads the local prototype output by the trained local diagnostic model to the central server; S6: Each medical institution node client determines whether the local diagnostic model has converged. If so, the learning ends; otherwise, it returns to step S3.
2. The heterogeneous federated learning method with adaptive prototype generation and dual-space guidance according to claim 1, characterized in that, in step S3, training the global prototype generation network in the global diagnostic model specifically includes: calculating the adaptive boundary for each classification category; generating a global disease prototype with the global prototype generation network, and calculating a boundary enhancement contrast loss based on the global disease prototype, the adaptive boundary of the category corresponding to the global disease prototype, and the local prototype of the category corresponding to the global disease prototype; calculating a category separation penalty based on the global disease prototype; using the weighted sum of the boundary enhancement contrast loss and the category separation penalty as the total loss for global disease prototype generation on the central server; and adjusting the parameters of the global prototype generation network based on this total loss to complete the training of the global prototype generation network.
3. The heterogeneous federated learning method with adaptive prototype generation and dual-space guidance according to claim 2, characterized in that the adaptive boundary for each category is calculated by a formula whose terms are: the dynamic adaptive boundary of category c; the unweighted average aggregation center of category c; the unweighted average aggregation center of any category other than category c; the preset cutoff threshold; and the L2 norm.
4. The heterogeneous federated learning method with adaptive prototype generation and dual-space guidance according to claim 2, characterized in that the boundary enhancement contrast loss is calculated by a formula whose terms are: the boundary enhancement contrast loss; the local disease feature prototype of true category y uploaded by medical institution node client i; the global disease prototype of category y generated by the global prototype generation network; the dynamic adaptive boundary of category y; and the L2 norm.
5. The heterogeneous federated learning method with adaptive prototype generation and dual-space guidance according to claim 2, characterized in that the category separation penalty is calculated by a formula whose terms are: the category separation penalty; the number of categories; the safe interval; the global disease prototype of category c; the global disease prototype of another category; and the L2 norm.
6. The heterogeneous federated learning method with adaptive prototype generation and dual-space guidance according to claim 1, characterized in that, in step S4, the process of training the global classification head includes: combining the local disease feature prototypes and the global disease prototypes to obtain a hybrid prototype dataset; inputting the hybrid prototype dataset into the global classification head and optimizing the global classification head using standard cross-entropy loss; and inputting the global disease prototype into the optimized global classification head and introducing a temperature scaling factor to obtain the global soft label.
7. The heterogeneous federated learning method with adaptive prototype generation and dual-space guidance according to claim 6, characterized in that the global soft label is given by a formula whose terms are: the global soft label of category c; the global disease prototype of category c; the learnable parameters of the global classification head; the temperature scaling factor; the global classification head; and the function used in the formula.
8. The heterogeneous federated learning method with adaptive prototype generation and dual-space guidance according to claim 1, characterized in that, in step S5, the loss function for multi-objective dual-guided training of the local diagnostic model is a weighted sum of the local classification loss, the feature space prototype alignment loss, and the prediction space knowledge distillation loss.
9. The heterogeneous federated learning method with adaptive prototype generation and dual-space guidance according to claim 8, characterized in that the feature space prototype alignment loss is given by a formula whose terms are: the feature space prototype alignment loss; the mathematical expectation over samples drawn from the local private dataset of medical institution node client i; the dimension of the unified feature space; the high-dimensional feature extracted locally; the input sample; the learnable parameters of the local feature extractor of medical institution node client i; the global disease prototype corresponding to the true category y of the input sample; and the L2 norm.