Semi-supervised classification model training, image classification method and apparatus
By serializing and randomly masking the material samples, and combining mask prediction with linear classification network training, the method reduces the dependence of deep neural networks on high-quality labeled samples, achieves efficient feature extraction and model training in semi-supervised learning, and improves classification performance.
Patent Information
- Authority / Receiving Office
- CN · China
- Patent Type
- Patents(China)
- Current Assignee / Owner
- BEIJING WODONG TIANJUN INFORMATION TECH CO LTD
- Filing Date
- 2022-12-02
- Publication Date
- 2026-06-12
AI Technical Summary
Training deep neural networks requires a large number of high-quality labeled samples, which leads to high data labeling costs. How to effectively utilize unlabeled samples on the Internet for semi-supervised learning to improve model accuracy is an important issue.
By serializing and randomly masking the material samples, the overall semantic features are extracted using a mask prediction network. This is then combined with a linear classification network for training to form a semi-supervised classification model. After training all the data using the mask prediction network, the model is fine-tuned on a small amount of labeled data to extract higher-order semantic information.
It improves the classification performance of semi-supervised classification models, effectively utilizes unlabeled samples and a small number of labeled samples for training, avoids wasting model capacity, and improves classification accuracy.
[Figure: CN115908933B_ABST]
Description
Technical Field
[0001] This disclosure relates to the field of computer technology, specifically to the field of artificial intelligence technology, and in particular to semi-supervised classification model training methods and apparatus, image classification methods and apparatus, electronic devices, and computer-readable media.
Background Technology
[0002] Deep neural networks have been applied to various fields such as image classification, object detection and tracking, semantic segmentation, sentiment analysis, machine translation, and speech recognition, becoming one of the most important methods in modern artificial intelligence. Training deep neural networks requires a large number of high-quality labeled samples; however, obtaining such samples is extremely difficult, and data labeling carries high time and economic costs. With mobile internet use now widespread, an important problem is how to use the massive amount of unlabeled samples on the internet so that a model can achieve high accuracy under the supervision of a small number of labeled samples and with the help of a large number of unlabeled samples. This problem is known as semi-supervised learning.
Summary of the Invention
[0003] Embodiments of this disclosure provide methods and apparatus for training semi-supervised classification models, methods and apparatus for image classification, electronic devices, and computer-readable media.
[0004] In a first aspect, embodiments of this disclosure provide a semi-supervised classification model training method, which includes: serializing acquired material samples to obtain a material sequence; randomly masking the material sequence to obtain a mask sequence including overall semantic features; inputting the material sequence and the mask sequence into a pre-constructed mask prediction network of a semi-supervised classification network to calculate the mask prediction loss of the mask prediction network; inputting the overall semantic features with target labels and predicted by the mask prediction network into a linear classification network of the semi-supervised classification network to calculate the supervision loss of the linear classification network; and training the semi-supervised classification network based on the mask prediction loss and the supervision loss to obtain a semi-supervised classification model corresponding to the semi-supervised classification network.
[0005] In some embodiments, the mask prediction network includes: a mask segmenter, a mask classifier, and a trained material segmenter and a material encoding dictionary; inputting the material sequence and mask sequence into the mask prediction network of a pre-constructed semi-supervised classification network, and calculating the mask prediction loss of the mask prediction network includes: inputting the material sequence into the material segmenter to obtain material block encoding; selecting material vectors that match the material block encoding from the material encoding dictionary to obtain a material vector sequence; inputting the mask sequence into the mask segmenter to obtain prediction block encoding; inputting the prediction block encoding into the mask classifier so that the mask classifier selects prediction vectors that match the prediction block encoding from the material encoding dictionary to obtain a prediction vector sequence; and calculating the mask prediction loss of the mask prediction network based on the material vector sequence and the prediction vector sequence.
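One possible reading of "selecting vectors that match the block encoding from the material encoding dictionary" is a nearest-neighbour lookup in a codebook. The sketch below assumes that reading; the squared Euclidean distance, the function names, and the toy dictionary are illustrative assumptions and are not details given in the disclosure.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def match_dictionary_index(block_encoding, dictionary):
    """Return the index of the dictionary vector closest to block_encoding.

    Assumption: "matching" means nearest neighbour under squared distance.
    """
    return min(
        range(len(dictionary)),
        key=lambda i: squared_distance(block_encoding, dictionary[i]),
    )


def encode_sequence(block_encodings, dictionary):
    """Map each block encoding to its matched dictionary index and vector."""
    indices = [match_dictionary_index(b, dictionary) for b in block_encodings]
    vectors = [dictionary[i] for i in indices]
    return indices, vectors


# Toy dictionary with three entries and three block encodings (hypothetical data).
dictionary = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
block_encodings = [[0.9, 0.1], [0.1, 0.8], [0.1, 0.1]]
indices, vectors = encode_sequence(block_encodings, dictionary)
print(indices)
```

Under this reading, the indices obtained from the unmasked material sequence could serve as classification targets for the mask classifier, which is consistent with the loss being computed from the material vector sequence and the prediction vector sequence.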
[0006] In some embodiments, the training process of the above-mentioned material segmenter and material encoding dictionary is as follows: the acquired sample materials are serialized to obtain a sample sequence; the sample sequence is input into the material segmentation network to obtain a sample feature sequence; a sample encoding sequence corresponding to the sample feature sequence is selected from the sample encoding dictionary, and the sample encoding sequence is decoded to obtain a prediction sequence; the sample sequence is input into a pre-trained sample supervision model to obtain a supervision sequence; based on the prediction sequence and the supervision sequence, the material segmentation network and the material encoding dictionary are trained; in response to the material segmentation network meeting the training completion condition, the material segmenter is obtained.
[0007] In some embodiments, the above-mentioned input of the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network to calculate the supervision loss of the linear classification network includes: inputting the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network to obtain the classification result output by the linear classification network; and calculating the supervision loss of the linear classification network based on the classification result and the target label.
[0008] In some embodiments, the above-mentioned training of a semi-supervised classification network based on mask prediction loss and supervised loss to obtain a semi-supervised classification model of the corresponding semi-supervised classification network includes: determining the weight values of the supervised loss; multiplying the supervised loss by the weight values and then adding them to the mask prediction loss to obtain the loss of the semi-supervised classification network; and training the semi-supervised classification network based on the loss of the semi-supervised classification network to obtain a semi-supervised classification model of the corresponding semi-supervised classification network.
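The weighted combination described above can be written compactly. The symbols are notation introduced here for illustration and do not appear in the original text: $\mathcal{L}_{\text{mask}}$ denotes the mask prediction loss, $\mathcal{L}_{\text{sup}}$ the supervision loss, and $\lambda$ the weight value determined for the supervision loss.

```latex
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{mask}} + \lambda \, \mathcal{L}_{\text{sup}}
```

With $\lambda = 1$ this reduces to a plain sum of the two losses.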
[0009] In a second aspect, embodiments of this disclosure provide an image classification method, the method comprising: acquiring an image to be classified; dividing the image to be classified into blocks to obtain an image block sequence; inputting the image block sequence into a semi-supervised classification model to obtain the classification result of the target in the image to be classified output by the semi-supervised classification model, wherein the semi-supervised classification model is trained using the semi-supervised classification model training method as described in any embodiment of the first aspect.
[0010] In a third aspect, embodiments of this disclosure provide a semi-supervised classification model training apparatus, comprising: an acquisition unit configured to serialize acquired material samples to obtain a material sequence; a masking unit configured to perform random masking on the material sequence to obtain a mask sequence including overall semantic features; a mask calculation unit configured to input the material sequence and the mask sequence into a pre-constructed mask prediction network of a semi-supervised classification network to calculate the mask prediction loss of the mask prediction network; a supervision calculation unit configured to input the overall semantic features with target labels and predicted by the mask prediction network into a linear classification network of the semi-supervised classification network to calculate the supervision loss of the linear classification network; and a training unit configured to train the semi-supervised classification network based on the mask prediction loss and the supervision loss to obtain a semi-supervised classification model corresponding to the semi-supervised classification network.
[0011] In some embodiments, the mask prediction network includes: a mask segmenter, a mask classifier, and a trained material segmenter and a material encoding dictionary; the mask calculation unit is further configured to: input the material sequence into the material segmenter to obtain material block encoding; select material vectors that match the material block encoding from the material encoding dictionary to obtain a material vector sequence; input the mask sequence into the mask segmenter to obtain prediction block encoding; input the prediction block encoding into the mask classifier so that the mask classifier selects prediction vectors that match the prediction block encoding from the material encoding dictionary to obtain a prediction vector sequence; and calculate the mask prediction loss of the mask prediction network based on the material vector sequence and the prediction vector sequence.
[0012] In some embodiments, the training process of the above-mentioned material segmenter and material encoding dictionary is as follows: the acquired sample materials are serialized to obtain a sample sequence; the sample sequence is input into the material segmentation network to obtain a sample feature sequence; a sample encoding sequence corresponding to the sample feature sequence is selected from the sample encoding dictionary, and the sample encoding sequence is decoded to obtain a prediction sequence; the sample sequence is input into a pre-trained sample supervision model to obtain a supervision sequence; based on the prediction sequence and the supervision sequence, the material segmentation network and the material encoding dictionary are trained; in response to the material segmentation network meeting the training completion condition, the material segmenter is obtained.
[0013] In some embodiments, the above-mentioned supervision calculation unit is further configured to: input the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network to obtain the classification result output by the linear classification network; and calculate the supervision loss of the linear classification network based on the classification result and the target label.
[0014] In some embodiments, the training unit is further configured to: determine the weight values of the supervised loss; multiply the supervised loss by the weight values and add them to the mask prediction loss to obtain the loss of the semi-supervised classification network; and train the semi-supervised classification network based on the loss of the semi-supervised classification network to obtain the semi-supervised classification model of the corresponding semi-supervised classification network.
[0015] In a fourth aspect, embodiments of this disclosure provide an image classification apparatus, comprising: an image acquisition unit configured to acquire an image to be classified; an image processing unit configured to perform block processing on the image to be classified to obtain an image block sequence; and a target classification unit configured to input the image block sequence into a semi-supervised classification model to obtain the classification result of the target in the image to be classified output by the semi-supervised classification model, wherein the semi-supervised classification model is trained using a semi-supervised classification model training apparatus according to any embodiment of the third aspect.
[0016] In a fifth aspect, embodiments of this disclosure provide an electronic device comprising: one or more processors; and a storage device having one or more programs stored thereon, wherein the one or more programs, when executed by the one or more processors, cause the one or more processors to implement the method described in any embodiment of the first or second aspect.
[0017] In a sixth aspect, embodiments of the present disclosure provide a computer-readable medium having a computer program stored thereon that, when executed by a processor, implements the methods described in any of the embodiments of the first or second aspect.
[0018] The semi-supervised classification model training method and apparatus provided in the embodiments of this disclosure first serialize the acquired material samples to obtain a material sequence; second, randomly mask the material sequence to obtain a mask sequence including overall semantic features; third, input the material sequence and the mask sequence into the mask prediction network of a pre-constructed semi-supervised classification network to calculate the mask prediction loss of the mask prediction network; next, input the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network to calculate the supervision loss of the linear classification network; and finally, train the semi-supervised classification network based on the mask prediction loss and the supervision loss to obtain the semi-supervised classification model corresponding to the semi-supervised classification network. Using a mask prediction network therefore lets the model focus on high-order semantics and global features related to downstream tasks, avoiding wasted capacity in the semi-supervised classification model. By first training on all the data with the mask prediction network and then fine-tuning on a small amount of labeled data with the linear classification network, the downstream task is anticipated during mask prediction modeling training so that higher-order, task-related semantic information is extracted, thereby improving the classification performance of the semi-supervised classification model.
Attached Figure Description
[0019] Other features, objects, and advantages of this disclosure will become more apparent from the following detailed description of non-limiting embodiments with reference to the accompanying drawings:
[0020] Figure 1 is an exemplary system architecture diagram to which one embodiment of this disclosure can be applied;
[0021] Figure 2 is a flowchart of an embodiment of a semi-supervised classification model training method according to the present disclosure;
[0022] Figure 3 is a schematic diagram of the network structure corresponding to the semi-supervised classification model of the present disclosure;
[0023] Figure 4 is a flowchart of an embodiment of the image classification method according to the present disclosure;
[0024] Figure 5 is a schematic diagram of the structure of an embodiment of a semi-supervised classification model training apparatus according to the present disclosure;
[0025] Figure 6 is a schematic diagram of the structure of an embodiment of the image classification apparatus according to the present disclosure;
[0026] Figure 7 is a schematic diagram of the structure of an electronic device suitable for implementing embodiments of the present disclosure.
Detailed Implementation
[0027] The present disclosure will now be described in further detail with reference to the accompanying drawings and embodiments. It should be understood that the specific embodiments described herein are merely illustrative of the invention and not intended to limit it. Furthermore, it should be noted that, for ease of description, only the parts relevant to the invention are shown in the accompanying drawings.
[0028] It should be noted that, unless otherwise specified, the embodiments and features described in this disclosure can be combined with each other. This disclosure will now be described in detail with reference to the accompanying drawings and embodiments.
[0029] Figure 1 shows an exemplary system architecture 100 to which the semi-supervised classification model training method or image classification method of this disclosure can be applied.
[0030] As shown in Figure 1, system architecture 100 may include terminals 101 and 102, network 103, database server 104, and server 105. Network 103 serves as the medium for providing communication links between terminals 101 and 102, database server 104, and server 105. Network 103 may include various connection types, such as wired or wireless communication links or fiber optic cables.
[0031] User 110 can use terminals 101 and 102 to interact with server 105 via network 103 to receive or send messages, etc. Various client applications can be installed on terminals 101 and 102, such as model training applications, image recognition applications, shopping applications, payment applications, web browsers, and instant messaging tools.
[0032] The terminals 101 and 102 here can be either hardware or software. When terminals 101 and 102 are hardware, they can be various electronic devices with displays, including but not limited to smartphones, tablets, e-book readers, MP3 players (Moving Picture Experts Group Audio Layer III), laptops, and desktop computers. When terminals 101 and 102 are software, they can be installed in the electronic devices listed above. They can be implemented as multiple software programs or software modules (e.g., to provide distributed services) or as a single software program or software module. No specific limitations are set here.
[0033] Database server 104 can be a database server that provides various services. For example, the database server can store a material sample set. The material sample set contains a large number of material samples, which can include material samples with target labels and material samples without target labels, where the target labels are the labels corresponding to the classification task. In this way, user 110 can also select material samples from the material sample set stored in database server 104 through terminals 101 and 102.
[0034] Server 105 can also be a server providing various services, such as a backend server supporting various applications displayed on terminals 101 and 102. The backend server can use material samples from the material sample set sent by terminals 101 and 102 to train a semi-supervised classification model, and can send the trained semi-supervised classification model back to terminals 101 and 102. In this way, users can apply the generated semi-supervised classification model to determine the classification result of objects in the image, etc.
[0035] The database server 104 and server 105 here can be either hardware or software. When they are hardware, they can be implemented as a distributed server cluster consisting of multiple servers, or as a single server. When they are software, they can be implemented as multiple software programs or software modules (e.g., used to provide distributed services), or as a single software program or software module. No specific limitations are made here.
[0036] It should be noted that the semi-supervised classification model training method or image classification method provided in the embodiments of this disclosure is generally executed by the server 105. Accordingly, the semi-supervised classification model training device or image classification device is also generally located in the server 105.
[0037] It should be noted that if server 105 can perform the relevant functions of database server 104, database server 104 may not be set up in system architecture 100.
[0038] It should be understood that the numbers of terminals, networks, database servers, and servers shown in Figure 1 are merely illustrative. Depending on implementation needs, any number of terminals, networks, database servers, and servers can be included.
[0039] This disclosure provides a semi-supervised classification model training method. During the training process, a mask prediction network is used to facilitate efficient semi-supervised learning with a small number of labeled samples. Mask prediction encoding performs unsupervised feature extraction by learning the global semantic association information of the sample content, which can be better applied to downstream tasks. Figure 2 shows a flowchart 200 of an embodiment of a semi-supervised classification model training method according to the present disclosure, which includes the following steps:
[0040] Step 201: Serialize the acquired material samples to obtain the material sequence.
[0041] In this embodiment, the material samples are samples obtained from the material sample set. For different semi-supervised classification tasks, the form of the material samples can be different. For example, for image classification tasks, the material samples are image samples; for text classification tasks, the material samples are text samples.
[0042] In this embodiment, the material samples can be data randomly extracted from a material sample set. The material sample set is a dataset used to implement the prediction of the semi-supervised classification model, and it includes multiple material samples. The material samples include samples with target labels and samples without target labels. The number of samples with target labels is relatively small, while the number of samples without target labels is relatively large. Together, the small number of labeled samples and the large number of unlabeled samples make up a large pool of material samples.
[0043] In this embodiment, the target label is a label related to the target classification task, that is, a label corresponding to the classification task. Through this target label, the model can accurately determine the target type to which the material sample belongs.
[0044] In this embodiment, the execution entity of the semi-supervised classification model training method (e.g., the server shown in Figure 1) can acquire a material sample set and extract material samples from it in various ways. For example, the execution entity can obtain, via a wired or wireless connection, an existing material sample set stored in a database server (e.g., the database server 104 shown in Figure 1). As another example, a user can collect samples via a terminal (e.g., the terminals 101 and 102 shown in Figure 1). In this way, the execution entity can receive the samples collected by the terminals and store them locally, thereby generating a material sample set.
[0045] For text classification tasks, serializing the acquired material samples to obtain the material sequence includes: segmenting the sample text into characters to obtain a character sequence containing the overall semantic features.
[0046] For image classification tasks, serializing the acquired material samples to obtain the material sequence includes: cutting the sample image into blocks in order and arranging the blocks into a material sequence, where each element in the material sequence is a small block of the sample image.
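The two serialization cases above can be sketched in a few lines. This is a minimal illustration under stated assumptions: the image is a plain 2-D list of pixel values, the blocks are square, non-overlapping, and read in row-major order, and the function names and `patch_size` parameter are hypothetical rather than taken from the disclosure.

```python
def image_to_patch_sequence(image, patch_size):
    """Cut a 2-D image (list of rows) into square blocks in row-major order."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be multiples of patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Each sequence element is one small block of the sample image.
            patch = [row[left:left + patch_size]
                     for row in image[top:top + patch_size]]
            patches.append(patch)
    return patches


def text_to_char_sequence(text):
    """Segment a sample text into a sequence of single characters."""
    return list(text)


# A 4x4 toy image whose pixel value encodes its position (hypothetical data).
image = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = image_to_patch_sequence(image, patch_size=2)
chars = text_to_char_sequence("abc")
print(len(patches), chars)
```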
[0047] Step 202: Perform random masking on the material sequence to obtain a mask sequence that includes the overall semantic features.
[0048] In this embodiment, the material sequence consists of multiple fragments of generated material samples. Masking is performed on any one or more of these fragments to obscure their content, resulting in a mask sequence that includes the masked fragments.
[0049] In this embodiment, the overall semantic feature is a semantic representation used for downstream classification tasks, and it can express the features of the entire sample. For a material sample with a target label, the overall semantic feature also carries the target label; for a material sample without a target label, the overall semantic feature carries no target label.
[0050] For text classification tasks, the semi-supervised classification network inserts a symbol before the material sample and uses the output vector corresponding to that symbol as the semantic representation of the entire text for text classification.
[0051] For text classification tasks, randomly masking the material sequence to obtain a mask sequence that includes the overall semantic features includes: masking any one or more characters in the character sequence that includes the overall semantic features to obtain a masked character sequence that includes the overall semantic features.
[0052] For image classification tasks, as shown in Figure 3, randomly masking the material sequence to obtain a mask sequence including the overall semantic features includes: masking any one or more image patches in the image patch sequence that includes the overall semantic feature (CLS), and appending a trainable additional image patch to the masked sequence, such as the "CLS" in Figure 3. This additional image patch serves as the semantic feature for the subsequent classification of the whole image. When the mask prediction network makes predictions, this additional image patch is also predicted, just like the other image patches in the sequence.
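Random masking with an extra CLS element can be illustrated as follows. This is a sketch under assumptions that the disclosure does not fix: the fraction of masked fragments (`mask_ratio`), the placeholder values `MASK` and `CLS`, and the choice to place the CLS element at the front of the sequence are all hypothetical.

```python
import random

MASK = "[MASK]"  # hypothetical placeholder that hides a fragment's content
CLS = "[CLS]"    # hypothetical extra element carrying the overall semantic feature


def random_mask(material_sequence, mask_ratio, rng):
    """Mask a random subset of fragments and add a CLS element.

    Returns the mask sequence and the masked positions, indexed relative to
    the original material sequence (that is, not counting the CLS element).
    """
    if not material_sequence:
        raise ValueError("material_sequence must not be empty")
    num_masked = max(1, round(len(material_sequence) * mask_ratio))
    num_masked = min(num_masked, len(material_sequence))
    masked_positions = sorted(
        rng.sample(range(len(material_sequence)), num_masked)
    )
    mask_sequence = list(material_sequence)
    for position in masked_positions:
        mask_sequence[position] = MASK
    # Assumption: the CLS element is placed first.
    return [CLS] + mask_sequence, masked_positions


rng = random.Random(0)  # seeded only to make the example repeatable
material_sequence = [f"patch{i}" for i in range(8)]
mask_sequence, masked_positions = random_mask(material_sequence, 0.4, rng)
print(mask_sequence)
```

Returning the masked positions alongside the mask sequence makes it straightforward to compute a loss only where content was hidden.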
[0053] Step 203: Input the material sequence and mask sequence into the mask prediction network of the pre-built semi-supervised classification network, and calculate the mask prediction loss of the mask prediction network.
[0054] In this embodiment, the mask prediction network is used to encode the material sequence and predict the mask material in the mask sequence. The mask prediction network includes an encoding sub-network, a prediction sub-network, and a loss calculation module. The encoding sub-network is used to encode the material sequence, the prediction sub-network is used to predict the mask material in the mask sequence, and the loss calculation module is used to calculate the loss between the prediction result of the prediction sub-network and the encoding result of the encoding sub-network to obtain the mask prediction loss.
[0055] For image classification tasks, the mask sequence is fed into the prediction subnetwork of the mask prediction network. The prediction subnetwork predicts the image words of each masked image patch in the mask sequence. This is actually a multi-classification problem, that is, correctly classifying the image patch into its corresponding image word. The encoding subnetwork can provide the true mask image words. Here, the mask prediction problem is transformed into a supervised classification problem. The mask prediction loss is the cross-entropy classification loss. By optimizing the mask prediction loss, the mask prediction network can correctly find the remaining parts based on a partial image. If the extracted features can accurately determine the "concept" represented by the masked partial image, it means that the model can extract the high-order semantic information of the sample well. At the same time, mask prediction modeling cleverly constructs such a feature extraction task into a classification problem, completing the extraction of unsupervised sample features through a relatively simple task.
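Treating mask prediction as multi-class classification over image words gives a cross-entropy loss at the masked positions. The sketch below assumes the prediction sub-network outputs one score vector (logits) per position and that the encoding sub-network supplies the true image-word index per position; averaging over masked positions, and all names used, are illustrative assumptions.

```python
import math


def softmax_cross_entropy(logits, target_index):
    """Cross-entropy of softmax(logits) against one target class index."""
    max_logit = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = max_logit + math.log(
        sum(math.exp(v - max_logit) for v in logits)
    )
    return log_sum_exp - logits[target_index]


def mask_prediction_loss(predicted_logits, true_word_ids, masked_positions):
    """Average cross-entropy over the masked positions only."""
    losses = [
        softmax_cross_entropy(predicted_logits[p], true_word_ids[p])
        for p in masked_positions
    ]
    return sum(losses) / len(losses)


# Toy case: 3 positions, a vocabulary of 4 image words, positions 0 and 2 masked.
predicted_logits = [
    [0.0, 0.0, 0.0, 0.0],  # no preference among the 4 words
    [5.0, 0.0, 0.0, 0.0],  # not masked, so ignored by the loss
    [0.0, 0.0, 0.0, 0.0],
]
true_word_ids = [2, 0, 1]
loss = mask_prediction_loss(predicted_logits, true_word_ids, [0, 2])
print(loss)
```

With uniform scores over four words the loss equals log 4, the value expected from a predictor that has learned nothing; a predictor that scores the true word highest drives the loss toward zero.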
[0056] In this embodiment, the mask prediction network performs mask prediction training on both the material samples with target labels and the material samples without target labels. Therefore, the mask prediction loss is equal to the sum of the first loss, corresponding to the samples with target labels, and the second loss, corresponding to the samples without target labels.
[0057] Step 204: Input the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network, and calculate the supervised loss of the linear classification network.
[0058] In this embodiment, the overall semantic feature processed by the semi-supervised classification network is a high-order semantic feature, which can effectively summarize the features of the material sample.
[0059] In this embodiment, the target label is a label related to the classification task. When the linear classification network makes predictions, it makes predictions related to the task type for samples with target labels to obtain classification prediction results. Based on the difference between the prediction results and the target labels, it calculates the cross-entropy classification loss to obtain the supervision loss of the linear classification network.
[0060] In this embodiment, the high-order semantic features with target labels belong to the task samples. The high-order semantic features with target labels are used as features and input into the linear classification network for training. The loss of the linear classification network is calculated using the multi-class cross-entropy loss function to obtain the supervised loss.
[0061] For image classification tasks, as shown in Figure 3, the overall semantic feature is CLS. The overall semantic feature Ct processed by the mask prediction network is the feature obtained after the mask prediction network trains on the overall semantic feature CLS. Inputting the overall semantic feature Ct processed by the mask prediction network into the linear classification network enables the linear classification network to predict the target type of the material sample.
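The supervision loss of Step 204 can be sketched as a linear layer followed by multi-class cross-entropy, applied only to samples that carry a target label. The representation of an unlabeled sample as `None`, the averaging over labeled samples, and all names below are assumptions made for illustration.

```python
import math


def linear_logits(feature, weights, biases):
    """One linear layer: a score per class from a single feature vector."""
    return [
        sum(w * x for w, x in zip(row, feature)) + b
        for row, b in zip(weights, biases)
    ]


def cross_entropy(logits, label):
    """Multi-class cross-entropy of softmax(logits) against a class index."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(
        sum(math.exp(v - max_logit) for v in logits)
    )
    return log_sum_exp - logits[label]


def supervision_loss(cls_features, labels, weights, biases):
    """Average cross-entropy over labeled samples; unlabeled ones are skipped."""
    losses = [
        cross_entropy(linear_logits(feature, weights, biases), label)
        for feature, label in zip(cls_features, labels)
        if label is not None  # assumption: None marks a sample without a label
    ]
    return sum(losses) / len(losses) if losses else 0.0


# Two classes, 2-D overall semantic features; only the first sample is labeled.
weights = [[0.0, 0.0], [0.0, 0.0]]  # untrained classifier (hypothetical values)
biases = [0.0, 0.0]
cls_features = [[1.0, 2.0], [3.0, 4.0]]
labels = [1, None]
sup_loss = supervision_loss(cls_features, labels, weights, biases)
print(sup_loss)
```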
[0062] Step 205: Based on the mask prediction loss and the supervision loss, train a semi-supervised classification network to obtain the semi-supervised classification model of the corresponding semi-supervised classification network.
[0063] In this embodiment, the semi-supervised classification model is the model corresponding to the semi-supervised classification network. When the semi-supervised classification network meets the training completion condition during the iterative training process, the current semi-supervised classification network is determined to be the semi-supervised classification model.
[0064] In this embodiment, the semi-supervised classification network and the semi-supervised classification model are used to represent the correspondence between the material and the target category in the material. The material can be an image or text. For example, if the material is an image, the semi-supervised classification model is used to represent the correspondence between the image and the target type in the image.
[0065] In this embodiment, the training completion conditions may include: the number of training iterations reaching a predetermined iteration threshold, and the loss of the semi-supervised classification network being less than a predetermined loss threshold. For example, the iteration threshold may be 50,000 training iterations and the loss threshold may be 0.05. Setting training completion conditions in this embodiment can accelerate model convergence.
[0066] In this embodiment, the above-mentioned training of a semi-supervised classification network based on mask prediction loss and supervised loss to obtain a semi-supervised classification model of the corresponding semi-supervised classification network includes: adding the mask prediction loss and the supervised loss to obtain the loss of the semi-supervised classification network; detecting whether the loss of the semi-supervised classification network reaches a predetermined loss threshold; in response to detecting that the loss of the semi-supervised classification network cannot reach the predetermined loss threshold, obtaining the number of training iterations of the semi-supervised classification network; when the number of training iterations reaches a predetermined iteration threshold, determining that the semi-supervised classification network is a semi-supervised classification model that has been trained; and in response to detecting that the loss of the semi-supervised classification network reaches the predetermined loss threshold, determining that the semi-supervised classification network is a semi-supervised classification model that has been trained.
[0067] It should be noted that, in response to the detection that the loss of the semi-supervised classification network has not reached the predetermined loss threshold and the number of training iterations has not reached the predetermined iteration threshold, steps 201 to 205 can continue to be executed until the semi-supervised classification network meets the training completion conditions.
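The completion logic of paragraphs [0065] to [0067] can be sketched as follows. This is a minimal illustration, not the disclosed implementation: the names `is_training_complete`, `train`, and `step_fn` are hypothetical, the default thresholds are the example values from [0065], and `step_fn` stands in for one pass of steps 201 to 205 that returns the two losses.

```python
def is_training_complete(loss, iteration, loss_threshold=0.05,
                         iteration_threshold=50_000):
    """Return True when either completion condition holds.

    The loss is checked first; the iteration count is consulted only
    when the loss has not yet dropped below the loss threshold.
    """
    if loss < loss_threshold:
        return True
    return iteration >= iteration_threshold


def train(step_fn, loss_threshold=0.05, iteration_threshold=50_000):
    """Repeat training steps until a completion condition is met.

    step_fn(iteration) is a hypothetical callback that runs one
    iteration and returns (mask_prediction_loss, supervised_loss).
    """
    iteration = 0
    while True:
        mask_loss, supervised_loss = step_fn(iteration)
        iteration += 1
        total = mask_loss + supervised_loss  # unweighted sum, as in [0066]
        if is_training_complete(total, iteration,
                                loss_threshold, iteration_threshold):
            return iteration, total
```

With a loss that never falls below the threshold, the loop still terminates once the iteration threshold is reached, which is the fallback branch described in [0066].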
[0068] After the above steps, data-efficient semi-supervised learning based on mask predictive coding can be achieved. Mask predictive coding can better extract task-related high-order global semantic features under the guidance of a small number of supervision signals, thereby achieving the goal of data-efficient semi-supervised learning.
[0069] The semi-supervised classification model training method provided in this disclosure first serializes the acquired material samples to obtain a material sequence; second, it randomly masks the material sequence to obtain a mask sequence including overall semantic features; third, it inputs the material sequence and the mask sequence into a pre-constructed mask prediction network of a semi-supervised classification network to calculate the mask prediction loss of the mask prediction network; next, it inputs the overall semantic features with target labels and predicted by the mask prediction network into a linear classification network of the semi-supervised classification network to calculate the supervision loss of the linear classification network; finally, it trains the semi-supervised classification network based on the mask prediction loss and the supervision loss to obtain the semi-supervised classification model of the corresponding semi-supervised classification network. Therefore, using a mask prediction network allows attention to high-order semantics and global features related to downstream tasks, avoiding wasted capacity in the semi-supervised classification model; by first training all data with the mask prediction network and then fine-tuning a small amount of labeled data with the linear classification network, the downstream task is predicted in advance during mask prediction modeling training to extract higher-order, task-related semantic information, thus improving the classification performance of the semi-supervised classification model.
[0070] In some optional implementations of this embodiment, the mask prediction network includes: a mask segmenter, a mask classifier, and a trained material segmenter and a material encoding dictionary; inputting the material sequence and mask sequence into the mask prediction network of a pre-constructed semi-supervised classification network, and calculating the mask prediction loss of the mask prediction network includes: inputting the material sequence into the material segmenter to obtain material block encoding; selecting material vectors that match the material block encoding from the material encoding dictionary to obtain a material vector sequence; inputting the mask sequence into the mask segmenter to obtain prediction block encoding; inputting the prediction block encoding into the mask classifier so that the mask classifier selects prediction vectors that match the prediction block encoding from the material encoding dictionary to obtain a prediction vector sequence; and calculating the mask prediction loss of the mask prediction network based on the material vector sequence and the prediction vector sequence.
[0071] In this embodiment, after obtaining the material vector sequence and the prediction vector sequence, the material vector sequence is used as the ground truth of the prediction vector sequence. The mask prediction loss of the mask prediction network can be calculated by using the multi-class cross-entropy loss function.
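As a rough sketch of the multi-class cross-entropy described above, assume the mask classifier outputs one score per dictionary entry for each block, and the material segmenter supplies the index of the matching dictionary entry as ground truth. Averaging only over the masked positions is an assumption made here for illustration; the function and parameter names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def mask_prediction_loss(logits_per_block, target_ids, masked_positions):
    """Mean cross-entropy between predicted and ground-truth dictionary ids.

    logits_per_block[i] holds the mask classifier's scores over the
    dictionary for block i; target_ids[i] is the dictionary index chosen
    by the material segmenter for the same block.
    """
    losses = [-log_softmax(logits_per_block[i])[target_ids[i]]
              for i in masked_positions]
    return sum(losses) / len(losses)
```

A classifier with no preference among K dictionary entries yields a loss of log K, and the loss approaches zero as the score of the ground-truth entry dominates.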
[0072] In this embodiment, when the mask prediction network is used to mask an image, the mask prediction network is a network that performs masking and image type prediction on the image. Specifically, the mask segmenter and the material segmenter can adopt the BEiT (Bidirectional Encoder representation from Image Transformers) model structure.
[0073] In this embodiment, when the mask prediction network is used to mask text, the mask prediction network is a network that performs masking and text type prediction on the text. Specifically, the mask segmenter and the material segmenter can adopt the BERT (Bidirectional Encoder Representations from Transformers) model structure.
[0074] As shown in Figure 3, the material segmenter VIT1 encodes the image block sequence into N "image words", and the mask segmenter VIT2 likewise encodes image blocks. The main purpose of the material segmenter and the mask segmenter is to encode image blocks into discrete "image words". An "image word" is analogous to a "concept": in Figure 3, for example, image word number 37 describes the concept "eyes" and image word number 78 describes the concept "mouth". Each image word is still a vector, and the set of these concept vectors is called the "material encoding dictionary". Specifically, let h denote the vector obtained by encoding the image patch sequence through the image segmenter. The final encoded image word z is the image word vector v in the dictionary that is closest to h.
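The dictionary lookup described above can be illustrated with a nearest-neighbour search. The sketch below assumes squared Euclidean distance as the notion of "closest", which the text does not specify, and uses hypothetical names (`quantize`, `features`, `dictionary`).

```python
def quantize(features, dictionary):
    """Map each encoded vector h to the index z of its nearest image word v.

    features: list of encoded block vectors (one h per image block).
    dictionary: list of image word vectors (the material encoding dictionary).
    Distance is squared Euclidean, an assumption made for this sketch.
    """
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return [
        min(range(len(dictionary)), key=lambda j: sq_dist(h, dictionary[j]))
        for h in features
    ]
```

For example, with a three-entry dictionary `[[0, 0], [1, 1], [5, 5]]`, the feature `[4, 6]` is assigned image word number 2, since `[5, 5]` is its nearest entry.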
[0075] A subset of the original image patches in the sequence is randomly masked. The masked image patches form a set, and each patch in this set is replaced with a trainable encoding e_m. This yields the mask sequence, which represents the entire input image.
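A minimal sketch of the random masking step follows. The mask ratio, the placeholder strings for the trainable encoding and the overall semantic feature, and the choice to prepend the overall semantic feature at position 0 are all assumptions for illustration; the disclosure does not fix these details.

```python
import random


def random_mask(patches, mask_ratio, rng, mask_token="[MASK]", cls_token="[CLS]"):
    """Replace a random subset of patches with a mask token.

    Returns (mask_sequence, masked_positions). The mask sequence starts
    with cls_token, standing in for the overall semantic feature, so
    patch i sits at index i + 1. masked_positions are patch indices.
    mask_token stands in for the trainable encoding e_m.
    """
    n_masked = round(len(patches) * mask_ratio)
    masked = set(rng.sample(range(len(patches)), n_masked))
    body = [mask_token if i in masked else p for i, p in enumerate(patches)]
    return [cls_token] + body, sorted(masked)
```

Passing a seeded `random.Random` instance as `rng` makes the choice of masked positions reproducible between runs.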
[0076] Using the mask segmenter, the mask prediction network provided in this embodiment can achieve a prediction accuracy of 65.12% with only 1% of the target label samples. It can effectively use unlabeled samples together with a small number of labeled samples for training, ensuring the effectiveness of semi-supervised classification model training.
[0077] The optional implementation provides a method for calculating the mask prediction loss of the mask prediction network. It obtains the prediction vector sequence of the corresponding prediction block encoding through a mask segmenter and a mask classifier, and provides a reliable annotation basis for the prediction vector sequence through a material segmenter and a material encoding dictionary, thereby improving the reliability of the mask prediction network in obtaining high-order semantics and global features.
[0078] In some optional implementations of this embodiment, the training process of the material segmenter and the material encoding dictionary is as follows: the acquired sample materials are serialized to obtain a sample sequence; the sample sequence is input into the material segmentation network to obtain a sample feature sequence; a sample encoding sequence corresponding to the sample feature sequence is selected from the sample encoding dictionary, and the sample encoding sequence is decoded to obtain a prediction sequence; the sample sequence is input into a pre-trained sample supervision model to obtain a supervision sequence; based on the prediction sequence and the supervision sequence, the material segmentation network and the material encoding dictionary are trained; in response to the material segmentation network meeting the training completion condition, the material segmenter is obtained.
[0079] In this optional implementation, the sample supervision model can employ a pre-trained CLIP (Contrastive Language-Image Pre-training) model, a transferable visual model trained with natural language supervision signals. The CLIP model first inputs the image and the text into an image encoder and a text encoder, respectively, to obtain vector representations of the image and the text. These vector representations are then mapped into a multimodal space, resulting in new image and text vector representations that can be compared directly. This is a common approach in multimodal learning: representations from different modalities may differ and cannot be compared directly, so mapping them into the same multimodal space facilitates subsequent similarity calculations. Next, the cosine similarity between the image and text vectors is calculated. Finally, an objective function based on the contrastive learning principle is used to ensure higher similarity for positive sample pairs and lower similarity for negative sample pairs.
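The similarity computation and contrastive objective outlined above can be sketched in plain Python. This is an illustrative approximation rather than the CLIP implementation: it assumes the vectors are already in the shared multimodal space, that the image at index k is paired with the text at index k, and a fixed temperature of 0.07 chosen here only for illustration.

```python
import math


def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def contrastive_loss(image_vecs, text_vecs, temperature=0.07):
    """Symmetric cross-entropy over the image-text similarity matrix.

    Pair (k, k) is the positive sample pair; every other combination
    in the batch is treated as a negative sample pair.
    """
    n = len(image_vecs)
    sims = [[cosine(i, t) / temperature for t in text_vecs] for i in image_vecs]

    def cross_entropy(row, target):
        m = max(row)
        return m + math.log(sum(math.exp(s - m) for s in row)) - row[target]

    image_to_text = sum(cross_entropy(sims[k], k) for k in range(n)) / n
    text_to_image = sum(
        cross_entropy([sims[r][k] for r in range(n)], k) for k in range(n)
    ) / n
    return (image_to_text + text_to_image) / 2
```

The loss is small when each image vector is most similar to its own text vector and grows when the pairing is shuffled, which is the behaviour the contrastive principle is meant to enforce.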
[0080] For image classification tasks, the material segmentation network is an image segmentation network, and the sample encoding dictionary is an image encoding dictionary. The acquired image samples are serialized to obtain a sample sequence. This sample sequence is then input into the image segmentation network to be encoded into image words. Finally, the decoder in the image segmentation network decodes the image words into image patch features, with the goal of reconstructing reasonable image patch features from the image words. The supervision signal t_i for the decoded features comes from a pre-trained sample supervision model. The overall training loss is:
[0081]
[0082] In this overall training loss, sg[h_i] is the vector of the decoded image. It should be noted that a decoder can be used to decode the sample encoding sequence. During the training of the material segmenter and the material encoding dictionary, the decoder is trained together with them until the overall training loss meets the training requirements, thus obtaining the trained material segmenter and material encoding dictionary.
[0083] In calculating the overall training loss, since image words are discrete and cannot be directly differentiated, a straight-through estimator can be used to approximate the gradient. After training, this yields the material segmenter and the material encoding dictionary. Unlike reconstruction based on generative models, this method does not perform pixel-by-pixel restoration but only reconstructs low-dimensional features, thus avoiding the wasted model capacity found in generative models.
[0084] The training method for the material segmenter and material encoding dictionary provided in this optional implementation uses a pre-trained sample-supervised model to train both the material segmenter and the material encoding dictionary simultaneously, which can provide a reliable basis for updating the parameters of the material segmenter and the material encoding dictionary.
[0085] In some optional implementations of this embodiment, the above-mentioned inputting the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network and calculating the supervision loss of the linear classification network includes: inputting the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network to obtain the classification result output by the linear classification network; and calculating the supervision loss of the linear classification network based on the classification result and the target label.
[0086] In this optional implementation, the linear classification network can be a binary classification network or a multi-class classification network.
[0087] In this optional implementation, after obtaining the classification result of the linear classification network, the classification result and the target label are substituted into the cross-entropy loss function to calculate the supervised loss of the linear classification network.
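To illustrate the supervised loss calculation, the following sketch applies a single linear layer to the overall semantic feature and evaluates cross-entropy against the target label. The function names, the explicit weight matrix and bias, and the use of softmax to turn scores into probabilities are assumptions for this example.

```python
import math


def linear_classifier(feature, weights, bias):
    """One linear layer: a score per class from the overall semantic feature."""
    return [sum(w * x for w, x in zip(row, feature)) + b
            for row, b in zip(weights, bias)]


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def supervised_loss(feature, label, weights, bias):
    """Cross-entropy between the classification result and the target label.

    Returns (loss, probabilities); label is the index of the target type.
    """
    probs = softmax(linear_classifier(feature, weights, bias))
    return -math.log(probs[label]), probs
```

The same code covers the binary case (two rows in `weights`) and the multi-class case (more than two rows).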
[0088] This optional implementation provides a method for calculating the supervised loss of the linear classification network. By inputting the overall semantic features that carry target labels and have been predicted by the mask prediction network into the linear classification network of the semi-supervised classification network, the linear classification network can perform type labeling on the overall semantic features based on the target label, thus providing a reliable basis for the classification performed by the linear classification network.
[0089] In some optional implementations of this embodiment, the above-mentioned training of a semi-supervised classification network based on mask prediction loss and supervised loss to obtain a semi-supervised classification model of the corresponding semi-supervised classification network includes: determining the weight values of the supervised loss; multiplying the supervised loss by the weight values and then adding them to the mask prediction loss to obtain the loss of the semi-supervised classification network; and training the semi-supervised classification network based on the loss of the semi-supervised classification network to obtain a semi-supervised classification model of the corresponding semi-supervised classification network.
[0090] In this optional implementation, the semi-supervised classification network is considered to have completed training when the loss of the semi-supervised classification network reaches a preset loss threshold, thus obtaining a semi-supervised classification model.
[0091] In this optional implementation, the mask prediction training loss and the supervised training loss are fused together by weight values, which achieves feature extraction guided by a small number of labeled samples. Data-efficient semi-supervised learning over the entire data can thus be completed end to end.
[0092] In this optional implementation, since there are few material samples with target labels, setting the weight value of the supervised loss to a larger value can increase the importance of the supervised loss and improve the accuracy of semi-supervised classification.
[0093] The optional implementation provides a method for determining the loss of a semi-supervised classification network. By assigning weights to the supervised loss, the proportion of the supervised loss in the entire semi-supervised classification network is increased, thereby improving the reliability and accuracy of training the semi-supervised classification network.
[0094] Optionally, the above-mentioned training of a semi-supervised classification network based on mask prediction loss and supervised loss to obtain the semi-supervised classification model of the corresponding semi-supervised classification network includes: determining a first weight value of the supervised loss; determining a second weight value of the mask prediction loss, multiplying the supervised loss by the first weight value, and adding the product of the second weight value and the mask prediction loss to obtain the loss of the semi-supervised classification network; and training the semi-supervised classification network based on the loss of the semi-supervised classification network to obtain the semi-supervised classification model of the corresponding semi-supervised classification network.
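Both weighting schemes, the single supervised-loss weight of [0089] and the two weights of [0094], reduce to one weighted sum. The sketch below uses hypothetical parameter names; the default of 1.0 for each weight reproduces the plain sum of [0066].

```python
def total_loss(mask_prediction_loss, supervised_loss,
               supervised_weight=1.0, mask_weight=1.0):
    """Weighted sum of the two losses.

    supervised_weight corresponds to the first weight value and
    mask_weight to the second weight value in [0094]; leaving
    mask_weight at 1.0 gives the single-weight scheme of [0089].
    """
    return supervised_weight * supervised_loss + mask_weight * mask_prediction_loss
```

For instance, `total_loss(2.0, 0.5, supervised_weight=4.0)` evaluates to 4.0, so the supervised term contributes as much as the mask prediction term even though its raw value is four times smaller.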
[0095] Referring to Figure 4, which illustrates a flow 400 of an embodiment of the image classification method provided in this disclosure, the method may include the following steps:
[0096] Step 401: Obtain the image to be classified.
[0097] In this embodiment, the execution entity on which the image classification method runs can communicate with a terminal (such as terminals 101 and 102 shown in Figure 1) to obtain the image to be classified sent by the terminal.
[0098] In this embodiment, the image to be classified is an image in which the type of the target cannot yet be determined. For example, the image to be classified is an image that includes different types of animals, but the types of the animals in the image have not been determined.
[0099] Step 402: Divide the image to be classified into blocks to obtain an image block sequence.
[0100] In this embodiment, the image to be classified is divided into blocks to obtain multiple image blocks. These multiple image blocks are combined together to obtain the image to be classified. The multiple image blocks are arranged in sequence to obtain an image block sequence.
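The block division of step 402 can be sketched for an image stored as a list of pixel rows. The block size, the row-major ordering of the resulting sequence, and the requirement that the image dimensions be exact multiples of the block size are assumptions for this example.

```python
def image_to_blocks(image, block_size):
    """Split a 2-D image (list of rows) into square blocks in row-major order.

    Raises ValueError when the image cannot be tiled exactly, which is a
    simplification made for this sketch.
    """
    height, width = len(image), len(image[0])
    if height % block_size or width % block_size:
        raise ValueError("image dimensions must be multiples of block_size")
    blocks = []
    for top in range(0, height, block_size):
        for left in range(0, width, block_size):
            blocks.append([row[left:left + block_size]
                           for row in image[top:top + block_size]])
    return blocks
```

A 4 x 4 image divided with `block_size=2` produces a sequence of four 2 x 2 blocks, and placing them back in the same order reproduces the original image.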
[0101] Step 403: Input the image patch sequence into the semi-supervised classification model to obtain the classification result of the target in the image to be classified from the output of the semi-supervised classification model.
[0102] In this embodiment, an image patch sequence is input into a semi-supervised classification model. The semi-supervised classification model identifies and classifies the features of targets in the image patch sequence, obtaining the classification result of the targets in the image to be classified. The classification result may include: the type of the target, and the confidence level of the target belonging to different target types. By comparing the confidence levels of the target in different target types, the specific type of the target can be determined. It should be noted that the semi-supervised classification model can be a binary classification model or a multi-class classification model. When the semi-supervised classification model is a binary classification model, the classification result of the target can be determined by whether the target in the image to be classified belongs to a predetermined target type. When the semi-supervised classification model is a multi-class classification model, the classification result of the target can be determined by which of several predetermined target types the target in the image to be classified belongs to.
[0103] In this embodiment, the predetermined target type and multiple target types are related to the target label of the semi-supervised classification model. When the target label represents only one type of target, the semi-supervised classification model is a binary classification model; when the target label identifies multiple types of targets, the semi-supervised classification model is a multi-class classification model.
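The comparison of confidence levels described in [0102] and [0103] might look like the sketch below. Representing the model output as a dictionary from target type to confidence, and using a threshold of 0.5 in the binary case, are assumptions made for illustration.

```python
def classification_result(confidences, threshold=0.5):
    """Pick the target type from per-type confidence levels.

    confidences maps each predetermined target type to a confidence.
    With a single type (binary case), the type is returned only if its
    confidence reaches the threshold, otherwise None. With several
    types (multi-class case), the type with the highest confidence wins.
    """
    if len(confidences) == 1:
        (target_type, confidence), = confidences.items()
        return (target_type if confidence >= threshold else None), confidence
    target_type = max(confidences, key=confidences.get)
    return target_type, confidences[target_type]
```

Calling `classification_result({"cat": 0.7, "dog": 0.2, "bird": 0.1})` selects `"cat"`, while a binary model reporting `{"cat": 0.3}` yields `None` for the type because the target does not belong to the predetermined type.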
[0104] In this embodiment, the semi-supervised classification model is generated using the semi-supervised classification model training method described in the above embodiment. For the specific generation process of the semi-supervised classification model, refer to the relevant description of the embodiment shown in Figure 2, which is not repeated here.
[0105] It should be noted that the image classification method in this embodiment can be used to test the semi-supervised classification models generated in the above embodiments. Furthermore, the semi-supervised classification models can be continuously optimized based on the test results. This method can also be a practical application of the semi-supervised classification models generated in the above embodiments. Using the semi-supervised classification models generated in the above embodiments to identify target types in images to be classified helps improve the efficiency of image recognition.
[0106] The image classification method provided in the embodiments of this disclosure first acquires an image to be classified; second, it divides the image to be classified into blocks to obtain an image block sequence; finally, it inputs the image block sequence into a semi-supervised classification model to obtain the classification result of the target in the image to be classified, output by the semi-supervised classification model. Thus, by using a pre-trained semi-supervised classification model to identify the image to be classified and obtain the target classification result, the efficiency of image classification is improved.
[0107] Optionally, this embodiment also provides a text classification method, which includes: obtaining the text to be classified; performing word segmentation on the text to be classified to obtain a word segmentation sequence; inputting the word segmentation sequence into a semi-supervised classification model to obtain the classification result of the text to be classified output by the semi-supervised classification model.
[0108] In this embodiment, the text to be classified is text whose theme, sentiment type, intent, or other information cannot be determined. For example, the text to be classified is a sentence "The weather is nice today," but the theme of the sentence cannot be determined.
[0109] In this embodiment, word segmentation is performed on the text to be classified, which can yield multiple words and characters. The words and characters are then arranged in order to obtain the word segmentation sequence.
[0110] In this embodiment, the word segmentation sequence is input into a semi-supervised classification model. The semi-supervised classification model identifies and classifies the features of the word segmentation sequence, obtaining the classification result of the text to be classified. The classification result can include: the type of the target, and the confidence level of the target belonging to different target types. By comparing the confidence levels of the target in different target types, the specific type of the target can be determined. Specifically, when classifying text by topic, the target types include different kinds of topics. When classifying text by sentiment, the target types include negative and positive. When classifying text by intent, the target types include different kinds of intent.
[0111] With further reference to Figure 5, as an implementation of the methods shown in the above figures, this disclosure provides an embodiment of a semi-supervised classification model training device. This device embodiment corresponds to the method embodiment shown in Figure 2, and the device can be applied to various electronic devices.
[0112] As shown in Figure 5, an embodiment of this disclosure provides a semi-supervised classification model training device 500, which includes: an acquisition unit 501, a mask processing unit 502, a mask calculation unit 503, a supervised calculation unit 504, and a training unit 505. The acquisition unit 501 can be configured to serialize acquired material samples to obtain a material sequence. The mask processing unit 502 can be configured to perform random masking on the material sequence to obtain a mask sequence including overall semantic features. The mask calculation unit 503 can be configured to input the material sequence and the mask sequence into a pre-constructed mask prediction network of a semi-supervised classification network to calculate the mask prediction loss of the mask prediction network. The supervised calculation unit 504 can be configured to input the overall semantic features with target labels and predicted by the mask prediction network into a linear classification network of the semi-supervised classification network to calculate the supervised loss of the linear classification network. The training unit 505 can be configured to train the semi-supervised classification network based on the mask prediction loss and the supervised loss to obtain a semi-supervised classification model corresponding to the semi-supervised classification network.
[0113] In this embodiment, for the specific processing of the acquisition unit 501, the mask processing unit 502, the mask calculation unit 503, the supervised calculation unit 504, and the training unit 505 in the semi-supervised classification model training device 500, and the resulting technical effects, refer to the descriptions of steps 201, 202, 203, 204, and 205, respectively, in the embodiment corresponding to Figure 2.
[0114] In some embodiments, the mask prediction network includes: a mask segmenter, a mask classifier, and a trained material segmenter and a material encoding dictionary; the mask calculation unit 503 is further configured to: input the material sequence into the material segmenter to obtain material block encoding; select material vectors that match the material block encoding from the material encoding dictionary to obtain a material vector sequence; input the mask sequence into the mask segmenter to obtain prediction block encoding; input the prediction block encoding into the mask classifier so that the mask classifier selects prediction vectors that match the prediction block encoding from the material encoding dictionary to obtain a prediction vector sequence; and calculate the mask prediction loss of the mask prediction network based on the material vector sequence and the prediction vector sequence.
[0115] In some embodiments, the training process of the above-mentioned material segmenter and material encoding dictionary is as follows: the acquired sample materials are serialized to obtain a sample sequence; the sample sequence is input into the material segmentation network to obtain a sample feature sequence; a sample encoding sequence corresponding to the sample feature sequence is selected from the sample encoding dictionary, and the sample encoding sequence is decoded to obtain a prediction sequence; the sample sequence is input into a pre-trained sample supervision model to obtain a supervision sequence; based on the prediction sequence and the supervision sequence, the material segmentation network and the material encoding dictionary are trained; in response to the material segmentation network meeting the training completion condition, the material segmenter is obtained.
[0116] In some embodiments, the supervised computation unit 504 is further configured to: input the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network to obtain the classification result output by the linear classification network; and calculate the supervised loss of the linear classification network based on the classification result and the target label.
[0117] In some embodiments, the training unit 505 is further configured to: determine the weight values of the supervised loss; multiply the supervised loss by the weight values and add them to the mask prediction loss to obtain the loss of the semi-supervised classification network; and train the semi-supervised classification network based on the loss of the semi-supervised classification network to obtain the semi-supervised classification model of the corresponding semi-supervised classification network.
[0118] In the semi-supervised classification model training apparatus provided in the embodiments of this disclosure, first, the acquisition unit 501 serializes the acquired material samples to obtain a material sequence; second, the mask processing unit 502 performs random masking on the material sequence to obtain a mask sequence including overall semantic features; third, the mask calculation unit 503 inputs the material sequence and the mask sequence into the mask prediction network of a pre-constructed semi-supervised classification network to calculate the mask prediction loss of the mask prediction network; fourth, the supervised calculation unit 504 inputs the overall semantic features with target labels and predicted by the mask prediction network into the linear classification network of the semi-supervised classification network to calculate the supervised loss of the linear classification network; and finally, the training unit 505 trains the semi-supervised classification network based on the mask prediction loss and the supervised loss to obtain a semi-supervised classification model corresponding to the semi-supervised classification network. Therefore, using a mask prediction network allows attention to high-order semantics and global features related to downstream tasks, avoiding wasted capacity in the semi-supervised classification model. By first training all the data with the mask prediction network and then fine-tuning on a small amount of labeled data with the linear classification network, the downstream task is anticipated during mask prediction modeling training so that higher-order, task-related semantic information is extracted, thereby improving the classification performance of the semi-supervised classification model.
[0119] With further reference to Figure 6, as an implementation of the methods shown in the above figures, this disclosure provides an embodiment of an image classification device. This device embodiment corresponds to the method embodiment shown in Figure 4, and the device can be applied to various electronic devices.
[0120] As shown in Figure 6, an embodiment of this disclosure provides an image classification device 600, which includes an image acquisition unit 601, an image processing unit 602, and a target classification unit 603. The image acquisition unit 601 is configured to acquire an image to be classified. The image processing unit 602 is configured to divide the image to be classified into blocks to obtain an image block sequence. The target classification unit 603 is configured to input the image block sequence into a semi-supervised classification model to obtain the classification result of the target in the image to be classified, output by the semi-supervised classification model.
[0121] In this embodiment, the semi-supervised classification model is trained using a semi-supervised classification model training device.
[0122] In this embodiment, for the specific processing of the image acquisition unit 601, the image processing unit 602, and the target classification unit 603 in the image classification device 600, and the resulting technical effects, refer to the descriptions of steps 401, 402, and 403, respectively, in the embodiment corresponding to Figure 4.
[0123] The image classification apparatus provided in the embodiments of this disclosure firstly acquires an image to be classified by an image acquisition unit 601; secondly, an image processing unit 602 performs block processing on the image to be classified to obtain an image block sequence; finally, a target classification unit 603 inputs the image block sequence into a semi-supervised classification model to obtain the classification result of the target in the image to be classified, output by the semi-supervised classification model. Thus, by using a pre-trained semi-supervised classification model to identify the image to be classified and obtain the target classification result, the efficiency of image classification is improved.
[0124] Referring now to Figure 7, a schematic diagram of the structure of an electronic device 700 suitable for implementing embodiments of the present disclosure is shown.
[0125] As shown in Figure 7, the electronic device 700 may include a processing device (e.g., a central processing unit, a graphics processor, etc.) 701, which can perform various appropriate actions and processes according to a program stored in a read-only memory (ROM) 702 or a program loaded from a storage device 708 into a random access memory (RAM) 703. The RAM 703 also stores various programs and data required for the operation of the electronic device 700. The processing device 701, ROM 702, and RAM 703 are interconnected via a bus 704. An input/output (I/O) interface 705 is also connected to the bus 704.
[0126] Typically, the following devices can be connected to the I/O interface 705: input devices 706 including, for example, touchscreens, touchpads, keyboards, mice, etc.; output devices 707 including, for example, liquid crystal displays (LCDs), speakers, vibrators, etc.; storage devices 708 including, for example, magnetic tapes, hard disks, etc.; and communication devices 709. The communication device 709 allows the electronic device 700 to exchange data with other devices through wireless or wired communication. Although Figure 7 shows an electronic device 700 with various devices, it should be understood that not all of the devices shown are required to be implemented or provided. More or fewer devices may alternatively be implemented or provided. Each box shown in Figure 7 can represent one device or multiple devices as needed.
[0127] In particular, according to embodiments of this disclosure, the processes described above with reference to the flowcharts can be implemented as computer software programs. For example, embodiments of this disclosure include a computer program product comprising a computer program carried on a computer-readable medium, the computer program containing program code for performing the methods shown in the flowcharts. In such embodiments, the computer program can be downloaded and installed from a network via the communication device 709, installed from the storage device 708, or installed from the ROM 702. When the computer program is executed by the processing unit 701, it performs the functions defined in the methods of embodiments of this disclosure.
[0128] It should be noted that the computer-readable medium in the embodiments of this disclosure may be a computer-readable signal medium or a computer-readable storage medium, or any combination thereof. A computer-readable storage medium may be, for example, but is not limited to, an electrical, magnetic, optical, electromagnetic, infrared, or semiconductor system, apparatus, or device, or any combination thereof. More specific examples of a computer-readable storage medium may include, but are not limited to: an electrical connection having one or more wires, a portable computer disk, a hard disk, random access memory (RAM), read-only memory (ROM), erasable programmable read-only memory (EPROM or flash memory), optical fiber, portable compact disk read-only memory (CD-ROM), an optical storage device, a magnetic storage device, or any suitable combination thereof. In the embodiments of this disclosure, a computer-readable storage medium may be any tangible medium containing or storing a program that can be used by or in conjunction with an instruction execution system, apparatus, or device. In the embodiments of this disclosure, a computer-readable signal medium may include a data signal propagated in baseband or as part of a carrier wave, carrying computer-readable program code. Such propagated data signals may take various forms, including but not limited to electromagnetic signals, optical signals, or any suitable combination thereof. A computer-readable signal medium may also be any computer-readable medium other than a computer-readable storage medium that can send, propagate, or transmit a program for use by or in connection with an instruction execution system, apparatus, or device. The program code contained on the computer-readable medium can be transmitted using any suitable medium, including but not limited to: wires, optical fibers, RF (radio frequency), or any suitable combination thereof.
[0129] The aforementioned computer-readable medium may be included in the aforementioned server, or it may exist independently without being assembled into the server. The aforementioned computer-readable medium carries one or more programs which, when executed by the server, cause the server to: serialize the acquired material samples to obtain a material sequence; perform random masking on the material sequence to obtain a mask sequence that includes overall semantic features; input the material sequence and the mask sequence into the mask prediction network of a pre-constructed semi-supervised classification network, and calculate the mask prediction loss of the mask prediction network; input the overall semantic features that have target labels and are predicted by the mask prediction network into the linear classification network of the semi-supervised classification network, and calculate the supervised loss of the linear classification network; and train the semi-supervised classification network based on the mask prediction loss and the supervised loss to obtain a semi-supervised classification model corresponding to the semi-supervised classification network.
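The program steps above can be sketched in Python as follows. This is a simplified illustration under stated assumptions, not the disclosed implementation: block codes are plain integers, `MASK` is a hypothetical placeholder for a masked position, the mask prediction loss is assumed to be an average cross-entropy over the masked positions, and the network outputs are stand-in probability lists. The weighted sum follows the combination recited in claim 5. All function names are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

MASK = None  # hypothetical placeholder for a masked position


def random_mask(sequence: Sequence[int], ratio: float,
                rng: random.Random) -> Tuple[List[Optional[int]], List[int]]:
    """Return a copy of `sequence` with a random subset of positions masked,
    together with the sorted list of masked positions."""
    if not sequence:
        raise ValueError("sequence must not be empty")
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    count = max(1, round(len(sequence) * ratio))
    positions = sorted(rng.sample(range(len(sequence)), count))
    masked: List[Optional[int]] = list(sequence)
    for i in positions:
        masked[i] = MASK
    return masked, positions


def cross_entropy(probs: Sequence[float], target: int) -> float:
    """Negative log-probability assigned to the target class."""
    return -math.log(max(probs[target], 1e-12))


def mask_prediction_loss(predicted: Sequence[Sequence[float]],
                         codes: Sequence[int],
                         positions: Sequence[int]) -> float:
    """Average cross-entropy over masked positions only (assumed loss form)."""
    losses = [cross_entropy(predicted[i], codes[i]) for i in positions]
    return sum(losses) / len(losses)


def total_loss(mask_loss: float, supervised_loss: float, weight: float) -> float:
    """Scale the supervised loss by a weight, then add the mask prediction loss."""
    return mask_loss + weight * supervised_loss


# Toy run with made-up numbers: 8 block codes drawn from a 10-entry dictionary.
rng = random.Random(0)
codes = [3, 7, 7, 1, 4, 2, 9, 5]
masked, positions = random_mask(codes, ratio=0.5, rng=rng)

# Stand-in for the mask prediction network: uniform output at every position.
uniform = [[0.1] * 10 for _ in codes]
l_mask = mask_prediction_loss(uniform, codes, positions)

# Stand-in linear classifier output for one labeled sample with target label 2.
l_sup = cross_entropy([0.2, 0.1, 0.7], target=2)
loss = total_loss(l_mask, l_sup, weight=0.5)
```

In this sketch every sample contributes to the mask prediction loss, while the supervised term would only be computed for samples that carry a target label; the weight value of 0.5 is arbitrary.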
[0130] Computer program code for performing the operations of embodiments of this disclosure can be written in one or more programming languages or a combination thereof. These include object-oriented programming languages, such as Java, Smalltalk, and C++, and conventional procedural programming languages, such as the "C" language or similar programming languages. The program code can be executed entirely on the user's computer, partially on the user's computer, as a standalone software package, partially on the user's computer and partially on a remote computer, or entirely on a remote computer or server. In cases involving a remote computer, the remote computer can be connected to the user's computer via any type of network, including a local area network (LAN) or a wide area network (WAN), or can be connected to an external computer (e.g., via the Internet using an Internet service provider).
[0131] The flowcharts and block diagrams in the accompanying drawings illustrate the architecture, functionality, and operation of possible implementations of systems, methods, and computer program products according to various embodiments of the present disclosure. In this regard, each block in a flowchart or block diagram may represent a module, segment, or portion of code containing one or more executable instructions for implementing a specified logical function. It should also be noted that in some alternative implementations, the functions indicated in the blocks may occur in a different order than that indicated in the drawings. For example, two consecutively indicated blocks may actually be executed substantially in parallel, and they may sometimes be executed in reverse order, depending on the functions involved. It should also be noted that each block in the block diagrams and/or flowcharts, and combinations of blocks in the block diagrams and/or flowcharts, may be implemented using a dedicated hardware-based system that performs the specified function or operation, or using a combination of dedicated hardware and computer instructions.
[0132] The units described in the embodiments of this disclosure can be implemented in software or hardware. The described units can also be provided in a processor; for example, a processor may be described as including an acquisition unit, a mask processing unit, a mask calculation unit, a supervisory calculation unit, a loss determination unit, and a training unit. The names of these units do not in themselves limit the units; for example, the acquisition unit may also be described as a unit "configured to perform serialization processing on acquired material samples to obtain a material sequence."
[0133] The above description is merely a preferred embodiment of this disclosure and an explanation of the technical principles employed. Those skilled in the art should understand that the scope of the invention involved in the embodiments of this disclosure is not limited to technical solutions formed by specific combinations of the above-described technical features, but should also cover other technical solutions formed by arbitrary combinations of the above-described technical features or their equivalents without departing from the above-described inventive concept. For example, it also covers technical solutions formed by interchanging the above-described features with (but not limited to) technical features having similar functions disclosed in the embodiments of this disclosure.
Claims
1. A semi-supervised classification model training method, the method comprising: serializing acquired material samples to obtain a material sequence, wherein the material includes images or text; performing random masking on the material sequence to obtain a mask sequence that includes overall semantic features; inputting the material sequence and the mask sequence into a mask prediction network of a pre-constructed semi-supervised classification network, and calculating a mask prediction loss of the mask prediction network, wherein the mask prediction network is used to encode the material sequence, predict the masked material in the mask sequence, and calculate the mask prediction loss based on the prediction result and the encoding result; inputting the overall semantic features that have target labels and are predicted by the mask prediction network into a linear classification network of the semi-supervised classification network, and calculating a supervised loss of the linear classification network; and training the semi-supervised classification network based on the mask prediction loss and the supervised loss to obtain a semi-supervised classification model corresponding to the semi-supervised classification network.
2. The method according to claim 1, wherein the mask prediction network includes: a mask segmenter, a mask classifier, a trained material segmenter, and a material encoding dictionary; and wherein inputting the material sequence and the mask sequence into the mask prediction network of the pre-constructed semi-supervised classification network, and calculating the mask prediction loss of the mask prediction network, includes: inputting the material sequence into the material segmenter to obtain a material block code; selecting, from the material encoding dictionary, a material vector that matches the material block code to obtain a material vector sequence; inputting the mask sequence into the mask segmenter to obtain a predicted block code; inputting the predicted block code into the mask classifier, so that the mask classifier selects, from the material encoding dictionary, a prediction vector that matches the predicted block code, thereby obtaining a prediction vector sequence; and calculating the mask prediction loss of the mask prediction network based on the material vector sequence and the prediction vector sequence.
3. The method according to claim 2, wherein the material segmenter and the material encoding dictionary are trained as follows: serializing acquired sample materials to obtain a sample sequence; inputting the sample sequence into a material segmentation network to obtain a sample feature sequence; selecting, from a sample encoding dictionary, a sample encoding sequence corresponding to the sample feature sequence, and decoding the sample encoding sequence to obtain a prediction sequence; inputting the sample sequence into a pre-trained sample supervision model to obtain a supervision sequence; training the material segmentation network and the material encoding dictionary based on the prediction sequence and the supervision sequence; and in response to the material segmentation network meeting a training completion condition, obtaining the material segmenter.
4. The method according to claim 1, wherein inputting the overall semantic features that have target labels and are predicted by the mask prediction network into the linear classification network of the semi-supervised classification network, and calculating the supervised loss of the linear classification network, includes: inputting the overall semantic features that have target labels and are predicted by the mask prediction network into the linear classification network of the semi-supervised classification network to obtain a classification result output by the linear classification network; and calculating the supervised loss of the linear classification network based on the classification result and the target label.
5. The method according to any one of claims 1-4, wherein training the semi-supervised classification network based on the mask prediction loss and the supervised loss to obtain the semi-supervised classification model corresponding to the semi-supervised classification network includes: determining a weight value of the supervised loss; multiplying the supervised loss by the weight value and adding the product to the mask prediction loss to obtain a loss of the semi-supervised classification network; and training the semi-supervised classification network based on the loss of the semi-supervised classification network to obtain the semi-supervised classification model corresponding to the semi-supervised classification network.
6. An image classification method, the method comprising: obtaining an image to be classified; dividing the image to be classified into blocks to obtain an image block sequence; and inputting the image block sequence into a semi-supervised classification model to obtain a classification result, output by the semi-supervised classification model, of the target in the image to be classified, wherein the semi-supervised classification model is trained using the semi-supervised classification model training method according to any one of claims 1-5.
7. A semi-supervised classification model training device, the device comprising: an acquisition unit configured to serialize acquired material samples to obtain a material sequence, wherein the material includes images or text; a mask processing unit configured to perform random masking on the material sequence to obtain a mask sequence that includes overall semantic features; a mask calculation unit configured to input the material sequence and the mask sequence into a mask prediction network of a pre-constructed semi-supervised classification network, and calculate a mask prediction loss of the mask prediction network, wherein the mask prediction network is used to encode the material sequence, predict the masked material in the mask sequence, and calculate the mask prediction loss based on the prediction result and the encoding result; a supervisory calculation unit configured to input the overall semantic features that have target labels and are predicted by the mask prediction network into a linear classification network of the semi-supervised classification network, and calculate a supervised loss of the linear classification network; and a training unit configured to train the semi-supervised classification network based on the mask prediction loss and the supervised loss, thereby obtaining a semi-supervised classification model corresponding to the semi-supervised classification network.
8. An image classification device, the device comprising: an image acquisition unit configured to acquire an image to be classified; an image processing unit configured to perform block processing on the image to be classified to obtain an image block sequence; and a target classification unit configured to input the image block sequence into a semi-supervised classification model to obtain a classification result, output by the semi-supervised classification model, of the target in the image to be classified, wherein the semi-supervised classification model is trained using the semi-supervised classification model training device of claim 7.
9. An electronic device, comprising: one or more processors; and a storage device on which one or more programs are stored; wherein, when the one or more programs are executed by the one or more processors, the one or more processors implement the method according to any one of claims 1-6.
10. A computer-readable medium having a computer program stored thereon, wherein, when the program is executed by a processor, it implements the method according to any one of claims 1-6.