A neural network training method and related apparatus
An intermediate transformation network is introduced between the teacher network and the student network. Its expansion and contraction modules align features in both directions, and it is refined through homomorphic transformation combined with network structure search. This addresses the poor prediction performance of the student network, achieving higher prediction accuracy with lower transformation overhead.
Patent Information
- Authority / Receiving Office: CN · China
- Patent Type: Patents (China)
- Current Assignee / Owner: HUAWEI TECH CO LTD
- Filing Date: 2021-07-20
- Publication Date: 2026-06-16
Figure CN115640831B_ABST
Abstract
Description
Technical Field
[0001] This application relates to the field of computer technology, and in particular to a neural network training method and related apparatus.
Background Technology
[0002] Knowledge distillation uses the "knowledge" of a large teacher network to guide the training of a small student network, improving the performance of the smaller model and thereby indirectly compressing the model. As shown in Figure 1, according to where the knowledge is distilled, distillation techniques can be broadly divided into two categories, output layer (Logits layer) knowledge distillation and network layer knowledge distillation, which are introduced below.
[0003] Logits layer knowledge distillation is typically used for tasks such as classification and recognition. As shown in Figure 2, Logits are the outputs of the last layer of a network. The teacher network and the student network each output their own Logits, which directly determine the final classification result. Logits layer knowledge distillation uses the teacher network's Logits to guide the training of the student network, usually by measuring the difference between the teacher's Logits and the student's Logits with a Kullback-Leibler (KL) divergence loss. Beyond Logits layer knowledge distillation and its variants, relation-based knowledge distillation has emerged, which uses not only the soft labels of individual samples but also the relationships between samples. However, Logits layer knowledge distillation generalizes poorly and is often applicable only to tasks such as classification and recognition.
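The KL-based Logits distillation loss described above can be sketched in plain Python as follows. This is an illustrative sketch of the general technique, not the method of this application: the temperature parameter, the temperature-squared scaling, and the function names `softmax` and `kd_kl_loss` are assumptions that follow the widely used soft-label convention.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; subtracting the max keeps exp() numerically stable."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_kl_loss(teacher_logits, student_logits, temperature=4.0):
    """KL(p_teacher || p_student) on temperature-softened distributions.

    Multiplying by temperature**2 is the common soft-label convention
    (an assumption here, not taken from the application).
    """
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2


teacher_logits = [3.0, 1.0, 0.2]
student_logits = [2.5, 1.2, 0.1]
print(kd_kl_loss(teacher_logits, student_logits))  # positive: the distributions differ
print(kd_kl_loss(teacher_logits, teacher_logits))  # 0.0: identical distributions
```

The loss is zero only when the student reproduces the teacher's softened distribution, which is why it can serve as the training signal for the student.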
[0004] Network layer knowledge distillation is the other type of method, in which the distilled "knowledge" comes from the features output by network layers (such as layer_t0, layer_t1, layer_ti, etc.). Because the teacher network is generally much larger than the student network, the features of corresponding teacher and student network layers differ in size and shape, so a loss function cannot be computed on them directly. Such methods therefore usually transform the student and teacher features so that the two features entering the loss function have the same shape and size, and then use a loss function to measure the difference between the teacher and student layer features, so that the teacher network guides the training of the student network. Network layer knowledge distillation was initially applied directly to raw intermediate features; attention-based methods were developed later, but all of them follow the same basic framework. Figure 3 shows the process of network layer knowledge distillation, which generalizes well. However, an intermediate transformation network must be added between the teacher network layer and the student network layer to transform the features output by the two layers to the same size so that the loss function can be computed, and after distillation of the student network is completed, this intermediate transformation network is discarded.
[0005] In the existing technology, the intermediate transformation network is discarded once distillation is complete, so the information it has learned is lost and the student network's prediction performance is poor.
Summary of the Invention
[0006] This application discloses a neural network training method and related apparatus that can improve the prediction accuracy of student networks.
[0007] In a first aspect, embodiments of this application provide a neural network training method, the method comprising:
[0008] A first feature is obtained by processing input data through a first network layer of a student network. The input data includes any one or more of the following: image data, audio data, and text data.
[0009] A second feature is obtained by processing the input data through a second network layer of a teacher network. The second network layer is the network layer in the teacher network that corresponds to the first network layer in the student network. Optionally, this correspondence means, for example, that the first network layer plays the same role in the student network as the second network layer plays in the teacher network; that the two layers occupy the same position or order in their respective networks; or that the features output by the two layers have the same structure. Other forms of correspondence are possible and are not listed here.
[0010] An intermediate transformation network is trained according to a first loss function. The intermediate transformation network includes an expansion module and a contraction module. The expansion module converts the first feature into a third feature that is aligned with the second feature, and the first loss function measures the difference between the third feature and the second feature.
[0011] The contraction module converts the third feature into a fourth feature that is aligned with the first feature. Note that feature alignment in the embodiments of this application refers to shape alignment of feature maps, for example, the feature maps having the same number of channels, length, width, and height.
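A minimal sketch of the expansion and contraction modules follows, assuming that each module is a single bias-free 1x1 convolution and that the student and teacher feature maps share the same spatial size, so alignment reduces to matching channel counts. A 1x1 convolution applies the same channel-mixing matrix at every pixel, so feature maps are represented here as lists of per-pixel channel vectors. The channel counts, the random initialization, and the use of mean squared error as the first loss function are hypothetical choices; the application does not fix the form of the first loss function.

```python
import random


def conv1x1(feature, weight):
    """1x1 convolution without bias.

    feature: list of pixels, each a list of C_in channel values.
    weight:  C_out x C_in matrix, applied identically at every pixel.
    """
    return [[sum(w * x for w, x in zip(row, pixel)) for row in weight] for pixel in feature]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


def mse(a, b):
    """Mean squared error over all pixels and channels (assumed form of the first loss)."""
    diffs = [(x - y) ** 2 for pa, pb in zip(a, b) for x, y in zip(pa, pb)]
    return sum(diffs) / len(diffs)


rng = random.Random(0)
C_S, C_T, HW = 4, 8, 6  # hypothetical student/teacher channel counts and H*W

first = [[rng.gauss(0.0, 1.0) for _ in range(C_S)] for _ in range(HW)]   # student layer output
second = [[rng.gauss(0.0, 1.0) for _ in range(C_T)] for _ in range(HW)]  # teacher layer output

expansion = random_matrix(C_T, C_S, rng)    # C_S -> C_T
contraction = random_matrix(C_S, C_T, rng)  # C_T -> C_S

third = conv1x1(first, expansion)     # same shape as the second feature
fourth = conv1x1(third, contraction)  # same shape as the first feature
first_loss = mse(third, second)       # the quantity training would minimize
print(len(third[0]), len(fourth[0]), round(first_loss, 4))
```

The expansion output can be compared with the teacher feature, while the contraction output returns to the student's channel count, which is what later allows the module pair to be folded back into the student network.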
[0012] In the above method, the expansion module can align the features output by the student network with the features output by the teacher network as much as possible, so as to make full use of the knowledge learned by the teacher network to improve the accuracy of the intermediate transformation network and the student network. The contraction module can shrink the expanded feature map back to its original size, which can ensure that the intermediate transformation network can be seamlessly integrated into the student network after training, making the student network perform better when applied.
[0013] In conjunction with the first aspect, in one possible implementation of the first aspect, training the intermediate transformation network according to the first loss function includes:
[0014] The intermediate transformation network is iteratively subjected to multiple homomorphic transformations until the first loss function no longer decreases.
[0015] In this embodiment, because the homomorphic transformation process is essentially an automatic network growth process, the resulting target network is generally no longer a single linear layer, so the first loss function can be reduced further with the target network. When the target network is subsequently fused into the student network, the fused student network fits the teacher network more closely and achieves higher prediction accuracy. In addition, stopping the transformations once the first loss function no longer decreases preserves the transformation effect while minimizing the overhead incurred during the transformation process.
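The stopping rule can be sketched as a generic loop. Here `transform` (one homomorphic transformation) and `train_and_eval` (training with the first loss function and returning its value) are hypothetical callbacks, and the tolerance `min_delta` and the cap `max_rounds` are assumed parameters for deciding when the loss "no longer decreases".

```python
from typing import Callable, List, Tuple, TypeVar

Net = TypeVar("Net")


def grow_until_plateau(
    net: Net,
    transform: Callable[[Net], Net],
    train_and_eval: Callable[[Net], float],
    min_delta: float = 1e-4,
    max_rounds: int = 20,
) -> Tuple[Net, List[float]]:
    """Apply homomorphic transformations until the first loss stops decreasing.

    Each round grows a candidate from the current network, trains/evaluates it,
    and keeps it only if the loss dropped by at least min_delta.
    """
    best_loss = train_and_eval(net)
    history = [best_loss]
    for _ in range(max_rounds):
        candidate = transform(net)
        loss = train_and_eval(candidate)
        history.append(loss)
        if loss > best_loss - min_delta:  # no meaningful decrease: stop growing
            break
        net, best_loss = candidate, loss
    return net, history


# Toy usage: the "network" is just a depth counter and the loss is an invented curve.
final_depth, losses = grow_until_plateau(
    net=0,
    transform=lambda depth: depth + 1,
    train_and_eval=lambda depth: (depth - 3) ** 2 + 1,
)
print(final_depth, losses)  # 3 [10, 5, 2, 1, 2]
```

The last candidate is evaluated but rejected, so the returned network is the one with the lowest observed loss; the cost of the search is bounded by `max_rounds`.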
[0016] In conjunction with the first aspect, or any of the above possible implementations of the first aspect, in yet another possible implementation of the first aspect, one of the multiple homomorphic transformations includes:
[0017] A homomorphic transformation search is performed based on a preset network structure search space to obtain a first target network equivalent to the expansion module in the intermediate transformation network, thereby updating the expansion module; and / or a homomorphic transformation search is performed based on a preset network structure search space to obtain a second target network equivalent to the contraction module in the intermediate transformation network, thereby updating the contraction module.
[0018] It is understandable that the expansion module and the contraction module are two separate parts that perform two different functions. Therefore, these two modules can also be trained separately, that is, they can each be subjected to homomorphic transformation search. Of course, one module can be updated using homomorphic transformation search, while the other module is updated using other methods.
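One search step in which the expansion module and the contraction module are searched independently might look like the sketch below. The module representation, the candidate operations in `search_space`, and the scoring function `evaluate` are hypothetical; in practice each candidate would be a function-preserving transformation drawn from the preset search space and followed by training with the first loss function. The `names` parameter models the "and / or": only the listed modules are updated.

```python
def search_step(modules, search_space, evaluate, names=("expansion", "contraction")):
    """One homomorphic transformation step.

    For each module in `names`, independently try every candidate transformation
    in `search_space` and keep whichever gives the lowest first-loss value;
    if no candidate improves on the current module, it is left unchanged.
    """
    updated = dict(modules)  # never mutate the caller's dict
    for name in names:
        best = updated[name]
        best_loss = evaluate(updated)
        for make_candidate in search_space:
            candidate = make_candidate(updated[name])
            loss = evaluate({**updated, name: candidate})
            if loss < best_loss:
                best, best_loss = candidate, loss
        updated[name] = best
    return updated


# Toy usage: each module is summarized by an integer "capacity", the search space
# holds two hypothetical growth operations, and the loss is an invented target curve.
space = [lambda c: c + 1, lambda c: c * 2]  # e.g. "deepen" and "add a branch"


def toy_loss(m):
    return abs(m["expansion"] - 4) + abs(m["contraction"] - 2)


start = {"expansion": 1, "contraction": 1}
step1 = search_step(start, space, toy_loss)
print(step1)  # {'expansion': 2, 'contraction': 2}
step2 = search_step(step1, space, toy_loss, names=("expansion",))  # expansion module only
print(step2)  # {'expansion': 4, 'contraction': 2}
```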
[0019] In conjunction with the first aspect, or any of the possible implementations of the first aspect described above, in another possible implementation of the first aspect, the network structure search space includes multiple network structures, all of which are convolutional layers with a kernel size of 1x1 and all of which can be integrated into the student network. Because these network structures are all 1x1 convolutional layers, the homomorphic transformation can keep the networks before and after the transformation mathematically equivalent as far as possible, reducing information loss, and the transformed network is also easier to integrate into the student network.
[0020] In conjunction with the first aspect, or any of the above-mentioned possible implementations of the first aspect, another possible implementation of the first aspect further includes:
[0021] The intermediate transformation network is integrated into the student network. It can be understood that integrating this intermediate transformation network into the student network preserves as much key information as possible related to the distillation process, thereby improving the prediction accuracy of the student network.
[0022] In conjunction with the first aspect, or any of the above possible implementations of the first aspect, in yet another possible implementation of the first aspect, the integration of the intermediate transformation network into the student network includes:
[0023] The weights of the first network layer of the student network, or of the network layer following the first network layer, are updated using the intermediate transformation network. Because the intermediate transformation network is trained on information related to the first network layer, updating the weights of the first network layer or of the following layer in this way preserves the information represented by the intermediate transformation network, thereby improving the prediction accuracy of the student network. Note that after the intermediate transformation network is merged into the first network layer or the following layer, the network structure of the student network does not change.
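Weight folding of this kind is possible when the intermediate transformation network contains no nonlinearity and has been collapsed into a single 1x1 convolution with matrix `M`, which is an assumption made for this sketch. Because a 1x1 convolution is linear in the channels, applying `M` after a layer with weights `W` and bias `b` equals one layer with weights `M·W` and bias `M·b`; applying it in front of the next layer with weights `V` equals that layer with weights `V·M`. Both layers are modeled as per-pixel linear maps, and the matrix values are arbitrary.

```python
def matmul(a, b):
    """Plain matrix product a @ b for nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(a, v):
    return [sum(x * y for x, y in zip(row, v)) for row in a]


def affine(w, b, x):
    """Per-pixel view of a layer: w @ x + b."""
    return [y + bias for y, bias in zip(matvec(w, x), b)]


W, b = [[1, 2], [0, 1], [3, -1]], [1, 0, -1]  # first layer: C_in=2 -> C_S=3
M = [[1, 0, 1], [0, 2, 0], [1, 1, 0]]          # collapsed intermediate network (C_S -> C_S)
V = [[1, 0, 1], [0, 1, 1]]                     # next layer: C_S=3 -> 2 (bias omitted)
x = [2, 3]                                     # one input pixel

h = affine(W, b, x)       # first feature
reference = matvec(M, h)  # first feature passed through the intermediate network

# Option 1: fold M into the first layer: W' = M @ W, b' = M @ b.
folded_first = affine(matmul(M, W), matvec(M, b), x)

# Option 2: fold M into the next layer: V' = V @ M (valid when M acts directly
# on V's input; a bias on the next layer, if any, would stay unchanged).
next_reference = matvec(V, reference)
folded_next = matvec(matmul(V, M), h)

print(reference, folded_first)      # [11, 6, 12] [11, 6, 12]
print(next_reference, folded_next)  # [23, 18] [23, 18]
```

Either option leaves the layer count and shapes of the student network unchanged, matching the note above that its structure does not change.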
[0024] In conjunction with the first aspect, or any of the above possible implementations of the first aspect, in yet another possible implementation of the first aspect, the integration of the intermediate transformation network into the student network includes:
[0025] The intermediate transformation network is fused into a target network layer, and the target network layer is inserted between the first network layer and the next network layer of the student network. Because the intermediate transformation network is trained on information related to the first network layer, fusing it into a target network layer inserted between the first network layer and the next network layer preserves the information represented by the intermediate transformation network, thereby helping to ensure the prediction accuracy of the student network. Note that inserting the target network layer changes the structure of the student network: one additional target network layer is added.
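Under the same no-nonlinearity assumption, the expansion module (C_T×C_S) and the contraction module (C_S×C_T) collapse into one C_S×C_S 1x1 convolution, which can serve as the inserted target network layer. The sketch below uses arbitrary integer weights and hypothetical student layers to check that the inserted layer reproduces the two-module output, and compares parameter counts.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(a, v):
    return [sum(x * y for x, y in zip(row, v)) for row in a]


def forward(layers, x):
    """Run a chain of bias-free 1x1 layers (per-pixel matrices) in order."""
    for w in layers:
        x = matvec(w, x)
    return x


def n_params(w):
    return len(w) * len(w[0])


expansion = [[1, 0], [0, 1], [1, 1]]           # C_S=2 -> C_T=3
contraction = [[1, 0, 1], [0, 1, -1]]          # C_T=3 -> C_S=2
target_layer = matmul(contraction, expansion)  # single C_S x C_S 1x1 layer

first_layer = [[1, 2], [3, 4]]  # hypothetical student layers with C_S=2 channels
next_layer = [[1, -1], [0, 2]]

student = [first_layer, next_layer]
student_with_target = student[:1] + [target_layer] + student[1:]  # insert between them

x = [1, 1]
print(forward(student_with_target, x), forward([first_layer, expansion, contraction, next_layer], x))
print(target_layer, n_params(target_layer), n_params(expansion) + n_params(contraction))
# [16, -6] [16, -6]
# [[2, 1], [-1, 0]] 4 12
```

The inserted layer always has C_S×C_S parameters regardless of how large the teacher-side width C_T was during training.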
[0026] In conjunction with the first aspect, or any of the above-mentioned possible implementations of the first aspect, in yet another possible implementation of the first aspect,
[0027] The number of branches in the target network is greater than the number of branches in the current intermediate transformation network, and / or,
[0028] The target network has a width greater than the width of the current intermediate transformation network, and / or the target network has a depth greater than the depth of the current intermediate transformation network.
[0029] It is understandable that because the number of branches, width, or depth of the target network increases, the number of parameters in the target network increases, thus enhancing the fitting ability of the target network.
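The three growth directions can be illustrated with function-preserving transformations in the style of Net2Net: splitting a branch into two half-weight copies, inserting an identity 1x1 layer, and duplicating a channel while splitting its outgoing weights. These specific transformations are assumptions chosen for illustration, not necessarily the ones in the application's search space. Each stage is a list of parallel branch matrices whose outputs are summed.

```python
def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def forward(stages, x):
    """Each stage is a list of parallel branch matrices whose outputs are summed."""
    for branches in stages:
        outputs = [matvec(m, x) for m in branches]
        x = [sum(vals) for vals in zip(*outputs)]
    return x


def n_params(stages):
    return sum(len(m) * len(m[0]) for branches in stages for m in branches)


def add_branch(stages, i):
    """More branches: replace each branch of stage i by two half-weight copies."""
    halves = [[[0.5 * w for w in row] for row in m] for m in stages[i] for _ in range(2)]
    return stages[:i] + [halves] + stages[i + 1:]


def deepen(stages, i):
    """More depth: insert an identity 1x1 layer after stage i."""
    width = len(stages[i][0])  # output channels of stage i
    identity = [[1.0 if r == c else 0.0 for c in range(width)] for r in range(width)]
    return stages[:i + 1] + [[identity]] + stages[i + 1:]


def widen(stages, i, j):
    """More width: duplicate output channel j of stage i and split its outgoing
    weights in stage i + 1 between the two copies (single-branch stages only)."""
    (a,), (b,) = stages[i], stages[i + 1]
    a_wide = a + [list(a[j])]
    b_wide = [row[:j] + [row[j] / 2] + row[j + 1:] + [row[j] / 2] for row in b]
    return stages[:i] + [[a_wide], [b_wide]] + stages[i + 2:]


net = [[[[1, 0], [0, 1], [1, 1]]],  # expansion: 2 -> 3 channels
       [[[1, 0, 1], [0, 1, -1]]]]   # contraction: 3 -> 2 channels
x = [3, 4]
for grown in (add_branch(net, 0), deepen(net, 1), widen(net, 0, 2)):
    print(forward(grown, x) == forward(net, x), n_params(net), "->", n_params(grown))
# True 12 -> 18
# True 12 -> 16
# True 12 -> 16
```

Each transformation leaves the output unchanged at the moment it is applied while adding parameters, so subsequent training with the first loss function starts from the previous solution with more fitting capacity.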
[0030] In conjunction with the first aspect, or any of the above-mentioned possible implementations of the first aspect, another possible implementation of the first aspect further includes:
[0031] The student network into which the intermediate transformation network has been integrated is trained according to a second loss function.
[0032] In other words, by integrating the intermediate transformation network into the student network and then continuing to train the student network, the negative impact of the integration can be corrected, thereby further improving the prediction accuracy of the student network.
[0033] In conjunction with the first aspect, or any of the above possible implementations of the first aspect, in another possible implementation of the first aspect, the teacher network is obtained by enlarging (e.g., widening or deepening) the student network, or the teacher network is a heterogeneous network with a different structure from the student network.
[0034] In a second aspect, embodiments of this application provide a neural network training apparatus, the apparatus comprising:
[0035] The first acquisition unit is used to acquire a first feature, which is obtained by processing the input data through the first network layer of the student network. The input data includes any one or more of the following data: image data, audio data, and text data.
[0036] The second acquisition unit is used to acquire a second feature, which is obtained by processing the input data through a second network layer of the teacher network. The second network layer is the network layer in the teacher network that corresponds to the first network layer of the student network. Optionally, this correspondence means, for example, that the first network layer plays the same role in the student network as the second network layer plays in the teacher network, or that the features output by the two layers have the same structure. Other forms of correspondence are possible and are not listed here.
[0037] The first training unit is used to train an intermediate transformation network according to a first loss function. The intermediate transformation network includes an expansion module and a contraction module. The expansion module is used to convert the first feature into a third feature, the third feature being aligned with the second feature. The first loss function is used to measure the difference between the third feature and the second feature.
[0038] The contraction module is used to convert the third feature into a fourth feature that is aligned with the first feature. Note that feature alignment in the embodiments of this application refers to shape alignment of feature maps.
[0039] In the above apparatus, the expansion module can align the features output by the student network with the features output by the teacher network as much as possible, so as to make full use of the knowledge learned by the teacher network to improve the accuracy of the intermediate transformation network and the student network. The contraction module can shrink the expanded feature map back to its original size, which ensures that the intermediate transformation network can be seamlessly integrated into the student network after training, making the student network more effective when applied.
[0040] In conjunction with the second aspect, in one possible implementation of the second aspect, in training the intermediate transformation network according to the first loss function, the first training unit is specifically used for:
[0041] The intermediate transformation network is iteratively subjected to multiple homomorphic transformations until the first loss function no longer decreases.
[0042] In this embodiment, because the homomorphic transformation process is essentially an automatic network growth process, the resulting target network is generally no longer a single linear layer, so the first loss function can be reduced further with the target network. When the target network is subsequently fused into the student network, the fused student network fits the teacher network more closely and achieves higher prediction accuracy. In addition, stopping the transformations once the first loss function no longer decreases preserves the transformation effect while minimizing the overhead incurred during the transformation process.
[0043] In conjunction with the second aspect, or any of the above possible implementations of the second aspect, in yet another possible implementation of the second aspect, regarding one homomorphic transformation among multiple homomorphic transformations, the first training unit is specifically used for:
[0044] A homomorphic transformation search is performed based on a preset network structure search space to obtain a first target network equivalent to the expansion module in the intermediate transformation network, thereby updating the expansion module; and / or a homomorphic transformation search is performed based on a preset network structure search space to obtain a second target network equivalent to the contraction module in the intermediate transformation network, thereby updating the contraction module.
[0045] It is understandable that the expansion module and the contraction module are two separate parts that perform two different functions. Therefore, these two modules can also be trained separately, that is, they can each be subjected to homomorphic transformation search. Of course, one module can be updated using homomorphic transformation search, while the other module is updated using other methods.
[0046] In conjunction with the second aspect, or any of the possible implementations of the second aspect described above, in yet another possible implementation of the second aspect, the network structure search space includes multiple network structures, all of which are convolutional layers with a kernel size of 1x1 and all of which can be integrated into the student network. Because these network structures are all 1x1 convolutional layers, the homomorphic transformation can keep the networks before and after the transformation mathematically equivalent as far as possible, reducing information loss, and the transformed network is also easier to integrate into the student network.
[0047] In conjunction with the second aspect, or any of the above-mentioned possible implementations of the second aspect, another possible implementation of the second aspect further includes:
[0048] The integration unit is used to integrate the intermediate transformation network into the student network. It can be understood that integrating the intermediate transformation network into the student network preserves as much key information as possible involved in the distillation process, thereby improving the prediction accuracy of the student network.
[0049] In conjunction with the second aspect, or any of the above possible implementations of the second aspect, in yet another possible implementation of the second aspect, in integrating the intermediate transformation network into the student network, the integration unit is specifically used for:
[0050] The weights of the first network layer of the student network, or of the network layer following the first network layer, are updated using the intermediate transformation network. Because the intermediate transformation network is trained on information related to the first network layer, updating the weights of the first network layer or of the following layer in this way preserves the information represented by the intermediate transformation network, thereby improving the prediction accuracy of the student network. Note that after the intermediate transformation network is merged into the first network layer or the following layer, the network structure of the student network does not change.
[0051] In conjunction with the second aspect, or any of the above possible implementations of the second aspect, in yet another possible implementation of the second aspect, in integrating the intermediate transformation network into the student network, the integration unit is specifically used for:
[0052] The intermediate transformation network is fused into a target network layer, and the target network layer is inserted between the first network layer and the next network layer of the student network. Because the intermediate transformation network is trained on information related to the first network layer, fusing it into a target network layer inserted between the first network layer and the next network layer preserves the information represented by the intermediate transformation network, thereby helping to ensure the prediction accuracy of the student network. Note that inserting the target network layer changes the structure of the student network: one additional target network layer is added.
[0053] In conjunction with the second aspect, or any of the above possible implementations of the second aspect, in yet another possible implementation of the second aspect,
[0054] The number of branches in the target network is greater than the number of branches in the current intermediate transformation network, and / or,
[0055] The target network has a width greater than the width of the current intermediate transformation network, and / or the target network has a depth greater than the depth of the current intermediate transformation network.
[0056] It is understandable that because the number of branches, width, or depth of the target network increases, the number of parameters in the target network increases, thus enhancing the fitting ability of the target network.
[0057] In conjunction with the second aspect, or any of the above-mentioned possible implementations of the second aspect, another possible implementation of the second aspect further includes:
[0058] The second training unit is used to train, according to a second loss function, the student network into which the intermediate transformation network has been integrated.
[0059] In other words, by integrating the intermediate transformation network into the student network and then continuing to train the student network, the negative impact of the integration can be corrected, thereby further improving the prediction accuracy of the student network.
[0060] In conjunction with the second aspect, or any of the above possible implementations of the second aspect, in another possible implementation of the second aspect, the teacher network is obtained by enlarging (e.g., widening or deepening) the student network, or the teacher network is a heterogeneous network with a different structure from the student network.
[0061] In a third aspect, embodiments of this application provide a neural network training method, the method comprising:
[0062] An intermediate transformation network sent by a server is received. The intermediate transformation network includes an expansion module and a contraction module. The expansion module converts a first feature into a third feature, and the contraction module converts the third feature into a fourth feature; the fourth feature is aligned with the first feature, and the third feature is aligned with a second feature. The first feature is obtained by processing input data through a first network layer of a student network, where the input data includes any one or more of the following: image data, audio data, and text data. The second feature is obtained by processing the input data through a second network layer of a teacher network, the second network layer being the network layer in the teacher network that corresponds to the first network layer of the student network. The intermediate transformation network is then integrated into the student network.
[0063] Specifically, the method may be performed by a user device that uses the student network. The user device receives the intermediate transformation network sent by the server, integrates it into the student network, and then performs the corresponding prediction operations with the student network. The user device may also be called a client device, for example a mobile phone, a computer, a smart car, or a robot.
[0064] In the above method, the expansion module can align the features output by the student network with the features output by the teacher network as much as possible, so as to make full use of the knowledge learned by the teacher network to improve the accuracy of the intermediate transformation network and the student network. The contraction module can shrink the expanded feature map back to its original size, which can ensure that the intermediate transformation network can be seamlessly integrated into the student network after training, making the student network perform better when applied.
[0065] In conjunction with the third aspect, in one optional implementation of the third aspect, integrating the intermediate transformation network into the student network includes:
[0066] The weights of the first network layer of the student network, or of the network layer following the first network layer, are updated using the intermediate transformation network. Because the intermediate transformation network is trained on information related to the first network layer, updating the weights of the first network layer or of the following layer in this way preserves the information represented by the intermediate transformation network, thereby improving the prediction accuracy of the student network. Note that after the intermediate transformation network is merged into the first network layer or the following layer, the network structure of the student network does not change.
[0067] In conjunction with the third aspect, or any of the above possible implementations of the third aspect, in yet another optional implementation of the third aspect, the integration of the intermediate transformation network into the student network includes:
[0068] The intermediate transformation network is fused into a target network layer, and the target network layer is inserted between the first network layer and the next network layer of the student network. Because the intermediate transformation network is trained on information related to the first network layer, fusing it into a target network layer inserted between the first network layer and the next network layer preserves the information represented by the intermediate transformation network, thereby helping to ensure the prediction accuracy of the student network. Note that inserting the target network layer changes the structure of the student network: one additional target network layer is added.
[0069] In conjunction with the third aspect, or any of the above-mentioned possible implementations of the third aspect, another optional implementation of the third aspect further includes:
[0070] The student network into which the intermediate transformation network has been integrated is trained according to a second loss function.
[0071] In other words, by integrating the intermediate transformation network into the student network and then continuing to train the student network, the negative impact of the integration can be corrected, thereby further improving the prediction accuracy of the student network.
[0072] In a fourth aspect, embodiments of this application provide a neural network training apparatus, the apparatus comprising:
[0073] A receiving unit is configured to receive an intermediate transformation network sent by a server. The intermediate transformation network includes an expansion module and a contraction module. The expansion module converts the first feature into a third feature, and the contraction module converts the third feature into a fourth feature. The fourth feature is aligned with the first feature, and the third feature is aligned with the second feature. The first feature is obtained by processing input data through a first network layer of a student network. The input data includes any one or more of the following: image data, audio data, and text data. The second feature is obtained by processing input data through a second network layer of a teacher network. The second network layer is a network layer in the teacher network corresponding to the first network layer of the student network.
[0074] An integration unit is used to integrate the intermediate transformation network into the student network.
[0075] Specifically, the apparatus may be a user device that uses the student network. The user device receives the intermediate transformation network sent by the server, integrates it into the student network, and then performs the corresponding prediction operations with the student network. The user device may also be called a client device, for example a mobile phone, a computer, a smart car, or a robot.
[0076] In the above apparatus, the expansion module can align the features output by the student network with the features output by the teacher network as much as possible, so as to make full use of the knowledge learned by the teacher network to improve the accuracy of the intermediate transformation network and the student network. The contraction module can shrink the expanded feature map back to its original size, which ensures that the intermediate transformation network can be seamlessly integrated into the student network after training, making the student network more effective when applied.
[0077] In conjunction with the fourth aspect, in one optional implementation of the fourth aspect, in terms of integrating the intermediate transformation network into the student network, the integration unit is specifically configured to:
[0078] Update, by using the intermediate transformation network, the weights of the first network layer of the student network or of the layer following the first network layer. It can be understood that the intermediate transformation network is trained based on information related to the first network layer. Updating the weights of the first network layer, or of its next layer, through the intermediate transformation network therefore effectively preserves the information represented in the intermediate transformation network, which improves the prediction accuracy of the student network. It should be noted that after the intermediate transformation network is merged into the first network layer or its next layer, the network structure of the student network does not change.
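The weight update can be illustrated with a small sketch. It assumes, hypothetically, that:

- the first network layer and its next layer are both bias-free 1*1 convolutions;
- the intermediate transformation network is a single expansion kernel (C_T*C_S) followed by a single contraction kernel (C_S*C_T), with no nonlinearity.

Under these assumptions the whole module collapses into one C_S*C_S matrix. That matrix can be multiplied into the weights of either layer without changing the layer shapes. All names and channel counts are illustrative, and plain Python lists stand in for tensors.

```python
import random

def matmul(a, b):
    """Matrix product of lists of rows: (p x q) @ (q x r) -> (p x r)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def matvec(a, x):
    """Apply a 1*1 kernel stored as an (out x in) matrix to one pixel's channel vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in a]

def rand_mat(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
c_in, c_s, c_t, c_next = 3, 4, 8, 5        # hypothetical channel counts
k_first = rand_mat(c_s, c_in, rng)         # student's first network layer (1*1 conv)
k_next = rand_mat(c_next, c_s, rng)        # student's next network layer (1*1 conv)
expand = rand_mat(c_t, c_s, rng)           # expansion module: C_S -> C_T
contract = rand_mat(c_s, c_t, rng)         # contraction module: C_T -> C_S

merged = matmul(contract, expand)          # whole module as one C_S x C_S map

k_first_folded = matmul(merged, k_first)   # option 1: fold into the first layer
k_next_folded = matmul(k_next, merged)     # option 2: fold into the next layer

x = [rng.uniform(-1.0, 1.0) for _ in range(c_in)]
reference = matvec(k_next, matvec(contract, matvec(expand, matvec(k_first, x))))
out1 = matvec(k_next, matvec(k_first_folded, x))
out2 = matvec(k_next_folded, matvec(k_first, x))
print(max(abs(a - b) for a, b in zip(reference, out1)),
      max(abs(a - b) for a, b in zip(reference, out2)))
```

Both folded kernels keep their original shapes (C_S x C_in and C_next x C_S). This is consistent with the statement that the structure of the student network does not change.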
[0079] In conjunction with the fourth aspect, or any of the above possible implementations of the fourth aspect, in yet another optional implementation of the fourth aspect, in terms of integrating the intermediate transformation network into the student network, the integration unit is specifically configured to:
[0080] Fuse the intermediate transformation network into a target network layer, and insert the target network layer between the first network layer and the next network layer of the student network. It can be understood that the intermediate transformation network is trained based on information related to the first network layer. Fusing it into a target network layer and inserting that layer between the first network layer and the next network layer therefore effectively preserves the information represented in the intermediate transformation network, which guarantees the prediction accuracy of the student network. It should be noted that inserting the target network layer between the first network layer and the next network layer changes the structure of the student network: one extra target network layer is added.
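A multi-branch module can be fused into one target layer in the same way. Serial 1*1 kernels inside a branch multiply together, and the outputs of parallel branches add. The sketch below makes these assumptions:

- all kernels are bias-free linear 1*1 kernels;
- the expansion module has m = 2 branches, of depths 1 and 2;
- the contraction module has k = 1 branch.

These shapes are hypothetical. The sketch builds the single C_S*C_S kernel that would form the inserted target layer, then checks it against running the module branch by branch.

```python
import random

def matmul(a, b):
    """Matrix product of lists of rows: (p x q) @ (q x r) -> (p x r)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def matvec(a, x):
    return [sum(w * v for w, v in zip(row, x)) for row in a]

def rand_mat(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def run_module(branches, x):
    """Run a module whose branches are serial lists of 1*1 kernels; branch outputs are summed."""
    total = None
    for branch in branches:
        h = x
        for kernel in branch:
            h = matvec(kernel, h)
        total = h if total is None else [t + v for t, v in zip(total, h)]
    return total

def collapse(branches):
    """Collapse the same module into one (out x in) kernel."""
    total = None
    for branch in branches:
        m = branch[0]
        for kernel in branch[1:]:
            m = matmul(kernel, m)          # a later kernel is applied after an earlier one
        total = m if total is None else [[p + q for p, q in zip(r1, r2)]
                                         for r1, r2 in zip(total, m)]
    return total

rng = random.Random(1)
c_s, c_t = 4, 8                                        # hypothetical C_S and C_T
expand = [[rand_mat(c_t, c_s, rng)],                   # branch 1: one layer
          [rand_mat(6, c_s, rng), rand_mat(c_t, 6, rng)]]  # branch 2: two layers
contract = [[rand_mat(c_s, c_t, rng)]]                 # single-branch contraction

target_layer = matmul(collapse(contract), collapse(expand))  # C_S x C_S kernel

x = [rng.uniform(-1.0, 1.0) for _ in range(c_s)]
via_module = run_module(contract, run_module(expand, x))
via_target = matvec(target_layer, x)
print(max(abs(a - b) for a, b in zip(via_module, via_target)))
```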
[0081] In conjunction with the fourth aspect, or any of the above possible implementations of the fourth aspect, in yet another optional implementation of the fourth aspect, the apparatus further includes:
[0082] A training unit is configured to train, according to the second loss function, the student network into which the intermediate transformation network has been integrated.
[0083] In other words, by integrating the intermediate transformation network into the student network and then continuing to train the student network, the negative impact of the integration can be corrected, thereby further improving the prediction accuracy of the student network.
[0084] In a fifth aspect, embodiments of this application provide a neural network training device, which includes a processor and a memory, wherein:
[0085] The memory is used to store computer programs;
[0086] The processor is used to invoke the computer program to implement the method described in the first aspect or any possible implementation of the first aspect.
[0087] In a sixth aspect, embodiments of this application provide a neural network training device, which includes a processor and a memory, wherein:
[0088] The memory is used to store computer programs;
[0089] The processor is used to invoke the computer program to implement the method described in the third aspect or any possible implementation of the third aspect.
[0090] In a seventh aspect, embodiments of this application provide a computer-readable storage medium storing a computer program that, when run on a processor, implements the method described in the first aspect, or any possible implementation of the first aspect, or the third aspect, or any possible implementation of the third aspect.
[0091] In an eighth aspect, embodiments of this application provide a computer program product that, when running on a processor, implements the method described in the first aspect, or any possible implementation of the first aspect, or the third aspect, or any possible implementation of the third aspect.
[0092] By implementing the embodiments of this application, the expansion module can align the features output by the student network with the features output by the teacher network as much as possible, so as to make full use of the knowledge learned by the teacher network to improve the accuracy of the intermediate transformation network and the student network. The contraction module can shrink the expanded feature map back to its original size, which can ensure that the intermediate transformation network can be seamlessly integrated into the student network after training, making the student network more effective when applied.
Attached Figure Description
[0093] The accompanying drawings used in the embodiments of this application are described below.
[0094] Figure 1 is a schematic diagram illustrating a classification of existing distillation techniques;
[0095] Figure 2 is a schematic diagram of a distillation network in the prior art;
[0096] Figure 3 is a schematic diagram of a distillation network in the prior art;
[0097] Figure 4 is a schematic diagram of a network distillation architecture provided in an embodiment of this application;
[0098] Figure 5 is a schematic diagram of the structure of a distillation network provided in an embodiment of this application;
[0099] Figure 6 is a flowchart illustrating a neural network training method provided in an embodiment of this application;
[0100] Figure 7A is a schematic diagram of an intermediate transformation network provided in an embodiment of this application;
[0101] Figure 7B is a schematic diagram of the structure of the 1*1 convolution kernel provided in this application;
[0102] Figure 8 is a schematic diagram of a homomorphic transformation scenario provided in an embodiment of this application;
[0103] Figure 9 is a schematic diagram of another homomorphic transformation scenario provided in an embodiment of this application;
[0104] Figure 10 is a schematic diagram of another homomorphic transformation scenario provided in an embodiment of this application;
[0105] Figure 11 is a schematic flowchart of a network distillation method provided in an embodiment of this application;
[0106] Figure 12 is a schematic diagram of the structure of a distillation network provided in an embodiment of this application;
[0107] Figure 13 is a schematic diagram of another distillation network provided in an embodiment of this application;
[0108] Figure 14 is a schematic diagram of the structure of a neural network training device provided in an embodiment of this application;
[0109] Figure 15 is a schematic diagram of the structure of a neural network training device provided in an embodiment of this application.
Detailed Implementation
[0110] The embodiments of this application are described below with reference to the accompanying drawings.
[0111] Deep neural networks are widely used in applications such as mobile terminals, surveillance videos, and cloud services. For example, the camera functions of terminals (such as Huawei's flagship P and Mate series phones) incorporate artificial intelligence (AI) enhancements, significantly improving image quality and enhancing photo and video effects. This functionality is primarily achieved through deep neural networks and is an essential technology in high-end mobile phones. In smart city applications, some infrared cameras use deep neural networks to fuse RGB and infrared images, enhancing imaging effects. Furthermore, photo albums in terminals (such as mobile phones) often feature AI-powered automatic classification functions, which also rely on deep neural networks for classification. In these applications, deep neural networks are typically deployed on the edge (i.e., dedicated chips in the phone or camera), creating a strong demand for miniaturization. The embodiments of this application can be applied to these scenarios (but are not limited to them). In general, the solutions in the embodiments of this application can be applied to any process where a teacher network guides a student network for training.
[0112] Refer to Figure 4, which is a schematic diagram of a network distillation architecture provided in an embodiment of this application. In one scenario, the architecture includes a neural network training device 401 and one or more network user devices 402. Device 401 and the network user devices 402 communicate by wired means (e.g., Ethernet or Universal Serial Bus) or wireless means (e.g., wireless local area network (WLAN), Wi-Fi, Bluetooth, Near Field Communication (NFC), infrared communication, or 2G/3G/4G/5G). Device 401 can therefore send the trained network (or model) to the network user devices 402, which then perform task prediction using the received network. In another scenario, device 401 and at least one device 402 perform collaborative training, and at least one device 402 uses the trained network (or model) to process specific tasks.
[0113] Optionally, the network user device 402 can feed back its prediction results to device 401, so that device 401 can further train the network based on those results. The retrained network can then be sent to device 402 to update the original network. The network user device 402 is a device that needs to perform prediction tasks, for example:

- a handheld device, such as a mobile phone, tablet, or PDA;
- an in-vehicle device, such as one in a car, bicycle, electric vehicle, airplane, or ship;
- a wearable device, such as a smartwatch (for example, the iWatch), a smart band, or a pedometer;
- a smart home device, such as a refrigerator, television, air conditioner, or electricity meter;
- a smart robot or workshop equipment.
[0114] The neural network training device 401 can be a device with strong computing power, such as a server or a server cluster consisting of multiple servers. The neural network training device 401 can include many neural networks. Generally, network distillation scenarios include teacher networks and student networks. The teacher network has a more complex structure than the student network, typically exhibiting one or more of the following characteristics: more layers, wider width, greater depth, or a different structure but with a larger number of parameters. Therefore, the teacher network generally has a larger capacity and stronger data representation ability than the student network. The process of using the teacher network to guide the training of the student network can be called network distillation.
[0115] When using a teacher network to guide the training of a student network, the teacher network can be deployed in device 401, and device 401 trains the teacher network. The student network can be deployed in any one or more devices 402, and device 402 trains the student network. In one embodiment, device 401 uses a larger dataset when training the teacher network, while device 402 uses a smaller dataset when training the student network (e.g., each device 402 uses locally collected data to train the student network, while device 401 uses data from multiple devices 402 to train the teacher network).
[0116] As shown in Figure 5, this application sets intermediate transformation networks between network layers of the teacher network and network layers of the student network. For example:

- intermediate transformation network R_0 is set between network layer layer_t0 of the teacher network and network layer layer_s0 of the student network;
- R_1 is set between layer_t1 and layer_s1;
- R_i is set between layer_ti and layer_si.

Each intermediate transformation network includes an expansion module and a contraction module. The expansion module converts the first feature output by the student network layer into a third feature aligned with the teacher network. While the teacher network guides the training of the student network, it also guides the training of the intermediate transformation network. The first loss function used for this training is the loss of the third feature relative to the second feature output by the corresponding teacher network layer. The contraction module converts the third feature into a fourth feature aligned with the first feature. As a result, after passing through the intermediate transformation network, the feature can still enter the next network layer of the student network.
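Below is a minimal sketch of one intermediate transformation network trained on the first loss. It makes these assumptions:

- each module is a single 1*1 convolution kernel;
- the feature map is flattened to C*(h*w) with the batch dimension dropped, so a 1*1 convolution is a matrix product;
- the first loss is the L2 (mean squared error) loss, minimized by plain gradient descent.

Only the expansion kernel receives the first-loss gradient here, because the first loss depends only on the third feature. Random matrices stand in for the real layer_s0 and layer_t0 outputs, and all sizes are hypothetical.

```python
import random

def matmul(a, b):
    """Matrix product of lists of rows: (p x q) @ (q x r) -> (p x r)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def rand_mat(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def mse(a, b):
    """Mean squared error between two equally shaped feature maps."""
    n = len(a) * len(a[0])
    return sum((p - q) ** 2 for ra, rb in zip(a, b) for p, q in zip(ra, rb)) / n

rng = random.Random(2)
c_s, c_t, pixels = 4, 8, 6                     # hypothetical C_S, C_T and h*w
first_feature = rand_mat(c_s, pixels, rng)     # stand-in for the layer_s0 output
second_feature = rand_mat(c_t, pixels, rng)    # stand-in for the layer_t0 output
expansion = rand_mat(c_t, c_s, rng)            # 1*1 kernel, C_S -> C_T
contraction = rand_mat(c_s, c_t, rng)          # 1*1 kernel, C_T -> C_S

initial_loss = mse(matmul(expansion, first_feature), second_feature)

# Gradient descent on the first loss w.r.t. the expansion kernel E:
# d/dE mean((E F - T)^2) = (2 / n) (E F - T) F^T
lr, n = 0.5, c_t * pixels
for _ in range(20):
    residual = [[p - q for p, q in zip(r3, r2)]
                for r3, r2 in zip(matmul(expansion, first_feature), second_feature)]
    grad = matmul(residual, transpose(first_feature))
    expansion = [[w - lr * 2.0 / n * g for w, g in zip(rw, rg)]
                 for rw, rg in zip(expansion, grad)]

third_feature = matmul(expansion, first_feature)     # C_T x h*w, aligned with the teacher
fourth_feature = matmul(contraction, third_feature)  # C_S x h*w, aligned with the student
first_loss = mse(third_feature, second_feature)
print(initial_loss, "->", first_loss)
```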
[0117] In this embodiment, optionally, a network structure search space is also configured in the neural network training device 401 to store some network structures, so that the neural network training device 401 can search in the network structure search space based on Neural Architecture Search (NAS) technology to determine the intermediate transformation network. The purpose of NAS technology is to automatically search for new network structures through search algorithms, thereby replacing manual design of network structures.
[0118] In this embodiment, the intermediate transformation network undergoes homomorphic transformation to reduce the aforementioned first loss function. Optionally, the homomorphically transformed networks are also obtained from the network structure search space. Take Figures 8 to 10 as an example, and define five networks:

- network 1: a single C2*C1*1*1 network;
- network 2: two parallel C2*C1*1*1 networks;
- network 3: a C2*C1*1*1 network in series with a C3*C2*1*1 network;
- network 4: a C4*C1*1*1 network in series with a C3*C4*1*1 network;
- network 5: a C5*C1*1*1 network in series with a C2*C5*1*1 network.

Network 1 is mathematically equivalent to network 2 and to network 5. A search for network 1 in the network structure search space may therefore find network 2 and/or network 5, and replacing network 1 with network 2 or network 5 can be regarded as a homomorphic transformation. When multiple networks are found, they can be further filtered through a certain mechanism. Similarly, network 3 and network 4 are mathematically equivalent, so a search for network 3 in the network structure search space may find network 4. Replacing network 3 with network 4 can be regarded as a homomorphic transformation.
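"Mathematically equivalent" can be made concrete. With bias-free 1*1 convolutions, serial kernels compose by matrix multiplication. Network 5 therefore equals network 1 whenever the product of its two kernels equals the kernel of network 1.

The sketch assumes C5 >= C1 and uses one possible construction:

- the first kernel copies the input channels and pads with zeros;
- the second kernel applies network 1's kernel to the copied channels and ignores the padding.

The text does not specify how network 5's weights are initialized, and the channel counts below are hypothetical.

```python
import random

def matmul(a, b):
    """Matrix product of lists of rows: (p x q) @ (q x r) -> (p x r)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def rand_mat(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def effective_kernel(serial_kernels):
    """Collapse serial bias-free 1*1 kernels (each out x in, applied in order) into one."""
    result = serial_kernels[0]
    for kernel in serial_kernels[1:]:
        result = matmul(kernel, result)
    return result

rng = random.Random(6)
c1, c2, c5 = 3, 4, 6                            # hypothetical channel counts, C5 >= C1
k = rand_mat(c2, c1, rng)
network1 = [k]                                  # one C2*C1*1*1 kernel

# Network 5: a C5*C1*1*1 kernel followed by a C2*C5*1*1 kernel.
first = [[1.0 if i == j else 0.0 for j in range(c1)] for i in range(c5)]  # copy + zero-pad
second = [row + [0.0] * (c5 - c1) for row in k]                          # ignore padding
network5 = [first, second]

x = rand_mat(c1, 5, rng)                        # C1 x (h*w) probe input
out1 = matmul(effective_kernel(network1), x)
out5 = matmul(effective_kernel(network5), x)
max_diff = max(abs(p - q) for r1, r5 in zip(out1, out5) for p, q in zip(r1, r5))
print(max_diff)
```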
[0119] Refer to Figure 6, which is a flowchart illustrating a neural network training method provided in an embodiment of this application. The method may be implemented based on the architecture shown in Figure 4 or on other architectures. It includes, but is not limited to, the following steps:
[0120] Step S601: Determine the teacher network and student network.
[0121] Specifically, the student network and the teacher network can each be pre-designed or selected from an existing network library; the specific method of determination is not limited here. Optionally, the teacher network can also be obtained by scaling up the student network. Scaling methods include, but are not limited to, increasing the width, increasing the number of network layers, and using a heterogeneous large network.
[0122] After obtaining the teacher network, it can be trained to obtain a teacher network with higher prediction accuracy. The teacher network used in the subsequent distillation process is the teacher network trained here.
[0123] Step S602: Configure an intermediate transformation network between the teacher network and the student network.
[0124] Specifically, one or more intermediate transformation networks can be configured. Each is configured between the outputs of corresponding network layers in the teacher network and the student network. Generally, both networks have many network layers, but not every layer's output is connected to an intermediate transformation network:

- If a network layer and its next layer both perform linear operations, such as convolution or deconvolution, the output of that layer can be connected to an intermediate transformation network.
- Otherwise, the layer is not connected to an intermediate transformation network. The intermediate transformation network could not truly be integrated between that layer and its next layer, i.e., it could not be integrated into the student network.

The final number of intermediate transformation networks deployed between the teacher network and the student network may also take other factors into account. As shown in Figure 5, for example:

- intermediate transformation network R_0 is set between network layer layer_t0 of the teacher network and network layer layer_s0 of the student network;
- R_1 is set between layer_t1 and layer_s1;
- R_i is set between layer_ti and layer_si, where i is an integer.
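The selection rule can be sketched as a scan over layer types. It keeps layer i only when layer i and layer i+1 both perform linear operations. The string labels and the example layer list are hypothetical. A real implementation would inspect the actual operators, and other factors may further reduce the final count.

```python
LINEAR_OPS = {"conv", "deconv"}   # hypothetical labels for linear operations

def eligible_layers(layer_types):
    """Indices of layers whose output may be connected to an intermediate
    transformation network: the layer and its next layer must both be linear."""
    return [i for i in range(len(layer_types) - 1)
            if layer_types[i] in LINEAR_OPS and layer_types[i + 1] in LINEAR_OPS]

student = ["conv", "conv", "relu", "conv", "deconv", "softmax"]
print(eligible_layers(student))   # layers 0 and 3 qualify
```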
[0125] Optionally, embodiments of this application may also configure a network structure search space to store some network structures, so that the neural network training device 401 can perform homomorphic transformation search in the network structure search space based on Neural Architecture Search (NAS) technology, thereby determining (or updating) the network in the intermediate transformation network. The purpose of NAS technology is to automatically search for new network structures through search algorithms, thereby replacing manual design of network structures. The network in the intermediate transformation network can be called the transformation network. The initial transformation network can also be searched from the network structure search space. Of course, the simplest network can also be used by default, and there is no specific limitation here.
[0126] It should be noted that the network structure search space includes multiple network structures, all of which are convolutional layers with a kernel size of 1x1, generally linear convolutional layers, and all can be integrated with the student network.
[0127] Step S603: Train the intermediate transformation network through the network layers of the teacher network.
[0128] Specifically, a first feature and a second feature can first be obtained. The first feature is obtained by processing the input data through the first network layer of the student network. The input data includes any one or more of the following: image data, audio data, and text data. The second feature is obtained by processing the input data through the second network layer of the teacher network. The second network layer is the network layer in the teacher network that corresponds to the first network layer of the student network. For example, layer_s0 and layer_s1 represent two different first network layers, and layer_t0 and layer_t1 represent two different second network layers.
[0129] As shown in Figure 7A, each intermediate transformation network includes an expansion module 701 and a contraction module 702. The expansion module converts the first feature output by the student network layer into a third feature aligned with the teacher network. The contraction module converts the third feature into a fourth feature aligned with the student network. As shown in Figure 7A, assume the student network has C_S feature channels and the teacher network has C_T feature channels.

- When the first feature is converted into a third feature aligned with the teacher network, the input layer has a fixed C_S channels and the output layer has a fixed C_T channels. The network that implements this function is called the expansion module.
- When the third feature is converted into a fourth feature aligned with the student network, the input layer has a fixed C_T channels and the output layer has a fixed C_S channels. The network that implements this function is called the contraction module.
[0130] In this embodiment, the expansion module includes m branches, where m is greater than or equal to 1, and the contraction module includes k branches, where k is greater than or equal to 1. During the subsequent homomorphic transformation, m and k may keep changing as the transformation progresses. Different branches may have different numbers of layers q, where q is greater than or equal to 1.

As shown in Figure 7A, each column of small squares can be regarded as a branch and each row as a layer, and each small square indicates the shape of a convolution kernel. For example, the small square marked S11 corresponds to a convolution kernel of shape C8*C9*1*1, indicating that the convolution has C8 input channels and C9 output channels.

As shown in Figure 7B, the input to a 1*1 convolution is a feature map of dimension C_in*h*w (ignoring the batch dimension and considering only one image). Each C_in*1*1 convolution kernel convolves the C_in*h*w feature map to produce one 1*h*w feature map. Because there are C_out convolution kernels, a C_out*h*w feature map is produced in the end.
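The 1*1 convolution described for Figure 7B can be written out directly. Each C_in*1*1 kernel takes a weighted sum across channels at every pixel, giving one 1*h*w map, and C_out kernels give a C_out*h*w output. This is a plain-Python sketch with the batch dimension dropped. The toy feature map and kernels are illustrative only.

```python
def conv1x1(feature, kernels):
    """feature: C_in x h x w nested lists; kernels: C_out kernels, each a list of
    C_in weights (one C_in*1*1 kernel). Returns a C_out x h x w feature map."""
    c_in = len(feature)
    h, w = len(feature[0]), len(feature[0][0])
    out = []
    for kernel in kernels:                       # each kernel yields one 1 x h x w map
        assert len(kernel) == c_in
        out.append([[sum(kernel[c] * feature[c][y][x] for c in range(c_in))
                     for x in range(w)] for y in range(h)])
    return out

feature = [[[1, 2], [3, 4]],       # channel 0
           [[10, 20], [30, 40]]]   # channel 1  (C_in = 2, h = w = 2)
kernels = [[1, 0], [0, 1], [1, 1]] # C_out = 3
out = conv1x1(feature, kernels)
print(len(out), len(out[0]), len(out[0][0]))   # C_out, h, w
print(out[2])                                  # channel 0 + channel 1
```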
[0131] While the teacher network guides the training of the student network, the intermediate transformation network is also trained. The first loss function used to train it is the loss of the third feature relative to the second feature output by the teacher network layer.

As an example of the second feature, consider Figure 5:

- For layer_t0, the arrow on its "data" side indicates its input feature, and the arrow toward layer_t1 indicates its output feature, i.e., its second feature.
- For layer_t1, the arrow on its layer_t0 side indicates its input feature, and the arrow toward layer_t2 indicates its output feature, i.e., its second feature.
- Second features at other positions follow the same pattern.

It can be understood that the first loss function determined from the second feature output by layer_t0 is used to train intermediate transformation network R_0. The first loss function determined from the second feature output by layer_t1 is used to train R_1, and so on.
[0132] During training, the parameters of the intermediate transformation network are continuously adjusted to make the first loss function smaller and smaller. Loss functions that may be used include, but are not limited to, the L1 loss, the L2 loss, cosine distance, Kullback-Leibler divergence (KLD), and cross-entropy.
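The candidate loss functions can be sketched on flattened features. The KL divergence and cross-entropy need probability distributions. The sketch assumes a softmax over the flattened feature, which is one common choice but is not specified in the text. The argument order is (target, prediction): the second (teacher) feature is the target and the third feature is the prediction. The feature values are illustrative.

```python
import math

def l1_loss(target, pred):
    return sum(abs(t - p) for t, p in zip(target, pred)) / len(target)

def l2_loss(target, pred):
    return sum((t - p) ** 2 for t, p in zip(target, pred)) / len(target)

def cosine_distance(target, pred):
    dot = sum(t * p for t, p in zip(target, pred))
    norm = math.sqrt(sum(t * t for t in target)) * math.sqrt(sum(p * p for p in pred))
    return 1.0 - dot / norm

def softmax(v):
    m = max(v)                                   # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in v]
    s = sum(exps)
    return [e / s for e in exps]

def kl_divergence(target, pred):
    """KL(P || Q) with P = softmax(target) and Q = softmax(pred)."""
    p, q = softmax(target), softmax(pred)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def cross_entropy(target, pred):
    p, q = softmax(target), softmax(pred)
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

second = [0.4, -0.8, 1.5, 0.1]   # flattened teacher feature (illustrative values)
third = [0.5, -1.0, 2.0, 0.0]    # flattened expansion-module output
for name, fn in [("L1", l1_loss), ("L2", l2_loss), ("cosine", cosine_distance),
                 ("KL", kl_divergence), ("CE", cross_entropy)]:
    print(name, round(fn(second, third), 4))
```

Minimizing cross-entropy over the prediction is equivalent to minimizing the KL divergence, because the two differ only by the entropy of the target distribution.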
[0133] In this embodiment, the third feature is converted into a fourth feature aligned with the student network because the intermediate transformation network must ultimately be integrated into the student network. The features passed between the layers of the student network must therefore match the student network's dimensions. After the intermediate transformation network converts the first feature into the third feature, the feature no longer matches the student network. Converting it into a fourth feature aligned with the student network allows the intermediate transformation network to integrate into the student network.
[0134] Step S604: Perform homomorphic transformation on the intermediate transformation network during the training process of the intermediate transformation network.
[0135] In this embodiment, the expansion module and the contraction module in the intermediate transformation network are two separate parts that implement different functions. They can therefore be trained separately, that is, each can undergo its own homomorphic transformation search. Specifically:

- homomorphic transformation search is performed in a preset network structure search space to obtain a first target network equivalent to the expansion module, which is used to update the expansion module; and/or
- homomorphic transformation search is performed in the preset network structure search space to obtain a second target network equivalent to the contraction module, which is used to update the contraction module.
[0136] There are many ways to perform homomorphic transformation search based on a preset network structure search space to obtain a target network that is mathematically equivalent to the current network in the expansion module and / or contraction module, and then update the expansion module and / or contraction module. Several optional homomorphic transformation search methods are listed below:
[0137] The first method is homomorphic transformation search based on splitting operations.
[0138] In the example shown in Figure 8, the network before the homomorphic transformation search includes an operator OP, a 1*1 convolution with a kernel shape of C2*C1*1*1. C_in denotes the input feature of OP, with shape n*C1*h*w, and C_out denotes its output feature, with shape n*C2*h*w.

The splitting operation increases the number of branches in the network. For example, one convolution operator OP can be split into two (or another number of) convolution operators OP1 and OP2. Each is a 1*1 convolution with a kernel shape of C2*C1*1*1. The network formed by OP1 and OP2 has the same input feature C_in as before the search (i.e., before the splitting operation), and its output feature is C_out2. Mathematically, the following relationships hold:
[0139] $C_{out} = C_{in} \cdot W$    (5-1)
[0140] $C_{out2} = C_{in} \cdot W_1 + C_{in} \cdot W_2 = C_{in} \cdot (W_1 + W_2)$    (5-2)
[0141] Here W denotes the weights of the OP network before the homomorphic transformation search, and W1 and W2 denote the weights of the OP1 and OP2 networks obtained after the search. As long as W = W1 + W2, C_out = C_out2 holds. The network obtained by the search then produces the same output as the network before the search, so the target network is mathematically equivalent to it. In this example, the target network includes the OP1 and OP2 networks, and the network before the search includes the OP network. The target network obtained after the search therefore has more branches than the intermediate transformation network before the search.
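The splitting operation can be checked numerically with a small sketch. The kernel of shape C2*C1*1*1 is stored as a C2 x C1 matrix applied on the left of a C1 x (h*w) feature. This is the transpose of the row-vector notation C_in * W in equations (5-1) and (5-2), and the condition W = W1 + W2 is unaffected. The text does not say how the split weights are initialized, so W1 is random here and W2 = W - W1. Sizes are hypothetical.

```python
import random

def matmul(a, b):
    """Matrix product of lists of rows: (p x q) @ (q x r) -> (p x r)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def rand_mat(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(3)
c1, c2, pixels = 3, 5, 4                  # hypothetical C1, C2 and h*w
w = rand_mat(c2, c1, rng)                 # OP kernel (C2*C1*1*1) as a C2 x C1 matrix
c_in = rand_mat(c1, pixels, rng)          # one input feature map, flattened to C1 x (h*w)

# Split OP into two parallel branches whose kernels sum to W.
w1 = rand_mat(c2, c1, rng)                                        # OP1: any values
w2 = [[p - q for p, q in zip(rw, r1)] for rw, r1 in zip(w, w1)]   # OP2 = W - W1

c_out = matmul(w, c_in)                                           # before the split
c_out2 = [[p + q for p, q in zip(b1, b2)]                         # branch outputs summed
          for b1, b2 in zip(matmul(w1, c_in), matmul(w2, c_in))]
max_err = max(abs(p - q) for r, r2 in zip(c_out, c_out2) for p, q in zip(r, r2))
print("branches: 2, max difference:", max_err)
```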
[0142] The second method is homomorphic transformation search based on widening operations.
[0143] In the example shown in Figure 9, the network before the homomorphic transformation search includes the OP1 and OP2 networks:

- OP1 is a 1*1 convolution with a kernel shape of C2*C1*1*1.
- OP2 is a 1*1 convolution with a kernel shape of C3*C2*1*1.

C_in denotes the input feature of this network, with shape n*C1*h*w, where n is the batch size. C_out denotes its output feature, with shape n*C3*h*w.

The widening operation increases the number of channels of the convolutional layers in the network. For example, the OP1 and OP2 networks are transformed into OP3 and OP4 networks:

- OP3 is a 1*1 convolution with a kernel shape of C4*C1*1*1.
- OP4 is a 1*1 convolution with a kernel shape of C3*C4*1*1.

The network obtained by the search has the same input feature C_in as before the search, with shape n*C1*h*w. Its output feature C_out2 has shape n*C3*h*w. Mathematically, the following relationships hold:
[0144] $C_{out} = C_{in} \cdot W_1 \cdot W_2$    (5-3)
[0145] $C_{out2} = C_{in} \cdot W_3 \cdot W_4$    (5-4)
[0146] Here W1 and W2 are the weights of the OP1 and OP2 networks before the homomorphic transformation search (here, the widening operation), and W3 and W4 are the weights of the OP3 and OP4 networks obtained through the search. The widening operation widens W1 into W3 and W2 into W4, so that W3 has shape C4*C1*1*1 and W4 has shape C3*C4*1*1. That is, the number of output channels of OP1 grows from C2 to C4, where C4 > C2. For the network before the search to be mathematically equivalent to the target network, C_out = C_out2 must hold, i.e., W1*W2 = W3*W4. This can be achieved by setting the weights associated with the newly added channels to zero, so that those channels contribute nothing to the output. The target network obtained by the search is therefore wider than the current intermediate transformation network (before the search).
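One way to satisfy W1*W2 = W3*W4 is sketched below. Kernels are stored as (out x in) matrices applied on the left:

- OP3 keeps W1's output channels and appends C4 - C2 new output channels with arbitrary weights;
- OP4 keeps W2 and gives the matching new input channels zero weights.

Zeroing only OP4's new input weights is already sufficient. This is one possible choice, not necessarily the one used in the embodiment. Sizes are hypothetical.

```python
import random

def matmul(a, b):
    """Matrix product of lists of rows: (p x q) @ (q x r) -> (p x r)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def rand_mat(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(4)
c1, c2, c3, c4, pixels = 3, 4, 2, 7, 5        # hypothetical sizes, C4 > C2
w1 = rand_mat(c2, c1, rng)                    # OP1: C2*C1*1*1
w2 = rand_mat(c3, c2, rng)                    # OP2: C3*C2*1*1

# OP3 (C4*C1*1*1): W1's rows plus C4 - C2 new output channels with arbitrary weights.
w3 = w1 + rand_mat(c4 - c2, c1, rng)
# OP4 (C3*C4*1*1): W2's columns plus zero weights for the new input channels.
w4 = [row + [0.0] * (c4 - c2) for row in w2]

c_in = rand_mat(c1, pixels, rng)
c_out = matmul(w2, matmul(w1, c_in))          # before widening
c_out2 = matmul(w4, matmul(w3, c_in))         # after widening
max_err = max(abs(p - q) for r, r2 in zip(c_out, c_out2) for p, q in zip(r, r2))
print(len(w3), len(w4[0]), max_err)
```

Because the new channels of OP3 still carry nonzero weights, later training can make use of them. Meanwhile the zero weights in OP4 keep the widened network equivalent at the moment of the transformation.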
[0147] The third method is homomorphic transformation search based on variable depth operations.
[0148] In the example shown in Figure 10, the network before the homomorphic transformation search includes OP, a 1*1 convolution with a kernel shape of C2*C1*1*1. C_in denotes the input feature of OP before the search, with shape n*C1*h*w, where n is the batch size of the neural network. C_out denotes the output feature of OP before the search, with shape n*C2*h*w.

The variable depth operation increases the number of layers in the network. For example, transforming the OP network into an OP network followed by an OP1 network is equivalent to inserting an OP1 network into the original network. OP1 is a 1*1 convolution with a kernel shape of C2*C2*1*1. The network obtained by the search has the same input feature C_in as before the search, with shape n*C1*h*w. Its output feature C_out2 has shape n*C2*h*w. Mathematically, the following relationships hold:
[0149] $C_{out} = C_{in} \times W$ (5-5)
[0150] $C_{out2} = C_{in} \times W \times W_1$ (5-6)
[0151] Here, W represents the weights of the OP network before the homomorphic transformation search, and W1 represents the weights of the OP1 network added by the homomorphic transformation search (here, a depth transformation operation). For the network before the search to be mathematically equivalent to the target network obtained through the search, $C_{out} = C_{out2}$ must hold, that is, $W = W \times W_1$. It therefore suffices to initialize W1 as the C2*C2 identity matrix, with ones on the diagonal and zeros elsewhere. The target network after the homomorphic transformation search is therefore deeper than the intermediate transformation network before the search.
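The depth operation can be checked with a minimal NumPy sketch. It uses the same row-vector notation, and the sizes are arbitrary assumptions. Inserting an identity-initialized W1 leaves the output of formula (5-5) unchanged. An all-ones C2*C2 matrix would not, because it replaces every output channel with the sum of all channels.

```python
import numpy as np

rng = np.random.default_rng(1)

def deepen(w):
    """Return the weights of an inserted OP1 layer, initialised as identity.

    w is OP's weight, stored as (C1, C2) in right-multiplication notation;
    the inserted W1 is the C2 x C2 identity, so C_in @ W @ W1 == C_in @ W.
    """
    return np.eye(w.shape[1])

C1, C2 = 3, 4
w = rng.standard_normal((C1, C2))
w1 = deepen(w)
c_in = rng.standard_normal((16, C1))

c_out = c_in @ w          # formula (5-5)
c_out2 = c_in @ w @ w1    # formula (5-6)
print(np.allclose(c_out, c_out2))                          # True
print(np.allclose(c_out, c_in @ w @ np.ones((C2, C2))))   # False
```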
[0152] In this embodiment of the application, the homomorphic transformation search can be any one of the three methods mentioned above, or any combination of multiple methods.
[0153] Through the above homomorphic transformations, this application changes the structure of the intermediate transformation network while keeping its output mathematically equivalent to the output before the transformation. The transformed intermediate transformation network can therefore quickly inherit what was already learned before the transformation. This speeds up distillation training and allows an intermediate transformation network better suited to distillation to be grown automatically.
[0154] Optionally, the purpose of performing the above homomorphic transformation is to further reduce the first loss function. During the homomorphic transformation, if the first loss function decreases, the homomorphic transformation can continue; if the first loss function no longer decreases, the homomorphic transformation ends. Generally, multiple homomorphic transformations are required before the first loss function stops decreasing.
[0155] Optionally, if multiple intermediate transformation networks exist, homomorphic transformations can be performed sequentially from the bottom-level intermediate transformation networks to the top-level ones. When the homomorphic transformation of one intermediate transformation network is complete, the homomorphic transformation of the next higher-level intermediate transformation network is then performed, until all intermediate transformation networks have completed their homomorphic transformations. It should be noted that intermediate transformation networks closer to the output layer of the student network (or teacher network) are considered to be at the top level.
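The stopping rule and the bottom-up order described above can be sketched as a small control loop. This is a hedged illustration, not the patented procedure:

- `transform` and `first_loss` are hypothetical callables. They stand in for one homomorphic transformation search (with training) and for evaluating the first loss function.
- Discarding a candidate that does not lower the loss is an assumption of the sketch.
- The toy usage represents a network by an integer purely to exercise the loop.

```python
from typing import Callable, List, Tuple

def grow_until_plateau(net, transform: Callable, first_loss: Callable,
                       max_rounds: int = 100) -> Tuple[object, float]:
    """Repeat homomorphic transformations while the first loss decreases.

    transform(net): hypothetical; one homomorphic transformation search
        (plus training), returning the transformed network.
    first_loss(net): hypothetical; evaluates the first loss function.
    """
    best = first_loss(net)
    for _ in range(max_rounds):
        candidate = transform(net)
        loss = first_loss(candidate)
        if loss >= best:           # first loss no longer decreases: stop
            break
        net, best = candidate, loss
    return net, best

def train_bottom_up(trans_nets: List, transform: Callable,
                    first_loss: Callable) -> List:
    """Grow each intermediate transformation network in turn.

    trans_nets is ordered from the bottom (closest to the input) to the top
    (closest to the output layer), so lower networks are finished first.
    """
    grown = []
    for net in trans_nets:
        grown_net, _ = grow_until_plateau(net, transform, first_loss)
        grown.append(grown_net)
    return grown

# Toy usage: a "network" is an int and the loss is minimised at 3.
grown = train_bottom_up([0, 2, 5], transform=lambda c: c + 1,
                        first_loss=lambda c: abs(c - 3))
print(grown)  # [3, 3, 5]
```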
[0156] As shown in Figure 7A, the homomorphic transformation process may change two quantities. One is the number of branches, for example the number of branches m in the expansion module and the number of branches k in the contraction module. The other is the number of layers q of the intermediate transformation network. Depending on the homomorphic transformation, the expansion module and the contraction module may have the same number of layers or different numbers.
[0157] Step S605: Integrate the intermediate transformation network into the student network.
[0158] Optionally, before the intermediate transformation network is integrated into the student network, the parts of it that have undergone homomorphic transformation can be merged. In practical applications, merging may also be skipped. Three example merging operations are described below; they can be understood with reference to Figure 7A. In practical applications, the merging process may involve more or fewer operations, which is not limited here.
[0159] (1) First merge each independent branch into a 1x1 convolution.
[0160] Consider feature expansion, that is, transforming the first feature output by a network layer of the student network into a third feature aligned with the teacher network. Here, the intermediate transformation network includes one or more branches that implement the feature expansion function; these can be called expansion modules. In Figure 7A, 7011, 7012, and 7013 are three different branches that implement this function. Taking one branch as an example, suppose it consists of q consecutive 1x1 convolutions. Its combined output $C_{out}$ is:
[0161] $C_{out} = C_{in} \times W_1 \times W_2 \times \cdots \times W_q$ (6-1)
[0162] where $C_{in}$ is the input feature of this branch, and $W_1, W_2, \ldots, W_q$ are the weights of the first, second, ..., q-th convolution operations on the branch.
[0163] Since 1x1 convolutions are linear operations, the associative law of matrix multiplication gives:
[0164] $C_{out} = C_{in} \times (W_1 \times W_2 \times \cdots \times W_q) = C_{in} \times W_{1q}$ (6-2)
[0165] where $W_{1q}$ has shape $C_t \times C_s$, that is, it is a single 1x1 convolution. All convolution operations in this branch are thus merged into one 1x1 convolution. The other branches can be merged in the same way, which is not elaborated here.
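Step (1) can be illustrated with a short NumPy sketch. It assumes bias-free 1x1 convolutions with no activation between them, which is the linearity the text relies on. It uses the row-vector notation of formula (6-1), with weights stored as (input channels, output channels). The intermediate widths 8 and 12 are arbitrary assumptions.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(2)

def merge_sequential(weights):
    """Collapse q consecutive 1x1 convolutions into one, per formula (6-2).

    With C_out = C_in x W notation the merged weight is W_1 @ W_2 @ ... @ W_q;
    only associativity of matrix multiplication is used.
    """
    return reduce(np.matmul, weights)

C_s, C_t = 4, 16
# One expansion branch with q = 3 layers: C_s -> 8 -> 12 -> C_t
branch = [rng.standard_normal(s) for s in [(C_s, 8), (8, 12), (12, C_t)]]
w_1q = merge_sequential(branch)

c_in = rng.standard_normal((32, C_s))
layerwise = c_in
for w in branch:
    layerwise = layerwise @ w          # formula (6-1), one layer at a time
print(w_1q.shape, np.allclose(layerwise, c_in @ w_1q))  # (4, 16) True
```

In this storage layout the merged weight has shape (C_s, C_t). This is the transpose of the $C_t \times C_s$ kernel shape stated in the text.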
[0166] Similarly, consider feature shrinkage, that is, transforming the third feature into a fourth feature aligned with the student network. Here, the intermediate transformation network includes one or more branches that implement the feature shrinkage function; these can be called shrinkage modules. In Figure 7A, 7024, 7025, and 7026 are three different branches that implement this function. By the merging principle above, each branch in the feature shrinkage process can be merged into a single $C_s \times C_t \times 1 \times 1$ convolution.
[0167] (2) Merge different branches into a single 1x1 convolution.
[0168] For example, suppose that merging the m branches of the feature expansion process in step (1) yields m 1x1 convolutions. In this step, these m branches can be further merged into a single main branch for feature expansion. The merged output $C_{out}$ is:
[0169] $C_{out} = C_{in} \times W_{1q} + C_{in} \times W_{2q} + \cdots + C_{in} \times W_{mq}$ (6-3)
[0170] where $W_{xq}$ is the convolution obtained in step (1) by merging the x-th branch, and x is a positive integer from 1 to m. For every such x, $W_{xq}$ has shape $C_t \times C_s$. By the distributive law, equation (6-3) can therefore be rewritten as a single 1x1 convolution:
[0171] $C_{out} = C_{in} \times (W_{1q} + W_{2q} + \cdots + W_{mq}) = C_{in} \times W_E$ (6-4)
[0172] where $W_E$, the final merged result, is a $C_t \times C_s \times 1 \times 1$ convolution.
[0173] Likewise, suppose that merging the k branches of the feature shrinking process yields k 1x1 convolutions. In this step, these k branches can be merged into a single main branch for feature shrinking, giving a $C_s \times C_t \times 1 \times 1$ convolution $W_S$.
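Step (2) can likewise be sketched in NumPy. The sketch assumes every branch of a module has already been reduced to one 1x1 weight by step (1) and uses the same row-vector notation. The branch counts m = 3 and k = 2 are arbitrary assumptions. Because all branches of a module share one shape, summing them is a single elementwise addition.

```python
import numpy as np

rng = np.random.default_rng(3)

def merge_parallel(branch_weights):
    """Sum parallel 1x1 branches into one weight, per formula (6-4).

    All branches must map the same input channels to the same output
    channels, so their weights share one shape and add elementwise.
    """
    shapes = {w.shape for w in branch_weights}
    if len(shapes) != 1:
        raise ValueError(f"branches have different shapes: {shapes}")
    return np.sum(branch_weights, axis=0)

C_s, C_t, m, k = 4, 16, 3, 2
expand_branches = [rng.standard_normal((C_s, C_t)) for _ in range(m)]
shrink_branches = [rng.standard_normal((C_t, C_s)) for _ in range(k)]
w_e = merge_parallel(expand_branches)   # W_E
w_s = merge_parallel(shrink_branches)   # W_S

c_in = rng.standard_normal((32, C_s))
branch_sum = sum(c_in @ w for w in expand_branches)   # formula (6-3)
print(w_e.shape, w_s.shape, np.allclose(branch_sum, c_in @ w_e))
# (4, 16) (16, 4) True
```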
[0174] (3) Merge the 1x1 convolution of the feature expansion process with the 1x1 convolution of the feature contraction process to obtain the final 1x1 convolution of the intermediate transformation network. From the results of step (2), the input feature $C_{in}$ and the output feature $C_{out}$ of the entire intermediate transformation network satisfy:
[0175] $C_{out} = C_{in} \times W_E \times W_S$ (6-5)
[0176] Merging $W_E \times W_S$ by the associative law then yields the final weights W of the intermediate transformation network. By the rules of matrix multiplication, W has shape $C_s \times C_s \times 1 \times 1$.
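The three steps can be combined into one end-to-end check. In this hedged NumPy sketch, each module is a hypothetical list of branches, and each branch is a list of bias-free 1x1 weights in row-vector notation. The branch layouts and intermediate widths (8 and 6) are assumptions. The merged C_s x C_s weight should reproduce the output of running every branch layer by layer.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(4)

def merge_module(branches):
    """Steps (1) and (2): chain-multiply each branch, then sum the branches."""
    return sum(reduce(np.matmul, branch) for branch in branches)

def forward_module(x, branches):
    """Reference forward pass: run every branch layer by layer and sum."""
    return sum(reduce(np.matmul, branch, x) for branch in branches)

C_s, C_t = 4, 16
# Hypothetical module: 2 expansion branches and 2 contraction branches
expansion = [[rng.standard_normal((C_s, 8)), rng.standard_normal((8, C_t))],
             [rng.standard_normal((C_s, C_t))]]
contraction = [[rng.standard_normal((C_t, C_s))],
               [rng.standard_normal((C_t, 6)), rng.standard_normal((6, C_s))]]

w_e = merge_module(expansion)    # W_E
w_s = merge_module(contraction)  # W_S
w = w_e @ w_s                    # step (3), formula (6-5)

c_in = rng.standard_normal((32, C_s))
reference = forward_module(forward_module(c_in, expansion), contraction)
print(w.shape, np.allclose(reference, c_in @ w))  # (4, 4) True
```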
[0177] Optionally, the branches of the expansion module can also be merged in orders other than the one described above, as long as mathematical equivalence is preserved. For example, suppose the expansion module has 5 branches with the same number of layers (e.g., 3 layers each) and the same number of channels in each layer. In this case, the corresponding layers of the different branches can be merged first, combining the 5 branches into one branch. The layers of the merged branch can then be merged into a single 1x1 convolution. Other cases follow similarly: any merging order may be used as long as mathematical equivalence is satisfied.
[0178] After obtaining the final network structure of the intermediate transformation network through the above methods, it is integrated into the aforementioned student network.
[0179] In this embodiment of the application, the intermediate transformation network can be integrated into the student network in ways that include, but are not limited to, the following two:
[0180] The first method involves inserting an intermediate transform network into the student network. Specifically, this includes first fusing the intermediate transform network into a target network layer, and then inserting the target network layer between the first network layer and the next network layer of the student network. This insertion method essentially preserves the structure of the intermediate transform network; after insertion, the student network's structure changes, as a target network layer is added. Besides directly inserting the intermediate transform network between two layers of the student network, it is also possible to first fuse the intermediate transform network into one or more network layers, and then insert the resulting one or more network layers between two layers of the student network.
[0181] The second approach involves directly merging the intermediate transformation network with a layer in the student network, such as the first network layer mentioned above. This integrates the intermediate transformation network into the first network layer, effectively updating its weights. The intermediate transformation network can then be considered nonexistent, and the student network's structure remains unchanged. To facilitate understanding, an example is provided below:
[0182] Specifically, the intermediate transformation network can be integrated into the layer preceding its connection point with the student network. For example, in Figure 5, the intermediate transformation network connecting the output of layer 0 and the input of layer 1 can be incorporated into layer 0. Alternatively, it can be incorporated into the layer immediately following its connection point; in Figure 5, the same intermediate transformation network can be incorporated into layer 1. The following explanation uses incorporation into the preceding layer as an example.
[0183] As described above, the output feature $C_{out}$ of the intermediate transformation network is expressed as:
[0184] $C_{out} = C_{in} \times W$ (6-6)
[0185] where W is a convolution kernel of shape $C_s \times C_s \times 1 \times 1$. Its input feature $C_{in}$ is itself the result of a convolution operation, so $C_{in}$ can be expanded as:
[0186] $C_{in} = C_p \times W_p$ (6-7)
[0187] where $C_p$ is the input of the layer preceding the connection point between this intermediate transformation network and the student network. $W_p$ is the convolution kernel of that preceding layer, with kernel size j, where j is a positive integer, so $W_p$ has shape $C_s \times C_p \times j \times j$. Combining formulas (6-6) and (6-7), the output $C_{out}$ of the preceding layer satisfies the relationship shown in formula (6-8):
[0188] $C_{out} = C_p \times W_p \times W = C_p \times (W_p \times W) = C_p \times W'_p$ (6-8)
[0189] As formula (6-8) shows, the associative law allows the preceding layer's weights $W_p$ to be fused with the final weights W of the intermediate transformation network. This yields new weights $W'_p$ for the preceding layer. The intermediate transformation network is thereby integrated into the preceding layer, that is, into the student network.
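The fusion in formula (6-8) can be checked numerically. This hedged sketch makes several assumptions:

- The preceding layer is a bias-free j×j convolution with stride 1 and no padding.
- No activation sits between that layer and the intermediate transformation network.
- Kernels use the (out, in, j, j) layout matching the shapes in the text. In this layout the 1x1 weight W multiplies from the left, the reverse of the right-multiplication order written in formula (6-8).

The naive `conv2d` helper is written only for this check.

```python
import numpy as np

rng = np.random.default_rng(5)

def conv2d(x, k):
    """Naive 'valid' cross-correlation, stride 1, no bias or padding.

    x: (C_in, H, W); k: (C_out, C_in, j, j) -> (C_out, H - j + 1, W - j + 1).
    """
    c_out, _, j, _ = k.shape
    h, w = x.shape[1] - j + 1, x.shape[2] - j + 1
    out = np.zeros((c_out, h, w))
    for a in range(j):
        for b in range(j):
            # Contribution of kernel tap (a, b), summed over input channels.
            out += np.einsum("oc,chw->ohw", k[:, :, a, b], x[:, a:a + h, b:b + w])
    return out

def fuse_into_previous(k_prev, w_trans):
    """Fold a C_s x C_s 1x1 kernel into the preceding C_s x C_p x j x j kernel.

    W'_p[o, p, a, b] = sum_s W[o, s] * W_p[s, p, a, b]
    """
    return np.einsum("os,sphw->ophw", w_trans, k_prev)

C_p, C_s, j = 3, 4, 3
k_prev = rng.standard_normal((C_s, C_p, j, j))   # preceding layer W_p
w_trans = rng.standard_normal((C_s, C_s))        # merged weights W
c_p = rng.standard_normal((C_p, 10, 10))         # input of preceding layer

two_step = conv2d(conv2d(c_p, k_prev), w_trans[:, :, None, None])
fused = conv2d(c_p, fuse_into_previous(k_prev, w_trans))
print(fused.shape, np.allclose(two_step, fused))  # (4, 8, 8) True
```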
[0190] The principle of integrating the intermediate transformation network into the latter layer is the same as that of integrating it into the former layer, and will not be repeated here.
[0191] Step S606: Train the student network that incorporates the intermediate transformation network according to the second loss function.
[0192] First, the value of the second loss function needs to be determined. For example, the value of the second loss function of the student network will be determined based on the output of the student network. Specifically, the type of the output of the student network is not limited. For example, it can be the pixel value of an image (or other types of numerical values), the image classification result (or the classification of other scenes), the image recognition result (or the recognition of other scenes), and so on. Then, by comparing the output result with the reference result, the loss of the student network can be obtained. For ease of description, the function that measures this loss can be called the second loss function. The reference result can be the real result in the pre-preserved test samples (such as the real pixel value, or the real recognition result, or the real classification result, etc.), or the output result of the teacher network, etc. Therefore, the second loss function can reflect the accuracy of the student network's output result. The smaller the value of the second loss function, the more accurate the prediction of the student network.
[0193] Then, the student network is trained based on the second loss function. For example, the parameters in the student network, which incorporates the intermediate transformation network, can be adjusted to reduce the value of the recalculated second loss function. This process is repeated iteratively until the value of the second loss function decreases to an expected value (usually a small value). The resulting student network is the final student network, which can then be used for prediction.
[0194] An optional logical flow of the process implemented by the neural network training device described above is shown in Figure 11. The process begins by determining the teacher and student networks and then configuring intermediate transformation networks between them. Next, the process checks whether all intermediate transformation networks have been trained. If not, an untrained intermediate transformation network is selected from the bottom up and trained using the first loss function. During training, homomorphic transformations are used to progressively reduce the first loss function. Training of that intermediate transformation network ends when the first loss function no longer decreases. Once all intermediate transformation networks are trained, they are integrated into the student network. The student network is then trained using the second loss function until the loss is reduced to the expected value, which completes the network distillation process.
[0195] Step S607: The neural network training device sends the student network to the network user device.
[0196] Step S608: The network user device receives the student network sent by the neural network training device.
[0197] Specifically, after receiving the student network, the network user device uses it to make predictions. Examples include predicting the pixel values (or other types of values) of an image, predicting an image classification result (or a classification in other scenarios), and predicting an image recognition result (or recognition in other scenarios). Other examples are performing machine translation based on acquired speech content, planning routes based on collected driving information, and classifying text content such as novels.
[0198] The method shown in Figure 6 is further explained below with reference to specific application scenarios.
[0199] Scenario 1: Distillation of the student network AI-VRAW, a RAW-domain video enhancement network. The input to the RAW-domain video enhancement network is a RAW-domain video frame image, and the output is an enhanced RGB video frame image. The network's purpose is to achieve effects such as noise reduction and detail enhancement on RAW-domain video frames. Some steps of the distillation process are described below.
[0200] (1) The student network is scaled up proportionally according to the number of channels to obtain the teacher network, and the final teacher network is obtained through training.
[0201] (2) As shown in Figure 12, an initial intermediate transformation network can be inserted in the student network between each Down stage and Up stage, both of which belong to the network layers mentioned above. These are the structures Trans0 to Trans6 in Figure 12. In Figure 12, the elements marked Down0 to Down3 are downsampling branches among the network layers, each with an output size 1/4 of its input size. The elements marked Up1 to Up3 are upsampling branches, each with an output size 4 times its input size. The final output block outputs RGB video frames. The elements marked Trans are the intermediate transformation networks of this embodiment. Note that "insertion" here only indicates how the intermediate transformation network is connected to the network layers of the student network. It does not imply that the intermediate transformation network has already been integrated into the student network.
[0202] (3) Based on the above homomorphic transformation principle, distillation training is performed stage by stage until all Trans elements are trained.
[0203] (4) Continue to train the entire student network using the original least absolute error loss function (L1 loss function).
[0204] (5) The final intermediate transformation network between the two stages is merged into one of the stages to obtain a new student network.
[0205] (6) Based on the original task and L1 loss function, the new student network is further fine-tuned to obtain the final student network.
[0206] Scenario 2: Student network distillation for a classification task. The network structure is a 50-layer deep residual network (ResNet50) whose input is an image and whose output is the class label of that image. The goal of this task is to classify images automatically, which can be applied to smart photo albums and similar applications. When the method shown in Figure 6 is applied to this classification task, the distillation process includes the following key steps.
[0207] (1) The student network is scaled up proportionally according to the number of channels to obtain the teacher network, and the final teacher network is obtained through training.
[0208] (2) In the student network, an initial intermediate transformation network is inserted at the 3x3 convolutional layer of the last bottleneck block in each stage. As shown in Figure 13, a standard bottleneck block contains two 1x1 convolutions and one 3x3 convolution. In this embodiment, the initial intermediate transformation network Trans is inserted after the 3x3 convolution and before the following 1x1 convolution.
[0209] (3) Based on the above homomorphic transformation principle, distillation training is performed stage by stage until all Trans elements are trained.
[0210] (4) Continue training the entire student network using the original cross-entropy (CE) loss function.
[0211] (5) The final intermediate transformation network between the two stages is merged into one of the stages to obtain a new student network.
[0212] As shown in Table 1, Table 1 compares the classification accuracy of the ordinary network without distillation, the student network with ordinary distillation, and the student network (ResNet50 structure) with the distillation technique of this application when performing classification tasks. It can be seen that the student network obtained by the distillation technique of this application has a higher prediction accuracy than the other two cases.
[0213] Table 1
[0214]
[0215] It can be understood that the general process of the method shown in Figure 6 is basically the same when applied to different tasks (or scenarios). Three things may vary between tasks (or scenarios): the network structure used, the specific insertion positions of the initial intermediate transformation networks, and the loss function used for distillation training.
[0216] The above embodiments mention the training of the teacher network, the intermediate transformation network, and the student network. These three types of training can be carried out simultaneously, or the teacher network can be trained first, followed by the intermediate transformation network, and then the student network. Training the teacher network first, and then training the intermediate transformation network and the student network, can reduce computational overhead.
[0217] In the method described in Figure 6, when the student network is distilled through the teacher network, an initial intermediate transformation network is inserted between the network layers of the teacher network and the student network. After distillation, this intermediate transformation network is integrated into the student network. This preserves as much of the key information from the distillation process as possible and improves the prediction accuracy of the student network. Furthermore, the homomorphic transformation process is essentially an automatic network growth process, so the resulting target network is generally no longer a single linear layer. The first loss function can therefore be reduced further on the basis of the target network. When the target network is subsequently fused into the student network, the fused student network fits the teacher network more closely and achieves higher prediction accuracy.
[0218] The methods of the embodiments of this application have been described in detail above, and the apparatus of the embodiments of this application is provided below.
[0219] Refer to Figure 14, which is a schematic structural diagram of a neural network training device 140 according to an embodiment of this application. The neural network training device can be the aforementioned neural network training equipment or a module in that equipment. The neural network training device 140 may include a first acquisition unit 1401, a second acquisition unit 1402, and a first training unit 1403. Each unit is described in detail below.
[0220] The first acquisition unit 1401 is used to acquire a first feature, which is obtained by processing the input data through the first network layer of the student network. The input data includes any one or more of the following data: image data, audio data, and text data.
[0221] The second acquisition unit 1402 is used to acquire a second feature, which is obtained by processing the input data through the second network layer of the teacher network. The second network layer is the network layer in the teacher network that corresponds to the first network layer of the student network. Optionally, the correspondence between the first network layer and the second network layer means, for example, that the function of the first network layer in the entire student network is the same as the function of the second network layer in the teacher network, or that the structure of the feature output by the first network layer is the same as the structure of the feature output by the second network layer. Of course, there are other situations, which will not be listed here.
[0222] The first training unit 1403 is used to train an intermediate transformation network according to a first loss function. The intermediate transformation network includes an expansion module and a contraction module. The expansion module is used to convert the first feature into a third feature. The third feature is aligned with the second feature. The first loss function is used to measure the difference between the third feature and the second feature.
[0223] The shrinking module is used to convert the third feature into a fourth feature, and the fourth feature is aligned with the first feature; it should be noted that the feature alignment mentioned in the embodiments of this application refers to the shape alignment of the feature map.
[0224] In the above device, the expansion module can align the features output by the student network with the features output by the teacher network as much as possible, so as to make full use of the knowledge learned by the teacher network to improve the accuracy of the intermediate transformation network and the student network. The shrinking module can shrink the expanded feature map back to its original size, which can ensure that the intermediate transformation network can be seamlessly integrated into the student network after training, making the student network more effective when applied.
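The roles of the expansion module, the contraction module, and the first loss function can be illustrated with a hedged NumPy sketch. The text does not fix the form of the first loss function, so the sketch makes these assumptions:

- The first loss is mean squared error.
- Each branch is a single 1x1 layer, and feature maps are flattened to (pixels, channels).
- For illustration only, a closed-form least-squares fit stands in for gradient-based training, to show the first loss dropping once the expansion output is fitted to the teacher feature.

```python
import numpy as np

rng = np.random.default_rng(6)

def expand(first_feature, expansion_branches):
    """Expansion module: sum of parallel 1x1 branches, C_s -> C_t channels."""
    return sum(first_feature @ w for w in expansion_branches)

def shrink(third_feature, contraction_branches):
    """Contraction module: sum of parallel 1x1 branches, C_t -> C_s channels."""
    return sum(third_feature @ w for w in contraction_branches)

def first_loss(third_feature, second_feature):
    """Assumed first loss: mean squared error between aligned feature maps."""
    if third_feature.shape != second_feature.shape:
        raise ValueError("third and second features must have the same shape")
    return float(np.mean((third_feature - second_feature) ** 2))

n_pix, C_s, C_t = 64, 4, 16
first = rng.standard_normal((n_pix, C_s))     # student layer output
second = rng.standard_normal((n_pix, C_t))    # teacher layer output
exp_w = [rng.standard_normal((C_s, C_t)) * 0.1 for _ in range(2)]
shr_w = [rng.standard_normal((C_t, C_s)) * 0.1 for _ in range(2)]

third = expand(first, exp_w)     # aligned with the second feature
fourth = shrink(third, shr_w)    # aligned with the first feature
print(third.shape == second.shape, fourth.shape == first.shape)  # True True

before = first_loss(third, second)
# Illustration only: least squares replaces gradient training of one branch.
w_fit, *_ = np.linalg.lstsq(first, second, rcond=None)
after = first_loss(expand(first, [w_fit]), second)
print(after <= before)  # True
```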
[0225] In an alternative embodiment, in training the intermediate transformation network according to the first loss function, the first training unit 1403 is specifically used for:
[0226] The intermediate transformation network is iteratively subjected to multiple homomorphic transformations until the first loss function no longer decreases.
[0227] In this embodiment, since the homomorphic transformation process is essentially an automatic network growth process, the resulting target network is generally no longer a single linear layer. Therefore, the first loss function can be further reduced based on the target network. When the target network is subsequently fused into the student network, the fused student network achieves a higher fit with the teacher network and higher prediction accuracy. Furthermore, a stopping mechanism is implemented for the homomorphic transformation process, ensuring the transformation effect while minimizing the various overheads incurred during the transformation process.
[0228] In one alternative embodiment, regarding one homomorphic transformation among multiple homomorphic transformations, the first training unit 1403 is specifically used for:
[0229] A homomorphic transformation search is performed based on a preset network structure search space to obtain a first target network equivalent to the expansion module in the intermediate transformation network, thereby updating the expansion module; and / or a homomorphic transformation search is performed based on a preset network structure search space to obtain a second target network equivalent to the contraction module in the intermediate transformation network, thereby updating the contraction module.
[0230] It is understandable that the expansion module and the contraction module are two separate parts that perform two different functions. Therefore, these two modules can also be trained separately, that is, they can each be subjected to homomorphic transformation search. Of course, one module can be updated using homomorphic transformation search, while the other module is updated using other methods.
[0231] In one optional scheme, the network structure search space includes multiple network structures, all of which are convolutional layers with a kernel size of 1x1, and all can be integrated with the student network. It is understood that since multiple network structures are convolutional layers with a kernel size of 1x1, homomorphic transformation can ensure the mathematical equivalence of the networks before and after the transformation as much as possible, reducing information loss. Furthermore, it also facilitates better integration of the homomorphically transformed network into the student network.
[0232] In an alternative embodiment, the device 140 further includes:
[0233] The integration unit is used to integrate the intermediate transformation network into the student network. It can be understood that integrating the intermediate transformation network into the student network preserves as much key information as possible involved in the distillation process, thereby improving the prediction accuracy of the student network.
[0234] In one alternative embodiment, in integrating the intermediate transformation network into the student network, the integration unit is specifically used for:
[0235] The weights of the first network layer of the student network, or of the layer following it, are updated using the intermediate transformation network. It can be understood that the intermediate transformation network is trained on information related to the first network layer. Updating the weights of the first network layer or of the next layer through the intermediate transformation network therefore preserves the information represented in the intermediate transformation network and improves the prediction accuracy of the student network. Note that after the intermediate transformation network is merged into the first network layer or the next layer, the network structure of the student network does not change.
[0236] In yet another possible implementation, in integrating the intermediate transformation network into the student network, the integration unit is specifically used for:
[0237] The intermediate transformation network is fused into a target network layer, and this target network layer is inserted between the first network layer and the next network layer of the student network. It can be understood that this intermediate transformation network is trained based on the relevant information of the first network layer. Therefore, fusing it into a target network layer and inserting it between the first network layer and the next network layer effectively ensures the representation of information in the intermediate transformation network, thereby guaranteeing the prediction accuracy of the student network. It should be noted that inserting the target network layer between the first network layer and the next network layer changes the structure of the student network, essentially adding an extra target network layer.
[0238] In yet another possible implementation...
[0239] The number of branches in the target network is greater than the number of branches in the current intermediate transformation network, and / or,
[0240] The target network has a width greater than the width of the current intermediate transformation network, and / or the target network has a depth greater than the depth of the current intermediate transformation network.
[0241] It is understandable that because the number of branches, width, or depth of the target network increases, the number of parameters in the target network increases, thus enhancing the fitting ability of the target network.
[0242] Another possible implementation includes:
[0243] The second training unit is used to train the student network incorporated into the intermediate transformation network according to the second loss function.
[0244] In other words, by integrating the intermediate transformation network into the student network and then continuing to train the student network, the negative impact of the integration can be corrected, thereby further improving the prediction accuracy of the student network.
[0245] In another alternative, the teacher network is obtained by enlarging (e.g., widening or deepening) the student network, or the teacher network is a heterogeneous network with a different structure from the student network.
[0246] For the implementation of each unit, reference may also be made to the corresponding description of the method embodiment shown in Figure 6.
[0247] Refer to Figure 15. As shown in Figure 15, an embodiment of this application provides a neural network training device 150, which includes a processor 1501, a memory 1502, and a communication interface 1503. The processor 1501, the memory 1502, and the communication interface 1503 are interconnected via a bus.
[0248] The memory 1502 includes, but is not limited to, random access memory (RAM), read-only memory (ROM), erasable programmable read-only memory (EPROM), or compact disc read-only memory (CD-ROM), and is used to store related computer programs and data. The communication interface 1503 is used to receive and send data.
[0249] Processor 1501 can be one or more central processing units (CPUs). When processor 1501 is a CPU, the CPU can be a single-core CPU or a multi-core CPU.
[0250] The processor 1501 in the neural network training device 150 is used to read the computer program code stored in the memory 1502 and perform the following operations:
[0251] A first feature is obtained by processing input data through the first network layer of the student network, where the input data includes any one or more of the following: image data, audio data, and text data.
[0252] A second feature is obtained by processing the input data through the second network layer of the teacher network, where the second network layer is the network layer in the teacher network that corresponds to the first network layer of the student network. Optionally, this correspondence means, for example, that the first network layer plays the same role in the student network as the second network layer plays in the teacher network, or that the features output by the two layers have the same structure. Other correspondences are also possible and are not enumerated here.
[0253] An intermediate transformation network is trained according to a first loss function. The intermediate transformation network includes an expansion module and a contraction module. The expansion module is used to convert the first feature into a third feature, and the third feature is aligned with the second feature. The first loss function is used to measure the difference between the third feature and the second feature.
[0254] The contraction module is used to convert the third feature into a fourth feature, and the fourth feature is aligned with the first feature. Note that feature alignment in the embodiments of this application refers to alignment of the shapes of the feature maps.
[0255] In the above method, the expansion module can align the features output by the student network with the features output by the teacher network as much as possible, so as to make full use of the knowledge learned by the teacher network to improve the accuracy of the intermediate transformation network and the student network. The contraction module can shrink the expanded feature map back to its original size, which can ensure that the intermediate transformation network can be seamlessly integrated into the student network after training, making the student network perform better when applied.
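The roles of the expansion module, the contraction module, and the first loss function can be sketched as follows. This assumes, hypothetically, that the student and teacher feature maps already share the same spatial size and differ only in channel count, that both modules are bias-free 1x1 convolutions, and that the first loss is a mean squared error; the application itself only states that the first loss measures the difference between the third and second features. All weights and feature values are made up for illustration.

```python
from typing import List, Tuple

Matrix = List[List[float]]       # 1x1 conv weight: out_channels x in_channels
FeatureMap = List[List[float]]   # one channel vector per spatial position (H*W rows)


def conv1x1(w: Matrix, fmap: FeatureMap) -> FeatureMap:
    """Apply a bias-free 1x1 convolution to every spatial position."""
    return [[sum(row[k] * px[k] for k in range(len(px))) for row in w] for px in fmap]


def shape(fmap: FeatureMap) -> Tuple[int, int]:
    """(number of spatial positions, number of channels)."""
    return len(fmap), len(fmap[0])


def mse(a: FeatureMap, b: FeatureMap) -> float:
    """Mean squared error between two shape-aligned feature maps."""
    assert shape(a) == shape(b), "features must be shape-aligned"
    n = len(a) * len(a[0])
    return sum((x - y) ** 2 for pa, pb in zip(a, b) for x, y in zip(pa, pb)) / n


first_feature = [[1.0, 0.0], [0.0, 1.0]]            # student: 2 positions, 2 channels
second_feature = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]  # teacher: 2 positions, 3 channels
expand = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]        # expansion module: 2 -> 3 channels
contract = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]        # contraction module: 3 -> 2 channels

third_feature = conv1x1(expand, first_feature)       # shape-aligned with the second feature
first_loss = mse(third_feature, second_feature)      # first loss function (MSE assumed)
fourth_feature = conv1x1(contract, third_feature)    # shape-aligned with the first feature
print(shape(third_feature), first_loss, shape(fourth_feature))
```

Training the intermediate transformation network would then adjust `expand` (and, through later fine-tuning, `contract`) to reduce `first_loss`; the contraction step guarantees the output has the student's original channel count.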
[0256] In one possible implementation, training the intermediate transformation network based on the first loss function includes:
[0257] The intermediate transformation network is iteratively subjected to multiple homomorphic transformations until the first loss function no longer decreases.
[0258] In this embodiment, because the homomorphic transformation process is essentially an automatic network growth process, the resulting target network is generally no longer a single linear layer, so the first loss function can be reduced further on the basis of the target network. When the target network is subsequently fused into the student network, the fused student network fits the teacher network more closely and achieves higher prediction accuracy. In addition, stopping the homomorphic transformations once the first loss function no longer decreases preserves the benefit of the transformation while keeping the overhead incurred during the transformation process as low as possible.
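The iterate-until-no-improvement rule can be sketched as a generic loop. In this sketch, `transform` stands for one homomorphic transformation and `train_and_eval` for training the grown network and returning its first loss; both are hypothetical placeholders, as are the toy stand-ins at the bottom, where a "network" is reduced to its width and the trained loss saturates at 0.25. The `max_rounds` cap is an added assumption, not something the application states.

```python
from typing import Callable, List, Tuple, TypeVar

Net = TypeVar("Net")


def grow_until_plateau(net: Net,
                       transform: Callable[[Net], Net],
                       train_and_eval: Callable[[Net], float],
                       max_rounds: int = 10,
                       tol: float = 0.0) -> Tuple[Net, float, List[float]]:
    """Repeatedly apply a growth transform and train, stopping as soon as
    the first loss no longer decreases; returns the last improving network."""
    best_loss = train_and_eval(net)
    history = [best_loss]
    for _ in range(max_rounds):                  # hard cap bounds the overhead
        candidate = transform(net)
        loss = train_and_eval(candidate)
        history.append(loss)
        if loss >= best_loss - tol:              # first loss no longer decreases: stop
            break
        net, best_loss = candidate, loss         # keep the grown network
    return net, best_loss, history


# Hypothetical toy stand-ins: growth doubles the width; the loss saturates at 0.25.
result = grow_until_plateau(1, lambda w: 2 * w, lambda w: max(0.25, 1.0 / w))
print(result)
```

With these stand-ins the loop accepts widths 2 and 4, then stops when width 8 brings no further reduction, so the cost of further transformations is never paid.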
[0259] In yet another possible implementation, one of the homomorphic transformations in the multiple homomorphic transformations includes:
[0260] A homomorphic transformation search is performed based on a preset network structure search space to obtain a first target network equivalent to the expansion module in the intermediate transformation network, thereby updating the expansion module; and / or a homomorphic transformation search is performed based on a preset network structure search space to obtain a second target network equivalent to the contraction module in the intermediate transformation network, thereby updating the contraction module.
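One way to picture the homomorphic transformation search is as a scan over a small search space of growth operations, keeping only the candidates that are equivalent to the current module. The sketch below assumes that "homomorphic transformation" means a function-preserving transformation in the spirit of network morphism, that the module is purely linear (parallel branches of 1x1 convolutions summed at the output), and that equivalence is checked numerically on a few probe inputs. The search space, the probe-based check, and every name in the code are hypothetical; how the application selects among equivalent candidates (for example, by the first loss after training) is not shown.

```python
import copy
from typing import Callable, Dict, List, Optional

Matrix = List[List[float]]        # 1x1 conv weight: out_channels x in_channels
Module = List[List[Matrix]]       # parallel branches; each branch is a chain of 1x1 convs


def apply(w: Matrix, x: List[float]) -> List[float]:
    return [sum(row[k] * x[k] for k in range(len(x))) for row in w]


def forward(module: Module, x: List[float]) -> List[float]:
    """Sum of the branch outputs; each branch applies its 1x1 convs in order."""
    outputs = []
    for branch in module:
        y = x
        for w in branch:
            y = apply(w, y)
        outputs.append(y)
    return [sum(vals) for vals in zip(*outputs)]


def add_zero_branch(m: Module) -> Module:
    """More branches: a zero-initialised parallel branch changes nothing at first."""
    c_out, c_in = len(m[0][-1]), len(m[0][0][0])
    return copy.deepcopy(m) + [[[[0.0] * c_in for _ in range(c_out)]]]


def deepen_identity(m: Module) -> Module:
    """More depth: append an identity 1x1 conv to branch 0."""
    new = copy.deepcopy(m)
    c = len(new[0][-1])
    new[0].append([[1.0 if i == j else 0.0 for j in range(c)] for i in range(c)])
    return new


def widen_hidden(m: Module) -> Optional[Module]:
    """More width (Net2Net-style): duplicate hidden channel 0 of branch 0 and
    split its outgoing weights in half. Needs a branch of depth >= 2."""
    if len(m[0]) < 2:
        return None
    new = copy.deepcopy(m)
    new[0][0].append(list(new[0][0][0]))     # copy the incoming weights
    for row in new[0][1]:
        row[0] /= 2.0                        # halve the original outgoing weight
        row.append(row[0])                   # the copy carries the other half
    return new


def add_random_branch(m: Module) -> Module:
    """Deliberately NOT function-preserving; the search should reject it."""
    c_out, c_in = len(m[0][-1]), len(m[0][0][0])
    return copy.deepcopy(m) + [[[[0.3] * c_in for _ in range(c_out)]]]


SEARCH_SPACE: Dict[str, Callable[[Module], Optional[Module]]] = {
    "add_zero_branch": add_zero_branch,
    "deepen_identity": deepen_identity,
    "widen_hidden": widen_hidden,
    "add_random_branch": add_random_branch,
}


def equivalent_candidates(module: Module, probes: List[List[float]],
                          tol: float = 1e-9) -> Dict[str, Module]:
    """Keep only candidates whose outputs match the current module on all probes."""
    reference = [forward(module, p) for p in probes]
    kept = {}
    for name, transform in SEARCH_SPACE.items():
        cand = transform(module)
        if cand is None:
            continue
        if all(abs(a - b) <= tol
               for p, ref in zip(probes, reference)
               for a, b in zip(forward(cand, p), ref)):
            kept[name] = cand
    return kept


# Hypothetical expansion module: one branch, 2 -> 2 -> 3 channels.
expansion = [[[[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]]
found = equivalent_candidates(expansion, probes=[[1.0, 2.0], [-0.5, 3.0]])
print(sorted(found))
```

The three function-preserving operations correspond to more branches, more depth, and more width; the randomly initialised branch is filtered out because it changes the module's output. The same search could be run on the contraction module independently.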
[0261] Because the expansion module and the contraction module are separate parts that perform different functions, they can also be trained separately; that is, a homomorphic transformation search can be performed on each of them. Alternatively, one module can be updated by homomorphic transformation search while the other is updated by another method.
[0262] In another possible implementation, the network structure search space includes multiple network structures, each of which is a convolutional layer with a 1x1 kernel, and all of which can be integrated into the student network. Because the structures in the search space are all 1x1 convolutional layers, the homomorphic transformation can keep the networks before and after the transformation mathematically equivalent as far as possible, which reduces information loss. It also makes it easier to integrate the transformed network into the student network.
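Why a module built only from 1x1 convolutions remains easy to integrate can be shown directly: if there are no activation functions between the layers (an assumption made here for illustration), any stack of parallel branches and sequential 1x1 convolutions collapses into a single 1x1 convolution whose weight is the sum, over branches, of the product of that branch's weights. The function `collapse` and the example weights are hypothetical.

```python
from typing import List

Matrix = List[List[float]]    # 1x1 conv weight: out_channels x in_channels
Module = List[List[Matrix]]   # parallel branches; each branch is a chain of 1x1 convs


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Return the matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def apply(w: Matrix, x: List[float]) -> List[float]:
    """Apply a bias-free 1x1 convolution to one pixel's channel vector."""
    return [sum(row[k] * x[k] for k in range(len(x))) for row in w]


def collapse(module: Module) -> Matrix:
    """Merge linear branches of 1x1 convs into one 1x1 conv:
    W = sum over branches of (W_last @ ... @ W_first)."""
    total = None
    for branch in module:
        w = branch[0]
        for layer in branch[1:]:
            w = matmul(layer, w)              # later layers multiply on the left
        total = w if total is None else [[a + b for a, b in zip(r1, r2)]
                                         for r1, r2 in zip(total, w)]
    return total


# Hypothetical grown module: branch 0 is 2 -> 3 -> 3 channels, branch 1 is 2 -> 3.
grown = [
    [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
     [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]],
    [[[0.5, 0.0], [0.0, 0.5], [0.0, 0.0]]],
]
merged = collapse(grown)
x = [2.0, 4.0]
branch_sum = [a + b for a, b in zip(apply(grown[0][1], apply(grown[0][0], x)),
                                   apply(grown[1][0], x))]
print(merged, apply(merged, x), branch_sum)
```

The single merged matrix reproduces the multi-branch output exactly, so a grown intermediate network of this kind adds no inference-time layers once it is merged. With a nonlinearity between layers this collapse would no longer hold.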
[0263] Another possible implementation includes:
[0264] The intermediate transformation network is integrated into the student network. It can be understood that integrating this intermediate transformation network into the student network preserves as much key information as possible related to the distillation process, thereby improving the prediction accuracy of the student network.
[0265] In yet another possible implementation, integrating the intermediate transformation network into the student network includes:
[0266] The weights of the first network layer of the student network, or of the network layer following the first network layer, are updated according to the intermediate transformation network. Because the intermediate transformation network is trained on information related to the first network layer, folding it into the weights of the first network layer or of the following layer preserves the information represented by the intermediate transformation network and thereby improves the prediction accuracy of the student network. Note that after the intermediate transformation network is folded into the first network layer or the following layer in this way, the network structure of the student network does not change.
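Folding the intermediate transformation network into existing weights can be sketched with plain linear algebra. The sketch assumes, hypothetically, that the intermediate network has already been reduced to a single C_s x C_s matrix `m` (for example, contraction weight times expansion weight), that both student layers are 1x1 convolutions, and that no activation sits between the first layer's output and `m`. Under those assumptions, M(W1 x + b1) = (M W1) x + M b1 and W2 (M f) = (W2 M) f. The function names and numbers are illustrative only.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Return the matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Return w @ x."""
    return [sum(row[k] * x[k] for k in range(len(x))) for row in w]


def fold_into_first(w1: Matrix, b1: Vector, m: Matrix) -> Tuple[Matrix, Vector]:
    """M(W1 x + b1) = (M W1) x + M b1; valid only if layer 1 has no activation after it."""
    return matmul(m, w1), matvec(m, b1)


def fold_into_next(w2: Matrix, m: Matrix) -> Matrix:
    """W2 (M f) = (W2 M) f; valid when the next layer acts linearly on the feature."""
    return matmul(w2, m)


w1, b1 = [[1.0, 2.0], [0.0, 1.0]], [0.5, -1.0]   # first network layer (1x1 conv + bias)
w2 = [[1.0, 1.0]]                                # next network layer (bias omitted)
m = [[1.0, 0.5], [0.5, 1.0]]                     # intermediate network collapsed to C_s x C_s
x = [1.0, 1.0]

feature = [a + b for a, b in zip(matvec(w1, x), b1)]
reference = matvec(w2, matvec(m, feature))        # first layer -> M -> next layer

w1_new, b1_new = fold_into_first(w1, b1, m)
via_first = matvec(w2, [a + b for a, b in zip(matvec(w1_new, x), b1_new)])
via_next = matvec(fold_into_next(w2, m), feature)
print(reference, via_first, via_next)             # still two layers after folding
```

Both folding options keep the student's layer count and weight shapes unchanged while reproducing the output of the network with the intermediate transformation network in place.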
[0267] In yet another possible implementation, integrating the intermediate transformation network into the student network includes:
[0268] The intermediate transformation network is fused into a target network layer, and this target network layer is inserted between the first network layer of the student network and the network layer that follows it. Because the intermediate transformation network is trained on information related to the first network layer, fusing it into a target network layer and inserting that layer directly after the first network layer preserves the information represented by the intermediate transformation network, thereby safeguarding the prediction accuracy of the student network. Note that inserting the target network layer changes the structure of the student network: it adds one extra network layer.
[0269] In yet another possible implementation...
[0270] the number of branches of the target network is greater than the number of branches of the current intermediate transformation network; and/or
[0271] the width of the target network is greater than the width of the current intermediate transformation network; and/or the depth of the target network is greater than the depth of the current intermediate transformation network.
[0272] Because the target network has more branches, greater width, or greater depth, it has more parameters, which enhances its fitting ability.
[0273] Another possible implementation includes:
[0274] The student network into which the intermediate transformation network has been integrated is trained according to a second loss function.
[0275] In other words, by integrating the intermediate transformation network into the student network and then continuing to train the student network, the negative impact of the integration can be corrected, thereby further improving the prediction accuracy of the student network.
[0276] In another possible implementation, the teacher network is obtained by enlarging (e.g., widening or deepening) the student network, or the teacher network is a heterogeneous network with a different structure from the student network.
[0277] For the implementation of each operation, reference may also be made to the corresponding description of the method embodiment shown in Figure 6.
[0278] An embodiment of this application further provides a chip system. The chip system includes at least one processor, a memory, and an interface circuit, which are interconnected via lines. The memory stores a computer program, and when the computer program is executed by the processor, the method flow shown in Figure 6 is implemented.
[0279] An embodiment of this application further provides a computer-readable storage medium storing a computer program that, when run on a processor, implements the method flow shown in Figure 6.
[0280] An embodiment of this application further provides a computer program product that, when run on a processor, implements the method flow shown in Figure 6.
[0281] In summary, in the embodiments of this application, when the student network is distilled by the teacher network, an initial intermediate transformation network is inserted between corresponding network layers of the teacher network and the student network. After distillation, the intermediate transformation network is integrated into the student network, which preserves as much of the key information involved in distillation as possible and improves the prediction accuracy of the student network. Furthermore, because the homomorphic transformation process is an automatic network growth process, the resulting target network is generally no longer a single linear layer, so the first loss function can be reduced further on the basis of the target network. When the target network is subsequently fused into the student network, the fused student network fits the teacher network more closely and achieves higher prediction accuracy.
[0282] Those skilled in the art will understand that all or some of the processes in the methods of the foregoing embodiments can be implemented by a computer program instructing related hardware. The computer program can be stored in a computer-readable storage medium, and when executed, it can include the processes of the foregoing method embodiments. The storage medium includes various media capable of storing computer program code, such as ROM, random access memory (RAM), magnetic disks, or optical disks.
Claims
1. A neural network training method, characterized by comprising: obtaining a first feature, where the first feature is obtained by processing input data through a first network layer of a student network, and the input data includes any one or more of the following: image data, audio data, and text data; obtaining a second feature, where the second feature is obtained by processing the input data through a second network layer of a teacher network, and the second network layer is the network layer in the teacher network that corresponds to the first network layer of the student network; training an intermediate transformation network according to a first loss function, where the intermediate transformation network includes an expansion module and a contraction module, the expansion module is used to convert the first feature into a third feature aligned with the second feature, the first loss function is used to measure the difference between the third feature and the second feature, and the contraction module is used to convert the third feature into a fourth feature aligned with the first feature; and integrating the intermediate transformation network into the student network.
2. The method according to claim 1, characterized in that training the intermediate transformation network according to the first loss function comprises: iteratively performing multiple homomorphic transformations on the intermediate transformation network until the first loss function no longer decreases.
3. The method according to claim 2, characterized in that one of the multiple homomorphic transformations comprises: performing a homomorphic transformation search based on a preset network structure search space to obtain a first target network equivalent to the expansion module of the intermediate transformation network, so as to update the expansion module; and/or performing a homomorphic transformation search based on the preset network structure search space to obtain a second target network equivalent to the contraction module of the intermediate transformation network, so as to update the contraction module.
4. The method according to claim 3, characterized in that the network structure search space includes multiple network structures, each of which is a convolutional layer with a 1x1 kernel and can be integrated into the student network.
5. The method according to any one of claims 1-4, characterized in that integrating the intermediate transformation network into the student network comprises: updating, according to the intermediate transformation network, the weights of the first network layer of the student network or of the network layer following the first network layer.
6. The method according to any one of claims 1-4, characterized in that integrating the intermediate transformation network into the student network comprises: merging the intermediate transformation network into a target network layer, and inserting the target network layer between the first network layer of the student network and the network layer following the first network layer.
7. The method according to claim 3 or 4, characterized in that the number of branches of the first target network or the second target network is greater than the number of branches of the current network of the intermediate transformation network; and/or the width of the first target network or the second target network is greater than the width of the current network of the intermediate transformation network; and/or the depth of the first target network or the second target network is greater than the depth of the current network of the intermediate transformation network.
8. The method according to claim 4, characterized by further comprising: training, according to a second loss function, the student network into which the intermediate transformation network has been integrated.
9. A neural network training apparatus, characterized by comprising: a first acquisition unit, configured to acquire a first feature, where the first feature is obtained by processing input data through a first network layer of a student network, and the input data includes any one or more of the following: image data, audio data, and text data; a second acquisition unit, configured to acquire a second feature, where the second feature is obtained by processing the input data through a second network layer of a teacher network, and the second network layer is the network layer in the teacher network that corresponds to the first network layer of the student network; a first training unit, configured to train an intermediate transformation network according to a first loss function, where the intermediate transformation network includes an expansion module and a contraction module, the expansion module is used to convert the first feature into a third feature aligned with the second feature, the first loss function is used to measure the difference between the third feature and the second feature, and the contraction module is used to convert the third feature into a fourth feature aligned with the first feature; and an integration unit, configured to integrate the intermediate transformation network into the student network.
10. The apparatus according to claim 9, characterized in that, in terms of training the intermediate transformation network according to the first loss function, the first training unit is specifically configured to: iteratively perform multiple homomorphic transformations on the intermediate transformation network until the first loss function no longer decreases.
11. The apparatus according to claim 10, characterized in that, in terms of performing one of the multiple homomorphic transformations, the first training unit is specifically configured to: perform a homomorphic transformation search based on a preset network structure search space to obtain a first target network equivalent to the expansion module of the intermediate transformation network, so as to update the expansion module; and/or perform a homomorphic transformation search based on the preset network structure search space to obtain a second target network equivalent to the contraction module of the intermediate transformation network, so as to update the contraction module.
12. The apparatus according to claim 11, characterized in that the network structure search space includes multiple network structures, each of which is a convolutional layer with a 1x1 kernel and can be integrated into the student network.
13. The apparatus according to any one of claims 9-12, characterized in that, in terms of integrating the intermediate transformation network into the student network, the integration unit is specifically configured to: update, according to the intermediate transformation network, the weights of the first network layer of the student network or of the network layer following the first network layer.
14. The apparatus according to any one of claims 9-12, characterized in that, in terms of integrating the intermediate transformation network into the student network, the integration unit is specifically configured to: merge the intermediate transformation network into a target network layer, and insert the target network layer between the first network layer of the student network and the network layer following the first network layer.
15. The apparatus according to claim 11 or 12, characterized in that the number of branches of the first target network or the second target network is greater than the number of branches of the current network of the intermediate transformation network; and/or the width of the first target network or the second target network is greater than the width of the current network of the intermediate transformation network; and/or the depth of the first target network or the second target network is greater than the depth of the current network of the intermediate transformation network.
16. The apparatus according to claim 12, characterized by further comprising: a second training unit, configured to train, according to a second loss function, the student network into which the intermediate transformation network has been integrated.
17. A neural network training device, characterized by comprising a processor and a memory, wherein the memory is configured to store a computer program, and the processor is configured to invoke the computer program to implement the method according to any one of claims 1-8.
18. A computer-readable storage medium, characterized in that the computer-readable storage medium stores a computer program that, when run on a processor, implements the method according to any one of claims 1-8.
19. A computer program product, characterized in that, when the computer program product is run on a processor, the method according to any one of claims 1-8 is implemented.