An image classification model optimization method based on confusion samples
Patent Information
- Authority / Receiving Office
- CN · China
- Patent Type
- Patents(China)
- Current Assignee / Owner
- TONGJI ARTIFICIAL INTELLIGENCE RES INST SUZHOU CO LTD
- Filing Date
- 2022-11-28
- Publication Date
- 2026-06-19
Figure CN115908927B_ABST
Abstract
Description
Technical Field
[0001] This invention relates to the field of intelligent image classification technology, and more specifically to an image classification model optimization method based on confused samples.
Background Technology
[0002] Image classification is a processing method that determines the category of the main objects in an image based on the features in the image information. It uses computers to perform quantitative analysis of images, classifying the image or each pixel or region in the image into one of several categories, in order to replace human visual interpretation.
[0003] Traditional image classification methods mainly include the following categories: (1) Color feature-based indexing techniques: Objects of the same type generally have similar color features, so objects can be distinguished based on color features, such as using color histograms for global color feature indexing and local color feature indexing; (2) Texture-based image classification techniques: Texture features are also one of the important features of an image, such as the gray-level co-occurrence matrix representation based on texture features and the wavelet transform-based representation; (3) Shape-based image classification techniques: Shape is also one of the important visualization contents of an image. Most current shape-based classification methods revolve around building image indexes from the contour features and regional features of the shape, such as line segment descriptions, spline fitting curves, Fourier descriptors, and Gaussian parameter curves. The advantage of these low-level features is that they are relatively simple to calculate, but the disadvantage is that they do not have good semantic expression capabilities, resulting in poor recognition performance, poor reliability, and poor stability when classifying complex, large, and fine-grained images.
[0004] Beyond feature representation, how to measure the feature distance between different images is also a key issue in image classification. Existing distance measurement models fall into two categories: non-learning methods and learning methods. Most methods choose simple non-learning metrics, such as first-order distance, second-order distance, and Bhattacharyya distance. However, because the extracted image features are redundant and not sufficiently robust, the recognition results are not ideal. Learning-based metrics, by contrast, typically learn to optimize the differences and similarities between samples, and thus often achieve relatively better recognition performance.
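The non-learning metrics named above can be illustrated with a minimal sketch. It assumes that "first-order distance" and "second-order distance" refer to the L1 and L2 distances, which the text does not state explicitly, and the two histograms are made-up values rather than data from this document.

```python
import math

def l1_distance(p, q):
    """Sum of absolute differences (assumed meaning of 'first-order distance')."""
    return sum(abs(a - b) for a, b in zip(p, q))

def l2_distance(p, q):
    """Euclidean distance (assumed meaning of 'second-order distance')."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))

def bhattacharyya_distance(p, q):
    """Bhattacharyya distance between two normalized histograms.

    Raises ValueError when the histograms do not overlap at all,
    because the Bhattacharyya coefficient is then zero.
    """
    coefficient = sum(math.sqrt(a * b) for a, b in zip(p, q))
    return -math.log(coefficient)

# Two hypothetical 4-bin color histograms, each summing to 1.
hist_a = [0.5, 0.25, 0.25, 0.0]
hist_b = [0.25, 0.25, 0.25, 0.25]

d1 = l1_distance(hist_a, hist_b)
d2 = l2_distance(hist_a, hist_b)
db = bhattacharyya_distance(hist_a, hist_b)
```

All three functions are fixed formulas with no trainable parameters, which is what distinguishes them from the learning-based metrics mentioned in the same paragraph.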
[0005] With the development of deep learning methods, the aforementioned feature representation and distance measurement problems have been unified into deep models. Through massive self-learning network weights, excellent feature extraction, feature representation, and feature measurement capabilities have been achieved. However, since the network weights are entirely dependent on the training samples, the model's classification performance is affected by the quality of the training data. It may learn incorrect or easily confused feature extraction knowledge, leading to the model outputting the wrong category during classification.
[0006] See patent publication number CN111428876A, which discloses an image classification method based on a hybrid dilated convolutional neural network using self-synchronous learning. This method can accelerate the convergence of traditional convolutional neural networks and improve generalization ability, while also avoiding image information loss caused by pooling, allowing for the computation of more information and thus improving classification performance. However, this method places high demands on the trained model, requiring the use of hybrid dilated convolutions, which limits its application scope. Furthermore, due to the influence of factors such as lighting changes, orientation, viewing angle, occlusion, and image resolution on easily confused samples, the model's image classification performance in complex and easily confused scenes remains poor.
Summary of the Invention
[0007] The purpose of this invention is to provide an image classification model optimization method based on confused samples.
[0008] To achieve the above objectives, the technical solution adopted by the present invention is as follows:
[0009] An image classification model optimization method based on confused samples includes:
[0010] S1: Build CNN feature extractors and use them to extract regular features and confused features from the input image.
[0011] S2: Build CNN feature classifiers, each a fully connected layer, and use them to classify the regular features, the confused features, and the de-confused features.
[0012] S3: Calculate the cross-entropy loss for each of the three classification outputs (regular features, confused features, and de-confused features).
[0013] S4: Sum the three cross-entropy losses, backpropagate the gradient, and train the model.
[0014] S5: Fuse the model branches to obtain CNN model weights with de-confusion capability.
[0015] Preferably, in S1, CNN feature extractors named f1 and f2 are established respectively. f1 is used to extract the regular features, and f2 is used to extract the confused features.
[0016] More preferably, the regular features are de-confused to obtain the de-confused features:
[0017] $X = f1(A) \in \mathbb{R}^{n}$
[0018] $X_{\mathrm{conf}} = f2(A) \in \mathbb{R}^{n}$
[0019] $X_{\mathrm{de}} = X - X_{\mathrm{conf}}$
[0020] Where: $A$ is the input image; f1 is the regular CNN feature extraction network; f2 is the confused CNN feature extraction network; $X$ is the regular convolutional feature; $X_{\mathrm{conf}}$ is the convolutional feature of the confused samples; and $n$ is the feature length. In the vectors $X$ and $X_{\mathrm{conf}}$, each position represents one classification feature of the image, and the value at that position represents the strength of the corresponding feature. $X_{\mathrm{de}}$ is the feature that remains after the confused features are removed.
[0021] Preferably, in step S2, CNN feature classifiers named c1, c2, and c3 are established respectively. c1 is used to classify regular features, c2 is used to classify confused features, and c3 is used to classify the de-confused features. More preferably, the output results of the classification by c1, c2, and c3 are as follows:
[0022] $y = c1(X)$
[0023] $y_{\mathrm{conf}} = c2(X_{\mathrm{conf}})$
[0024] $y_{\mathrm{de}} = c3(X - X_{\mathrm{conf}})$
[0025] Where: c1 is the regular fully connected classification network, which outputs the regular classification result $y$; c2 is the fully connected classification network for the confused features, which outputs the classification result $y_{\mathrm{conf}}$ of the confused category; and c3 is the fully connected classification network for the de-confused features, which outputs the classification result $y_{\mathrm{de}}$ of the correct category after de-confusion.
[0026] Preferably, in S3, the regular classification labels are used to supervise the outputs for the regular features and the de-confused features, and the labels of the confused categories are used to supervise the output for the confused features.
[0027] More preferably, the loss calculation formula is as follows:
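[0028] $\mathrm{loss}_{\mathrm{base}} = \mathrm{CE}(y, y^{*})$
[0029] $\mathrm{loss}_{\mathrm{conf}} = \mathrm{CE}(y_{\mathrm{conf}}, y')$
[0030] $\mathrm{loss}_{\mathrm{de}} = \mathrm{CE}(y_{\mathrm{de}}, y^{*})$
[0031] $\mathrm{loss} = \mathrm{loss}_{\mathrm{base}} + \mathrm{loss}_{\mathrm{conf}} + \mathrm{loss}_{\mathrm{de}}$
[0032] Where: CE is the cross-entropy loss; $\mathrm{loss}_{\mathrm{base}}$ is the regular classification loss; $\mathrm{loss}_{\mathrm{conf}}$ is the confused classification loss; $\mathrm{loss}_{\mathrm{de}}$ is the classification loss after de-confusion; $y^{*}$ is the regular classification label; and $y'$ is the supervision label of the easily confused category, taken as the category with the second-highest confidence in the prediction of c1.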
[0033] Preferably, in the above technical solution, in S4, the gradient propagation direction during model training is:
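[0034] $\mathrm{loss}_{\mathrm{base}} \to y \to X \to f1$
[0035] $\mathrm{loss}_{\mathrm{conf}} \to y_{\mathrm{conf}} \to X_{\mathrm{conf}} \to f2$
[0036] $\mathrm{loss}_{\mathrm{de}} \to y_{\mathrm{de}} \to X_{\mathrm{conf}} \to f2$.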
[0037] Preferably, in S5, after the loss converges, a dual-branch image classification network is obtained; the network weight averaging method is used to fuse the two branches of the CNN, and the fused branch is combined with the fully connected classification network for the de-confused features to obtain the final network model.
[0038] More preferably, the network is simplified using a network weight averaging method:
[0039] f = WA(f1, f2)
[0040] c = c3
[0041] Where: WA is the network weight averaging operation, and f and c are the feature extraction network and classification network obtained by merging after averaging.
[0042] Due to the application of the above technical solution, the present invention has the following advantages compared with the prior art:
[0043] 1. Extracting easily confused features from images and using easily confused samples to supervise the training of the feature extraction network can enable the network to recognize easily confused features;
[0044] 2. Based on the interpretable feature attribution method, the deconfused features are obtained by combining ordinary features and confusing features, which reduces the problem of easily confused features in ordinary features, and supervised learning ensures that the deconfused features still have enough features for correct classification.
[0045] 3. By adopting the idea of averaging network weights, multiple branches of the network are combined into one branch, so that the computational cost and efficiency remain unchanged compared with the original network.
[0046] 4. It has broad applicability and can be applied to most conventional CNN models, improving classification accuracy.
Attached Figure Description
[0047] Figure 1 is a schematic flow diagram of the method of the present invention.
[0048] Figure 2 is a network structure diagram of the method of the present invention.
[0049] Figure 3 is a feature distribution map, visualized with t-SNE, obtained by the method of the present invention on the CIFAR-100 dataset.
[0050] Figure 4 is a graph showing the change in model accuracy during training with the method of the present invention.
[0051] Figure 5 shows CAM heatmaps obtained from the de-confused features of the model in the method of the present invention.
Detailed Implementation
[0052] The technical solution of the present invention will now be clearly and completely described with reference to the accompanying drawings. Obviously, the described embodiments are only some, not all, of the embodiments of the present invention. Based on the embodiments of the present invention, all other embodiments obtained by those skilled in the art without creative effort are within the scope of protection of the present invention.
[0053] As shown in Figure 1, the image classification model optimization method based on confused samples includes the following steps:
[0054] In image classification tasks, different categories often share similar features. To identify these shared features, which contribute little to image classification and are prone to causing confusion, a dual-branch feature extraction network and a classification network are first established:
[0055] 1) Establish two CNN feature extractors (the dual-branch feature extraction network) to extract regular features and confused features from the input image. Specifically, the extractors are named f1 and f2: f1 extracts the regular features and f2 extracts the confused features.
[0056] 2) Establish three CNN feature classifiers (the classification networks), each a fully connected layer, to classify the regular features, the confused features, and the de-confused features. Specifically, the classifiers are named c1, c2, and c3: c1 classifies the regular features, c2 classifies the confused features, and c3 classifies the de-confused features.
[0057] The established network structure is shown in Figure 2.
[0058] The extraction of regular and confusing features from an image is specifically as follows: the image is simultaneously input into f1 and f2 for feature extraction. In the initial training phase, the two feature extraction branches are randomly initialized. As supervised training progresses, the two branches develop different feature extraction preferences, and the extracted features represent regular image classification and easily confused image classification, respectively. Simultaneously, a de-confusing operation is performed, that is, the confusing features are subtracted from the regular features to obtain the de-confusing features, i.e.:
[0059] $X = f1(A) \in \mathbb{R}^{n}$
[0060] $X_{\mathrm{conf}} = f2(A) \in \mathbb{R}^{n}$
[0061] $X_{\mathrm{de}} = X - X_{\mathrm{conf}}$
[0062] Where: $A$ is the input image; f1 is the regular CNN feature extraction network; f2 is the confused CNN feature extraction network; $X$ is the regular convolutional feature; $X_{\mathrm{conf}}$ is the convolutional feature of the confused samples; and $n$ is the feature length. In the vectors $X$ and $X_{\mathrm{conf}}$, each position represents one classification feature of the image, and the value at that position represents the strength of the corresponding feature. $X_{\mathrm{de}}$ is the feature that remains after the confused features are removed.
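The de-confusing operation described above, subtracting the confused features from the regular features position by position, can be sketched as follows. This is only an illustration: the two vectors are hypothetical stand-ins for the outputs of f1 and f2, and the helper name `deconfuse` does not appear in the original text.

```python
def deconfuse(x, x_conf):
    """Return X_de = X - X_conf, computed position by position."""
    if len(x) != len(x_conf):
        raise ValueError("both feature vectors must have the same length n")
    return [a - b for a, b in zip(x, x_conf)]

# Hypothetical feature vectors of length n = 4 (stand-ins for f1(A) and f2(A)).
x = [0.9, 0.6, 0.1, 0.4]       # regular features: strength of each feature
x_conf = [0.1, 0.5, 0.0, 0.3]  # confused features: strength of each confusing feature

x_de = deconfuse(x, x_conf)
```

In this made-up example the second position is strong in both vectors, so it is largely cancelled in the result, while the first position, which is strong only in the regular features, is mostly preserved.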
[0063] The classification of regular features, confusing features, and de-confusing features is as follows: Two fully connected networks are used to classify the extracted regular features and confusing features respectively. A single fully connected network is used to classify the de-confusing features. In total, three classification networks are used to classify the three types of features, resulting in three outputs:
[0064] $y = c1(X)$
[0065] $y_{\mathrm{conf}} = c2(X_{\mathrm{conf}})$
[0066] $y_{\mathrm{de}} = c3(X - X_{\mathrm{conf}})$
[0067] Where: c1 is the regular fully connected classification network, which outputs the regular classification result $y$; c2 is the fully connected classification network for the confused features, which outputs the classification result $y_{\mathrm{conf}}$ of the confused category; and c3 is the fully connected classification network for the de-confused features, which outputs the classification result $y_{\mathrm{de}}$ of the correct category after de-confusion.
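The three classification outputs can be sketched with tiny fully connected heads. The weights, the feature values, and the use of a softmax to turn logits into confidences are all assumptions made for illustration; the text only states that c1, c2, and c3 are fully connected classification networks.

```python
import math

def linear(weights, bias, x):
    """Fully connected layer: one output logit per row of `weights`."""
    return [sum(w * v for w, v in zip(row, x)) + b
            for row, b in zip(weights, bias)]

def softmax(logits):
    """Numerically stable softmax that turns logits into confidences."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical features (n = 2) and three separate 3-class heads.
x = [1.0, 0.5]
x_conf = [0.2, 0.4]
x_de = [a - b for a, b in zip(x, x_conf)]

c1 = ([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]], [0.0, 0.0, 0.0])
c2 = ([[0.0, 1.0], [1.0, 0.0], [-1.0, -1.0]], [0.0, 0.0, 0.0])
c3 = ([[2.0, 0.0], [0.0, 2.0], [-2.0, -2.0]], [0.0, 0.0, 0.0])

y = softmax(linear(*c1, x))            # y      = c1(X)
y_conf = softmax(linear(*c2, x_conf))  # y_conf = c2(X_conf)
y_de = softmax(linear(*c3, x_de))      # y_de   = c3(X - X_conf)
```

Each head has its own weights, mirroring the statement that three classification networks are used for the three types of features.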
[0068] The outputs of regular features and deconfused features are supervised using regular classification labels, while the outputs of confused features are supervised using labels of confused categories. Cross-entropy loss is calculated for the outputs of the three classification networks, and the loss calculation formula is as follows:
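[0069] $\mathrm{loss}_{\mathrm{base}} = \mathrm{CE}(y, y^{*})$
[0070] $\mathrm{loss}_{\mathrm{conf}} = \mathrm{CE}(y_{\mathrm{conf}}, y')$
[0071] $\mathrm{loss}_{\mathrm{de}} = \mathrm{CE}(y_{\mathrm{de}}, y^{*})$
[0072] $\mathrm{loss} = \mathrm{loss}_{\mathrm{base}} + \mathrm{loss}_{\mathrm{conf}} + \mathrm{loss}_{\mathrm{de}}$
[0073] Where: CE is the cross-entropy loss; $\mathrm{loss}_{\mathrm{base}}$ is the regular classification loss; $\mathrm{loss}_{\mathrm{conf}}$ is the confused classification loss; $\mathrm{loss}_{\mathrm{de}}$ is the classification loss after de-confusion; $y^{*}$ is the regular classification label; and $y'$ is the supervision label of the confused category, taken as the category with the second-highest confidence in the prediction of c1.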
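The loss computation can be sketched for a single image as follows. The three probability vectors are hypothetical outputs of c1, c2, and c3, and the sketch follows the text literally by taking the confused-category label from the second-highest confidence of c1; the helper names are illustrative only.

```python
import math

def cross_entropy(probs, label):
    """Cross-entropy between a probability vector and an integer class label."""
    return -math.log(probs[label])

def second_highest_class(probs):
    """Index of the class with the second-highest confidence."""
    ranked = sorted(range(len(probs)), key=lambda k: probs[k], reverse=True)
    return ranked[1]

# Hypothetical softmax outputs of c1, c2, and c3 for one image (3 classes).
y = [0.6, 0.3, 0.1]
y_conf = [0.2, 0.7, 0.1]
y_de = [0.8, 0.1, 0.1]

y_star = 0                         # regular classification label
y_prime = second_highest_class(y)  # confused-category label taken from c1

loss_base = cross_entropy(y, y_star)
loss_conf = cross_entropy(y_conf, y_prime)
loss_de = cross_entropy(y_de, y_star)
loss = loss_base + loss_conf + loss_de
```

Here class 1 has the second-highest confidence in `y`, so it becomes the label that supervises the confused branch, while the regular and de-confused branches are both supervised by the regular label.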
[0074] After summing the three cross-entropy losses, backpropagate the gradient to train the model. The gradient propagation direction during model training is as follows:
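[0075] $\mathrm{loss}_{\mathrm{base}} \to y \to X \to f1$
[0076] $\mathrm{loss}_{\mathrm{conf}} \to y_{\mathrm{conf}} \to X_{\mathrm{conf}} \to f2$
[0077] $\mathrm{loss}_{\mathrm{de}} \to y_{\mathrm{de}} \to X_{\mathrm{conf}} \to f2$.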
[0078] Specifically, the loss gradients of the output branches of the regular feature classification and the confused feature classification are propagated to the corresponding feature extraction networks (f1, f2), while the loss gradient of the output branch of the de-confused feature classification is propagated only to the feature extraction network of the confused features, thereby preventing it from affecting the originally correct feature extraction network.
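The gradient routing can be made concrete with a small sketch that combines per-loss gradients by the chain rule. The gradient values are made up, and the function name is hypothetical. In an automatic-differentiation framework, the same effect is usually obtained by treating X as a constant in the subtraction (for example, `X.detach()` in PyTorch), which is an assumption about implementation rather than something the text specifies.

```python
def route_feature_gradients(g_base, g_conf, g_de):
    """Combine per-loss gradients into the gradients sent to f1 and f2.

    g_base: gradient of loss_base with respect to X
    g_conf: gradient of loss_conf with respect to X_conf
    g_de:   gradient of loss_de with respect to X_de, where X_de = X - X_conf
    """
    # loss_base -> y -> X -> f1: only the regular loss reaches f1.
    grad_x = list(g_base)
    # loss_conf -> y_conf -> X_conf -> f2, plus loss_de -> y_de -> X_conf -> f2.
    # Since dX_de/dX_conf = -1, the de-confusion gradient enters with a minus sign.
    grad_x_conf = [gc - gd for gc, gd in zip(g_conf, g_de)]
    return grad_x, grad_x_conf

# Hypothetical gradients for features of length n = 2.
g_base = [0.2, -0.1]
g_conf = [0.05, 0.3]
g_de = [0.4, -0.2]

grad_x, grad_x_conf = route_feature_gradients(g_base, g_conf, g_de)
```

The gradient of the de-confusion loss never appears in `grad_x`, which reflects the stated goal of not disturbing the regular feature extraction network.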
[0079] Because the network trained above has multiple branches, its parameter count is more than double that of the basic network, significantly impacting its operational efficiency. Therefore, a network weight averaging method is used to simplify the network, fusing the two branches of the CNN:
[0080] f = WA(f1, f2)
[0081] c = c3
[0082] Where: WA is the network weight averaging operation, and f and c are the feature extraction network and classification network obtained by merging after averaging.
[0083] Specifically, the coefficients for averaging the network weights of f1 and f2 are 1 and -1, respectively. That is, the final network weights can be obtained by subtracting the two network weights. Finally, it is combined with c3 to form a single-branch network model, which improves feature extraction and classification capabilities without increasing network parameters.
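The fusion step can be sketched as a parameter-by-parameter weighted combination with the coefficients 1 and -1 stated above. The parameter names and values are hypothetical, and each tensor is represented as a flat list to keep the sketch free of external libraries.

```python
def weighted_average(weights_1, weights_2, coeff_1=1.0, coeff_2=-1.0):
    """WA(f1, f2): combine two identically shaped branches parameter by parameter."""
    if weights_1.keys() != weights_2.keys():
        raise ValueError("both branches must share the same architecture")
    fused = {}
    for name in weights_1:
        fused[name] = [coeff_1 * a + coeff_2 * b
                       for a, b in zip(weights_1[name], weights_2[name])]
    return fused

# Hypothetical flattened parameters of two identically shaped branches.
f1_weights = {"conv1.weight": [0.5, -0.2, 0.1], "conv1.bias": [0.3]}
f2_weights = {"conv1.weight": [0.1, 0.1, 0.1], "conv1.bias": [0.1]}

f_weights = weighted_average(f1_weights, f2_weights)  # f = WA(f1, f2)
```

For a single linear layer, subtracting the weights reproduces the feature subtraction exactly, since W1·A - W2·A = (W1 - W2)·A. For a nonlinear CNN no such identity holds in general, so this sketch makes no claim about how closely the fused network matches the dual-branch output.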
[0084] The trained network exhibits a certain ability to de-confuse features, showing improvements in both feature distribution and accuracy, as shown in Figure 3 and Figure 4. Figure 5 shows feature heatmaps obtained with the class activation mapping (CAM) network-interpretability method. The first row shows the original images, the second row shows the heatmaps of the standard model, and the third row shows the heatmaps of this method. It can be seen that a network using de-confused features focuses less on weakly discriminative features and more on strongly discriminative ones.
[0085] Table 1 shows the performance comparison of the algorithm of this invention on the CIFAR-100 dataset for ResNet18 and VGG16:
[0086]
[0087]
[0088] The above embodiments are only for illustrating the technical concept and features of the present invention, and are intended to enable those skilled in the art to understand the content of the present invention and implement it accordingly. They should not be construed as limiting the scope of protection of the present invention. All equivalent changes or modifications made in accordance with the spirit and essence of the present invention should be covered within the scope of protection of the present invention.
Claims
1. A method for optimizing an image classification model based on confused samples, characterized in that it comprises:
S1: Build CNN feature extractors named f1 and f2 respectively. Use f1 to extract regular features and f2 to extract confused features. De-confuse the regular features to obtain the de-confused features:
$X = f1(A) \in \mathbb{R}^{n}$
$X_{\mathrm{conf}} = f2(A) \in \mathbb{R}^{n}$
$X_{\mathrm{de}} = X - X_{\mathrm{conf}}$
Where: $A$ is the input image; f1 is the regular CNN feature extraction network; f2 is the confused CNN feature extraction network; $X$ is the regular convolutional feature; $X_{\mathrm{conf}}$ is the convolutional feature of the confused samples; $n$ is the feature length; in the vectors $X$ and $X_{\mathrm{conf}}$, each position represents one classification feature of the image, and the value at that position represents the strength of the corresponding feature; and $X_{\mathrm{de}}$ is the feature that remains after the confused features are removed.
S2: Build CNN feature classifiers named c1, c2, and c3 respectively. Use c1 to classify the regular features, c2 to classify the confused features, and c3 to classify the de-confused features. The output results of c1, c2, and c3 are as follows:
$y = c1(X)$
$y_{\mathrm{conf}} = c2(X_{\mathrm{conf}})$
$y_{\mathrm{de}} = c3(X - X_{\mathrm{conf}})$
Where: c1 is the regular fully connected classification network, which outputs the regular classification result $y$; c2 is the fully connected classification network for the confused features, which outputs the classification result $y_{\mathrm{conf}}$ of the confused category; and c3 is the fully connected classification network for the de-confused features, which outputs the classification result $y_{\mathrm{de}}$ of the correct category after de-confusion.
S3: Calculate the cross-entropy loss for the outputs of the regular features, confused features, and de-confused features. Use the regular classification labels to supervise the outputs of the regular features and de-confused features, and use the labels of the confused categories to supervise the outputs of the confused features.
S4: Sum the three cross-entropy losses, backpropagate the gradient, and train the model.
S5: After the loss converges, a dual-branch image classification network is obtained; using the network weight averaging method, the two branches of the CNN are fused together and combined with the fully connected classification network for the de-confused features to obtain the final network model.
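2. The image classification model optimization method based on confused samples according to claim 1, characterized in that the formula for calculating the loss is:
$\mathrm{loss}_{\mathrm{base}} = \mathrm{CE}(y, y^{*})$
$\mathrm{loss}_{\mathrm{conf}} = \mathrm{CE}(y_{\mathrm{conf}}, y')$
$\mathrm{loss}_{\mathrm{de}} = \mathrm{CE}(y_{\mathrm{de}}, y^{*})$
$\mathrm{loss} = \mathrm{loss}_{\mathrm{base}} + \mathrm{loss}_{\mathrm{conf}} + \mathrm{loss}_{\mathrm{de}}$
Where: CE is the cross-entropy loss; $\mathrm{loss}_{\mathrm{base}}$ is the regular classification loss; $\mathrm{loss}_{\mathrm{conf}}$ is the confused classification loss; $\mathrm{loss}_{\mathrm{de}}$ is the classification loss after de-confusion; and $y'$ is the supervision label of the easily confused category, taken as the category with the second-highest confidence in the prediction of c1.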
3. The image classification model optimization method based on confused samples according to claim 1, characterized in that, in S4, the gradient propagation direction during model training is:
$\mathrm{loss}_{\mathrm{base}} \to y \to X \to f1$
$\mathrm{loss}_{\mathrm{conf}} \to y_{\mathrm{conf}} \to X_{\mathrm{conf}} \to f2$
$\mathrm{loss}_{\mathrm{de}} \to y_{\mathrm{de}} \to X_{\mathrm{conf}} \to f2$.
4. The image classification model optimization method based on confused samples according to claim 1, characterized in that the network is simplified using a network weight averaging method:
$f = \mathrm{WA}(f1, f2)$
$c = c3$
Where: WA is the network weight averaging operation, and f and c are the feature extraction network and classification network obtained by merging after averaging.