Generative adversarial network optimization method and electronic device
By setting the weights of the generator and discriminator to be equal and using gradient descent to iteratively update the weights, the problem of unstable training of generative adversarial networks is solved, achieving diversity of generated samples and stability of the network.
Patent Information
- Authority / Receiving Office
- CN · China
- Patent Type
- Patents(China)
- Current Assignee / Owner
- FU TAI HUA IND SHENZHEN
- Filing Date
- 2021-05-19
- Publication Date
- 2026-06-16
AI Technical Summary
Generative adversarial networks are prone to instability during training, leading to mode collapse and insufficient diversity of generated sample images.
By setting the weights of the generator and discriminator to be equal and using gradient descent to iteratively update the weights, the learning rate is dynamically adjusted until both the generator and discriminator converge, thus achieving a balance in learning ability.
This improves the stability and diversity of generated samples in generative adversarial networks, ensuring the balance and convergence of the generator and discriminator during training.
Figure CN115374899B_ABST
Abstract
Description
Technical Field
[0001] This application relates to the field of generative adversarial networks (GANs), specifically to a GAN optimization method and an electronic device.
Background Technology
[0002] A generative adversarial network (GAN) consists of a generator and a discriminator. Through adversarial training between the two, the samples produced by the generator come to conform to the distribution of the real data. During training, the generator produces sample images from input random noise, aiming to generate images realistic enough to deceive the discriminator. The discriminator learns to tell real sample images from generated ones.
[0003] However, the training of generative adversarial networks has too many degrees of freedom. When training is unstable, the generator and the discriminator easily fall into an abnormal adversarial state, resulting in mode collapse and insufficient diversity of the generated sample images.
Summary of the Invention
[0004] In view of this, this application provides a generative adversarial network optimization method and electronic device that can balance the losses of the generator and the discriminator, so that the generator and the discriminator have the same learning ability, thereby improving the stability of the generative adversarial network.
[0005] The generative adversarial network optimization method of this application includes: determining a first weight of the generator and a second weight of the discriminator, wherein the first weight and the second weight are equal, the first weight is used to represent the learning ability of the generator, and the second weight is used to represent the learning ability of the discriminator; and iteratively training the generator and the discriminator alternately until both the generator and the discriminator converge.
[0006] In this embodiment of the application, the learning ability is positively correlated with the first weight or the second weight.
[0007] The electronic device of this application includes a memory and a processor. The memory is used to store a computer program. When the computer program is called by the processor, it implements the generative adversarial network optimization method of this application.
[0008] This application iteratively updates the first weight of the generator and the second weight of the discriminator using gradient descent. As training progresses, the learning rates of the generator and the discriminator are adjusted dynamically until the loss functions of both converge, yielding the optimal weights. Because the first weight equals the second weight, the generator and the discriminator have the same learning ability, which improves the stability of the generative adversarial network.
Attached Figure Description
[0009] Figure 1 is a schematic diagram of a generative adversarial network.
[0010] Figure 2 is a schematic diagram of a neural network.
[0011] Figure 3 is a flowchart of the generative adversarial network optimization method.
[0012] Figure 4 is a schematic diagram of an electronic device.
[0013] Explanation of main component symbols
[0014] 10 Generative adversarial network
[0015] 11 Generator
[0016] 12 Discriminator
[0017] z Noise sample
[0018] x Data sample
[0019] D Probability of the real/fake judgment
[0020] 20 Neural network
[0021] y Output
[0022] W1, W2, W3 Weights
[0023] z1, z2, z3 Hidden-layer inputs
[0024] f1(z1), f2(z2), f3(z3) Activation functions
[0025] 40 Electronic device
[0026] 41 Memory
[0027] 42 Processor
Detailed Implementation
[0028] To better understand the above-mentioned objectives, features, and advantages of this application, the application will be described in detail below with reference to the accompanying drawings and specific embodiments. It should be noted that, unless otherwise specified, the embodiments and features described in the embodiments of this application can be combined with each other. Many specific details are set forth in the following description to provide a thorough understanding of this application; the described embodiments are only some embodiments of this application, and not all embodiments.
[0029] It should be noted that although a logical order is shown in the flowchart, in some cases the steps shown or described may be performed in a different order than that shown in the flowchart. The methods disclosed in the embodiments of this application include one or more steps or actions for implementing the method. Method steps and/or actions may be interchanged with each other without departing from the scope of the claims. In other words, unless a specific order of steps or actions is specified, the order and/or use of specific steps and/or actions may be modified without departing from the scope of the claims.
[0030] Generative Adversarial Networks (GANs) are commonly used for data augmentation. When sample data is difficult to collect, they can be trained with a small amount of data to generate a large amount of sample data, thus solving the problem of insufficient sample data. However, GANs are prone to problems such as vanishing gradients, training instability, and slow convergence during training. When training is unstable, GANs are prone to mode collapse, resulting in insufficient diversity of generated sample data.
[0031] Based on this, this application provides a generative adversarial network optimization method, apparatus, electronic device, and storage medium that can balance the losses of the generator and the discriminator, so that the generator and the discriminator have the same learning ability, thereby improving the stability of the generative adversarial network.
[0032] Referring to Figure 1, Figure 1 is a schematic diagram of a generative adversarial network 10. The generative adversarial network 10 includes a generator 11 and a discriminator 12. The generator 11 receives a noise sample z and generates a first image; the generated first image, together with a second image obtained from a data sample x, is fed into the discriminator 12. The discriminator 12 receives the first image and the second image and outputs a probability D of the real/fake judgment. D takes values in [0, 1], where 1 indicates that the input is judged real and 0 indicates that it is judged fake.
[0033] In this embodiment of the application, both the generator 11 and the discriminator 12 are neural networks, including but not limited to convolutional neural networks (CNN), recurrent neural networks (RNN), or deep neural networks (DNN).
[0034] During the training of the generative adversarial network 10, the generator 11 and the discriminator 12 are trained alternately and iteratively, each optimizing its own network with its own cost (loss) function. For example, when the generator 11 is trained, the weights of the discriminator 12 are fixed while the weights of the generator 11 are updated; when the discriminator 12 is trained, the weights of the generator 11 are fixed while the weights of the discriminator 12 are updated. Because each network strives to optimize itself against the other, the two form a competitive, adversarial relationship until they reach a dynamic equilibrium, namely a Nash equilibrium. At this point, the first images generated by the generator 11 follow the same distribution as the second images obtained from the data samples x, so the discriminator 12 cannot distinguish real images from generated ones and outputs a probability D of 0.5.
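The value 0.5 can be checked with a small numerical example. It relies on a standard result that is not stated in this application: for a fixed generator, the best possible discriminator outputs D*(x) = p_data(x) / (p_data(x) + p_g(x)), where p_data and p_g are the densities of real and generated samples. The one-dimensional Gaussian densities below are hypothetical stand-ins for those distributions.

```python
import math


def gaussian_pdf(x: float, mean: float, std: float) -> float:
    # Density of the one-dimensional normal distribution N(mean, std^2).
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def optimal_discriminator(x: float, data_mean: float, gen_mean: float, std: float = 1.0) -> float:
    # D*(x) = p_data(x) / (p_data(x) + p_g(x)) for a fixed generator.
    p_data = gaussian_pdf(x, data_mean, std)
    p_gen = gaussian_pdf(x, gen_mean, std)
    return p_data / (p_data + p_gen)


# Generator not yet trained (means differ): D* favours "real" near the data mean.
print(round(optimal_discriminator(2.0, data_mean=2.0, gen_mean=0.0), 3))  # 0.881
# Generator matches the data distribution: D* is 0.5 for every input.
print([optimal_discriminator(x, 2.0, 2.0) for x in (-1.0, 0.0, 3.5)])  # [0.5, 0.5, 0.5]
```

Once the two densities coincide, the ratio is exactly one half regardless of x, which is the "cannot distinguish" state described above.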
[0035] In the embodiments of this application, weight refers to the number of weights in a neural network, which characterizes the learning ability of the neural network. The learning ability is positively correlated with the weight.
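Under this reading of "weight" as a count, the condition of step S31 that the first weight equals the second weight can be checked by counting parameters. A minimal sketch, assuming fully connected layers with biases; the layer widths are hypothetical and were chosen only so that the two counts coincide.

```python
from typing import Sequence


def dense_param_count(layer_sizes: Sequence[int]) -> int:
    # Each fully connected layer n_in -> n_out has n_in * n_out weights plus n_out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# Hypothetical architectures: 10-D noise -> 8x8 image, and 8x8 image -> one probability.
generator_layers = [10, 15, 64]
discriminator_layers = [64, 18, 1]

first_weight = dense_param_count(generator_layers)       # generator 11
second_weight = dense_param_count(discriminator_layers)  # discriminator 12
print(first_weight, second_weight)  # 1189 1189
```

Changing any width breaks the equality, so in practice the widths would be tuned until the counts match.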
[0036] Referring to Figure 2, Figure 2 is a schematic diagram of a neural network 20. The learning process of the neural network 20 consists of two processes: forward propagation of the signal and backward propagation of the error. During forward propagation, a data sample x enters at the input layer, is processed layer by layer through the hidden layers, and propagates to the output layer. If the output y of the output layer does not match the expected output, the process switches to backward propagation. In backward propagation, the output error is propagated back through the hidden layers to the input layer and distributed to all neural units in each layer, yielding an error signal for each unit. This error signal serves as the basis for adjusting the weights W.
[0037] In this embodiment, the neural network includes an input layer, hidden layers, and an output layer. The input layer receives data from outside the neural network, and the output layer outputs the computation results of the neural network. All layers other than the input and output layers are hidden layers. The hidden layers map the features of the input data into another feature space in which different classes of data can be separated linearly.
[0038] The output y of the neural network 20 is shown in formula (1):
[0039] $$y = f_3\big(W_3\, f_2\big(W_2\, f_1(W_1 x)\big)\big) \qquad (1)$$
[0040] where x is a data sample; f1, f2 and f3 are the activation functions applied to the hidden-layer inputs z1 = W1·x, z2 = W2·f1(z1) and z3 = W3·f2(z2), respectively; and W1, W2 and W3 are the weights between layers.
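Formula (1) translates directly into a forward pass. The sketch below assumes tanh for f1 and f2 and a sigmoid for f3, and stores each weight matrix as a list of rows; none of these choices comes from the application, and the example weights are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Matrix = Sequence[Sequence[float]]


def matvec(W: Matrix, v: Sequence[float]) -> Vector:
    # Row-by-row dot products: (W v)_j = sum_k W[j][k] * v[k].
    return [sum(w_jk * v_k for w_jk, v_k in zip(row, v)) for row in W]


def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u))


def forward(x: Sequence[float], W1: Matrix, W2: Matrix, W3: Matrix,
            f1: Callable[[float], float] = math.tanh,
            f2: Callable[[float], float] = math.tanh,
            f3: Callable[[float], float] = sigmoid) -> Vector:
    # Formula (1): y = f3(W3 * f2(W2 * f1(W1 * x))), where z_k is the input of f_k.
    z1 = matvec(W1, x)
    z2 = matvec(W2, [f1(u) for u in z1])
    z3 = matvec(W3, [f2(u) for u in z2])
    return [f3(u) for u in z3]


x = [1.0, 2.0]
W1 = [[0.5, -0.2], [0.1, 0.3]]  # 2 inputs -> 2 hidden units
W2 = [[1.0, -1.0]]              # 2 hidden units -> 1 hidden unit
W3 = [[2.0]]                    # 1 hidden unit -> 1 output
print(forward(x, W1, W2, W3))
```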
[0041] The weights W are updated using the gradient descent method as shown in equation (2):
[0042] $$W^{+} = W - \eta\,\frac{\partial \mathrm{Loss}}{\partial W} \qquad (2)$$
[0043] where W+ denotes the updated weights, W the original weights, Loss the loss function, and η the learning rate, which determines the magnitude of each update of W.
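To connect the backward propagation of [0036] with the weight update of [0041], the sketch below uses a scalar version of formula (1) (one unit per layer) and applies the gradient-descent step in its standard form W+ = W - η·dLoss/dW. It assumes f1 = f2 = tanh, f3 = identity and a squared-error loss; the initial weights, input and target are hypothetical.

```python
import math
from typing import Tuple

Weights = Tuple[float, float, float]


def forward_backward(w: Weights, x: float, target: float) -> Tuple[float, Weights]:
    # Scalar formula (1) with f1 = f2 = tanh and f3 = identity (assumed),
    # and Loss = 0.5 * (y - target)^2 (also assumed).
    W1, W2, W3 = w
    a1 = math.tanh(W1 * x)   # f1(z1)
    a2 = math.tanh(W2 * a1)  # f2(z2)
    y = W3 * a2              # f3(z3)
    loss = 0.5 * (y - target) ** 2

    # Backward propagation of the error: each delta is dLoss/dz for one layer.
    d3 = y - target
    d2 = d3 * W3 * (1.0 - a2 ** 2)
    d1 = d2 * W2 * (1.0 - a1 ** 2)
    return loss, (d1 * x, d2 * a1, d3 * a2)  # (dLoss/dW1, dLoss/dW2, dLoss/dW3)


def gradient_step(w: Weights, grads: Weights, eta: float) -> Weights:
    # W+ = W - eta * dLoss/dW, applied to every weight.
    W1, W2, W3 = w
    g1, g2, g3 = grads
    return (W1 - eta * g1, W2 - eta * g2, W3 - eta * g3)


w = (0.5, -0.3, 0.8)
for step in range(5):
    loss, grads = forward_backward(w, x=1.0, target=0.5)
    print(f"step {step}: loss = {loss:.5f}")
    w = gradient_step(w, grads, eta=0.5)
```

The learning rate η scales every step; too large a value can make the loss oscillate instead of decreasing.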
[0044] In the embodiments of this application, the loss function measures how well the discriminator judges the generated images. The smaller the value of the loss function, the better the discriminator performs in the current iteration, that is, the better it distinguishes the images produced by the generator; conversely, the larger the value, the worse the discriminator performs.
[0045] Referring to Figures 1 to 3, Figure 3 is a flowchart of the generative adversarial network (GAN) optimization method, which includes the following steps:
[0046] S31, determine the first weight of the generator and the second weight of the discriminator, wherein the first weight and the second weight are equal.
[0047] In the embodiments of this application, the methods for determining the first weight and the second weight include, but are not limited to, Xavier initialization, Kaiming initialization, Fixup initialization, LSUV initialization, or transfer learning.
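Two of the initialization schemes listed above are easy to write down. The sketch below assumes fully connected layers stored as fan_out rows of fan_in values; the layer shapes are hypothetical, and the Kaiming variant uses the gain for ReLU activations. Fixup, LSUV and transfer learning are not shown.

```python
import math
import random
from typing import List


def xavier_uniform(fan_in: int, fan_out: int, rng: random.Random) -> List[List[float]]:
    # Glorot/Xavier uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
    a = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-a, a) for _ in range(fan_in)] for _ in range(fan_out)]


def kaiming_normal(fan_in: int, fan_out: int, rng: random.Random) -> List[List[float]]:
    # He/Kaiming normal with ReLU gain: N(0, std^2) with std = sqrt(2 / fan_in).
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


rng = random.Random(0)
# Hypothetical first layers: generator 10 -> 15, discriminator 64 -> 18.
g_layer1 = xavier_uniform(10, 15, rng)
d_layer1 = kaiming_normal(64, 18, rng)
print(len(g_layer1), len(g_layer1[0]), len(d_layer1), len(d_layer1[0]))  # 15 10 18 64
```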
[0048] The fact that the first weight is equal to the second weight indicates that the generator and the discriminator have the same learning ability.
[0049] S32, train the generator and update the first weights.
[0050] The update of the first weight depends on the generator's learning rate and loss function. The learning rate is set dynamically according to the number of training iterations, and the loss function $L_g$ is given by formula (3):
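The application does not say how the learning rate depends on the iteration count. Two common choices are sketched below as assumptions; base_lr, drop, every and decay are hypothetical parameters.

```python
def step_decay_lr(iteration: int, base_lr: float = 2e-4, drop: float = 0.5, every: int = 1000) -> float:
    # Piecewise-constant schedule: multiply by `drop` after every `every` iterations.
    return base_lr * drop ** (iteration // every)


def inverse_time_lr(iteration: int, base_lr: float = 2e-4, decay: float = 1e-3) -> float:
    # Smooth schedule: the learning rate shrinks as 1 / (1 + decay * iteration).
    return base_lr / (1.0 + decay * iteration)


# The generator and the discriminator can each query a schedule with the current iteration.
for iteration in (0, 999, 1000, 2500):
    print(iteration, step_decay_lr(iteration), inverse_time_lr(iteration))
```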
[0051]
[0052] where m is the number of noise samples z, $z^{(i)}$ is the i-th noise sample, $G(z^{(i)})$ is the image generated from the noise sample $z^{(i)}$, $D(G(z^{(i)}))$ is the probability that this generated image is judged real, and $\theta_g$ is the first weight.
[0053] The generator aims to maximize the loss function $L_g$, so that the distribution of the generated samples fits the distribution of the real samples as closely as possible.
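Since the exact expression of formula (3) is not given here, the sketch below assumes one common form that a generator maximizes, L_g = (1/m) Σ log D(G(z^(i))), computed from the discriminator's outputs on a batch of generated images. The eps clamp is an added safeguard against log(0).

```python
import math
from typing import Sequence


def generator_loss(d_fake: Sequence[float], eps: float = 1e-12) -> float:
    # Assumed form: L_g = (1/m) * sum_i log D(G(z^(i))); the generator maximizes it.
    # d_fake[i] is the discriminator's probability that G(z^(i)) is real.
    m = len(d_fake)
    return sum(math.log(max(p, eps)) for p in d_fake) / m


print(generator_loss([0.1, 0.2]))  # discriminator is not fooled: strongly negative
print(generator_loss([0.6, 0.9]))  # discriminator is mostly fooled: closer to 0
```

Under this assumption L_g is never positive and rises towards 0 as the discriminator is fooled, which is why the generator maximizes it.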
[0054] S33, train the discriminator and update the second weights.
[0055] The update of the second weight depends on the discriminator's learning rate and loss function. The learning rate is set dynamically according to the number of training iterations, and the loss function $L_d$ is given by formula (4):
[0056]
[0057] where $x^{(i)}$ is the i-th real image, $D(x^{(i)})$ is the probability that the real image $x^{(i)}$ is judged real, and $\theta_d$ is the second weight.
[0058] The discriminator aims to minimize the loss function $L_d$, so as to distinguish as reliably as possible whether an input sample is a real image or an image generated by the generator.
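Likewise, the exact expression of formula (4) is not given here. The sketch assumes the standard binary cross-entropy form L_d = -(1/m) Σ [log D(x^(i)) + log(1 - D(G(z^(i))))], which the discriminator minimizes. This choice agrees with [0044] (a smaller value means a better discriminator) and with the equilibrium output D = 0.5 of [0034], where L_d = 2 ln 2.

```python
import math
from typing import Sequence


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float], eps: float = 1e-12) -> float:
    # Assumed form: L_d = -(1/m) * sum_i [log D(x^(i)) + log(1 - D(G(z^(i))))].
    # d_real[i] = D(x^(i)), d_fake[i] = D(G(z^(i))); the discriminator minimizes L_d.
    m = len(d_real)
    total = sum(math.log(max(pr, eps)) + math.log(max(1.0 - pf, eps))
                for pr, pf in zip(d_real, d_fake))
    return -total / m


print(discriminator_loss([0.9, 0.8], [0.1, 0.2]))  # confident, correct discriminator: small loss
print(discriminator_loss([0.5, 0.5], [0.5, 0.5]))  # cannot tell real from fake: 2 * ln 2
```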
[0059] S34, repeat steps S32 and S33 until both the generator and the discriminator converge.
[0060] In the embodiments of this application, the execution order of steps S32 and S33 is not limited. That is, in the alternating iterative training process of the generator and the discriminator, the generator can be trained first, or the discriminator can be trained first.
[0061] This application uses gradient descent to iteratively update the first weight $\theta_g$ and the second weight $\theta_d$. As the training period increases, the learning rates of the generator and the discriminator are adjusted dynamically until both the generator's loss function $L_g$ and the discriminator's loss function $L_d$ converge, thereby obtaining the optimal weights.
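Steps S32 to S34 can be sketched end to end on a deliberately tiny one-dimensional problem. Everything concrete here is a hypothetical choice for illustration, not taken from the application: real data x ~ N(3, 1); a generator G(z) = z + b and a discriminator D(x) = sigmoid(w·x) with one parameter each, so that in the sense of [0035] their weights are equal; the loss forms assumed above for formulas (3) and (4); an inverse-time learning-rate decay; and a fixed number of steps standing in for a convergence test. The generator_first flag reflects that the order of S32 and S33 is not limited.

```python
import math
import random
from typing import Tuple


def sigmoid(u: float) -> float:
    # Numerically stable logistic function.
    if u >= 0.0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def lr_at(step: int, base: float = 0.1, decay: float = 1e-3) -> float:
    # Hypothetical schedule: the learning rate shrinks as the iteration count grows.
    return base / (1.0 + decay * step)


def train_toy_gan(steps: int = 3000, m: int = 64, mu: float = 3.0,
                  seed: int = 0, generator_first: bool = True) -> Tuple[float, float, float]:
    rng = random.Random(seed)
    real = [mu + rng.gauss(0.0, 1.0) for _ in range(m)]  # data samples x^(i)
    noise = [rng.gauss(0.0, 1.0) for _ in range(m)]      # noise samples z^(i)

    # G(z) = z + b and D(x) = sigmoid(w * x): one trainable parameter each (S31).
    b, w = 0.0, 0.0

    def update_generator(eta: float) -> None:
        # S32: w is frozen; gradient ascent on L_g = mean(log D(G(z))).
        nonlocal b
        grad_b = sum((1.0 - sigmoid(w * (z + b))) * w for z in noise) / m
        b += eta * grad_b

    def update_discriminator(eta: float) -> None:
        # S33: b is frozen; gradient descent on L_d = -mean(log D(x) + log(1 - D(G(z)))).
        nonlocal w
        grad_w = (-sum((1.0 - sigmoid(w * x)) * x for x in real)
                  + sum(sigmoid(w * (z + b)) * (z + b) for z in noise)) / m
        w -= eta * grad_w

    # S34: alternate for a fixed number of steps (a stand-in for a convergence test).
    for step in range(steps):
        eta = lr_at(step)
        if generator_first:
            update_generator(eta)
            update_discriminator(eta)
        else:
            update_discriminator(eta)
            update_generator(eta)

    # Equilibrium of this toy game: w = 0 (D outputs 0.5 everywhere) and
    # b = mean(x) - mean(z), where both gradients vanish.
    target_b = sum(real) / m - sum(noise) / m
    return b, w, target_b


b, w, target_b = train_toy_gan()
print(f"b = {b:.3f} (equilibrium {target_b:.3f}), w = {w:.4f}")
```

In this toy game the discriminator ends up outputting roughly 0.5 for every input, matching the Nash-equilibrium description in [0034]. With real networks, alternating gradient updates are not guaranteed to converge, so an implementation would monitor both losses rather than rely on a fixed step count.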
[0062] Referring to Figure 4, Figure 4 is a schematic diagram of an electronic device 40. The electronic device 40 includes a memory 41 and a processor 42. The memory 41 stores a computer program, which, when called by the processor 42, implements the generative adversarial network optimization method of this application.
[0063] The electronic device 40 includes, but is not limited to, at least one of smartphones, tablets, personal computers (PCs), e-book readers, workstations, servers, personal digital assistants (PDAs), portable multimedia players (PMPs), MPEG-1 audio layer 3 (MP3) players, mobile medical devices, cameras, and wearable devices. The wearable device includes at least one of the following types: accessory type (e.g., watch, ring, bracelet, anklet, necklace, glasses, contact lens, or head-mounted device (HMD)), fabric or clothing integrated type (e.g., electronic clothing), body-installed type (e.g., skin pads or tattoos), and bio-implantable type (e.g., implantable circuitry).
[0064] The memory 41 is used to store computer programs and/or modules. The processor 42 implements the generative adversarial network optimization method of this application by running or executing the computer programs and/or modules stored in the memory 41 and calling the data stored in the memory 41. The memory 41 includes volatile or non-volatile storage devices, such as a digital versatile disc (DVD) or other optical discs, a magnetic disk, a hard disk, a smart media card (SMC), a secure digital (SD) card, a flash card, etc.
[0065] The processor 42 includes a central processing unit (CPU), a digital signal processor (DSP), an application-specific integrated circuit (ASIC), a field-programmable gate array (FPGA), or other programmable logic devices, discrete gate or transistor logic devices, discrete hardware components, etc.
[0066] It is understood that when the electronic device 40 implements the generative adversarial network optimization method of this application, the specific implementation of the generative adversarial network optimization method is applicable to the electronic device 40.
[0067] The embodiments of this application have been described in detail above with reference to the accompanying drawings. However, this application is not limited to the above embodiments. Within the scope of knowledge possessed by those skilled in the art, various changes can be made without departing from the spirit of this application. Furthermore, unless otherwise specified, the embodiments and features described in the embodiments of this application can be combined with each other.
Claims
1. A generative adversarial network optimization method, characterized in that the method comprises: determining a first weight of a generator and a second weight of a discriminator, wherein the first weight and the second weight are equal, the first weight being used to represent the learning ability of the generator and the second weight being used to represent the learning ability of the discriminator; wherein the update of the first weight is related to the learning rate and loss function of the generator, and the update of the second weight is related to the learning rate and loss function of the discriminator; the formula for updating the weights is $$W^{+} = W - \eta\,\frac{\partial \mathrm{Loss}}{\partial W}$$ where W+ represents the updated weights, W represents the original weights, Loss represents the loss function, and η represents the learning rate, which refers to the magnitude of the update of W; and training the generator and the discriminator alternately and iteratively until both the generator and the discriminator converge, wherein the generator is used to receive noise samples and generate a first image, and to feed the generated first image together with a second image obtained from data samples into the discriminator.
2. The generative adversarial network optimization method as described in claim 1, characterized in that the learning ability is positively correlated with either the first weight or the second weight.
3. The generative adversarial network optimization method as described in claim 1 or 2, characterized in that both the generator and the discriminator are neural networks, and the neural network includes one of the following: a convolutional neural network, a recurrent neural network, or a deep neural network.
4. The generative adversarial network optimization method as described in claim 3, characterized in that the first weight of the generator and the second weight of the discriminator are determined using one of the following methods: Xavier initialization, Kaiming initialization, Fixup initialization, LSUV initialization, or transfer learning.
5. The generative adversarial network optimization method as described in claim 3, characterized in that the alternating iterative training of the generator and the discriminator comprises: training the generator and updating the first weight; and training the discriminator and updating the second weight.
6. The generative adversarial network optimization method as described in claim 5, characterized in that the learning rate is set dynamically based on the number of training iterations.
7. The generative adversarial network optimization method as described in claim 5, characterized in that the loss function of the generator is $L_g$, in which m is the number of noise samples z, $z^{(i)}$ is the i-th noise sample, $G(z^{(i)})$ is the image generated from the noise sample $z^{(i)}$, $D(G(z^{(i)}))$ is the probability that the image is judged real, and $\theta_g$ is the first weight.
8. The generative adversarial network optimization method as described in claim 7, characterized in that the loss function of the discriminator is $L_d$, in which $x^{(i)}$ is the i-th real image, $D(x^{(i)})$ is the probability that the real image $x^{(i)}$ is judged real, and $\theta_d$ is the second weight.
9. An electronic device comprising a memory and a processor, wherein the memory is used to store a computer program, characterized in that, When the computer program is invoked by the processor, it implements the generative adversarial network optimization method as described in any one of claims 1 to 8.