Method for improving generalization of classification model based on data enhancement and related device
By constructing a classification model that includes original and data augmentation branches, using the data augmentation module to extract features and classify the data to be classified, and constructing data augmentation loss and classification loss, the problem of insufficient generalization of the classification model is solved, and accurate classification is achieved on different datasets.
Patent Information
- Authority / Receiving Office
- CN · China
- Patent Type
- Patents(China)
- Current Assignee / Owner
- PING AN TECH (SHENZHEN) CO LTD
- Filing Date
- 2023-06-28
- Publication Date
- 2026-06-16
AI Technical Summary
Existing classification models have poor generalization ability outside the training dataset, making it difficult to obtain accurate classification results on different datasets. Traditional data augmentation methods cannot quickly and effectively improve the generalization ability of the model.
An initial classification model is constructed, including a raw branch and a data augmentation branch. The data to be classified is downsampled multiple times by the raw encoder and the augmentation encoder. The data augmentation module is used to randomly perturb the augmented downsampled map. Data augmentation loss and classification loss are constructed. The model is updated by gradient descent to improve generalization.
While keeping the semantic information of the data to be classified unchanged, the perturbation of style features is increased, which improves the classification accuracy and generalization of the classification model on different datasets.
Description
Technical Field
[0001] This application relates to the fields of artificial intelligence and digital healthcare technology, and in particular to a method and related equipment for improving the generalization of a classification model based on data augmentation.
Background Technology
[0002] Classification models are widely used in industries such as finance and digital healthcare. However, the main obstacle to applying them to practical problems is poor generalization: a classification model trained on one dataset often fails to produce accurate classification results on other datasets. For example, a medical image classification model trained on a first medical image dataset usually struggles to produce accurate results when applied to the classification task of a second medical image dataset. It is therefore necessary to improve the generalization ability of the classification model so that it maintains high classification accuracy on practical problems.
[0003] Currently, the common approach is to first apply data augmentation to the existing raw data (such as the first medical image dataset mentioned above), and then train the classification model on both the augmented data and the original data in order to improve its generalization. Common data augmentation methods include: Method 1, performing undirected or random data augmentation on the data; and Method 2, directly generating augmented data with a generative adversarial model. However, Method 1 cannot find the augmented data that is most beneficial to classification, and Method 2 requires a large amount of training resources. Neither method can quickly and effectively improve the generalization of the classification model.
Summary of the Invention
[0004] In view of the above, it is necessary to propose a method and related equipment for improving the generalization of classification models based on data augmentation, so as to solve the technical problem of how to quickly and effectively improve the generalization of classification models. The related equipment includes a device for improving the generalization of classification models based on data augmentation, electronic equipment, and storage media.
[0005] This application provides a method for improving the generalization of a data-augmented classification model, the method comprising:
[0006] S10, Build an initial classification model. The initial classification model includes an original branch and a data augmentation branch. The original branch includes an original encoder and an original classifier. The data augmentation branch includes an augmentation encoder and an augmentation classifier. The augmentation encoder is an original encoder that includes at least one data augmentation module.
[0007] S11, Collect multiple pieces of data to be classified, each with a category label, as the training dataset;
[0008] S12, Select any data to be classified from the training dataset and input it into the original branch. The original encoder performs multiple downsamplings on the data to be classified to obtain at least one original downsampled image, and inputs the last original downsampled image into the original classifier to obtain the original classification result.
[0009] S13, the data to be classified is input into the data augmentation branch, the augmentation encoder performs multiple downsamplings on the data to be classified to obtain at least one augmented downsampling map, and randomly perturbs the pre-selected augmented downsampling map based on the data augmentation module to obtain at least one augmented feature map, and the last augmented downsampling map is input into the augmentation classifier to obtain the augmented classification result;
[0010] S14, construct a data augmentation loss based on the enhanced feature map, the original downsampled map, and the enhanced downsampled map; construct a classification loss based on the original classification result, the enhanced classification result, and the category label of the data to be classified; and use the sum of the data augmentation loss and the classification loss as the target loss.
[0011] S15, update the initial classification model according to the gradient descent method to complete one iteration of training, return to step S12, until the target loss value is less than the preset value, and obtain the target classification model.
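The stopping rule in steps S12 to S15 (repeat gradient-descent updates until the target loss falls below a preset value) can be sketched as follows. This is only an illustrative outline: `train_until_threshold`, `toy_target_loss`, the learning rate, and the iteration cap are hypothetical names and values that do not appear in the source, and the toy quadratic loss merely stands in for the sum of the data augmentation loss and the classification loss.

```python
def train_until_threshold(params, loss_and_grad, lr=0.1, preset=1e-4, max_iters=10_000):
    """Repeat gradient-descent updates until the target loss is below `preset`."""
    for step in range(max_iters):
        loss, grads = loss_and_grad(params)      # one forward pass (S12 to S14)
        if loss < preset:                        # stopping rule of S15
            return params, loss, step
        # Plain gradient-descent update of every parameter (S15).
        params = [p - lr * g for p, g in zip(params, grads)]
    # max_iters is a safety cap added for this sketch; it is not in the source.
    return params, loss_and_grad(params)[0], max_iters


def toy_target_loss(params):
    # Hypothetical stand-in for "data augmentation loss + classification loss".
    loss = sum((p - 1.0) ** 2 for p in params)
    grads = [2.0 * (p - 1.0) for p in params]
    return loss, grads


final_params, final_loss, steps = train_until_threshold([0.0, 3.0], toy_target_loss)
```

In a real model the parameters would be the weights of both branches and the gradients would come from backpropagation; only the loop structure is meant to carry over.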
[0012] In some embodiments, the original encoder includes multiple convolutional layers, and the original classifier includes multiple fully connected layers;
[0013] A data augmentation module is inserted at at least one preset position in the plurality of convolutional layers of the original encoder to obtain the augmented encoder. The augmented classifier includes a plurality of fully connected layers. The data augmentation module is used to add random perturbations to the downsampled map output by the convolutional layer to achieve data augmentation.
[0014] The original classifier and the augmented classifier may have the same or different structures, and when the original classifier and the augmented classifier have the same structure, the structural network parameters of the original classifier and the augmented classifier may or may not be shared.
[0015] In some embodiments, the augmentation encoder downsamples the data to be classified multiple times to obtain at least one augmented downsampled map, and randomly perturbs a pre-selected augmented downsampled map based on the data augmentation module to obtain at least one augmented feature map, including:
[0016] For any convolutional layer in the enhanced encoder, an enhanced downsampled map is obtained by performing downsampling processing on the input data based on the convolutional layer, wherein the size of the enhanced downsampled map is less than or equal to the size of the input data.
[0017] When the end of the convolutional layer is directly connected to the data augmentation module, it indicates that the augmented downsampled map is a pre-selected augmented downsampled map. Then, the data augmentation module randomly perturbs the augmented downsampled map to obtain an augmented feature map.
[0018] Determine whether the convolutional layer is the last convolutional layer in the enhanced encoder;
[0019] If the convolutional layer is not the last convolutional layer in the enhanced encoder, then the enhanced downsampled map or the enhanced feature map is used as the input data for the next convolutional layer.
[0020] If the convolutional layer is the last convolutional layer in the enhanced encoder, then the enhanced downsampled map is used as the last enhanced downsampled map.
[0021] In some embodiments, the step of randomly perturbing the enhanced downsampled map based on the data augmentation module to obtain an enhanced feature map includes:
[0022] Obtain all feature values in the enhanced downsampled image;
[0023] Each feature value in the enhanced downsampled image is randomly perturbed according to the random perturbation formula to obtain the perturbation value corresponding to each feature value. The random perturbation formula satisfies the following relationship:
[0024]
[0025] Where x is any feature value in the enhanced downsampling image, μ and σ are the mean and variance of all feature values, respectively, γ and β are the learnable parameters of the data augmentation module, and LDP(x) is the perturbation value corresponding to feature value x;
[0026] Replace all feature values in the enhanced downsampled image with the corresponding perturbation values to obtain the enhanced feature map corresponding to the enhanced downsampled image.
[0027] In some embodiments, the enhanced downsampled maps and the original downsampled maps are in one-to-one correspondence, where each enhanced downsampled map and its corresponding original downsampled map have the same size. Constructing the data augmentation loss based on the enhanced feature map, the original downsampled map, and the enhanced downsampled map includes:
[0028] For any enhanced feature map, obtain the enhanced downsampled map corresponding to the enhanced feature map, and use the original downsampled map corresponding to the enhanced downsampled map in the original branch as the target downsampled map;
[0029] Calculate the Gram matrix of the target downsampled image and the enhanced feature map respectively, and calculate the style bias of the enhanced feature map based on the Gram matrix. The style bias satisfies the following relationship:
[0030]
[0031] where $\hat{f}_i$ is the i-th enhanced feature map, $f_i$ is the target downsampled map corresponding to the i-th enhanced feature map, $G(f_i)$ and $G(\hat{f}_i)$ are the Gram matrices of $f_i$ and $\hat{f}_i$ respectively, $\|\cdot\|_F$ denotes the F-norm (Frobenius norm), and $L_{spe}^{i}$ is the style bias of the i-th enhanced feature map;
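A Gram-matrix comparison of this kind can be sketched as below. The exact style-bias formula is not reproduced in the text, so this sketch assumes the simplest reading: the Frobenius norm of the difference between the two Gram matrices, with no normalization. The function names and the channel-by-position list layout are illustrative choices, not taken from the source.

```python
import math


def gram_matrix(feature_map):
    """feature_map: list of C channels, each a flat list of N spatial values.

    Returns the C x C matrix of inner products between channels.
    """
    return [[sum(a * b for a, b in zip(ci, cj)) for cj in feature_map]
            for ci in feature_map]


def frobenius_norm(matrix):
    """Square root of the sum of squared entries."""
    return math.sqrt(sum(v * v for row in matrix for v in row))


def style_bias(target_map, augmented_map):
    # ASSUMPTION: style bias = || G(target) - G(augmented) ||_F, unnormalized.
    g_t = gram_matrix(target_map)
    g_a = gram_matrix(augmented_map)
    diff = [[x - y for x, y in zip(rt, ra)] for rt, ra in zip(g_t, g_a)]
    return frobenius_norm(diff)


target = [[1.0, 2.0], [0.0, 1.0]]      # 2 channels, 2 spatial positions
augmented = [[2.0, 4.0], [0.0, 1.0]]   # first channel rescaled
bias = style_bias(target, augmented)
```

Because the Gram matrix discards spatial layout and keeps only channel co-activation statistics, it is a common proxy for "style"; rescaling a channel changes it even though the spatial pattern is unchanged.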
[0032] The semantic bias is calculated based on the last original downsampled image and the last enhanced downsampled image, and the semantic bias satisfies the following relation:
[0033]
[0034] where $f^{*}$ and $\hat{f}^{*}$ are the last original downsampled map and the last enhanced downsampled map, respectively, $\|\cdot\|_2$ denotes the L2 norm, and $L_{sem}$ is the semantic bias;
[0035] A data augmentation loss is constructed based on the semantic bias and the style bias of each augmented feature map, and the data augmentation loss satisfies the following relationship:
[0036]
[0037] where $L_{sem}$ is the semantic bias, $K$ is the number of all enhanced feature maps, $L_{spe}^{i}$ is the style bias of the i-th enhanced feature map, $\lambda_{sem}$ and $\lambda_{spe}$ are preset coefficients greater than 0, and $L_{zq}$ is the data augmentation loss.
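The way the semantic bias and the style biases might be combined can be sketched as follows. The combining formula is not reproduced in the text, so the form used here is an assumption chosen to match the stated goal (keep the semantic bias small while making the style perturbation large): the weighted semantic bias minus the weighted mean style bias. The function names and default coefficient values are hypothetical.

```python
import math


def semantic_bias(last_original, last_augmented):
    """L2 norm of the element-wise difference between two flattened maps."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(last_original, last_augmented)))


def augmentation_loss(sem_bias, style_biases, lam_sem=1.0, lam_spe=0.1):
    # ASSUMPTION: loss = lam_sem * L_sem - lam_spe * mean(style biases).
    # The minus sign rewards larger style perturbation; the source's exact
    # formula may differ (for example, it could use a reciprocal instead).
    k = len(style_biases)                      # K: number of enhanced feature maps
    mean_style = sum(style_biases) / k
    return lam_sem * sem_bias - lam_spe * mean_style


sem = semantic_bias([1.0, 2.0, 2.0], [1.0, 0.0, 2.0])
loss = augmentation_loss(sem, [3.0, 5.0])
```

Under this assumed form, minimizing the loss pushes the two branches toward the same final representation while pushing the perturbed intermediate maps away from the originals in Gram-matrix terms.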
[0038] In some embodiments, the classification loss includes a first classification loss $L_{cls1}$ and a second classification loss $L_{cls2}$. The first classification loss is a cross-entropy loss function constructed based on the original classification result and the category label of the data to be classified, and the second classification loss is a cross-entropy loss function constructed based on the enhanced classification result and the category label of the data to be classified.
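The two cross-entropy terms can be illustrated with a short sketch. The text states only that the classification loss includes both terms, so adding them with equal weight is an assumption of this example, as are the function names and the use of raw logits as classifier outputs.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Negative log-probability assigned to the true class index `label`."""
    return -math.log(softmax(logits)[label])


def classification_loss(original_logits, augmented_logits, label):
    l_cls1 = cross_entropy(original_logits, label)    # original branch
    l_cls2 = cross_entropy(augmented_logits, label)   # data augmentation branch
    # ASSUMPTION: the two terms are summed with equal weight.
    return l_cls1 + l_cls2


uniform = classification_loss([0.0, 0.0], [0.0, 0.0], label=1)
confident = classification_loss([0.0, 5.0], [0.0, 5.0], label=1)
```

Both branches are scored against the same category label, so the augmentation branch is penalized whenever its perturbations change the predicted class.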
[0039] In some embodiments, when the value of the target loss is less than the preset value, the integration result of the original branch and the data augmentation branch is used as the target classification model. The input of the target classification model is any data to be classified, and the output is the classification result of the data to be classified.
[0040] This application embodiment also provides a device for improving the generalization of a data-augmented classification model, the device comprising:
[0041] A building unit is used to build an initial classification model. The initial classification model includes an original branch and a data augmentation branch. The original branch includes an original encoder and an original classifier. The data augmentation branch includes an augmentation encoder and an augmentation classifier. The augmentation encoder is an original encoder that includes at least one data augmentation module.
[0042] The acquisition unit is used to collect multiple pieces of data to be classified, each with a category label, as the training dataset;
[0043] The original branch unit is used to select any data to be classified from the training dataset and input it into the original branch. The original encoder performs multiple downsamplings on the data to be classified to obtain at least one original downsampled image, and inputs the last original downsampled image into the original classifier to obtain the original classification result.
[0044] An augmentation branch unit is used to input the data to be classified into the data augmentation branch. The augmentation encoder performs multiple downsampling on the data to be classified to obtain at least one augmented downsampled map, and randomly perturbs the pre-selected augmented downsampled map based on the data augmentation module to obtain at least one augmented feature map. The last augmented downsampled map is input into the augmentation classifier to obtain the augmented classification result.
[0045] The target loss unit is used to construct a data augmentation loss based on the enhanced feature map, the original downsampled map, and the enhanced downsampled map, and to construct a classification loss based on the original classification result, the enhanced classification result, and the category label of the data to be classified, and to use the sum of the data augmentation loss and the classification loss as the target loss;
[0046] An iterative training unit is used to update the initial classification model according to the gradient descent method to complete one iteration of training, and return to the original branch unit until the target loss is less than a preset value, thus obtaining the target classification model.
[0047] This application embodiment also provides an electronic device, the electronic device comprising:
[0048] Memory, storing at least one instruction;
[0049] The processor executes instructions stored in the memory to implement the method for improving the generalization of the data augmentation-based classification model.
[0050] This application also provides a computer-readable storage medium storing at least one instruction, which is executed by a processor in an electronic device to implement the method for improving the generalization of the data augmentation-based classification model.
[0051] In summary, this application performs feature extraction and classification on the same data to be classified through the original branch and a data augmentation branch with an added data augmentation module. A data augmentation loss is constructed based on the Gram matrix of the feature maps and semantic information during feature extraction to ensure that the data augmentation module can significantly increase the perturbation of style features while preserving the semantic information in the data to be classified. Simultaneously, the classification loss constrains both the original and augmentation branches to obtain accurate classification results, thereby improving the generalization of the classification model. This application can be used to improve the generalization of medical data classification models in the field of digital healthcare.
Attached Figure Description
[0052] Figure 1 is a flowchart of a preferred embodiment of the method for improving the generalization of a classification model based on data augmentation involved in this application.
[0053] Figure 2 is a schematic diagram of the structure of the initial classification model involved in this application.
[0054] Figure 3 is a functional block diagram of a preferred embodiment of the device for improving the generalization of a classification model based on data augmentation involved in this application.
[0055] Figure 4 is a schematic diagram of the structure of an electronic device in a preferred embodiment of the method for improving the generalization of a classification model based on data augmentation involved in this application.
Detailed Implementation
[0056] To better understand the purpose, features, and advantages of this application, a detailed description of the application is provided below with reference to the accompanying drawings and specific embodiments. It should be noted that, unless otherwise specified, the embodiments and features described in the embodiments of this application can be combined with each other. Numerous specific details are set forth in the following description to provide a thorough understanding of this application; the described embodiments are only a part of the embodiments of this application, and not all of them.
[0057] Furthermore, the terms "first" and "second" are used for descriptive purposes only and should not be construed as indicating or implying relative importance or implicitly specifying the number of technical features indicated. Thus, a feature defined as "first" or "second" may explicitly or implicitly include one or more of the stated features. In the description of this application, "a plurality of" means two or more, unless otherwise explicitly specified.
[0058] Unless otherwise defined, all technical and scientific terms used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this application belongs. The terminology used herein is for the purpose of describing particular embodiments only and is not intended to be limiting of the application. The term "and / or" as used herein includes any and all combinations of one or more of the associated listed items.
[0059] This application provides a method for improving the generalization of a data-augmented classification model, which can be applied to one or more electronic devices. An electronic device is a device that can automatically perform numerical calculations and / or information processing according to pre-set or stored instructions. Its hardware includes, but is not limited to, microprocessors, application-specific integrated circuits (ASICs), field-programmable gate arrays (FPGAs), digital signal processors (DSPs), embedded devices, etc.
[0060] Electronic devices can be any electronic product that allows human-computer interaction with a customer, such as personal computers, tablets, smartphones, personal digital assistants (PDAs), game consoles, interactive network television (IPTV), smart wearable devices, etc.
[0061] Electronic devices may also include network devices and / or client devices. The network devices include, but are not limited to, a single network server, a server group consisting of multiple network servers, or a cloud based on cloud computing consisting of a large number of hosts or network servers.
[0062] The networks in which electronic devices are located include, but are not limited to, the Internet, wide area networks, metropolitan area networks, local area networks, and virtual private networks (VPNs).
[0063] Figure 1 is a flowchart of a preferred embodiment of the method of this application for improving the generalization of a classification model based on data augmentation. The order of steps in the flowchart can be changed, and some steps can be omitted, depending on different needs. The method provided in this application can be applied to any scenario requiring data classification, and thus to products in these scenarios, such as medical image classification or medical text classification in the digital healthcare field, and business data classification in the banking and insurance fields. Medical images are images of internal tissues obtained non-invasively for medical treatment or research, such as images of the stomach, abdomen, heart, knee, and brain, including CT (Computed Tomography), MRI (Magnetic Resonance Imaging), US (ultrasound), and X-ray images, electroencephalograms, and images generated by medical instruments such as optical imaging devices. Medical text can be electronic healthcare records or electronic personal health records, including medical records, electrocardiograms, medical images, and other electronic records with archival value.
[0064] S10, Build an initial classification model. The initial classification model includes an original branch and a data augmentation branch. The original branch includes an original encoder and an original classifier. The data augmentation branch includes an augmentation encoder and an augmentation classifier. The augmentation encoder is an original encoder that includes at least one data augmentation module.
[0065] Referring to Figure 2, which is a schematic diagram of the structure of the initial classification model involved in this application: in an optional embodiment, the initial classification model 2 includes two branches, a data augmentation branch 20 and an original branch 21. The input of the original branch 21 is the data to be classified, and its output is the original classification result of the data to be classified. The original encoder 211 includes multiple convolutional layers, and the original classifier 212 includes multiple fully connected layers. The input of the data augmentation branch 20 is the data to be classified, and its output is the enhanced classification result of the data to be classified. A data augmentation module is inserted at at least one preset position among the multiple convolutional layers of the original encoder 211 to obtain the enhanced encoder 201. The enhanced classifier 202 includes multiple fully connected layers, and the data augmentation module adds random perturbations to the downsampled map output by the preceding convolutional layer to achieve data augmentation.
[0066] In this optional embodiment, the original encoder can be any existing convolutional neural network encoder structure such as ResNet or DenseNet, and the number of data augmentation modules in the enhanced encoder is preset.
[0067] In an optional embodiment, the original classifier and the augmented classifier may have the same or different structures. When the original classifier and the augmented classifier have the same structure, their network parameters may or may not be shared. The structure of the original classifier refers to the number of fully connected layers and the number of neurons in each fully connected layer. The network parameters of the original classifier include the weights and biases of all fully connected layers. Network parameter sharing means that the values of all network parameters in the original classifier and the augmented classifier are consistent.
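The difference between shared and non-shared classifier parameters can be made concrete with a small sketch. `FullyConnectedClassifier` is a hypothetical single-layer stand-in (the source describes multiple fully connected layers); sharing is modelled here by letting both branches hold the same object, and non-sharing by giving the augmentation branch its own copy.

```python
import copy
from dataclasses import dataclass


@dataclass
class FullyConnectedClassifier:
    """Hypothetical one-layer classifier: weights[j] is the weight row of neuron j."""
    weights: list
    biases: list

    def forward(self, features):
        return [sum(w * x for w, x in zip(row, features)) + b
                for row, b in zip(self.weights, self.biases)]


original = FullyConnectedClassifier(weights=[[1.0, 0.0], [0.0, 1.0]], biases=[0.0, 0.0])

shared_augmented = original                        # same parameters, always consistent
independent_augmented = copy.deepcopy(original)    # same structure, separate parameters

# Simulate a training update to the original classifier.
original.weights[0][0] = 2.0

shared_out = shared_augmented.forward([1.0, 1.0])
independent_out = independent_augmented.forward([1.0, 1.0])
```

With sharing, an update to one classifier is an update to both, so their parameter values stay consistent; without sharing, the two classifiers start identical in structure but drift apart during training.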
[0068] Thus, the initial classification model is built, which includes the original branch and the data augmentation branch containing at least one data augmentation module, providing a model foundation for improving the generalization of the classification model.
[0069] S11, Collect multiple pieces of data to be classified, each with a category label, as the training dataset;
[0070] In an optional embodiment, the data to be classified can be any data that needs to be classified, and the category label is the actual classification result corresponding to the data to be classified. The data to be classified and the classification result are determined according to the specific application scenario and classification task. The data to be classified includes, but is not limited to, image data and time-series data. For example, in the field of digital medicine, the data to be classified is a medical image, and the corresponding classification result can be whether or not the person is ill; in the field of insurance, the data to be classified can be obtained by forming a row vector from the annual insurance and claims data of any individual, and then arranging the row vectors from each year along the column direction, and the corresponding classification result can be whether or not the insurance should continue.
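The insurance example above (one row vector per year, stacked along the column direction) can be sketched as below. The field layout (premium paid, number of claims, claimed amount) and the function name are hypothetical and serve only to show the resulting years-by-fields matrix.

```python
def build_sample(yearly_records):
    """Stack one row vector per year so that rows are years and columns are fields."""
    width = len(yearly_records[0])
    if any(len(record) != width for record in yearly_records):
        raise ValueError("every yearly row vector must have the same length")
    return [list(record) for record in yearly_records]


# Hypothetical fields per year: premium paid, number of claims, claimed amount.
sample = build_sample([
    [1200.0, 0.0, 0.0],
    [1250.0, 1.0, 300.0],
    [1300.0, 0.0, 0.0],
])
label = 1  # hypothetical category label, e.g. 1 = continue the insurance
```

The resulting two-dimensional array can then be fed to a convolutional encoder in the same way as a single-channel image.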
[0071] In an optional embodiment, a large amount of data to be classified with category labels is stored as a training dataset, which is used to train the initial classification model.
[0072] Thus, a training dataset is obtained for training the initial classification model. The training dataset includes a large amount of data to be classified with category labels, providing a data foundation for improving the generalization of the classification model.
[0073] S12, select any data to be classified from the training dataset and input it into the original branch. The original encoder performs multiple downsampling on the data to be classified to obtain at least one original downsampled image, and inputs the last original downsampled image into the original classifier to obtain the original classification result.
[0074] In an optional embodiment, the original encoder downsamples the data to be classified multiple times to obtain at least one original downsampled map, including:
[0075] For any convolutional layer in the original encoder, the input data is downsampled based on the convolutional layer to obtain an original downsampled image, the size of the original downsampled image being less than or equal to the size of the input data;
[0076] If the convolutional layer is not the last convolutional layer in the original encoder, then the original downsampled image is used as the input data for the next convolutional layer.
[0077] If the convolutional layer is the last convolutional layer in the original encoder, then the original downsampled image is used as the last original downsampled image.
[0078] Here, the input data of the first convolutional layer in the original encoder is the data to be classified. When a convolutional layer downsamples its input data to obtain an original downsampled map, the size of that map depends on preset parameters of the convolutional layer, including the kernel size and stride. As successive convolutional layers downsample the data to be classified, the resulting original downsampled maps become smaller and smaller, which enables the extraction of semantic information from the data to be classified. This semantic information consists of content features related to the classification task.
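How kernel size and stride determine the size of a downsampled map can be shown with the standard convolution output-size formula. The padding parameter is not mentioned in the source and is included here only because the usual formula has it; the layer settings in the example are illustrative.

```python
def conv_output_size(input_size, kernel_size, stride, padding=0):
    """Spatial output size of a convolution along one dimension.

    Standard formula: floor((n + 2p - k) / s) + 1.
    """
    return (input_size + 2 * padding - kernel_size) // stride + 1


# Three successive layers with kernel 3, stride 2, padding 1 (illustrative values).
sizes = [224]
for _ in range(3):
    sizes.append(conv_output_size(sizes[-1], kernel_size=3, stride=2, padding=1))
```

With a stride of 1 and suitable padding the output keeps the input size, which is consistent with the statement that a downsampled map is less than or equal to its input in size.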
[0079] In an optional embodiment, the last original downsampled image is input into the original classifier to obtain the original classification result.
[0080] Thus, in the original branch, multiple downsampling operations are performed on the data to be classified using multiple convolutional layers in the original encoder, and the original downsampled image obtained from the last downsampling operation is input into the original classifier to obtain the original classification result.
[0081] S13, the data to be classified is input into the data augmentation branch, the augmentation encoder performs multiple downsamplings on the data to be classified to obtain at least one augmented downsampling map, and randomly perturbs the pre-selected augmented downsampling map based on the data augmentation module to obtain at least one augmented feature map, and the last augmented downsampling map is input into the augmentation classifier to obtain the augmented classification result.
[0082] In an optional embodiment, the augmentation encoder downsamples the data to be classified multiple times to obtain at least one augmented downsampled map, and randomly perturbs the pre-selected augmented downsampled map based on the data augmentation module to obtain at least one augmented feature map, including:
[0083] For any convolutional layer in the enhanced encoder, an enhanced downsampled map is obtained by performing downsampling processing on the input data based on the convolutional layer, wherein the size of the enhanced downsampled map is less than or equal to the size of the input data.
[0084] When the end of the convolutional layer is directly connected to the data augmentation module, it indicates that the augmented downsampled map is a pre-selected augmented downsampled map. Then, the data augmentation module randomly perturbs the augmented downsampled map to obtain an augmented feature map.
[0085] Determine whether the convolutional layer is the last convolutional layer in the enhanced encoder;
[0086] If the convolutional layer is not the last convolutional layer in the enhanced encoder, then the enhanced downsampled map or the enhanced feature map is used as the input data for the next convolutional layer.
[0087] If the convolutional layer is the last convolutional layer in the enhanced encoder, then the enhanced downsampled map is used as the last enhanced downsampled map.
[0088] Here, the input data of the first convolutional layer in the enhanced encoder is the data to be classified. Since the number and positions of the data augmentation modules in the enhanced encoder are preset, if the end of any convolutional layer is connected to a data augmentation module, the enhanced downsampled map output by that convolutional layer is a pre-selected enhanced downsampled map. Because the enhanced encoder is obtained by inserting data augmentation modules at preset positions in the original encoder, the enhanced downsampled maps and the original downsampled maps are in one-to-one correspondence, where each enhanced downsampled map and its corresponding original downsampled map have the same size. Optionally, in the enhanced encoder, the enhanced downsampled maps output by all convolutional layers except the last convolutional layer can be used as pre-selected enhanced downsampled maps.
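The layer-by-layer control flow described above can be sketched as follows. This is a simplified outline under stated assumptions: `downsample` is an average-pooling stand-in for a real convolutional layer, the data is one-dimensional, augmentation modules are placed only after non-final layers (the optional arrangement mentioned above), and the `shift` function is a deterministic placeholder for the random perturbation so that the flow is easy to follow. All names are hypothetical.

```python
def downsample(values, stride=2):
    """Stand-in for a convolutional layer: average non-overlapping windows."""
    windows = [values[i:i + stride] for i in range(0, len(values), stride)]
    return [sum(w) / len(w) for w in windows]


def run_augmentation_encoder(data, num_layers, augment_after, perturb):
    """augment_after: indices of layers directly followed by an augmentation module."""
    if any(i >= num_layers - 1 for i in augment_after):
        raise ValueError("this sketch places modules only after non-final layers")
    downsampled_maps, feature_maps = [], []
    x = data
    for layer in range(num_layers):
        x = downsample(x)
        downsampled_maps.append(x)        # enhanced downsampled map of this layer
        if layer in augment_after:        # pre-selected map: perturb it
            x = perturb(x)
            feature_maps.append(x)        # enhanced feature map fed to the next layer
    # The last downsampled map is what the augmentation classifier receives.
    return downsampled_maps, feature_maps, downsampled_maps[-1]


def shift(feature_values):
    # Deterministic placeholder for the random perturbation.
    return [v + 1.0 for v in feature_values]


maps, features, last_map = run_augmentation_encoder(
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
    num_layers=3, augment_after={0}, perturb=shift)
```

Each downsampled map has the same size as the map the original encoder would produce at that depth, since the augmentation module changes values but not shape.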
[0089] In an optional embodiment, the step of randomly perturbing the enhanced downsampled map based on the data augmentation module to obtain the enhanced feature map includes:
[0090] Obtain all feature values in the enhanced downsampled image;
[0091] Each feature value in the enhanced downsampled image is randomly perturbed according to the random perturbation formula to obtain the perturbation value corresponding to each feature value. The random perturbation formula satisfies the following relationship:
[0092]
[0093] Where x is any feature value in the enhanced downsampling image, μ and σ are the mean and variance of all feature values, respectively, γ and β are the learnable parameters of the data augmentation module, and LDP(x) is the perturbation value corresponding to feature value x;
[0094] Replace all feature values in the enhanced downsampled image with the corresponding perturbation values to obtain the enhanced feature map corresponding to the enhanced downsampled image.
[0095] Here, γ and β are the learnable parameters of the data augmentation module, and their specific values are determined by the training process of the initial classification model. It should be noted that the learnable parameters in different data augmentation modules are independent of each other and do not affect one another.
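A per-module perturbation with its own learnable parameters might look like the sketch below. The perturbation formula itself is not reproduced in the text, so everything about the functional form is an assumption: the feature values are normalized with their mean and variance and then rescaled with γ and shifted with β, and the randomness is modelled as Gaussian noise added to γ and β. `PerturbationModule`, `noise_scale`, `seed`, and the small stabilizing constant are all hypothetical.

```python
import math
import random


class PerturbationModule:
    """Hypothetical data augmentation module with its own gamma and beta."""

    def __init__(self, gamma=1.0, beta=0.0, noise_scale=0.0, seed=None):
        self.gamma = gamma              # learnable in a real model
        self.beta = beta                # learnable in a real model
        self.noise_scale = noise_scale  # ASSUMPTION: where the randomness enters
        self.rng = random.Random(seed)

    def __call__(self, values):
        n = len(values)
        mu = sum(values) / n
        var = sum((v - mu) ** 2 for v in values) / n
        std = math.sqrt(var + 1e-5)     # small constant avoids division by zero
        gamma = self.gamma + self.rng.gauss(0.0, self.noise_scale)
        beta = self.beta + self.rng.gauss(0.0, self.noise_scale)
        # ASSUMPTION: normalize, then apply the (noisy) affine transform.
        return [gamma * (v - mu) / std + beta for v in values]


module_a = PerturbationModule(gamma=2.0, beta=3.0)   # parameters are per module
module_b = PerturbationModule(gamma=0.5, beta=-1.0)  # independent of module_a
out_a = module_a([1.0, 2.0, 3.0, 4.0])
out_b = module_b([1.0, 2.0, 3.0, 4.0])
```

Under this assumed form the output has mean close to β and standard deviation close to γ, so the module changes first- and second-order feature statistics (often associated with style) while keeping the relative ordering of values when γ is positive.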
[0096] In an optional embodiment, the last augmented downsampled map in the augmented encoder is input into the augmented classifier to obtain the augmented classification result.
[0097] Thus, the same data to be classified is input into the augmentation encoder to obtain multiple augmented downsampled maps. At the same time, the pre-selected augmented downsampled map is input into the data augmentation module to obtain at least one augmented feature map. The last augmented downsampled map is input into the augmentation classifier to obtain the augmented classification result, which is obtained based on the augmented downsampled map after data augmentation.
[0098] S14, construct a data augmentation loss based on the enhanced feature map, the original downsampled map, and the enhanced downsampled map; construct a classification loss based on the original classification result, the enhanced classification result, and the category label of the data to be classified; and use the sum of the data augmentation loss and the classification loss as the target loss.
[0099] In an optional embodiment, to ensure that the data augmentation module can maximize the perturbation of style features while preserving the semantic information in the data to be classified, a data augmentation loss needs to be constructed to constrain the data augmentation module to obtain the ideal data augmentation effect. That is, the ideal data augmentation effect is to add as much perturbation as possible to the data without changing its semantic information. In this application, the last original downsampled image and the last augmented downsampled image are used as the semantic information extracted from the data to be classified by the original branch and the data augmentation branch, respectively.
[0100] In an optional embodiment, the enhanced downsampled images and the original downsampled images are in one-to-one correspondence, and each corresponding pair has the same size. Constructing the data augmentation loss based on the enhanced feature map, the original downsampled image, and the enhanced downsampled image includes:
[0101] For any enhanced feature map, obtain the enhanced downsampled map corresponding to the enhanced feature map, and use the original downsampled map corresponding to the enhanced downsampled map in the original branch as the target downsampled map;
[0102] Calculate the Gram matrix of the target downsampled image and the enhanced feature map respectively, and calculate the style bias of the enhanced feature map based on the Gram matrix. The style bias satisfies the following relationship:
[0103]
[0104] where $\hat{f}_i$ is the i-th enhanced feature map, $f_i$ is the target downsampled image corresponding to the i-th enhanced feature map, $G(f_i)$ and $G(\hat{f}_i)$ are the Gram matrices of $f_i$ and $\hat{f}_i$ respectively, $\|\cdot\|_F$ denotes the F-norm, and $L_{spe}^{i}$ is the style bias of the i-th enhanced feature map;
[0105] The semantic bias is calculated based on the last original downsampled image and the last enhanced downsampled image, and the semantic bias satisfies the following relation:
[0106]
[0107] where $f^{*}$ and $\hat{f}^{*}$ are the last original downsampled image and the last enhanced downsampled image, respectively, $\|\cdot\|_2$ denotes the L2 norm, and $L_{sem}$ is the semantic bias;
[0108] A data augmentation loss is constructed based on the semantic bias and the style bias of each augmented feature map, and the data augmentation loss satisfies the following relationship:
[0109]
[0110] where $L_{sem}$ is the semantic bias, $K$ is the number of all enhanced feature maps, $L_{spe}^{i}$ is the style bias of the i-th enhanced feature map, $\lambda_{sem}$ and $\lambda_{spe}$ are preset coefficients greater than 0, and $L_{zq}$ is the data augmentation loss.
[0111] Here, the values of $\lambda_{sem}$ and $\lambda_{spe}$ are 0.8 and 0.6, respectively. It should be noted that the value of the data augmentation loss depends on the semantic bias and on the style bias of each augmented feature map. The smaller the semantic bias, the smaller the value of the data augmentation loss; the larger the style bias of all augmented feature maps, the smaller the value of the data augmentation loss. During the subsequent training of the initial classification model, the value of the data augmentation loss is continuously reduced, that is, the perturbation of the style features is greatly increased while the semantic information in the data to be classified is kept unchanged.
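The monotonic behavior described in paragraph [0111] can be illustrated with a small sketch. The exact loss formula is not reproduced in this text, so the code assumes one form consistent with that behavior: the weighted semantic bias minus the weighted mean of the K style biases. The function name and the subtraction-of-the-mean form are assumptions; only the coefficient values 0.8 and 0.6 come from the description.

```python
def data_augmentation_loss(sem_bias, style_biases, lam_sem=0.8, lam_spe=0.6):
    """Assumed form: lam_sem * L_sem - lam_spe * (1/K) * sum of style biases.

    The loss falls when the semantic bias falls and when the style biases rise,
    which matches the behavior described for the data augmentation loss.
    """
    k = len(style_biases)
    if k == 0:
        raise ValueError("at least one enhanced feature map is required")
    return lam_sem * sem_bias - lam_spe * sum(style_biases) / k

# Hypothetical values: one semantic bias and K = 2 style biases.
loss = data_augmentation_loss(0.5, [1.0, 3.0])
```

Minimizing a loss of this shape pushes the two branches toward the same semantic features while pushing the style of the perturbed feature maps away from the original ones.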
[0112] In this optional embodiment, the Gram matrix is obtained by calculating the inner product between different channels in the same target downsampled map or enhanced feature map. That is, the Gram matrix measures the characteristics of each channel in the feature map and the correlation between channels, and can be used to characterize the style information of any feature map. For example, for a feature map of size W×H×C, where C is the number of channels in the feature map, each channel can be considered as a feature vector with 1 row and W×H columns. The Gram matrix then has C rows and C columns, and the value $g(u, v)$ in the u-th row and v-th column represents the inner product between channel u and channel v in the feature map. This value satisfies the relationship $g(u, v) = h_u (h_v)^T$, where $h_u$ is the feature vector of channel u and $(h_v)^T$ is the transpose of the feature vector of channel v.
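The Gram matrix definition above can be sketched directly: each channel is flattened to a vector of W×H values and the entry (u, v) is the inner product of channels u and v. The second helper measures the distance between two Gram matrices with the F-norm; treating that distance as the style bias is an assumption, since the style bias formula is not reproduced in this text. All names and sample values are illustrative.

```python
import math

def gram_matrix(channels):
    """channels: list of C channel vectors, each flattened to W*H values.

    Returns the C x C matrix whose (u, v) entry is the inner product of
    channel u and channel v.
    """
    return [[sum(a * b for a, b in zip(hu, hv)) for hv in channels]
            for hu in channels]

def frobenius_distance(g1, g2):
    """F-norm of the element-wise difference of two equally sized matrices."""
    return math.sqrt(sum((a - b) ** 2
                         for r1, r2 in zip(g1, g2)
                         for a, b in zip(r1, r2)))

# Hypothetical 2-channel maps with W*H = 4 values per channel.
target = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]
enhanced = [[2.0, 0.0, 4.0, 2.0], [0.0, 1.0, 1.0, 3.0]]

g = gram_matrix(target)
style_bias = frobenius_distance(g, gram_matrix(enhanced))
```

Because the Gram matrix discards spatial positions and keeps only channel correlations, two maps with the same content layout but different channel statistics yield a nonzero distance.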
[0113] In an optional embodiment, the classification loss includes a first classification loss $L_{cls1}$ and a second classification loss $L_{cls2}$. The first classification loss is a cross-entropy loss function constructed based on the original classification result and the category labels of the data to be classified, and the second classification loss is a cross-entropy loss function constructed based on the enhanced classification result and the category labels of the data to be classified.
[0114] In an optional embodiment, the target loss is the sum of the data augmentation loss and the classification loss, and the target loss satisfies the following relationship:
[0115] $L_{final} = L_{cls1} + L_{cls2} + L_{zq}$
[0116] where $L_{cls1}$ is the first classification loss, $L_{cls2}$ is the second classification loss, $L_{zq}$ is the data augmentation loss, and $L_{final}$ is the target loss.
[0117] Thus, in constructing the loss function, the first classification loss and the second classification loss are used to constrain the classification accuracy of the original classifier and the augmented classifier, and the data augmentation loss is used to constrain the data augmentation module to greatly increase the perturbation of style features while keeping the semantic information in the data to be classified unchanged.
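A minimal sketch of the target loss for a single sample is given below. It assumes that each classifier outputs raw scores (logits) and that the category label is a class index; the softmax cross-entropy helper and all names are illustrative, while the sum of the two classification losses and the data augmentation loss follows the relationship stated above.

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample; label is the true class index."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]

def target_loss(orig_logits, aug_logits, label, l_zq):
    """L_final = L_cls1 + L_cls2 + L_zq for one sample."""
    l_cls1 = cross_entropy(orig_logits, label)  # original branch
    l_cls2 = cross_entropy(aug_logits, label)   # data augmentation branch
    return l_cls1 + l_cls2 + l_zq

# Hypothetical two-class example in which both branches are undecided.
l_final = target_loss([0.0, 0.0], [0.0, 0.0], label=0, l_zq=0.5)
```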
[0118] S15, update the initial classification model according to the gradient descent method to complete one iteration of training, return to step S12, until the target loss value is less than the preset value, and obtain the target classification model.
[0119] In an optional embodiment, the initial classification model is updated by gradient descent to complete one iteration of training. In one iteration of training, all model parameters in the initial classification model are updated in the direction that decreases the target loss. The model parameters include the convolution kernel parameters in all convolutional layers, the weights and biases of all fully connected layers, and the learnable parameters in all data augmentation modules.
[0120] In an optional embodiment, after completing one iteration of training, the process returns to step S12 to select new data to be classified for the next iteration of training. Training stops when the target loss is less than a preset value, and a target classification model is obtained. The preset value is 0.001.
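The stopping rule of steps S12 to S15 can be sketched as a loop that keeps performing single-sample updates until the target loss falls below the preset value 0.001. The callable `step`, the cyclic sample selection, the iteration cap `max_iters`, and the toy one-parameter problem are all assumptions made for illustration; a real model would compute the target loss and its gradients over all parameters.

```python
def train(samples, step, threshold=0.001, max_iters=10_000):
    """Repeat single-sample updates until the target loss is below threshold.

    `step` is a hypothetical callable that takes one sample, performs one
    gradient-descent update, and returns the target loss for that sample.
    """
    loss = float("inf")
    for it in range(max_iters):
        sample = samples[it % len(samples)]  # select the next data item (S12)
        loss = step(sample)                  # forward pass, loss, update (S12-S15)
        if loss < threshold:
            return it + 1, loss
    return max_iters, loss

# Toy stand-in for the model: one parameter w fitted to a single target value.
state = {"w": 5.0}

def toy_step(sample, lr=0.1):
    loss = (state["w"] - sample) ** 2
    grad = 2 * (state["w"] - sample)
    state["w"] -= lr * grad                  # gradient-descent update
    return loss

iters, final_loss = train([1.0], toy_step)
```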
[0121] In an optional embodiment, when the target loss is less than the preset value, the ensemble result of the original branch and the data augmentation branch is used as the target classification model. The input of the target classification model is any data to be classified, and the output is the classification result of the data to be classified. The ensemble result of the original branch and the data augmentation branch is obtained by fusing the original branch and the data augmentation branch using an ensemble learning algorithm, which includes, but is not limited to, Bagging, Boosting, and Stacking. For example, if the ensemble learning algorithm is Bagging, the classification results of the original branch and the data augmentation branch are weighted and averaged to obtain the classification result of the data to be classified output by the target model.
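The weighted averaging mentioned for the Bagging example can be sketched as follows, assuming each branch outputs a probability per class. The function name, the equal default weights, and the sample probabilities are illustrative assumptions.

```python
def ensemble_predict(orig_probs, aug_probs, w_orig=0.5, w_aug=0.5):
    """Fuse the class probabilities of the two branches by weighted average.

    Returns the index of the predicted class and the fused probabilities.
    """
    total = w_orig + w_aug
    fused = [(w_orig * p + w_aug * q) / total
             for p, q in zip(orig_probs, aug_probs)]
    label = max(range(len(fused)), key=fused.__getitem__)
    return label, fused

# Hypothetical outputs of the original branch and the data augmentation branch.
label, fused = ensemble_predict([0.7, 0.3], [0.4, 0.6])
```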
[0122] In an optional embodiment, the target classification model can obtain accurate classification results both before and after data augmentation, and the data after data augmentation has a large degree of perturbation of style features compared with the data before data augmentation. Therefore, the target classification model can obtain accurate classification results on different data, and the target classification model has high generalization.
[0123] In this way, by training the initial classification model through the objective function, the obtained target classification model can obtain accurate classification results both before and after data augmentation, thereby improving the generalization ability of the classification model.
[0124] As can be seen from the above technical solution, this application performs feature extraction and classification on the same data to be classified through the original branch and the data augmentation branch with added data augmentation module; it constructs a data augmentation loss based on the Gram matrix of the feature map and semantic information during the feature extraction process to ensure that the data augmentation module can greatly increase the perturbation of style features while keeping the semantic information in the data to be classified unchanged; at the same time, the classification loss constrains both the original branch and the augmentation branch to obtain accurate classification results, thereby improving the generalization of the classification model. The above method can be applied to medical image classification or medical text classification in the field of digital healthcare.
[0125] Please refer to Figure 3, which is a functional block diagram of a preferred embodiment of the data augmentation-based classification model generalization improvement device of this application. The data augmentation-based classification model generalization improvement device 11 includes a construction unit 110, a data acquisition unit 111, an original branch unit 112, an augmented branch unit 113, a target loss unit 114, and an iterative training unit 115. A module/unit referred to in this application is a series of computer-readable instruction segments that are stored in the memory 12, can be executed by the processor 13, and perform a fixed function. In this embodiment, the functions of each module/unit are detailed in subsequent embodiments.
[0126] In an optional embodiment, the construction unit 110 is used to build an initial classification model, which includes an original branch and a data augmentation branch. The original branch includes an original encoder and an original classifier, and the data augmentation branch includes an augmentation encoder and an augmentation classifier. The augmentation encoder is an original encoder that includes at least one data augmentation module.
[0127] In some embodiments, the original encoder includes multiple convolutional layers, and the original classifier includes multiple fully connected layers;
[0128] A data augmentation module is inserted at at least one preset position in the plurality of convolutional layers of the original encoder to obtain the augmented encoder. The augmented classifier includes a plurality of fully connected layers. The data augmentation module is used to add random perturbations to the downsampled map output by the convolutional layer to achieve data augmentation.
[0129] The original classifier and the augmented classifier may have the same or different structures, and when the original classifier and the augmented classifier have the same structure, the structural network parameters of the original classifier and the augmented classifier may or may not be shared.
[0130] In an optional embodiment, the data acquisition unit 111 is used to acquire multiple pieces of data to be classified, each with a category label, as a training dataset.
[0131] In an optional embodiment, the original branch unit 112 is used to select any data to be classified from the training dataset and input it into the original branch. The original encoder performs multiple downsamplings on the data to be classified to obtain at least one original downsampled image, and inputs the last original downsampled image into the original classifier to obtain the original classification result.
[0132] In an optional embodiment, the enhancement branch unit 113 is used to input the data to be classified into the data enhancement branch, the enhancement encoder performs multiple downsamplings on the data to be classified to obtain at least one enhanced downsampled map, and randomly perturbs the pre-selected enhanced downsampled map based on the data enhancement module to obtain at least one enhanced feature map, and inputs the last enhanced downsampled map into the enhancement classifier to obtain the enhanced classification result.
[0133] In some embodiments, the augmentation encoder downsamples the data to be classified multiple times to obtain at least one augmented downsampled map, and randomly perturbs a pre-selected augmented downsampled map based on the data augmentation module to obtain at least one augmented feature map, including:
[0134] For any convolutional layer in the enhanced encoder, an enhanced downsampled map is obtained by performing downsampling processing on the input data based on the convolutional layer, wherein the size of the enhanced downsampled map is less than or equal to the size of the input data.
[0135] When the end of the convolutional layer is directly connected to the data augmentation module, it indicates that the augmented downsampled map is a pre-selected augmented downsampled map. Then, the data augmentation module randomly perturbs the augmented downsampled map to obtain an augmented feature map.
[0136] Determine whether the convolutional layer is the last convolutional layer in the enhanced encoder;
[0137] If the convolutional layer is not the last convolutional layer in the enhanced encoder, then the enhanced downsampled map or the enhanced feature map is used as the input data for the next convolutional layer.
[0138] If the convolutional layer is the last convolutional layer in the enhanced encoder, then the enhanced downsampled map is used as the last enhanced downsampled map.
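The per-layer flow of paragraphs [0134] to [0138] can be sketched as a loop over the layers of the enhancement encoder. The convolutional layers and data augmentation modules are replaced by simple stand-in callables (`halve` and `shift`), and the data is a flat list rather than an image; these stand-ins and all names are assumptions for illustration only.

```python
def run_enhancement_encoder(data, layers, modules):
    """layers: list of callables, each downsampling its input.
    modules: dict mapping a layer index to the perturbation callable attached
    to the end of that layer (the preset positions).
    Returns (all downsampled maps, feature maps by layer index, last map).
    """
    downsampled, features = [], {}
    x = data
    for idx, layer in enumerate(layers):
        d = layer(x)                      # downsample the current input
        downsampled.append(d)
        if idx in modules:                # pre-selected map: perturb it
            features[idx] = modules[idx](d)
            x = features[idx]             # perturbed map feeds the next layer
        else:
            x = d                         # unperturbed map feeds the next layer
    return downsampled, features, downsampled[-1]

def halve(seq):
    """Stand-in for a convolutional layer: average neighboring pairs."""
    return [(seq[i] + seq[i + 1]) / 2 for i in range(0, len(seq) - 1, 2)]

def shift(seq):
    """Stand-in for a data augmentation module: add a constant offset."""
    return [v + 1.0 for v in seq]

maps, feats, last = run_enhancement_encoder(
    [1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0],
    [halve, halve, halve],
    {0: shift},
)
```

Keeping the downsampled maps and the perturbed feature maps in separate containers mirrors the description, where the former are compared for semantic bias and the latter for style bias.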
[0139] In some embodiments, the step of randomly perturbing the enhanced downsampled map based on the data augmentation module to obtain an enhanced feature map includes:
[0140] Obtain all feature values in the enhanced downsampled image;
[0141] Each feature value in the enhanced downsampled image is randomly perturbed according to the random perturbation formula to obtain the perturbation value corresponding to each feature value. The random perturbation formula satisfies the following relationship:
[0142]
[0143] Where x is any feature value in the enhanced downsampling image, μ and σ are the mean and variance of all feature values, respectively, γ and β are the learnable parameters of the data augmentation module, and LDP(x) is the perturbation value corresponding to feature value x;
[0144] Replace all feature values in the enhanced downsampled image with the corresponding perturbation values to obtain the enhanced feature map corresponding to the enhanced downsampled image.
[0145] In an optional embodiment, the target loss unit 114 is used to construct a data augmentation loss based on the enhanced feature map, the original downsampled map, and the enhanced downsampled map, construct a classification loss based on the original classification result, the enhanced classification result, and the category label of the data to be classified, and use the sum of the data augmentation loss and the classification loss as the target loss.
[0146] In some embodiments, the enhanced downsampled images and the original downsampled images are in one-to-one correspondence, and each corresponding pair has the same size. Constructing the data augmentation loss based on the enhanced feature map, the original downsampled image, and the enhanced downsampled image includes:
[0147] For any enhanced feature map, obtain the enhanced downsampled map corresponding to the enhanced feature map, and use the original downsampled map corresponding to the enhanced downsampled map in the original branch as the target downsampled map;
[0148] Calculate the Gram matrix of the target downsampled image and the enhanced feature map respectively, and calculate the style bias of the enhanced feature map based on the Gram matrix. The style bias satisfies the following relationship:
[0149]
[0150] where $\hat{f}_i$ is the i-th enhanced feature map, $f_i$ is the target downsampled image corresponding to the i-th enhanced feature map, $G(f_i)$ and $G(\hat{f}_i)$ are the Gram matrices of $f_i$ and $\hat{f}_i$ respectively, $\|\cdot\|_F$ denotes the F-norm, and $L_{spe}^{i}$ is the style bias of the i-th enhanced feature map;
[0151] The semantic bias is calculated based on the last original downsampled image and the last enhanced downsampled image, and the semantic bias satisfies the following relation:
[0152]
[0153] where $f^{*}$ and $\hat{f}^{*}$ are the last original downsampled image and the last enhanced downsampled image, respectively, $\|\cdot\|_2$ denotes the L2 norm, and $L_{sem}$ is the semantic bias;
[0154] A data augmentation loss is constructed based on the semantic bias and the style bias of each augmented feature map, and the data augmentation loss satisfies the following relationship:
[0155]
[0156] where $L_{sem}$ is the semantic bias, $K$ is the number of all enhanced feature maps, $L_{spe}^{i}$ is the style bias of the i-th enhanced feature map, $\lambda_{sem}$ and $\lambda_{spe}$ are preset coefficients greater than 0, and $L_{zq}$ is the data augmentation loss.
[0157] In some embodiments, the classification loss includes a first classification loss $L_{cls1}$ and a second classification loss $L_{cls2}$. The first classification loss is a cross-entropy loss function constructed based on the original classification result and the category labels of the data to be classified, and the second classification loss is a cross-entropy loss function constructed based on the enhanced classification result and the category labels of the data to be classified.
[0158] In an optional embodiment, the iterative training unit 115 is used to update the initial classification model according to the gradient descent method to complete one iteration of training, return to the original branch unit, and obtain the target classification model when the target loss value is less than a preset value.
[0159] In some embodiments, when the value of the target loss is less than the preset value, the integration result of the original branch and the data augmentation branch is used as the target classification model. The input of the target classification model is any data to be classified, and the output is the classification result of the data to be classified.
[0160] As can be seen from the above technical solutions, this application performs feature extraction and classification on the same data to be classified through the original branch and the data augmentation branch with added data augmentation module; the data augmentation loss is constructed based on the Gram matrix of the feature map and semantic information during the feature extraction process to ensure that the data augmentation module can greatly increase the perturbation of style features while keeping the semantic information in the data to be classified unchanged; at the same time, the classification loss constrains both the original branch and the augmentation branch to obtain accurate classification results, thereby improving the generalization of the classification model.
[0161] Please refer to Figure 4, which is a schematic diagram of the structure of an electronic device provided in an embodiment of this application. The electronic device 1 includes a memory 12 and a processor 13. The memory 12 is used to store computer-readable instructions, and the processor 13 executes the computer-readable instructions stored in the memory to implement the method for improving the generalization of the data augmentation-based classification model described in any of the above embodiments.
[0162] In an alternative embodiment, the electronic device 1 further includes a bus and a computer program stored in the memory 12 and executable on the processor 13, such as a data augmentation-based classification model generalization enhancement program.
[0163] Figure 4 shows only the electronic device 1 with the memory 12 and the processor 13. It will be understood by those skilled in the art that the structure shown in Figure 4 does not constitute a limitation on the electronic device 1, which may include fewer or more components than shown, combine certain components, or have a different arrangement of components.
[0164] With reference to Figure 1, the memory 12 in the electronic device 1 stores a plurality of computer-readable instructions to implement a method for improving the generalization of a data augmentation-based classification model, and the processor 13 can execute the plurality of instructions to achieve:
[0165] S10, Build an initial classification model. The initial classification model includes an original branch and a data augmentation branch. The original branch includes an original encoder and an original classifier. The data augmentation branch includes an augmentation encoder and an augmentation classifier. The augmentation encoder is an original encoder that includes at least one data augmentation module.
[0166] S11, Collect multiple data sets to be classified with category labels as training datasets;
[0167] S12, Select any data to be classified from the training dataset and input it into the original branch. The original encoder performs multiple downsamplings on the data to be classified to obtain at least one original downsampled image, and inputs the last original downsampled image into the original classifier to obtain the original classification result.
[0168] S13, the data to be classified is input into the data augmentation branch, the augmentation encoder performs multiple downsamplings on the data to be classified to obtain at least one augmented downsampling map, and randomly perturbs the pre-selected augmented downsampling map based on the data augmentation module to obtain at least one augmented feature map, and the last augmented downsampling map is input into the augmentation classifier to obtain the augmented classification result;
[0169] S14, construct a data augmentation loss based on the enhanced feature map, the original downsampled map, and the enhanced downsampled map; construct a classification loss based on the original classification result, the enhanced classification result, and the category label of the data to be classified; and use the sum of the data augmentation loss and the classification loss as the target loss.
[0170] S15, update the initial classification model according to the gradient descent method to complete one iteration of training, return to step S12, until the target loss value is less than the preset value, and obtain the target classification model.
[0171] Specifically, for the way in which the processor 13 implements the above instructions, refer to the description of the relevant steps in the embodiment corresponding to Figure 1, which is not repeated here.
[0172] Those skilled in the art will understand that the schematic diagram is merely an example of electronic device 1 and does not constitute a limitation on electronic device 1. Electronic device 1 can be a bus-type structure or a star-type structure. Electronic device 1 may also include more or fewer other hardware or software than shown in the diagram, or different component arrangements. For example, electronic device 1 may also include input / output devices, network access devices, etc.
[0173] It should be noted that electronic device 1 is only an example. Other existing or future electronic products that are suitable for this application should also be included within the scope of protection of this application and are incorporated herein by reference.
[0174] The memory 12 includes at least one type of readable storage medium, which can be non-volatile or volatile. The readable storage medium includes flash memory, portable hard drives, multimedia cards, card-type memory (e.g., SD or DX memory), magnetic storage, magnetic disks, optical disks, etc. In some embodiments, the memory 12 can be an internal storage unit of the electronic device 1, such as a portable hard drive of the electronic device 1. In other embodiments, the memory 12 can also be an external storage device of the electronic device 1, such as a plug-in portable hard drive, Smart Media Card (SMC), Secure Digital (SD) card, Flash Card, etc., equipped on the electronic device 1. The memory 12 can be used not only to store application software and various types of data installed on the electronic device 1, such as code for improving the generalization of data augmentation-based classification models, but also to temporarily store data that has been output or will be output.
[0175] In some embodiments, the processor 13 may be composed of integrated circuits, such as a single packaged integrated circuit or multiple integrated circuits packaged with the same or different functions, including combinations of one or more central processing units (CPUs), microprocessors, digital processing chips, graphics processors, and various control chips. The processor 13 is the control unit of the electronic device 1, connecting various components of the electronic device 1 via various interfaces and lines. It executes programs or modules stored in the memory 12 (e.g., executing programs to improve the generalization of data-augmented classification models) and calls data stored in the memory 12 to perform various functions of the electronic device 1 and process data.
[0176] The processor 13 executes the operating system of the electronic device 1 and various installed applications. The processor 13 executes these applications to implement the steps in the embodiments of the data augmentation-based classification model generalization improvement method described above, for example, the steps shown in Figure 1.
[0177] For example, the computer program may be divided into one or more modules/units, which are stored in the memory 12 and executed by the processor 13 to complete this application. The one or more modules/units may be a series of computer-readable instruction segments capable of performing specific functions, which describe the execution process of the computer program in the electronic device 1. For example, the computer program may be divided into a construction unit 110, a data acquisition unit 111, an original branch unit 112, an enhancement branch unit 113, a target loss unit 114, and an iterative training unit 115.
[0178] The integrated unit implemented as a software functional module described above can be stored in a computer-readable storage medium. This software functional module, stored in a storage medium, includes several instructions to cause a computer device (which may be a personal computer, computer equipment, or network device, etc.) or processor to execute portions of the data augmentation-based classification model generalization improvement method described in the various embodiments of this application.
[0179] If the modules / units integrated in electronic device 1 are implemented as software functional units and sold or used as independent products, they can be stored in a computer-readable storage medium. Based on this understanding, all or part of the processes in the methods of the above embodiments can also be implemented by a computer program instructing related hardware devices. The computer program can be stored in a computer-readable storage medium, and when executed by a processor, it can implement the steps of the various method embodiments described above.
[0180] The computer program includes computer program code, which may be in the form of source code, object code, executable file, or some intermediate form. The computer-readable medium may include: any entity or device capable of carrying the computer program code, recording media, USB flash drive, portable hard drive, magnetic disk, optical disk, computer memory, read-only memory (ROM), random access memory, and other memory.
[0181] Furthermore, the computer-readable storage medium may primarily include a stored program area and a stored data area, wherein the stored program area may store the operating system, an application program required for at least one function, etc.; and the stored data area may store data created based on the use of blockchain nodes, etc.
[0182] The blockchain referred to in this application is a novel application model of computer technologies such as distributed data storage, peer-to-peer transmission, consensus mechanisms, and encryption algorithms. Essentially, a blockchain is a decentralized database, a chain of data blocks linked together using cryptographic methods. Each data block contains information about a batch of network transactions, used to verify the validity of the information (anti-counterfeiting) and generate the next block. A blockchain can include an underlying blockchain platform, a platform product service layer, and an application service layer.
[0183] The bus can be a Peripheral Component Interconnect (PCI) bus or an Extended Industry Standard Architecture (EISA) bus, etc. This bus can be divided into an address bus, a data bus, a control bus, etc. For ease of representation, the bus is represented by only one arrow in Figure 4, but this does not indicate that there is only one bus or one type of bus. The bus is configured to enable communication between the memory 12 and at least one processor 13, etc.
[0184] This application also provides a computer-readable storage medium (not shown) storing computer-readable instructions, which are executed by a processor in an electronic device to implement the data augmentation-based classification model generalization improvement method described in any of the above embodiments.
[0185] In the several embodiments provided in this application, it should be understood that the disclosed systems, apparatuses, and methods can be implemented in other ways. For example, the apparatus embodiments described above are merely illustrative; for instance, the division of modules is only a logical functional division, and other division methods may be used in actual implementation.
[0186] The modules described as separate components may or may not be physically separate. The components shown as modules may or may not be physical units; that is, they may be located in one place or distributed across multiple network units. Some or all of the modules can be selected to achieve the purpose of this embodiment according to actual needs.
[0187] Furthermore, the functional modules in the various embodiments of this application can be integrated into one processing unit, or each unit can exist physically separately, or two or more units can be integrated into one unit. The integrated unit can be implemented in hardware or in the form of hardware plus software functional modules.
[0188] Furthermore, it is clear that the word "comprising" does not exclude other units or steps, and the singular does not exclude the plural. Multiple units or devices described in the specification may also be implemented by a single unit or device through software or hardware. Terms such as "first," "second," etc., are used to indicate names and do not indicate any specific order.
[0189] Finally, it should be noted that the above embodiments are only used to illustrate the technical solutions of this application and are not intended to limit it. Although this application has been described in detail with reference to preferred embodiments, those skilled in the art should understand that modifications or equivalent substitutions can be made to the technical solutions of this application without departing from the spirit and scope of the technical solutions of this application.
Claims
1. A method for improving the generalization of a data augmentation-based classification model, characterized in that the method includes:
S10, building an initial classification model, wherein the initial classification model includes an original branch and a data augmentation branch; the original branch includes an original encoder and an original classifier; the data augmentation branch includes an augmentation encoder and an augmentation classifier; and the augmentation encoder is an original encoder that includes at least one data augmentation module;
S11, collecting multiple pieces of data to be classified, each with a category label, as a training dataset, wherein the data to be classified is any data that needs to be classified, the category label is the actual classification result corresponding to the data to be classified, and the data to be classified includes medical images or medical text;
S12, selecting any data to be classified from the training dataset and inputting it into the original branch, wherein the original encoder downsamples the data to be classified multiple times to obtain at least one original downsampled map, and the last original downsampled map is input into the original classifier to obtain an original classification result;
S13, inputting the data to be classified into the data augmentation branch, wherein the augmentation encoder downsamples the data to be classified multiple times to obtain at least one augmented downsampled map, the data augmentation module randomly perturbs each pre-selected augmented downsampled map to obtain at least one augmented feature map, and the last augmented downsampled map is input into the augmentation classifier to obtain an augmented classification result;
S14, constructing a data augmentation loss based on the augmented feature map, the original downsampled map, and the augmented downsampled map; constructing a classification loss based on the original classification result, the augmented classification result, and the category label of the data to be classified; and using the sum of the data augmentation loss and the classification loss as the target loss;
S15, updating the initial classification model by gradient descent to complete one training iteration, and returning to step S12 until the target loss is less than a preset value, thereby obtaining the target classification model.
2. The method for improving the generalization of a data-augmented classification model as described in claim 1, characterized in that the original encoder includes multiple convolutional layers, and the original classifier includes multiple fully connected layers; a data augmentation module is inserted at at least one preset position among the convolutional layers of the original encoder to obtain the augmentation encoder, and the augmentation classifier includes multiple fully connected layers; the data augmentation module is used to add random perturbations to the downsampled map output by a convolutional layer to achieve data augmentation; the original classifier and the augmentation classifier may have the same or different structures, and when they have the same structure, their network parameters may or may not be shared.
3. The method for improving the generalization of a data-augmented classification model as described in claim 2, characterized in that the step in which the augmentation encoder downsamples the data to be classified multiple times to obtain at least one augmented downsampled map, and the data augmentation module randomly perturbs the pre-selected augmented downsampled map to obtain at least one augmented feature map, includes:
for any convolutional layer in the augmentation encoder, downsampling the input data with the convolutional layer to obtain an augmented downsampled map, wherein the size of the augmented downsampled map is less than or equal to the size of the input data;
when the output of the convolutional layer is directly connected to a data augmentation module, treating the augmented downsampled map as a pre-selected augmented downsampled map, and randomly perturbing it with the data augmentation module to obtain an augmented feature map;
determining whether the convolutional layer is the last convolutional layer in the augmentation encoder;
if the convolutional layer is not the last convolutional layer in the augmentation encoder, using the augmented downsampled map or the augmented feature map as the input data of the next convolutional layer;
if the convolutional layer is the last convolutional layer in the augmentation encoder, using the augmented downsampled map as the last augmented downsampled map.
4. The method for improving the generalization of a data-augmented classification model as described in claim 3, characterized in that the step of randomly perturbing the augmented downsampled map with the data augmentation module to obtain the augmented feature map includes:
obtaining all feature values in the augmented downsampled map;
randomly perturbing each feature value in the augmented downsampled map according to a random perturbation formula to obtain the perturbation value corresponding to each feature value, wherein the random perturbation formula satisfies the following relationship: [formula image not reproduced in this text], where x is any feature value in the augmented downsampled map, and the remaining symbols denote the mean and the variance of all feature values, the learnable parameters of the data augmentation module, and the perturbation value corresponding to the feature value x;
replacing all feature values in the augmented downsampled map with the corresponding perturbation values to obtain the augmented feature map corresponding to the augmented downsampled map.
5. The method for improving the generalization of a data-augmented classification model as described in claim 1, characterized in that the augmented downsampled maps and the original downsampled maps correspond one-to-one, and each augmented downsampled map has the same size as its corresponding original downsampled map; constructing the data augmentation loss based on the augmented feature map, the original downsampled map, and the augmented downsampled map includes:
for any augmented feature map, obtaining the augmented downsampled map corresponding to the augmented feature map, and using the original downsampled map in the original branch that corresponds to that augmented downsampled map as the target downsampled map;
calculating the Gram matrices of the target downsampled map and of the augmented feature map, and calculating the style bias of the augmented feature map based on the Gram matrices, wherein the style bias satisfies the following relationship: [formula image not reproduced in this text], where the symbols denote the i-th augmented feature map, the target downsampled map corresponding to the i-th augmented feature map, the Gram matrices of these two maps, the F-norm operation, and the style bias of the i-th augmented feature map;
calculating the semantic bias based on the last original downsampled map and the last augmented downsampled map, wherein the semantic bias satisfies the following relationship: [formula image not reproduced in this text], where the symbols denote the last original downsampled map, the last augmented downsampled map, the 2-norm operation, and the semantic bias;
constructing the data augmentation loss based on the semantic bias and the style bias of each augmented feature map, wherein the data augmentation loss satisfies the following relationship: [formula image not reproduced in this text], where the symbols denote the semantic bias, the number K of all augmented feature maps, the style bias of the i-th augmented feature map, the preset coefficients (each greater than 0), and the data augmentation loss.
6. The method for improving the generalization of a data-augmented classification model as described in claim 1, characterized in that the classification loss includes a first classification loss and a second classification loss; the first classification loss is a cross-entropy loss function constructed based on the original classification result and the category label of the data to be classified, and the second classification loss is a cross-entropy loss function constructed based on the augmented classification result and the category label of the data to be classified.
7. The method for improving the generalization of a data-augmented classification model as described in claim 1, characterized in that, When the target loss is less than the preset value, the integration result of the original branch and the data augmentation branch is used as the target classification model. The input of the target classification model is any data to be classified, and the output is the classification result of the data to be classified.
8. A device for improving the generalization of a data-augmented classification model, characterized in that the device includes:
a building unit, used to build an initial classification model, wherein the initial classification model includes an original branch and a data augmentation branch; the original branch includes an original encoder and an original classifier; the data augmentation branch includes an augmentation encoder and an augmentation classifier; and the augmentation encoder is an original encoder that includes at least one data augmentation module;
an acquisition unit, used to acquire multiple pieces of data to be classified, each with a category label, as a training dataset, wherein the data to be classified is any data that needs to be classified, the category label is the actual classification result corresponding to the data to be classified, and the data to be classified includes medical images or medical text;
an original branch unit, used to select any data to be classified from the training dataset and input it into the original branch, wherein the original encoder downsamples the data to be classified multiple times to obtain at least one original downsampled map, and the last original downsampled map is input into the original classifier to obtain an original classification result;
an augmentation branch unit, used to input the data to be classified into the data augmentation branch, wherein the augmentation encoder downsamples the data to be classified multiple times to obtain at least one augmented downsampled map, the data augmentation module randomly perturbs each pre-selected augmented downsampled map to obtain at least one augmented feature map, and the last augmented downsampled map is input into the augmentation classifier to obtain an augmented classification result;
a target loss unit, used to construct a data augmentation loss based on the augmented feature map, the original downsampled map, and the augmented downsampled map, to construct a classification loss based on the original classification result, the augmented classification result, and the category label of the data to be classified, and to use the sum of the data augmentation loss and the classification loss as the target loss;
an iterative training unit, used to update the initial classification model by gradient descent to complete one training iteration, and to return to the original branch unit until the target loss is less than a preset value, thereby obtaining the target classification model.
9. An electronic device, characterized in that the electronic device includes: a memory, which stores computer-readable instructions; and a processor, which executes the computer-readable instructions stored in the memory to implement the method for improving the generalization of a data-augmented classification model as described in any one of claims 1 to 7.
10. A computer-readable storage medium, characterized in that, The computer-readable storage medium stores computer-readable instructions, which, when executed by a processor, implement the method for improving the generalization of a data-augmented classification model as described in any one of claims 1 to 7.
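The layer-by-layer control flow recited in claim 3 can be illustrated with a minimal Python sketch. It is an illustration under stated assumptions, not the claimed implementation: `downsample` (pairwise averaging) stands in for a convolutional layer, `perturb` (additive uniform noise) stands in for the data augmentation module, and all function and parameter names are hypothetical. Only the routing of maps between layers follows the claim.

```python
import random


def downsample(values):
    """Hypothetical stand-in for a convolutional layer: average adjacent pairs."""
    return [(values[i] + values[i + 1]) / 2 for i in range(0, len(values) - 1, 2)]


def perturb(values, rng):
    """Hypothetical stand-in for the data augmentation module (not the claimed formula)."""
    return [v + rng.uniform(-0.1, 0.1) for v in values]


def run_augmentation_encoder(data, num_layers, augmented_layers, rng):
    """Return (augmented downsampled maps, augmented feature maps, last downsampled map).

    augmented_layers holds the indices of layers whose output is directly
    connected to a data augmentation module (the pre-selected positions).
    """
    downsampled_maps = []
    feature_maps = {}
    current = data
    for layer in range(num_layers):
        down = downsample(current)
        downsampled_maps.append(down)
        if layer in augmented_layers:
            # Pre-selected map: perturb it and pass the perturbed map onward.
            feature_maps[layer] = perturb(down, rng)
            current = feature_maps[layer]
        else:
            current = down
    # The last downsampled map (unperturbed) is what goes to the classifier.
    return downsampled_maps, feature_maps, downsampled_maps[-1]


if __name__ == "__main__":
    rng = random.Random(0)
    maps, feats, last = run_augmentation_encoder(
        [float(i) for i in range(16)], 3, {0, 1}, rng
    )
    print([len(m) for m in maps], sorted(feats), last)
```

Keeping the downsampled maps and the perturbed feature maps in separate containers mirrors the claims, which use both kinds of map when constructing the data augmentation loss.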
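The style bias and semantic bias of claim 5 can be sketched as follows. Because the formulas themselves are not reproduced in this text, the sketch makes two assumptions that should be read as such: the style bias is taken as the F-norm of the difference between the two Gram matrices, and the semantic bias as the 2-norm of the difference between the two last downsampled maps. Feature maps are represented as lists of channels, each a flat list of values; all names are hypothetical. The weighted combination into the data augmentation loss is omitted because its exact form is not recoverable here.

```python
import math


def gram_matrix(feature_map):
    """Gram matrix of a feature map given as C channels of N values each."""
    return [
        [sum(a * b for a, b in zip(row_i, row_j)) for row_j in feature_map]
        for row_i in feature_map
    ]


def frobenius_norm(matrix):
    """Square root of the sum of squared entries."""
    return math.sqrt(sum(v * v for row in matrix for v in row))


def style_bias(augmented_feature_map, target_downsampled_map):
    """Assumed form: F-norm of the difference between the two Gram matrices."""
    g_aug = gram_matrix(augmented_feature_map)
    g_tgt = gram_matrix(target_downsampled_map)
    diff = [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(g_aug, g_tgt)]
    return frobenius_norm(diff)


def semantic_bias(last_original_map, last_augmented_map):
    """Assumed form: 2-norm of the element-wise difference of the last maps."""
    return math.sqrt(
        sum(
            (a - b) ** 2
            for row_o, row_a in zip(last_original_map, last_augmented_map)
            for a, b in zip(row_o, row_a)
        )
    )


if __name__ == "__main__":
    target = [[1.0, 0.0], [0.0, 1.0]]
    augmented = [[2.0, 0.0], [0.0, 2.0]]
    print(style_bias(augmented, target))
    print(semantic_bias([[0.0, 0.0]], [[3.0, 4.0]]))
```

The Gram matrix captures channel-to-channel correlations rather than spatial content, which is why it is commonly used as a proxy for style, while the direct difference of the last maps tracks the semantic content passed to the classifiers.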
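The classification loss of claim 6 and the target loss and stopping condition of steps S14 and S15 in claim 1 can be sketched in the same way. The sketch assumes that the classifiers output raw scores (logits), that the first and second classification losses are simply summed, and that the data augmentation loss is supplied as an already computed number; the function names are hypothetical.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy of softmax(logits) against an integer class label."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def classification_loss(original_logits, augmented_logits, label):
    """First loss (original branch) plus second loss (augmentation branch); summing is assumed."""
    return cross_entropy(original_logits, label) + cross_entropy(augmented_logits, label)


def target_loss(data_augmentation_loss, original_logits, augmented_logits, label):
    """Sum of the data augmentation loss and the classification loss (step S14)."""
    return data_augmentation_loss + classification_loss(
        original_logits, augmented_logits, label
    )


def training_finished(loss, preset_value):
    """Stopping condition of step S15: target loss below the preset value."""
    return loss < preset_value


if __name__ == "__main__":
    loss = target_loss(0.5, [0.0, 0.0], [0.0, 0.0], 1)
    print(loss, training_finished(loss, 0.1))
```

In a real training loop the gradient descent update of step S15 would run between loss evaluations; it is left out here because it depends on the model and framework used.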