Training student neural networks using teacher neural networks and additional blocks
Patent Information
- Authority / Receiving Office
- WO · WO
- Patent Type
- Applications
- Current Assignee / Owner
- DEEPMIND TECH LTD
- Filing Date
- 2025-12-11
- Publication Date
- 2026-06-18
AI Technical Summary
Larger and more complex neural networks increase computational and communication costs during training, making them impractical for deployment on devices with limited resources, and existing teacher distillation methods are computationally inefficient and prone to training instability.
A student neural network is trained using a teacher neural network with additional blocks to adapt to varying probability distributions, optimizing common loss terms like cross-entropy and KL divergence, and using an objective function that improves computational efficiency and reduces overfitting.
The student neural network achieves comparable accuracy to the teacher network with fewer computational resources, faster training, and is optimized for deployment on devices with limited resources, avoiding overfitting and training instability.
Description
Attorney Docket No. 45288-0598WO1
TRAINING STUDENT NEURAL NETWORKS USING TEACHER NEURAL NETWORKS AND ADDITIONAL BLOCKS
CROSS-REFERENCE TO RELATED APPLICATIONS
[0001] This application claims priority to U.S. Provisional Application No. 63/730,882, filed on December 11, 2024. The disclosure of the prior application is considered part of and is incorporated by reference in its entirety in the disclosure of this application.
BACKGROUND
[0002] This specification relates to training neural networks.
[0003] Neural networks are machine learning models that employ one or more layers of nonlinear units to predict an output for a received input. Some neural networks include one or more hidden layers in addition to an output layer. The output of each hidden layer is used as input to the next layer in the network, i.e., the next hidden layer or the output layer. Each layer of the network generates an output from a received input in accordance with current values of a respective set of parameters.
[0004] A general trend with neural networks has been to make larger and more complicated networks in order to achieve higher accuracy. As neural networks increase in size and complexity in service of increased accuracy, they also increase in computational and communication cost during the training of the neural networks.
SUMMARY
[0005] This specification describes a training system implemented as computer programs on one or more computers in one or more locations that trains a student neural network to perform one or more machine learning tasks through teacher distillation training.
[0006] The subject matter described in this specification can be implemented in particular embodiments so as to realize one or more of the following advantages.
[0007] The specification describes techniques that improve the computational efficiency for training a student neural network through teacher distillation training using a teacher neural network that has already been trained. The student neural network is easier to deploy than the teacher neural network, i.e., because it requires less computation, memory, or both, to generate outputs at run time than the teacher neural network. For example, compared to the teacher neural network, the student neural network is more practical to run on a device with limited computational power or resources, e.g., a mobile user device such as a smartphone or laptop computer.
[0008] As a particular example, the described techniques can be used to train a student neural network that has an architecture that is optimized for processing on a target hardware platform, e.g., a user device or one or more computers in the cloud that have a specific arrangement of processors, memory, storage, etc., by using a teacher neural network that is not optimized for processing on such a target hardware platform. That is, a student neural network can be trained to achieve hardware optimization for deployment: it may run more quickly and use less memory and power on the target hardware platform compared to the teacher neural network with comparable accuracy.
[0009] By using one or more additional neural network blocks together with the existing neural network blocks included in the student neural network during the teacher distillation training, the described techniques enable the student neural network to better adapt to the varying probability distributions included in the teacher outputs generated by the teacher neural network.
[0010] Not only does the use of the one or more additional neural network blocks improve the ability of the student neural network to adapt to the variation in training signals and avoid overfitting to the noise in teacher outputs, but it also allows the student neural network to be trained based on optimizing an objective function that has common loss terms, e.g., cross-entropy loss terms or Kullback-Leibler (KL) divergence terms, that are robust and computationally efficient to compute, and obviates the need to compute a complex distillation loss term that is computationally inefficient to compute or that may otherwise introduce training instability.
[0011] Once trained using the described techniques, the distilled student neural network, which is easier to deploy than the teacher neural network, can generate outputs for any of a range of machine learning tasks, and in particular generative machine learning tasks, that have comparable quality (e.g., accuracy, comprehensiveness, or the like) to outputs generated by the teacher neural network, despite a teacher distillation training process that consumes fewer computational resources, is faster in terms of wall-clock time, or both, than a conventional teacher distillation training process that does not involve any additional neural network blocks beyond those that are included in the student neural network.
[0012] According to an implementation there is provided a method of training a student neural network using a teacher neural network, wherein the student neural network comprises an input block followed by one or more intermediate blocks followed by an output block. The method comprises: obtaining a batch of training inputs from training data that comprises a plurality of training inputs; for each training input in the batch: processing the training input using the teacher neural network to generate a soft teacher output that includes a soft teacher probability distribution over tokens in a vocabulary; processing the training input using the input block and the one or more intermediate blocks of the student neural network to generate an initial intermediate output; processing the initial intermediate output using one or more additional blocks to generate a further processed intermediate output; and processing the further processed intermediate output using the output block of the student neural network to generate a first student output that includes a first student probability distribution over the tokens in the vocabulary; and training the student neural network on an objective function that includes (i) a first term that measures, for each training input in the batch, a difference between the first student output and the soft teacher output and (ii) a second term that, for each training input in the batch, is dependent on a second student output that is generated by using the student neural network based on the training input and without using the one or more additional blocks.
[0013] According to an implementation, the method further comprises, for each training input in the batch: processing the training input using the input block, the one or more intermediate blocks, and the output block of the student neural network to generate the second student output without using the one or more additional blocks, wherein the second student output includes a second student probability distribution over the tokens in the vocabulary.
[0014] According to an implementation: the method further comprises, for each training input in the batch, obtaining a hard teacher output that includes a hard teacher probability distribution over the tokens in the vocabulary; and the second term measures, for each training input in the batch, a difference between the second student output and the hard teacher output.
[0015] The details of one or more embodiments of the subject matter described in this specification are set forth in the accompanying drawings and the description below. Other features, aspects, and advantages of the subject matter will become apparent from the description, the drawings, and the claims.
BRIEF DESCRIPTION OF THE DRAWINGS
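The two-term objective described in the preceding implementations can be sketched in a few lines. This is only an illustrative sketch under stated assumptions, not the applicants' implementation: it assumes the first term is a KL divergence against the soft teacher distribution and the second term is a cross-entropy against the hard teacher distribution, and the function names and the `second_term_weight` parameter are hypothetical. The distributions are taken as already computed; the first student distribution is the one produced with the additional blocks, the second the one produced without them.

```python
import math


def softmax(logits):
    """Convert raw scores into a probability distribution."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(target, predicted, eps=1e-12):
    """KL(target || predicted) over a shared token vocabulary."""
    return sum(
        t * (math.log(t + eps) - math.log(p + eps))
        for t, p in zip(target, predicted)
        if t > 0
    )


def cross_entropy(target, predicted, eps=1e-12):
    """Cross-entropy of the predicted distribution against the target."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))


def distillation_objective(batch, second_term_weight=1.0):
    """Average two-term loss over a batch.

    Each batch element is a tuple of four distributions:
    (soft_teacher, hard_teacher, first_student, second_student).
    The choice of KL for term (i), cross-entropy for term (ii), and the
    weighting between them are assumptions made for illustration.
    """
    first_term = sum(kl_divergence(soft_t, s1) for soft_t, _, s1, _ in batch)
    second_term = sum(cross_entropy(hard_t, s2) for _, hard_t, _, s2 in batch)
    return (first_term + second_term_weight * second_term) / len(batch)
```

When the first student distribution equals the soft teacher distribution, the first term vanishes and only the second term contributes.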
[0016] FIG. 1 shows an example training system.
[0017] FIG. 2 is a flow diagram of an example process for training a student neural network using a teacher neural network.
[0018] FIG. 3 is a flow diagram of sub-steps of one of the steps of the process of FIG. 2.
[0019] FIG. 4 is an illustration of an example data flow during a training iteration.
[0020] Like reference numbers and designations in the various drawings indicate like elements.
DETAILED DESCRIPTION
[0021] FIG. 1 shows an example training system 100. The training system 100 is an example of a system implemented as computer programs on one or more computers in one or more locations that trains a student neural network 110 through teacher distillation training by using a teacher neural network 150.
[0022] The teacher neural network 150 is a neural network that has already been trained, e.g., by the training system or another training system.
[0023] Generally, the student neural network 110 is a neural network that has a different architecture from the teacher neural network 150 that makes it easier to deploy than the teacher neural network 150, e.g., because the student neural network 110 requires less memory to store and less computation to generate outputs at run time than the teacher neural network 150.
[0024] For example, the student neural network 110 may have fewer layers, fewer parameters, or both than the teacher neural network 150. In this example, in some cases, the smaller, more computationally efficient student neural network 110 can then be deployed at an inference system, e.g., on an edge device or in another computing environment with limited computational budget, whereas the teacher neural network 150 cannot.
[0025] For example, the teacher neural network 150 may have too large of a memory footprint or generate outputs with too long of a latency in order to be effectively deployed by an inference system. However, the training system 100 may be implemented in a data center with a large number of computing devices, and the extra computational resources that are available in the data center can be used to allow the teacher neural network 150 to be used to improve the training of the student neural network 110.
[0026] Both the student neural network 110 and the teacher neural network 150 can be configured through training to perform any kind of machine learning task, i.e., the neural networks can each be configured through training to receive any kind of input and to generate an output that includes any kind of score, classification, or regression (e.g., generative) output based on the input.
[0027] The neural network (either the student neural network 110 or the teacher neural network 150) can be a generative neural network that can be configured to perform a generative task to generate, as output, data that includes, for example, text data, image data, video data, audio data, or multimodal data that includes data in two or more different modalities.
[0028] Some examples of machine learning tasks, including generative tasks, that the neural network (either the student neural network 110 or the teacher neural network 150) can be configured to perform follow.
[0029] As one example, the task may be a neural machine translation task. For example, if the input to the neural network is a sequence of text, e.g., a sequence of words, phrases, characters, or word pieces, in one language, the output generated by the neural network may be a translation of the sequence of text into another language, i.e., a sequence of text in the other language that is a translation of the input sequence of text. As a particular example, the task may be a multi-lingual machine translation task, where a single neural network is configured to translate between multiple different source language - target language pairs. In this example, the source language text may be augmented with an identifier that indicates the target language into which the neural network should translate the source language text.
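As a minimal sketch of the target-language identifier mentioned above, the source tokens can be prefixed with a tag token before being passed to the network. The tag format `<2fr>` and the function name are hypothetical choices for illustration; the specification does not prescribe how the identifier is encoded.

```python
def add_target_language_tag(source_tokens, target_language):
    """Prepend a hypothetical tag token (e.g. '<2fr>') naming the target language."""
    return [f"<2{target_language}>"] + list(source_tokens)


tagged = add_target_language_tag(["Hello", "world"], "fr")
```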
[0030] As another example, the task may be an audio processing task. For example, if the input to the neural network is a sequence representing a spoken utterance, the output generated by the neural network may be a score for each of a set of pieces of text, each score representing an estimated likelihood that the piece of text is the correct transcript for the utterance. As another example, if the input to the neural network is a sequence representing a spoken utterance, the output generated by the neural network can indicate whether a particular word or phrase (“hotword”) was spoken in the utterance. As another example, if the input to the neural network is a sequence representing a spoken utterance, the output generated by the neural network can identify the natural language in which the utterance was spoken.
[0031] As another example, the inference system can be part of a dialog system and the input sequence can include audio or text from the most recent conversational turn submitted by a user of the dialog system during the dialog while the output sequence is the next turn in the conversation, e.g., either text or audio that is a response to the most recent conversational turn. Optionally, the input sequence can also include one or more historical conversational turns that occurred earlier in the conversation.
[0032] As another example, the task can be a natural language processing or understanding task, e.g., an entailment task, a paraphrase task, a textual similarity task, a sentiment task, a sentence completion task, a grammaticality task, and so on, that operates on a sequence of text in some natural language.
[0033] As another example, the task can be a text to speech task, where the input is text in a natural language or features of text in a natural language and the output is a spectrogram, a waveform, or other data defining audio of the text being spoken in the natural language.
[0034] As another example, the task can be a health prediction task, where the input is a sequence derived from electronic health record data for a patient and the output is a prediction that is relevant to the future health of the patient, e.g., a predicted treatment that should be prescribed to the patient, the likelihood that an adverse health event will occur to the patient, or a predicted diagnosis for the patient.
[0035] As another example, the task can be a text generation task, where the input is a sequence of text, and the output is another sequence of text, e.g., a completion of the input sequence of text, a response to a question posed in the input sequence, or a sequence of text that is about a topic specified by the first sequence of text. In this example, both the input sequence of text and the output sequence of text can include tokens from a vocabulary of text tokens that includes, e.g., one or more of characters, sub-words, words, punctuation marks, numbers, or other symbols that appear in a natural language or a computer language.
[0036] As a similar example, the task can be an automatic code generation task, where the input is a sequence of words, wordpieces or characters in a first natural language and the output is a sequence of tokens that represent instructions in a computer programming or markup language, or instructions for controlling an application program to perform a task, e.g., to build a data item such as an image or web page.
[0037] As a particular example of this, the input can represent a context input that includes a text description of a desired piece of code and / or a snippet of computer code in a programming language and the output can be an output sequence that includes computer code, e.g., a snippet of code that is described by the context input or a snippet of code that follows the context input in a computer program.
[0038] As another particular example, the inference system can be part of a computer code verification system and the input sequence can include data characterizing a piece of software code, including code snippets from the software code, data characterizing execution of the piece of software code including artifacts of execution of the software code (e.g. error messages, exceptions, or OpenTelemetry), or program logs. The output sequence may specify (e.g., may include a statement that defines) whether the piece of software code will execute or has executed as intended on the computer system.
[0039] As another example, the input to the text generation task can be an input other than text, e.g., an image, video and / or audio, and the output sequence can be text that describes the input. In this example, both the input sequence and the output sequence of text can include tokens from a vocabulary of tokens that includes tokens that can represent data other than text, in addition to the text tokens mentioned above.
[0040] For example, the vocabulary of tokens can additionally include image tokens that represent a discrete set of image patch embeddings of an image that can be generated by an image encoder neural network based on processing the image patches of the image. Where video is being processed, each frame (e.g., image) within the video may be represented by a series of image tokens. As another example, the vocabulary of tokens can additionally include audio tokens that represent one or more audio waveforms, e.g., code vectors in a codebook of a quantizer, e.g., a residual vector quantizer.
[0041] As another example, the task can be an image generation task, where the input is a conditioning input, e.g., text, an image (e.g., an image of a different (e.g., lower) resolution to the output), or a partial image, and the output is a sequence of intensity values for the pixels of an image (e.g. a sequence of tokens that represent pixels). As another example, the task can be a video generation task, where the input is a conditioning input, e.g., text, a video (e.g., a video of a different (e.g., lower) resolution to the output), or a partial video (e.g., a video to be completed), and the output is a sequence of intensity values for the pixels of a sequence of frames (e.g., images) of a video (e.g. a sequence of tokens that represent pixels within the frames). For example, the output sequence generated by the generative neural network includes a plurality of color values for pixels in an image arranged according to a specified order. As another example, the output sequence generated by the neural network includes a plurality of tokens that represent image patch embeddings of an image which can then be processed by a decoder neural network to generate the image. For example, the inference system can use the neural network to generate an image or a video conditioned on an input sequence (prompt) that includes a text description of the content of the image or the video to be generated.
[0042] As another example, the task can be an image or video processing task. For example, the input can be the intensity values of the pixels of the image or video or an encoded representation of the intensity values of the pixels generated by an encoder neural network (e.g., a sequence of tokens that represent the pixels), and the output can be (i) an image or video classification output that classifies the input image or video into one of a plurality of object categories, (ii) an object detection output, i.e., a sequence that specifies the coordinates of one or more bounding boxes in the image or video that are predicted to encompass objects, or (iii) a segmentation output that classifies each pixel in the input image or video into one of a plurality of categories. As another example, the input can include the intensity values of the pixels of the image or video or an encoded representation of the intensity values of the pixels generated by an encoder neural network (e.g., a sequence of tokens that represent the pixels) and optionally text, and the output can be text that characterizes the image or video, e.g., captions the image or video, or answers a question posed by the text in the input about the image or video.
[0043] As another example, the task can be an audio generation task, where the input is a conditioning input, e.g., text, an image, or context audio, and the output is a sequence of tokens that represents audio (e.g. representing an audio waveform).
[0044] As another example, the task can be an audio processing task. For example, the input can include audio or an encoded representation of the audio generated by an encoder neural network, and the output can be text or an image that characterizes the audio.
[0045] As another example, the task can be an agent control task, where the input is a sequence of observations or other data characterizing states of an environment and the output defines an action to be performed by the agent in response to the most recent data in the sequence. The input sequence can comprise a natural language description of one or more tasks for the agent. The output sequence can comprise a sequence of instructions (e.g., joint angles, torques, velocities, etc., or a sequence of sub-tasks, e.g. in PDDL (Planning Domain Definition Language), or as a component in a system such as SayCan (Ahn et al., arXiv:2204.01691, 2022)) for the agent that cause the agent to perform the one or more tasks described in the input sequence. The agent can be, e.g., a real-world or simulated robot, a control system for an industrial facility, or a control system that controls a different kind of agent. For instance, the environment may be a real-world or simulated environment. The agent may be a mechanical agent operating in the environment.
[0046] In a similar example, the inference system can be part of or associated with a control system in a manufacturing environment for manufacturing a product, i.e., a system for controlling a manufacturing unit or a machine that operates to manufacture the product. In another similar example, the inference system can be part of or associated with a control system in a service facility comprising a plurality of items of electronic equipment.
[0047] As another example, the task can be a genomics task, where the input is a sequence representing a fragment of a DNA sequence or other molecule sequence and the output is either an embedding of the fragment for use in a downstream task, e.g., by making use of an unsupervised learning technique on a data set of DNA sequence fragments, or an output for the downstream task. Examples of downstream tasks include promoter site prediction, methylation analysis, predicting functional effects of non-coding variants, and so on.
[0048] In some cases, the machine learning task is a combination of multiple individual machine learning tasks, i.e., the neural network is configured to perform multiple different individual machine learning tasks, e.g., two or more of the machine learning tasks mentioned above. For example, the neural network can be configured to perform multiple individual natural language understanding tasks, with the input including one or more identifiers for one or more individual natural language understanding tasks to be performed on the input.
[0049] As another particular example, the inference system can be part of a computer-assisted medical diagnosis system. For example, the input sequence can be a sequence of data from an electronic medical record and the output sequences can each be a sequence of predicted treatments.
[0050] In some cases, the machine learning task is a multi-modal processing task that requires processing multi-modal data. In general, multi-modal data is a combination of two or more different types of data, e.g., two or more of audio data, image data, video data, text data, or graph data. As one example the multi-modal data may comprise audio-visual data, comprising a combination of pixels of an image or of video and audio data representing values of a digitized audio waveform. As another example the multi-modal data may comprise a combination of i) text data representing text in a natural language and ii) pixels of an image or of video or audio data representing values of an audio waveform. Optionally, but not necessarily, the different types of data may represent the same or overlapping objects using the different modalities (types), and when processing multi-modal data the data may be mapped into a common embedding space.
[0051] As a particular example, the task is a multi-modal processing task that requires processing both text and image inputs, so that the neural network includes both a computer vision neural network and a text processing neural network. That is, the target output to be generated by the computer vision neural network for a given image depends on one or more outputs generated by the text processing neural network for one or more corresponding text inputs (and vice versa). Examples of such tasks include open-vocabulary image classification, open-vocabulary object detection, image captioning, text-based image search, image-based retrieval, and so on.
[0052] More generally, the multi-modal processing task may correspond to any of the tasks previously described for any of the types of data making up the multi-modal combination. For example, an accuracy of the previously described tasks may be increased when the task is applied to multi-modal data combining the data for which the task has been previously described and another type of data. For example, detection or classification of an object or event may be improved when data of multiple different types (modalities) is processed. As another example, the quality (e.g., accuracy, fidelity, or intelligibility) of a generated image, video, or audio may be improved when data of multiple different types (modalities) is processed.
[0053] In some cases, once trained, the neural network can perform tasks that it was not explicitly trained to perform. For example the neural network can perform translation tasks (provided that the training corpus included words in different languages), generative tasks, and many other tasks.
[0054] In these cases, the neural network can be made to perform a particular task by providing a natural language description of the desired response as a part of the input or “prompt”. The prompt may be a few-shot prompt where a few, e.g., 1 to 10, examples of a query and an example output are provided in the text prior to the actual query.
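A few-shot prompt of the kind described above could be assembled as follows. This is a hedged sketch: the `Input:` / `Output:` labels, the blank-line separators, and the function name are hypothetical formatting choices, not something the specification defines.

```python
def build_few_shot_prompt(task_description, examples, query):
    """Place a task description and a few (query, output) examples before the actual query."""
    lines = [task_description, ""]
    for example_query, example_output in examples:
        lines.append(f"Input: {example_query}")
        lines.append(f"Output: {example_output}")
        lines.append("")
    # The actual query comes last, with the output left for the network to complete.
    lines.append(f"Input: {query}")
    lines.append("Output:")
    return "\n".join(lines)


prompt = build_few_shot_prompt(
    "Translate English to French.",
    [("cat", "chat"), ("dog", "chien")],
    "bird",
)
```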
[0055] Additional description of generative tasks that the neural network (when configured as a “generative” neural network) can perform is provided below.
[0056] Generally, the generative neural network is configured to process a conditioning input (“prompt”) to generate a data item. The data item can include data in any of a variety of modalities, e.g., text data, image data, video data, or audio data. Generally, the data item represents a response to the conditioning input which may be, e.g. a “prompt” for the generative neural network. For example, the conditioning input can characterize one or more desired properties for the generated data item.
[0057] In some implementations the generative neural network generates an output token sequence from an input token sequence including the conditioning input. The generative neural network may then be configured to process the input token sequence to generate for each position in the output token sequence, a respective score for each token in a vocabulary of output tokens, that is used to select an output token for the output token sequence.
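To illustrate how per-token scores can be used to select an output token at one position, the sketch below normalizes the scores with a softmax and then either takes the highest-probability token or samples from the distribution. Both selection rules are common options rather than anything mandated by the specification, and the function names are hypothetical.

```python
import math
import random


def scores_to_probabilities(scores):
    """Softmax over the scores for one output position."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def select_token(scores, vocabulary, rng=None):
    """Pick one token: greedy if no random generator is given, sampled otherwise."""
    probabilities = scores_to_probabilities(scores)
    if rng is None:
        best_index = max(range(len(probabilities)), key=probabilities.__getitem__)
        return vocabulary[best_index]
    return rng.choices(vocabulary, weights=probabilities, k=1)[0]


vocabulary = ["a", "b", "c"]
greedy_choice = select_token([0.1, 2.5, -1.0], vocabulary)
sampled_choice = select_token([0.1, 2.5, -1.0], vocabulary, rng=random.Random(0))
```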
[0058] In some implementations the tokens can represent text, e.g., words, wordpieces or characters, in a natural or computer language. For example, text may be received, e.g., as a series of encoded characters, e.g. UTF-8 encoded characters; such “characters” can include Chinese and other similar characters, as well as logograms, syllabograms and the like. A text encoder, i.e. a tokenizer, can process a sequence of text to represent the text as a series of text tokens from a vocabulary of text tokens, e.g. that each represent words, wordpieces or characters in a natural or computer language. The computer language may be any formal language used to communicate with a computer, e.g. a markup language, or a command or configuration language, or a data exchange language such as JSON, or a programming language. The tokenizer can, e.g., implement BPE (Byte Pair Encoding) or Wordpiece tokenization. Optionally the text can be obtained from audio data representing speech; the output tokens may be converted into audio data that represent speech corresponding to the text.
[0059] Also, or instead the tokens may represent an image. For example, a set (sequence) of input or output tokens can represent an image. Each image token may comprise a patch encoding of values of the pixels in a different region of an image that maps a set of values of the pixels to a respective image token. The patch encoder may comprise a neural network, e.g. having one or more (self-)attention layers, such as a Transformer neural network.
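The first step of such a patch encoding, splitting the pixel values into regions, can be sketched as below. The sketch assumes a single-channel image stored as a list of rows and non-overlapping square patches; it stops short of the encoder itself (e.g., a neural network), which would map each flattened patch to an image token. The function name is hypothetical.

```python
def split_into_patches(image, patch_size):
    """Split a 2D grid of pixel values into flattened, non-overlapping square patches."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image dimensions must be multiples of patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Collect the rows of this region and flatten them in row-major order.
            patch = [
                value
                for row in image[top:top + patch_size]
                for value in row[left:left + patch_size]
            ]
            patches.append(patch)
    return patches


# A 4x4 image with pixel values 0..15, split into four 2x2 patches.
image = [[4 * r + c for c in range(4)] for r in range(4)]
patches = split_into_patches(image, 2)
```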
[0060] Also, or instead the tokens may represent an audio waveform. For example, a set (sequence) of input or output tokens can represent audio data representing a waveform e.g. instantaneous audio amplitude values or time-frequency audio data. Each audio token may comprise a segment encoding of the audio waveform in a different time segment of the audio that maps a set of values representing the audio waveform to a respective audio token. The segment encoder may comprise a neural network, e.g. having one or more (self-)attention layers, such as a Transformer neural network.
[0061] In a multimodal system audio data or an image may be flagged by a start-of-audio token or start-of-image token.
[0062] In some implementations the generative neural network can be a multimodal network that is configured to process a conditioning input comprising one or more of text data, audio data defining an audio signal (e.g. as amplitude values of the audio signal or as a timefrequency representation of the audio signal), or a still or moving image (e.g. as image pixel values), to generate a data item that can similarly comprise text data, audio data, or a still or moving image.
[0063] In more detail, the input sequence (prompt) can represent an image or video and the output sequence can comprise data, e.g. text, describing the image or video, i.e. it can comprise a caption or other description of the image or video. In an image-based questionAttorney Docket No. 45288-0598WO1 answering system the input sequence (prompt) can represent an image or video and data, e g. text or audio, that specifies a question and the output sequence can comprise data, e.g. text, that provides an answer to the question. Some other examples of this include an OCR (optical character recognition) task, an open vocabulary classification task (e.g. an image or action recognition task), or an object detection task.
[0064] For example, the conditioning input may comprise text and the data item may comprise an image, or an audio signal (e.g., audio waveform) that represents speech, generated in response to the text, e.g. described by the text. Also, or instead, the conditioning input may comprise an audio signal (e.g., audio waveform) that represents speech, or an image, and the data item may comprise text, e.g. that describes the conditioning input. For instance, the input sequence (prompt) can represent audio comprising speech and the output sequence can comprise text corresponding to the speech (e.g., a transcription of the speech), e.g. in a speech recognition task.
[0065] As another example the conditioning input may comprise an observation, e.g. of a real-world environment, e.g. from a sensor such as a camera or other image sensor; and optionally additional information such as information defining a particular task to be performed. The output data item may comprise agent control data that defines one or more actions to be performed by an agent, e.g. by a mechanical agent such as a robot or autonomous vehicle, to perform a task. One or more reward models may be used to train the neural network. The reward model(s) may, e.g., define a preferred trajectory of motion of the mechanical agent in the (real-world) environment.
[0066] In general, the student neural network 110 can have any appropriate architecture to enable it to perform these machine learning tasks, e.g., to perform a generative task by processing the conditioning input to generate the data item.
[0067] In some implementations, the student neural network 110 and the teacher neural network 150 can each be configured as an auto-regressive generative neural network.
[0068] A neural network is referred to as an auto-regressive generative neural network when the neural network auto-regressively generates an output sequence of tokens as the output.
[0069] More specifically, the auto-regressively generated output is created by generating each particular token in the output sequence conditioned on a current input sequence that includes an input sequence included in the input and any tokens that precede the particular token in the output sequence, i.e., the tokens that have already been generated for any previous positions in the output sequence that precede the particular position of the particular token.
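The conditioning described above can be sketched as a simple decoding loop. This is a minimal sketch, not the specification's method: greedy selection, the `end_token` stopping rule, and the stand-in scoring function `toy_scores` are all assumptions introduced for illustration, and any network that returns one score per vocabulary token could take the place of `toy_scores`.

```python
def generate(next_token_scores, prompt, max_new_tokens, end_token):
    """Greedy auto-regressive decoding with a caller-supplied scoring function."""
    sequence = list(prompt)
    for _ in range(max_new_tokens):
        # The current input sequence is the prompt plus every token generated so far.
        scores = next_token_scores(sequence)
        token = max(range(len(scores)), key=scores.__getitem__)
        sequence.append(token)
        if token == end_token:
            break
    # Return only the newly generated output sequence.
    return sequence[len(prompt):]


def toy_scores(sequence, vocab_size=5):
    # Hypothetical stand-in for a neural network: favours (last token + 1) mod vocab_size.
    target = (sequence[-1] + 1) % vocab_size
    return [1.0 if token == target else 0.0 for token in range(vocab_size)]


output = generate(toy_scores, prompt=[0, 1], max_new_tokens=10, end_token=4)
```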
[0070] As a particular example, the neural network (either the student neural network 110 or the teacher neural network 150) can have any of a variety of Transformer-based neural network architectures, e.g., encoder-only Transformer architectures, encoder-decoder Transformer architectures, decoder-only Transformer architectures, mixture-of-experts (MoE) Transformer architectures, other attention-based architectures, and so on.
[0071] Examples of such Transformer-based neural network architectures include those described in Colin Raffel, et al., Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv preprint arXiv:1910.10683, 2019; Daniel Adiwardana, et al., Towards a human-like open-domain chatbot. CoRR, abs/2001.09977, 2020; Aakanksha Chowdhery, et al., PaLM: Scaling Language Modeling with Pathways, arXiv preprint arXiv:2204.02311; Rohan Anil, et al. Palm 2 technical report. arXiv preprint arXiv:2305.10403, 2023; Gemini Team, et al., Gemini: a family of highly capable multimodal models. arXiv preprint arXiv:2312.11805 (2023); Comanici, Gheorghe, et al. Gemini 2.5: Pushing the frontier with advanced reasoning, multimodality, long context, and next generation agentic capabilities. arXiv preprint arXiv:2507.06261 (2025); and Gemma Team, et al. Gemma 3 technical report. arXiv preprint arXiv:2503.19786 (2025).
[0072] The student neural network 110 includes a sequence (or stack) of blocks. A block refers to a group of one or more contiguous neural network layers in a neural network.
[0073] For example, the student neural network 110 can include a plurality of neural network layers, and each block can include a different subset of the plurality of neural network layers.
[0074] Each block includes a respective set of parameters. The respective set of parameters includes parameters of each of the one or more contiguous neural network layers included in the block.
[0075] The plurality of neural network layers can include any kind of neural network layers, for example, attention layers, e.g., self-attention layers, multi-head self-attention layers, or cross-attention layers, convolutional layers, fully-connected layers, embedding layers, activation layers, or recurrent layers, e.g., Long Short-Term Memory (LSTM) layers or gated recurrent unit (GRU) layers.
[0076] An attention layer is a neural network layer that applies an attention mechanism. An attention mechanism is a scoring and weighting process that determines the contribution of each element within a set of input vectors to the calculation of an output element, resulting in a contextually-weighted aggregation of features. Generally, to apply the attention mechanism, the attention layer uses one or more attention heads. Each attention head generates a set of queries, a set of keys, and a set of values, and then applies any of a variety of variants of query-key-value (QKV) attention using the queries, keys, and values to generate an output. When there are multiple attention heads, the attention layer then combines the outputs of the multiple attention heads, e.g., by concatenating the outputs and, optionally, processing the concatenated outputs through a linear layer.
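For concreteness, the sketch below shows one common QKV variant, scaled dot-product attention for a single head, written with plain Python lists. It is offered as an assumption-laden illustration only: the specification permits any QKV attention variant, the learned projections that produce the queries, keys, and values are omitted, and the function names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(score - peak) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def attention(queries, keys, values):
    """Scaled dot-product attention for one head; inputs are lists of vectors."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    outputs = []
    for query in queries:
        # Score each key against the query, then normalise the scores into weights.
        scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
        weights = softmax(scores)
        # The output is the weight-averaged combination of the value vectors.
        outputs.append([
            sum(weight * value[dim] for weight, value in zip(weights, values))
            for dim in range(len(values[0]))
        ])
    return outputs


# Identical keys give uniform weights, so the output is the mean of the values.
result = attention(
    queries=[[1.0, 0.0]],
    keys=[[0.0, 0.0], [0.0, 0.0]],
    values=[[1.0, 2.0], [3.0, 4.0]],
)
```

With multiple heads, the per-head outputs would be concatenated and optionally passed through a linear layer, as the paragraph above describes.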
[0077] Examples of QKV attention variants are described in Vaswani, et al., Attention Is All You Need, arXiv:1706.03762, Raffel, et al., Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer, arXiv:1910.10683, Devlin et al., BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, arXiv:1810.04805, Kitaev, et al., Reformer: The efficient transformer, arXiv preprint arXiv:2001.04451, and Chowdhery, et al., Palm: Scaling language modeling with pathways. Journal of Machine Learning Research 24.240 (2023): 1-113, the entire contents of which are hereby incorporated by reference herein in their entirety.
[0078] As illustrated in FIG. 1, the student neural network 110 includes an input block 114, followed by one or more intermediate blocks, e.g., intermediate block 1 116-1 through intermediate block N 116-N, followed by an output block 118. When there are multiple intermediate blocks, they can be stacked, i.e., arranged in a stack (sequence) with the output of any block except the last being an input to another of the blocks.
[0079] The input block 114 is configured to receive an input to the student neural network 110.
[0080] The input block 114 can include an embedding layer that can be used to map each token included in the input to a corresponding embedding of the token.
[0081] In some implementations, e.g., where the tokens are text tokens, the embedding layer includes parameters that represent elements of an embedding matrix which can be used to map each token in the vocabulary of tokens into a corresponding embedding of the token.
[0082] In other implementations, e.g., where the tokens are image tokens or audio tokens, the embedding layer can be or include a projection layer that transforms the tokens from an input space into an embedding space (e.g., that has a different, e.g., lower, dimensionality than the input space) where the corresponding embeddings of the tokens reside.
[0083] An “embedding” as used in this specification is a sequence of one or more vectors of numeric values, e.g., floating point values or other values, each vector having a predetermined dimensionality.
[0084] The one or more intermediate blocks 116-1 through 116-N can include one or more attention blocks, e.g., one or more local attention blocks, one or more global attention blocks, or both. An attention block includes an attention layer. For example, a local attention block includes a local attention layer, while a global attention block includes a global attention layer.
[0085] A global attention layer is configured to apply a global attention mechanism that, for each of a plurality of input positions in an input sequence, attends over all of the input positions preceding or equal to the input position in the input sequence.
[0086] A local attention layer, on the other hand, is configured to apply a local attention mechanism that, for each of the plurality of input positions, attends only over a set of local input positions that are within a local window of the input position in the input sequence.
[0087] That is, unlike the global attention mechanism, the local attention mechanism does not attend to any position that is outside of the local window of the input position.
[0088] The local windows are generally “causal,” so that, for any given input position, they include up to a fixed number of input positions that are closest to the given input position and that precede or are equal to the given input position, but not any input positions that are after the given input position in the input sequence.
[0089] The fixed number of input positions is generally much smaller than the total number of positions in the input sequence and is referred to as the size of the context window.
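The difference between the two mechanisms can be made concrete by listing which positions a given input position may attend to. In this sketch, `attended_positions` is a hypothetical helper, positions are zero-indexed, and `window_size=None` is used as an assumed convention for global attention.

```python
def attended_positions(position, window_size=None):
    """Positions a causal attention layer may attend to from `position`.

    `window_size=None` models global attention; an integer models a causal
    local window holding up to that many most recent positions.
    """
    if window_size is None:
        start = 0
    else:
        start = max(0, position - window_size + 1)
    return list(range(start, position + 1))


global_view = attended_positions(4)               # every position up to and including 4
local_view = attended_positions(4, window_size=3)  # only the 3 most recent positions
```

Near the start of the sequence the local window simply contains fewer positions, which is why the window holds "up to" a fixed number of them.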
[0090] The output block 118 is configured to generate an output of the student neural network 110. For example, the output can include a score distribution, e.g., a probability distribution, over tokens in the vocabulary of tokens. The score distribution assigns a respective score, e.g., a respective probability, to each token in the vocabulary of tokens.
[0091] When auto-regressively generating an output sequence of tokens, the student neural network 110 can, at each generation step, use the score distribution to select a token from the vocabulary of tokens to occupy an output position in the output sequence that corresponds to the generation step.
[0092] The output block 118 can include a de-embedding layer that can be used to map embeddings to tokens. The de-embedding layer may map embeddings onto an output space. The output space may be the same as the input space. The de-embedding layer may function as a decoder, receiving a lower-dimensional embedding and utilizing a learned transformation to produce a higher-dimensional output. De-embedding may be otherwise known as decoding.
[0093] In some implementations, e.g., where the tokens are text tokens, the de-embedding layer includes parameters that represent elements of a de-embedding matrix which can be used to map an embedding generated by the last intermediate block to a corresponding token in the vocabulary of tokens.
[0094] In some of these implementations, the de-embedding matrix can be a transpose of the embedding matrix represented by the embedding layer included in the input block, such that each parameter of the de-embedding matrix has a corresponding parameter in the embedding matrix.
[0095] In other implementations, e.g., where the tokens are image tokens or audio tokens, the de-embedding layer can include one or more reconstruction layers that transform embeddings that reside in the embedding space to the tokens in the vocabulary.
[0096] The output block 118 can also include a softmax layer. The softmax layer can process an output of the de-embedding layer to generate the output of the student neural network 110 that includes the score distribution over the tokens in the vocabulary.
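A minimal sketch of such an output block for text tokens follows, assuming the tied-weight case in which the de-embedding matrix is the transpose of the embedding matrix. The function name `output_block` and the toy three-token vocabulary are illustrative assumptions; a real de-embedding layer could equally use its own separately learned matrix.

```python
import math


def output_block(embedding, embedding_matrix):
    """Map one embedding to a probability distribution over the vocabulary.

    `embedding_matrix` holds one row per vocabulary token. Multiplying the
    embedding by the transpose of that matrix amounts to taking a dot product
    with every row, which yields one logit per token.
    """
    logits = [
        sum(e * w for e, w in zip(embedding, row)) for row in embedding_matrix
    ]
    # Softmax layer: turn the logits into probabilities that sum to 1.
    peak = max(logits)
    exps = [math.exp(logit - peak) for logit in logits]
    total = sum(exps)
    return [value / total for value in exps]


# Toy vocabulary of three tokens with two-dimensional embeddings.
embedding_matrix = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
probabilities = output_block([2.0, 0.0], embedding_matrix)
```

The token whose embedding is most aligned with the final hidden embedding receives the highest probability.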
[0097] The training system 100 makes use of one or more additional blocks, e.g., additional block 1 120-1 through additional block N 120-N. The one or more additional blocks are separate from and independent of both the student neural network 110 and the teacher neural network 150. For instance, the one or more additional blocks may form part of an additional module for use during training.
[0098] These additional blocks 120-1 through 120-N are used by the training system 100 during the teacher distillation training of the student neural network 110 to update the values of the parameters of the student neural network 110 in a way that can result in faster convergence of the student neural network 110. By using the one or more additional blocks, the training system 100 reduces the consumption of computational resources during the training because the number of training iterations that are needed to train the student neural network 110 to convergence can be reduced.
[0099] Once the teacher distillation training is complete, the one or more additional blocks are no longer needed, and if desired, they can be discarded by the training system 100. Upon deployment of the student neural network 110 in an inference system, they need not be provided by the training system 100 together with the student neural network 110 to the inference system for deployment.
[0100] The training system 100 can make use of any number of additional blocks during the training. Each additional block can generally include any number of neural network layers and any kind of neural network layers, e.g., one or more of the layers mentioned above.
[0101] Each additional block includes a respective set of parameters. The respective set of parameters includes parameters of each of the one or more neural network layers included in the additional block.
[0102] In some implementations, an additional block does not include layers included in either the input block 114 or the output block 118 of the student neural network 110. For example, an additional block may not include any embedding layers, any de-embedding layers, or any softmax layers.
[0103] In some implementations, the one or more additional blocks include a plurality of additional blocks. For example, the plurality of additional blocks can include a plurality of local attention blocks. As another example, the plurality of additional blocks can include a plurality of global attention blocks. As another example, the plurality of additional blocks can include one or more local attention blocks and one or more global attention blocks.
[0104] In some implementations, the one or more additional blocks include only a single additional block. For example, the additional block can be a local attention block. As another example, the additional block can be a global attention block.
[0105] In particular, during training, the training system 100 uses the one or more additional blocks 120-1-120-N together with the existing blocks of the student neural network 110 (including the input block 114, the one or more intermediate blocks 116-1-116-N, and the output block 118) during each forward pass in a proper subset of all forward passes through the student neural network 110 performed during the training. The proper subset includes fewer than all of the forward passes through the student neural network 110 performed during the training.
[0106] The training system 100 does not use the one or more additional blocks 120-1-120-N during each remaining forward pass through the student neural network 110 that is not in the subset; instead, the training system 100 only uses the existing blocks of the student neural network 110 during those remaining forward passes.
[0107] In particular, the proper subset of forward passes during which the one or more additional blocks 120-1-120-N are used includes some, but fewer than all, of the multiple forward passes through the student neural network 110 performed by the training system 100 for each training input.
[0108] During each forward pass through the student neural network 110 where the one or more additional blocks are used, the training system 100 receives an input to the student neural network 110 and performs a forward pass using the input through the input block 114, the one or more intermediate blocks 116-1-116-N, the one or more additional blocks 120-1-120-N, and the output block 118 to generate an output of the student neural network 110.
[0109] That is, rather than directly providing an output of the last intermediate block as input to the output block 118, the training system 100 provides the output as input to the one or more additional blocks 120-1-120-N, and then provides an output of the last additional block as input to the output block 118 for processing to generate the output of the student neural network 110.
[0110] During each forward pass through the student neural network 110 where the one or more additional blocks are not used, the training system 100 receives an input to the student neural network 110 and performs a forward pass using the input through the input block 114, the one or more intermediate blocks 116-1-116-N, and the output block 118 to generate an output of the student neural network 110.
[0111] That is, the training system 100 directly provides an output of the last intermediate block as input to the output block 118 for processing to generate the output of the student neural network 110, skipping the one or more additional blocks 120-1-120-N.
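The two kinds of forward pass differ only in whether the additional blocks sit between the last intermediate block and the output block. The sketch below illustrates that routing with hypothetical identity blocks that merely record the order in which they run; the function names and the block labels are illustrative assumptions, and no actual neural network computation is performed.

```python
def forward(tokens, input_block, intermediate_blocks, output_block,
            additional_blocks=()):
    """Run one forward pass; `additional_blocks` is empty on passes that skip them."""
    hidden = input_block(tokens)
    for block in intermediate_blocks:
        hidden = block(hidden)
    # On passes in the proper subset, the additional blocks process the output
    # of the last intermediate block before it reaches the output block.
    for block in additional_blocks:
        hidden = block(hidden)
    return output_block(hidden)


def traced_pass(use_additional_blocks):
    """Record the order in which hypothetical identity blocks are applied."""
    trace = []

    def named(name):
        def block(hidden):
            trace.append(name)
            return hidden
        return block

    extra = [named("additional 120-1")] if use_additional_blocks else []
    forward([3, 1, 4], named("input 114"),
            [named("intermediate 116-1"), named("intermediate 116-N")],
            named("output 118"), extra)
    return trace


with_blocks = traced_pass(True)
without_blocks = traced_pass(False)
```

Because the additional blocks are passed in separately rather than stored in the student network, discarding them after training leaves the student's own blocks untouched.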
[0112] In some implementations, the training system 100 trains the student neural network 110 from scratch, i.e., trains the student neural network 110 starting from randomly initialized values of the parameters of the student neural network 110.
[0113] In some other implementations, the training system 100 trains the student neural network 110 starting from pre-trained values of the parameters of the student neural network 110 that are, e.g., determined as a result of the pre-training of the student neural network 110, or initialized to be the same as the trained values of (at least a portion of) the parameters of another neural network that has already been trained.
[0114] In some implementations, the training system 100 trains the one or more additional blocks 120-1-120-N from scratch, i.e., trains the one or more additional blocks 120-1-120-N starting from randomly initialized values of the parameters of the additional blocks.
[0115] In some other implementations, the training system 100 trains the one or more additional blocks 120-1-120-N starting from pre-trained values of the parameters of the additional blocks that are, e.g., initialized to be the same as the trained values of (at least a portion of) the parameters of another neural network that has already been trained.
[0116] How the training system 100 performs the teacher distillation training using the one or more additional blocks 120-1-120-N will be described below.
[0117] FIG. 2 is a flow diagram of an example process 200 for training a student neural network using a teacher neural network. The student neural network includes an input block followed by one or more intermediate blocks followed by an output block. For convenience, the process 200 will be described as being performed by a system of one or more computers located in one or more locations. For example, a training system, e.g., the training system 100 depicted in FIG. 1, appropriately programmed in accordance with this specification, can perform the process 200.
[0118] Prior to performing the first iteration of the process 200, the system obtains training data for training the student neural network. The training data includes a plurality of training inputs.
[0119] For example, the system can receive training data as an upload from a remote user of the system over a data communication network, e.g., using an application programming interface (API) made available by the system.
[0120] As another example, the system can receive an input from a user specifying which data that is already maintained by the system or another system that is accessible by the system should be used as the training data.
[0121] For example, in implementations where the student and teacher neural networks are each configured as an auto-regressive generative neural network, the training inputs can be training input sequences. Each training input sequence has a plurality of positions. Each position has a token selected from a vocabulary of tokens.
[0122] As mentioned above, the vocabulary of tokens can include one or more of characters, sub-words, words, punctuation marks, numbers, or other symbols that appear in a corpus of natural language text and / or computer code. Additionally, or alternatively, the vocabulary of tokens can include tokens that can represent data other than text. For example, the vocabulary of tokens can include image tokens that represent a discrete set of image embeddings of an image that can be generated by an image encoder neural network based on processing the image. As another example, the vocabulary of tokens can include audio tokens that represent code vectors in a codebook of a quantizer, e.g., a residual vector quantizer.
[0123] For example, the training input sequences included in the training data can be generated from a large dataset of text in one or more natural languages, e.g., text that is publicly available from the Internet or another text corpus, a large dataset of computer code in one or more programming languages, e.g., Python, C++, C#, Java, Ruby, PHP, and so on, e.g., computer code that is publicly available from the Internet or another code repository, a large dataset of audio samples, e.g., audio recordings or waveforms that represent the audio recordings, a large dataset of images where each image includes an array of pixels, a large dataset of videos where each video includes a temporal sequence of frames, or a large multimodal dataset that includes a combination of two or more of these datasets.
[0124] The system obtains a batch of training inputs from the training data (step 202). For example, the system can obtain a batch of training input sequences. The system will generally obtain different training inputs at different iterations, e.g., by sampling a fixed number of training inputs from a larger number of training inputs included in the training data at each iteration.
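One simple way to realise step 202 is uniform random sampling, sketched below. Sampling without replacement within a batch, the helper name `sample_batch`, and the fixed seed are assumptions made for illustration; the specification only requires that a fixed number of training inputs be sampled at each iteration.

```python
import random


def sample_batch(training_inputs, batch_size, rng):
    """Draw `batch_size` distinct training inputs uniformly at random."""
    return rng.sample(training_inputs, batch_size)


# Toy training data: each training input is a sequence of token ids.
training_inputs = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10], [11, 12]]
rng = random.Random(0)  # fixed seed only so the sketch is repeatable
batches = [sample_batch(training_inputs, 2, rng) for _ in range(3)]
```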
[0125] For each training input in the batch, the system processes the training input using the teacher neural network to generate one or more soft teacher outputs (step 204). Each soft teacher output includes a soft teacher probability distribution over tokens in the vocabulary.
[0126] In some implementations, the system processes, by using the teacher neural network, some or all of the multiple training inputs in the batch in parallel such that the time during which the one or more soft teacher outputs are generated for one training input substantially overlaps the time during which the one or more soft teacher outputs are generated for another training input in the batch.
[0127] For example, when each training input is a training input sequence, leveraging causal masking, the system can process each training input sequence using the teacher neural network to generate, for each different prefix that includes tokens at a different subset of the plurality of positions in the training input sequence, a soft teacher output for the prefix. Each prefix can include a subset starting from the first token in the training input sequence and including each token in the training input sequence up until an end token for the subset. The end token and the first token may be the same (e.g., one prefix may include only the first token). For example, for a training input sequence that has a total of 5 tokens [x1, x2, x3, x4, x5], the prefixes can include [x1], [x1, x2], [x1, x2, x3], and so on.
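The prefixes in the example above can be enumerated as follows. The helper name `prefixes` is hypothetical, and including the full sequence as the final prefix is an assumption; an implementation might drop it when no next token is available to predict.

```python
def prefixes(sequence):
    """Return every prefix that starts at the first token of `sequence`."""
    return [sequence[:end] for end in range(1, len(sequence) + 1)]


sequence = ["x1", "x2", "x3", "x4", "x5"]
all_prefixes = prefixes(sequence)
```

With causal masking, a single pass over the whole sequence yields the output for every one of these prefixes at once, rather than requiring one pass per prefix.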
[0128] The soft teacher probability distribution assigns a respective probability to each token in the vocabulary of tokens. For each token in the vocabulary of tokens, the respective probability assigned to the token by the soft teacher probability distribution may represent a likelihood predicted by the teacher neural network that the vocabulary token is the token that occupies the next position in the plurality of positions that immediately follows the prefix in the training input sequence. The likelihood may be determined based on each token in the training sequence that precedes the next position. A soft teacher probability distribution may be determined for each position within the input token sequence.
[0129] For each training input in the batch, the system generates one or more first student outputs (step 206). Each first student output includes a first student probability distribution over the tokens in the vocabulary. The system makes use of the one or more additional blocks together with the existing blocks of the student neural network to generate the first student outputs.
[0130] When each training input is a training input sequence, how the system generates the one or more first student outputs for each training input sequence by using the one or more additional blocks is described in more detail with reference to FIG. 3, which is an example flow diagram of sub-steps 302-306 of the step 206 of the process 200.
[0131] The system can repeatedly perform an iteration of steps 302-306 to generate the one or more first student outputs for each training input sequence in the batch. In some implementations, for each batch, the system performs some or all of the iterations of steps 302-306 in parallel such that the time during which the one or more first student outputs are generated for one training input sequence substantially overlaps the time during which the one or more first student outputs are generated for another training input sequence in the batch.
[0132] The system processes the training input sequence using the input block and the one or more intermediate blocks of the student neural network to generate an initial intermediate output (step 302).
[0133] The input block can process the tokens in the training input sequence to generate an embedding of each token in the training input sequence. The one or more intermediate blocks then update each of the embeddings at least in part by applying an attention mechanism to generate a respective output embedding for each of the tokens.
[0134] The input embeddings for the first intermediate block are the embeddings of the tokens in the training input sequence, and the input embeddings for each subsequent intermediate block are the output embeddings generated by the preceding intermediate block.
[0135] Thus, the initial intermediate output can include the output embeddings generated by the last intermediate block in the stack of one or more intermediate blocks for the tokens in the training input sequence.
[0136] For example, each intermediate block can apply either a global attention mechanism, a local attention mechanism, or another variant of self-attention mechanism to update the input embeddings.
[0137] In any case, a causal mask (e.g., that is in the form of a triangular matrix of zeros and negative infinities) can be added to the attention mechanism to ensure that, for a given position in the training input sequence, the intermediate block updates the input embedding for the given position based only on input embeddings at preceding positions that precede the given position in the training input sequence (and not on input embeddings at subsequent positions that are after the given position in the training input sequence).
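A triangular mask of this kind can be sketched as below. The helper name `causal_mask` is hypothetical, and the sketch assumes the mask is added to the attention scores before the softmax and that each position may also attend to itself, consistent with the earlier description of attending over positions "preceding or equal to" the input position.

```python
import math


def causal_mask(length):
    """Additive mask: 0.0 where key position <= query position, -inf elsewhere."""
    return [
        [0.0 if key <= query else -math.inf for key in range(length)]
        for query in range(length)
    ]


mask = causal_mask(3)
# Row i is added to the attention scores of query position i; after the
# softmax, every position that received -inf ends up with zero weight.
```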
[0138] The system processes the initial intermediate output using one or more additional blocks to generate a further processed intermediate output (step 304).
[0139] For example, when the one or more additional blocks include one or more attention blocks, the one or more additional blocks update each of the output embeddings included in the initial intermediate output at least in part by applying an attention mechanism to generate a respective further processed output embedding for each of the output embeddings.
[0140] For example, each additional block can apply either a global attention mechanism, a local attention mechanism, or another variant of self-attention mechanism, with the causal mask added, to update the output embeddings.
[0141] Thus, the further processed intermediate output can include the further processed output embeddings generated by the last additional block in the stack of one or more additional blocks for the tokens in the training input sequence.
[0142] In particular, the initial intermediate output and the further processed intermediate output can include the same number of embeddings, but the numeric values in each of the embeddings will generally differ between the initial and further processed intermediate outputs, i.e., because of the further processing performed by using the one or more additional blocks.
[0143] The system processes the further processed intermediate output using the output block of the student neural network to generate the one or more first student outputs (step 306). That is, the system provides the further processed intermediate output as input to the output block, and the output block processes the further processed intermediate output to generate the one or more first student outputs.
[0144] For example, when the training input is a training input sequence, the system can generate, at the output block and for each different prefix that includes tokens at a different subset of the plurality of positions in the training input sequence, a first student output for the prefix that includes a first student probability distribution over the tokens in the vocabulary.
[0145] The first student probability distribution assigns a respective probability to each token in the vocabulary of tokens. For each token in the vocabulary of tokens, the respective probability assigned to the token by the first student probability distribution may represent a likelihood predicted by the student neural network together with the one or more additional blocks that the vocabulary token is the token that occupies the next position in the plurality of positions that immediately follows the prefix in the training input sequence.
[0146] For each training input in the batch, the system obtains one or more hard teacher outputs (step 208). Each hard teacher output includes a hard teacher probability distribution over the tokens in the vocabulary. When each training input is a training input sequence, the system can obtain the same number of hard teacher outputs as the soft teacher outputs for each training input. A hard output may be a discrete (e.g., binary) output (e.g., an output including values of either “1” or “0”). In contrast, a soft output may be a probability distribution (e.g., a vector of likelihoods or confidence values) which may be continuous (e.g., that can include values that can range between 0 and 1).
[0147] In some implementations, the system generates the hard teacher outputs for the training input sequence based on the actual training input sequence.
[0148] For example, for a given position in the training input sequence, a hard teacher output can include a set of scores that includes a 1 for a ground truth token in the vocabulary that actually occupies the given position in the training input sequence, and a 0 for each remaining token in the vocabulary.
[0149] In some implementations, the system generates the hard teacher outputs for the training input sequence based on the soft teacher outputs generated for the same training input sequence.
[0150] For example, for a given position in the training input sequence, a hard teacher output can include a set of scores that includes a 1 for a token that is assigned the highest probability amongst all tokens in the vocabulary by the soft teacher probability distribution generated by the teacher neural network for the given position, and a 0 for each remaining token in the vocabulary.
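Both ways of obtaining a hard teacher output reduce to building a one-hot vector, as the sketch below shows. The function names are hypothetical, tokens are represented by integer indices into the vocabulary, and breaking ties in favour of the lowest index is an assumption the specification does not address.

```python
def one_hot(index, vocab_size):
    """Scores with a 1 at `index` and a 0 for each remaining vocabulary token."""
    return [1 if token == index else 0 for token in range(vocab_size)]


def hard_from_ground_truth(token_id, vocab_size):
    """Hard teacher output taken from the actual training input sequence."""
    return one_hot(token_id, vocab_size)


def hard_from_soft(soft_distribution):
    """Hard teacher output taken from the teacher's most probable token."""
    best = max(range(len(soft_distribution)), key=soft_distribution.__getitem__)
    return one_hot(best, len(soft_distribution))


from_truth = hard_from_ground_truth(token_id=2, vocab_size=3)
from_teacher = hard_from_soft([0.1, 0.7, 0.2])
```

The two variants agree only when the teacher's most probable token matches the ground truth token at that position.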
[0151] For each training input in the batch, the system generates one or more second student outputs (step 210). Each second student output includes a second student probability distribution over the tokens in the vocabulary. The system generates the second student outputs using the existing blocks of the student neural network and without using the one or more additional blocks.
[0152] In some implementations, the system processes, by using the student neural network, some or all of the multiple training inputs in the batch in parallel such that the time during which the one or more second student outputs are generated for one training input substantially overlaps the time during which the one or more second student outputs are generated for another training input in the batch.
[0153] For example, when each training input is a training input sequence, the system processes each training input sequence using the input block, the one or more intermediate blocks, and the output block of the student neural network to generate the one or more second student outputs for the training input sequence.
[0154] As part of the processing of each training input sequence, the system provides the output generated by the input block as input to the first intermediate block in the one or more intermediate blocks, and then provides the output generated by the last intermediate block in the one or more intermediate blocks as input to the output block. The output block then processes the output to generate the one or more second student outputs for the training input sequence.
[0155] For example, for each training input sequence, the system can generate, at the output block and for each different prefix that includes tokens at a different subset of the plurality of positions in the training input sequence, a second student output for the prefix that includes a second student probability distribution over the tokens in the vocabulary.
[0156] The second student probability distribution assigns a respective probability to each token in the vocabulary of tokens. For each token in the vocabulary of tokens, the respective probability assigned to the token by the second student probability distribution may represent a likelihood predicted by the student neural network without using the one or more additional blocks that the vocabulary token is the token that occupies the next position in the plurality of positions that immediately follows the prefix in the training input sequence.
[0157] Hence, the system performs two forward passes through the student neural network for each training input (i.e., a first forward pass which uses the one or more additional blocks and a second forward pass which skips the one or more additional blocks) to generate two different student outputs for the same prefix: a first student output and a second student output that include different student probability distributions over the tokens in the same vocabulary. Notably, the first and second student outputs are generated by the same output block (from different inputs to the output block).
[0158] In some implementations, to conserve computational resources, the input block and one or more intermediate blocks need not be repeatedly used in the two forward passes through the student neural network for the same prefix, and the output from the last intermediate block is used twice - once directly provided to the output block and once through the one or more additional blocks, without having to calculate the outputs of the input and intermediate blocks twice.
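The reuse of the last intermediate block's output can be sketched as follows. This is a toy illustration only: the functions `trunk`, `additional_blocks`, and `output_block` are hypothetical stand-ins with made-up arithmetic, not the blocks of the specification.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def trunk(x):
    # Stand-in for the input block followed by the intermediate blocks.
    return [2.0 * v for v in x]


def additional_blocks(h):
    # Stand-in for the training-only additional blocks.
    return [v + 1.0 if i == 0 else v for i, v in enumerate(h)]


def output_block(h):
    # Stand-in for de-embedding followed by softmax.
    return softmax(h)


def student_outputs(x):
    h = trunk(x)                                # computed once, used twice
    first = output_block(additional_blocks(h))  # path through the additional blocks
    second = output_block(h)                    # path that skips them
    return first, second


first_out, second_out = student_outputs([0.1, 0.2, 0.3])
```

Both outputs come from the same `output_block`, which differs only in the input it receives, so the trunk is evaluated a single time per training input.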
[0159] The system trains the student neural network to update the values of the parameters of the student neural network based on optimizing an objective function that includes a first term and a second term (step 212).
[0160] The first term measures, for each training input in the batch of training inputs, a difference between (i) each of the one or more first student outputs that are generated by using the student neural network and the one or more additional blocks to process the training input, and (ii) each of the one or more soft teacher outputs that are generated by using the teacher neural network to process the same training input.
[0161] For example, because the first student outputs and the soft teacher outputs can each include a probability distribution over the tokens in the vocabulary, the first term can be a cross-entropy loss term or a Kullback-Leibler (KL) divergence term.
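The two candidate forms of the first term can be sketched for a single position as follows. The direction of the KL divergence (teacher as the reference distribution), the small `eps` constant, and the example probabilities are assumptions of this sketch; the specification does not fix them.

```python
import math


def cross_entropy(target, predicted, eps=1e-12):
    """H(target, predicted) = -sum_i target_i * log(predicted_i)."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))


def kl_divergence(target, predicted, eps=1e-12):
    """KL(target || predicted) = sum_i target_i * log(target_i / predicted_i)."""
    return sum(t * math.log((t + eps) / (p + eps))
               for t, p in zip(target, predicted) if t > 0.0)


# Hypothetical three-token distributions for one position.
soft_teacher = [0.7, 0.2, 0.1]
first_student = [0.5, 0.3, 0.2]

ce = cross_entropy(soft_teacher, first_student)
kl = kl_divergence(soft_teacher, first_student)
teacher_entropy = cross_entropy(soft_teacher, soft_teacher)
```

With the teacher as the reference, the cross-entropy equals the KL divergence plus the teacher's own entropy. That entropy does not depend on the student's parameters, so under this assumption both choices yield the same gradients for the student.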
[0162] For each training input in the batch, the second term is dependent on the one or more second student outputs that are generated by using the student neural network based on the training input and without using the one or more additional blocks.
[0163] In some implementations, the second term measures, for each training input in the batch, a difference between (i) each of the one or more second student outputs that are generated by using the student neural network and without using the one or more additional blocks to process the training input and (ii) each of the one or more hard teacher outputs for the training input.
[0164] In these implementations, the use of both soft and hard teacher outputs allows the system to train the student neural network to achieve a better trade-off between accuracy to the ground truth outputs and generalization based on the teacher’s knowledge gained through its training. Moreover, the combination of soft and hard teacher outputs provides a form of noise mitigation which softens the negative impact of any incorrect soft (or hard) teacher outputs on the training of the student neural network.
[0165] In some implementations, the second term measures, for each training input in the batch, a difference between (i) each of the one or more second student outputs that are generated by using the student neural network and without using the one or more additional blocks to process the training input and (ii) each of the one or more soft teacher outputs that are generated by using the teacher neural network to process the same training input.
[0166] For example, the second term can be a cross-entropy loss term or a Kullback-Leibler (KL) divergence term.
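One way the two-term objective described above could be assembled is sketched below, using cross-entropy for both terms and a hard teacher output as the target of the second term. The record layout, the averaging over the batch, and the `first_weight` and `second_weight` arguments are assumptions of this sketch; the specification does not state how the two terms are weighted or combined.

```python
import math


def cross_entropy(target, predicted, eps=1e-12):
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))


def objective(batch, first_weight=1.0, second_weight=1.0):
    """Average two-term loss over a batch of per-position records.

    Each record holds the first student output (with additional blocks),
    the second student output (without them), the soft teacher output,
    and the hard teacher output. The weights are hypothetical.
    """
    total = 0.0
    for rec in batch:
        first_term = cross_entropy(rec["soft_teacher"], rec["first_student"])
        second_term = cross_entropy(rec["hard_teacher"], rec["second_student"])
        total += first_weight * first_term + second_weight * second_term
    return total / len(batch)


# Hypothetical single-record batch over a three-token vocabulary.
batch = [{
    "soft_teacher": [0.7, 0.2, 0.1],
    "hard_teacher": [1.0, 0.0, 0.0],
    "first_student": [0.6, 0.3, 0.1],
    "second_student": [0.5, 0.3, 0.2],
}]
loss = objective(batch)
```

Swapping `rec["hard_teacher"]` for `rec["soft_teacher"]` in the second term gives the alternative in which both terms are measured against the soft teacher outputs.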
[0167] The system can update the values of the parameters of the student neural network (including the parameters of the input block, the parameters of the one or more intermediate blocks, the parameters of the one or more additional blocks, and the parameters of the output block) by computing, for each training input in the batch, respective gradients of the objective function with respect to the parameters of the student neural network by backpropagation through the appropriate parameters. The system can then determine the updates by applying an update rule, e.g., an Adam update rule, an Rmsprop update rule, or a stochastic gradient descent (SGD) update rule, to (e.g., based on) the respective gradients.
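As one concrete instance of the update rules named above, a plain SGD step can be sketched as follows. The parameter names, gradient values, and learning rate are hypothetical; Adam or RMSprop would additionally keep running statistics of the gradients, which this sketch omits.

```python
def sgd_update(params, grads, learning_rate=0.1):
    """Return new parameter values after one plain SGD step."""
    return {name: value - learning_rate * grads[name]
            for name, value in params.items()}


# Hypothetical scalar parameters, one per block, and their gradients.
params = {"input_block": 1.0, "intermediate_blocks": -2.0,
          "additional_blocks": 0.5, "output_block": 3.0}
grads = {"input_block": 0.2, "intermediate_blocks": -0.4,
         "additional_blocks": 1.0, "output_block": 0.0}

updated = sgd_update(params, grads)
```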
[0168] In particular, the system computes the gradients of the first term with respect to the parameters of the output block, computes the gradients of the second term with respect to the parameters of the output block, and then backpropagates the gradients of the first term through the parameters of the output block and through parameters of the one or more additional blocks, while backpropagating the gradients of the second term through the parameters of the output block and skipping the parameters of the one or more additional blocks, to determine the gradients with respect to parameters of the one or more intermediate blocks.
[0169] Thus, the system determines the update to values of the parameters of the output block and the one or more intermediate blocks based on both the gradients of the first term and the gradients of the second term, and determines the update to values of the parameters of the one or more additional blocks based on the gradients of the first term and not on the gradients of the second term.
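The gradient routing described above can be checked on a toy scalar model. In this sketch each block is reduced to a single multiplicative weight, the loss terms are squared errors rather than cross-entropy or KL divergence, and gradients are estimated by finite differences instead of backpropagation; all of these are simplifications chosen for brevity, and every name and value is hypothetical.

```python
def losses(params, x, soft_target, hard_target):
    """Toy scalar model: each 'block' is a single multiplicative weight."""
    h = params["intermediate"] * x                         # shared trunk output
    first = params["output"] * (params["additional"] * h)  # with additional block
    second = params["output"] * h                          # skipping it
    first_term = (first - soft_target) ** 2
    second_term = (second - hard_target) ** 2
    return first_term, second_term


def grad(term_index, name, params, args, step=1e-6):
    """Central finite-difference gradient of one loss term w.r.t. one weight."""
    up = dict(params, **{name: params[name] + step})
    down = dict(params, **{name: params[name] - step})
    return (losses(up, *args)[term_index]
            - losses(down, *args)[term_index]) / (2 * step)


params = {"intermediate": 0.5, "additional": 1.5, "output": 2.0}
args = (1.0, 1.0, 0.0)  # x, soft_target, hard_target (hypothetical values)

# For each weight: (gradient of first term, gradient of second term).
grads = {
    name: (grad(0, name, params, args), grad(1, name, params, args))
    for name in params
}
```

In this toy setting the `additional` weight receives a nonzero gradient only from the first term, while the `output` and `intermediate` weights receive gradients from both terms.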
[0170] As mentioned above, in some implementations, the de-embedding layer includes parameters that represent elements of a de-embedding matrix, and each parameter of the de-embedding matrix can have a corresponding parameter in the embedding matrix represented by an embedding layer included in the input block (with the de-embedding matrix being a transpose of the embedding matrix). For example, the embedding matrix can be defined as $W_E \in \mathbb{R}^{\text{VocabSize} \times \text{HiddenDim}}$, and the de-embedding matrix can be defined as $W_D = W_E^{\top}$.
[0171] Thus, in these implementations, the update to the values of the parameters of the output block is also based on the gradients of the first term and gradients of the second term computed through backpropagation with respect to the parameters of the input block.
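The effect of parameter tying on the gradient can be illustrated with a toy check: when one matrix serves as both the embedding matrix and (transposed) the de-embedding matrix, its gradient is the sum of the gradient obtained through the input block and the gradient obtained through the output block. The sketch below omits the intermediate blocks, uses a made-up loss, and estimates gradients by finite differences; the matrix values and function names are hypothetical.

```python
import copy


def forward(embed, deembed_rows, token_id):
    """Look up a token embedding, then score every vocabulary token against it."""
    hidden = embed[token_id]  # embedding lookup: one row of the embedding matrix
    # Multiplying by the transposed matrix is a dot product with each row.
    return [sum(h * w for h, w in zip(hidden, row)) for row in deembed_rows]


def loss(embed, deembed_rows, token_id, target_id):
    # Toy loss: negative logit of the target token.
    return -forward(embed, deembed_rows, token_id)[target_id]


def fd_grad(fn, matrix, i, j, step=1e-6):
    """Central finite-difference gradient of fn w.r.t. matrix[i][j]."""
    up = copy.deepcopy(matrix)
    up[i][j] += step
    down = copy.deepcopy(matrix)
    down[i][j] -= step
    return (fn(up) - fn(down)) / (2 * step)


# Hypothetical matrix: vocabulary of 3 tokens, hidden dimension 2.
W = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

g_embed = fd_grad(lambda m: loss(m, W, 0, 0), W, 0, 0)    # via the input block only
g_deembed = fd_grad(lambda m: loss(W, m, 0, 0), W, 0, 0)  # via the output block only
g_tied = fd_grad(lambda m: loss(m, m, 0, 0), W, 0, 0)     # shared (tied) parameter
```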
[0172] The system can continue performing iterations of the process 200 until termination criteria for the teacher distillation training of the student neural network have been satisfied, e.g., until the parameters have converged, until a threshold amount of wall clock time has elapsed, or until a threshold number of iterations of the process 200 have been performed.
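A loop that checks the three example termination criteria might look like the sketch below. Convergence is approximated here by the change in loss between consecutive iterations, which is an assumption of this sketch (the specification refers to the parameters having converged), and `fake_step`, the tolerance, and the limits are hypothetical.

```python
import time


def run_training(step_fn, max_steps=1000, max_seconds=60.0, tol=1e-6):
    """Repeat `step_fn` until a stopping criterion is met; report which one."""
    start = time.monotonic()
    previous = None
    for step in range(1, max_steps + 1):
        loss = step_fn()
        if previous is not None and abs(previous - loss) < tol:
            return step, "converged"
        if time.monotonic() - start > max_seconds:
            return step, "time limit"
        previous = loss
    return max_steps, "step limit"


# Hypothetical stand-in for one training iteration: a loss that halves each step.
state = {"loss": 1.0}


def fake_step():
    state["loss"] *= 0.5
    return state["loss"]


steps_run, reason = run_training(fake_step)
```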
[0173] In some implementations, after having trained the student neural network using the teacher neural network, the system or another system, e.g., a fine-tuning system, fine-tunes, i.e., further trains, the student neural network to perform one or more downstream tasks. The one or more downstream tasks can include any of the tasks mentioned above, and possibly other tasks.
[0174] In these implementations, the system or the other system proceeds to adapt the student neural network to perform a downstream task, e.g., through supervised fine-tuning or reinforcement learning from human feedback (RLHF), on labeled or unlabeled training data that is specific to the downstream task.
[0175] That is, a neural network that can be deployed to compute inference for the one or more downstream tasks, e.g., performing generative tasks in response to prompts received from users, can be generated from a student neural network that has been trained by the system implementing the teacher distillation training process.
[0176] The neural network includes the input block, the one or more intermediate blocks, and the output block, and includes parameters having values resulting from the teacher distillation training process. The neural network need not include the one or more additional blocks that are only used during the training.
[0177] Correspondingly, the processing of a new input using the neural network to generate a new output for the downstream task does not involve performing a forward pass through any of the one or more additional blocks using the new input (even though some forward passes through the neural network made use of the one or more additional blocks during the training).
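Dropping the training-only blocks before deployment can be sketched as a simple filter over a parameter container. The dictionary layout and key names below are hypothetical, and the string values stand in for real weight arrays.

```python
def export_for_inference(trained_params):
    """Keep the blocks used at inference and drop the training-only ones."""
    keep = ("input_block", "intermediate_blocks", "output_block")
    return {name: trained_params[name] for name in keep}


# Hypothetical parameter container; the strings stand in for real weight arrays.
trained_params = {
    "input_block": "embedding weights",
    "intermediate_blocks": "trunk weights",
    "additional_blocks": "training-only weights",
    "output_block": "de-embedding weights",
}

deployed_params = export_for_inference(trained_params)
```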
[0178] FIG. 4 is an illustration 400 of an example data flow during a training iteration. In the example illustrated, the student neural network includes 32 intermediate blocks (“transformer blocks 1-32”) that include both global attention blocks and local attention blocks. The training makes use of 4 additional blocks (“distill blocks 1-4”) that include both global attention blocks and local attention blocks. The student neural network employs parameter tying, i.e., has an output block that includes a de-embedding layer representing a de-embedding matrix that is a transpose of the embedding matrix represented by an embedding layer included in an input block.
[0179] During the training iteration, a first student output can be generated by performing a forward pass through the input block (“input emb”), the 32 intermediate blocks, the 4 additional blocks, and the output block (“shared softmax”), whereas a second student output can be generated by performing a forward pass through the input block (“input emb”), the 32 intermediate blocks, and the output block (“shared softmax”).
[0180] The objective function to be evaluated during the training iteration includes (i) a first term that measures at least a difference between the first student output and a soft teacher output generated by a teacher neural network (“distill teacher soft log probs”), and (ii) a second term that measures at least a difference between the second student output and a hard teacher output (“ground truth hard labels”).
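The data flow of the FIG. 4 example can be written out as two block sequences, each paired with the target of its loss term. The labels are taken from the figure description above; representing the paths as Python lists is purely an illustrative choice.

```python
transformer_blocks = [f"transformer block {i}" for i in range(1, 33)]
distill_blocks = [f"distill block {i}" for i in range(1, 5)]

# First student output: passes through the distill (additional) blocks.
first_path = ["input emb", *transformer_blocks, *distill_blocks, "shared softmax"]
# Second student output: goes straight from block 32 to the shared softmax.
second_path = ["input emb", *transformer_blocks, "shared softmax"]

# Each path is paired with the target its loss term is measured against.
loss_terms = {
    "first term": (first_path, "distill teacher soft log probs"),
    "second term": (second_path, "ground truth hard labels"),
}
```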
[0181] In this specification, the term “configured” is used in relation to computing systems and environments, as well as computer program components. A computing system or environment is considered “configured” to perform specific operations or actions when it possesses the necessary software, firmware, hardware, or a combination thereof, enabling it to carry out those operations or actions during operation. For instance, configuring a system might involve installing a software library with specific algorithms, updating firmware with new instructions for handling data, or adding a hardware component for enhanced processing capabilities. Similarly, one or more computer programs are “configured” to perform particular operations or actions when they contain instructions that, upon execution by a computing device or hardware, cause the device to perform those intended operations or actions.
[0182] The embodiments and functional operations described in this specification can be implemented in various forms, including digital electronic circuitry, software, firmware, computer hardware (encompassing the disclosed structures and their structural equivalents), or any combination thereof. The subject matter can be realized as one or more computer programs, essentially modules of computer program instructions encoded on a tangible non-transitory storage medium for execution by or to control the operation of a computing device or hardware. The storage medium can be a storage device such as a hard drive or solid-state drive (SSD), a storage medium, a random or serial access memory device, or a combination of these. Additionally or alternatively, the program instructions can be encoded on a transmitted signal, such as a machine-generated electrical, optical, or electromagnetic signal, designed to carry information for transmission to a receiving device or system for execution by a computing device or hardware. Furthermore, implementations may leverage emerging technologies like quantum computing or neuromorphic computing for specific applications, and may be deployed in distributed or cloud-based environments where components reside on different machines or within a cloud infrastructure.
[0183] The term “computing device or hardware” refers to the physical components involved in data processing and encompasses all types of devices and machines used for this purpose. Examples include processors or processing units, computers, multiple processors or computers working together, graphics processing units (GPUs), tensor processing units (TPUs), and specialized processing hardware such as field-programmable gate arrays (FPGAs) or application-specific integrated circuits (ASICs). In addition to hardware, a computing device or hardware may also include code that creates an execution environment for computer programs. This code can take the form of processor firmware, a protocol stack, a database management system, an operating system, or a combination of these elements. Embodiments may particularly benefit from utilizing the parallel processing capabilities of GPUs, in a General-Purpose computing on Graphics Processing Units (GPGPU) context, where code specifically designed for GPU execution, often called kernels or shaders, is employed. Similarly, TPUs excel at running optimized tensor operations crucial for many machine learning algorithms. By leveraging these accelerators and their specialized programming models, the system can achieve significant speedups and efficiency gains for tasks involving artificial intelligence and machine learning, particularly in areas such as computer vision, natural language processing, and robotics.
[0184] A computer program, also referred to as software, an application, a module, a script, code, or simply a program, can be written in any programming language, including compiled or interpreted languages, and declarative or procedural languages. It can be deployed in various forms, such as a standalone program, a module, a component, a subroutine, or any other unit suitable for use within a computing environment. A program may or may not correspond to a single file in a file system and can be stored in various ways. This includes being embedded within a file containing other programs or data (e.g., scripts within a markup language document), residing in a dedicated file, or distributed across multiple coordinated files (e.g., files storing modules, subprograms, or code segments). A computer program can be executed on a single computer or across multiple computers, whether located at a single site or distributed across multiple sites and interconnected through a data communication network. The specific implementation of the computer programs may involve a combination of traditional programming languages and specialized languages or libraries designed for GPGPU programming or TPU utilization, depending on the chosen hardware platform and desired performance characteristics.
[0185] In this specification, the term “engine” broadly refers to a software-based system, subsystem, or process designed to perform one or more specific functions. An engine is typically implemented as one or more software modules or components installed on one or more computers, which can be located at a single site or distributed across multiple locations. In some instances, one or more dedicated computers may be used for a particular engine, while in other cases, multiple engines may operate concurrently on the same one or more computers. Examples of engine functions within the context of AI and machine learning could include data pre-processing and cleaning, feature engineering and extraction, model training and optimization, inference and prediction generation, and post-processing of results. The specific design and implementation of engines will depend on the overall architecture and the distribution of computational tasks across various hardware components, including CPUs, GPUs, TPUs, and other specialized processors.
[0186] The processes and logic flows described in this specification can be executed by one or more programmable computers running one or more computer programs to perform functions by operating on input data and generating output. Additionally, graphics processing units (GPUs) and tensor processing units (TPUs) can be utilized to enable concurrent execution of aspects of these processes and logic flows, significantly accelerating performance. This approach offers significant advantages for computationally intensive tasks often found in AI and machine learning applications, such as matrix multiplications, convolutions, and other operations that exhibit a high degree of parallelism. By leveraging the parallel processing capabilities of GPUs and TPUs, significant speedups and efficiency gains compared to relying solely on CPUs can be achieved. Alternatively or in combination with programmable computers and specialized processors, these processes and logic flows can also be implemented using specialized processing hardware, such as field-programmable gate arrays (FPGAs) or application-specific integrated circuits (ASICs), for even greater performance or energy efficiency in specific use cases.
[0187] Computers capable of executing a computer program can be based on general-purpose microprocessors, special-purpose microprocessors, or a combination of both. They can also utilize any other type of central processing unit (CPU). Additionally, graphics processing units (GPUs), tensor processing units (TPUs), and other machine learning accelerators can be employed to enhance performance, particularly for tasks involving artificial intelligence and machine learning. These accelerators often work in conjunction with CPUs, handling specialized computations while the CPU manages overall system operations and other tasks. Typically, a CPU receives instructions and data from read-only memory (ROM), random access memory (RAM), or both. Computer elements may include a CPU for executing instructions and one or more memory devices for storing instructions and data. The specific configuration of processing units and memory will depend on factors like the complexity of the AI model, the volume of data being processed, and the desired performance and latency requirements. Embodiments can be implemented on a wide range of computing platforms, from small embedded devices with limited resources to large-scale data center systems with high-performance computing capabilities. The system may include storage devices like hard drives, SSDs, or flash memory for persistent data storage.
[0188] Computer-readable media suitable for storing computer program instructions and data encompass all forms of non-volatile memory, media, and memory devices. Examples include semiconductor memory devices such as read-only memory (ROM), solid-state drives (SSDs), and flash memory devices; hard disk drives (HDDs); optical media; and optical discs such as CDs, DVDs, and Blu-ray discs. The specific type of computer-readable media used will depend on factors such as the size of the data, access speed requirements, cost considerations, and the desired level of portability or permanence.
[0189] To facilitate user interaction, embodiments of the subject matter described in this specification can be implemented on a computing device equipped with a display device, such as a liquid crystal display (LCD) or an organic light-emitting diode (OLED) display, for presenting information to the user. Input can be provided by the user through various means, including a keyboard, touchscreens, voice commands, gesture recognition, or other input modalities depending on the specific device and application. Additional input methods can include acoustic, speech, or tactile input, while feedback to the user can take the form of visual, auditory, or tactile feedback. Furthermore, computers can interact with users by exchanging documents with a user's device or application. This can involve sending web content or data in response to requests or sending and receiving text messages or other forms of messages through mobile devices or messaging platforms. The selection of input and output modalities will depend on the specific application and the desired form of user interaction.
[0190] Machine learning models can be implemented and deployed using machine learning frameworks, such as TensorFlow or JAX. These frameworks offer comprehensive tools and libraries that facilitate the development, training, and deployment of machine learning models.
[0191] Embodiments of the subject matter described in this specification can be implemented within a computing system comprising one or more components, depending on the specific application and requirements. These may include a back-end component, such as a back-end server or cloud-based infrastructure; an optional middleware component, such as a middleware server or application programming interface (API), to facilitate communication and data exchange; and a front-end component, such as a client device with a user interface, a web browser, or an app, through which a user can interact with the implemented subject matter. For instance, the described functionality could be implemented solely on a client device (e.g., for on-device machine learning) or deployed as a combination of front-end and back-end components for more complex applications. These components, when present, can be interconnected using any form or medium of digital data communication, such as a communication network like a local area network (LAN) or a wide area network (WAN) including the Internet. The specific system architecture and choice of components will depend on factors such as the scale of the application, the need for real-time processing, data security requirements, and the desired user experience.
[0192] The computing system can include clients and servers that may be geographically separated and interact through a communication network. The specific type of network, such as a local area network (LAN), a wide area network (WAN), or the Internet, will depend on the reach and scale of the application. The client-server relationship is established through computer programs running on the respective computers and designed to communicate with each other using appropriate protocols. These protocols may include HTTP, TCP / IP, or other specialized protocols depending on the nature of the data being exchanged and the security requirements of the system. In certain embodiments, a server transmits data or instructions to a user's device, such as a computer, smartphone, or tablet, acting as a client. The client device can then process the received information, display results to the user, and potentially send data or feedback back to the server for further processing or storage. This allows for dynamic interactions between the user and the system, enabling a wide range of applications and functionalities.
[0193] While this specification contains many specific implementation details, these should not be construed as limitations on the scope of any invention or on the scope of what may be claimed, but rather as descriptions of features that may be specific to particular embodiments of particular inventions. Certain features that are described in this specification in the context of separate embodiments can also be implemented in combination in a single embodiment. Conversely, various features that are described in the context of a single embodiment can also be implemented in multiple embodiments separately or in any suitable subcombination. Moreover, although features may be described above as acting in certain combinations and even initially be claimed as such, one or more features from a claimed combination can in some cases be excised from the combination, and the claimed combination may be directed to a subcombination or variation of a subcombination.
[0194] Similarly, while operations are depicted in the drawings and recited in the claims in a particular order, this should not be understood as requiring that such operations be performed in the particular order shown or in sequential order, or that all illustrated operations be performed, to achieve desirable results. In certain circumstances, multitasking and parallel processing may be advantageous. Moreover, the separation of various system modules and components in the embodiments described above should not be understood as requiring such separation in all embodiments, and it should be understood that the described program components and systems can generally be integrated together in a single software product or packaged into multiple software products.
[0195] Particular embodiments of the subject matter have been described. Other embodiments are within the scope of the following claims. For example, the actions recited in the claims can be performed in a different order and still achieve desirable results. As one example, the processes depicted in the accompanying figures do not necessarily require the particular order shown, or sequential order, to achieve desirable results. In some cases, multitasking and parallel processing may be advantageous.
[0196] What is claimed is:
Claims
1. A method of training a student neural network using a teacher neural network, wherein the student neural network comprises an input block followed by one or more intermediate blocks followed by an output block, and wherein the method comprises: obtaining a batch of training inputs from training data that comprises a plurality of training inputs; for each training input in the batch: processing the training input using the teacher neural network to generate a soft teacher output that includes a soft teacher probability distribution over tokens in a vocabulary; processing the training input using the input block and the one or more intermediate blocks of the student neural network to generate an initial intermediate output; processing the intermediate output using one or more additional blocks to generate a further processed intermediate output; and processing the further processed intermediate output using the output block of the student neural network to generate a first student output that includes a first student probability distribution over the tokens in the vocabulary; and training the student neural network on an objective function that includes (i) a first term that measures, for each training input in the batch, a difference between the first student output and the soft teacher output and (ii) a second term that, for each training input in the batch, is dependent on a second student output that is generated by using the student neural network based on the training input and without using the one or more additional blocks.
2. The method of claim 1, further comprising, for each training input in the batch: obtaining a hard teacher output that includes a hard teacher probability distribution over the tokens in the vocabulary; and processing the training input using the input block, the one or more intermediate blocks, and the output block of the student neural network to generate the second student output without using the one or more additional blocks, wherein the second student output includes a second student probability distribution over the tokens in the vocabulary.
3. The method of any one of claims 1-2, wherein the second term measures, for each training input in the batch, a difference between the second student output and the hard teacher output.
4. The method of any one of claims 1-3, wherein the input block comprises an embedding layer, the embedding layer comprising parameters that represent elements of an embedding matrix, the embedding matrix being used to map each token in the vocabulary of tokens into a corresponding embedding of the token.
5. The method of any one of claims 1-4, wherein the one or more intermediate blocks comprise one or more attention blocks.
6. The method of claim 5, wherein the one or more attention blocks comprise one or more local attention blocks, one or more global attention blocks, or both.
7. The method of any one of claims 1-6, wherein the one or more additional blocks comprise one or more attention blocks.
8. The method of claim 7, wherein the one or more additional blocks comprises a plurality of additional blocks, the plurality of additional blocks comprising one or more local attention blocks and one or more global attention blocks.
9. The method of claim 7, wherein the one or more additional blocks comprises only a single additional block, the additional block being a local attention block.
10. The method of any one of claims 1-9, wherein the output block comprises a de-embedding layer, the de-embedding layer comprising parameters that represent elements of a de-embedding matrix, the de-embedding matrix being used to map an embedding to a corresponding token in the vocabulary of tokens.
11. The method of claim 10, wherein the output block further comprises a softmax layer that processes an output of the de-embedding layer to generate the first student output that includes the first student probability distribution over the tokens in the vocabulary.
12. The method of any one of claims 1-11, wherein the first term and the second term are each a cross-entropy loss term.
13. The method of any one of claims 1-12, wherein training the student neural network comprises: computing gradients of the first term with respect to parameters of the output block; computing gradients of the second term with respect to the parameters of the output block; and determining an update to values of the parameters of the output block based on the gradients of the first term and the gradients of the second term.
14. The method of claim 13 when also dependent on claim 4 and on one of claims 10 or 11, wherein each parameter of the de-embedding matrix has a corresponding parameter in the embedding matrix, and wherein the update to the values of the parameters of the output block is also based on gradients of the first term and gradients of the second term computed with respect to the parameters of the input block.
15. The method of any one of claims 1-14, wherein training the student neural network comprises: backpropagating gradients of the first term through parameters of the output block and through parameters of the one or more additional blocks to determine gradients with respect to parameters of the one or more intermediate blocks.
16. The method of any one of claims 1-15, wherein processing the intermediate output using the one or more additional blocks to generate the further processed intermediate output comprises: providing the initial intermediate output as input to both (i) a first block in the one or more additional blocks and (ii) the output block.
17. A computer-implemented method comprising: receiving a new input; and processing the new input using a student neural network that has been trained using the method of any preceding claim to generate a new output.
18. The method of claim 17, wherein processing the new input using the student neural network comprises processing the new input without using any of the one or more additional blocks.
19. A system comprising one or more computers and one or more storage devices storing instructions that when executed by the one or more computers cause the one or more computers to perform the operations of the respective method of any one of claims 1-18.
20. One or more computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform the operations of the respective method of any one of claims 1-18.
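For illustration only, and not as part of the claims, the two-path forward pass and two-term objective recited in claims 2, 3, 10-12, 14 and 16 might be sketched as follows. Claim 1 is not reproduced in this section, so the sketch assumes that the first term compares the first student output with a teacher probability distribution (called `soft_teacher` here). All function names, the toy matrices, and the stand-in blocks are hypothetical, and gradient computation is omitted.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target, predicted, eps=1e-12):
    """Cross-entropy between a target distribution and a predicted one."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))


def matvec(matrix, vec):
    """Multiply a (rows x dim) matrix by a dim-vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def student_outputs(token_id, embedding, intermediate, additional):
    """Return (first, second) student distributions for one input token.

    Assumption: the de-embedding matrix is tied to the embedding matrix,
    one possible reading of the parameter correspondence in claim 14.
    """
    h = embedding[token_id]   # input block: embedding lookup
    h = intermediate(h)       # intermediate blocks
    # The same intermediate output feeds both the output block directly
    # and the additional blocks (compare claim 16).
    second = softmax(matvec(embedding, h))             # no additional blocks
    first = softmax(matvec(embedding, additional(h)))  # via additional blocks
    return first, second


def objective(first, second, soft_teacher, hard_teacher):
    """Sum of two cross-entropy terms (compare claim 12)."""
    first_term = cross_entropy(soft_teacher, first)    # assumed form of term 1
    second_term = cross_entropy(hard_teacher, second)  # term 2, per claim 3
    return first_term + second_term


# Toy example: 3-token vocabulary, 2-dimensional embeddings.
embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
first, second = student_outputs(
    token_id=0,
    embedding=embedding,
    intermediate=lambda h: h,                   # stand-in for attention blocks
    additional=lambda h: [2.0 * x for x in h],  # stand-in additional block
)
loss = objective(first, second,
                 soft_teacher=[0.5, 0.0, 0.5],
                 hard_teacher=[1.0, 0.0, 0.0])
```

In this sketch only the `second` path would be evaluated at inference time, consistent with claim 18, under which a new input is processed without any of the additional blocks.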