Automatic partitioning of machine learning models through static program analysis

The described system efficiently partitions machine learning models across multiple processing nodes by using static program analysis to assign unique labels and detect conflicts, enhancing partitioning speed and reducing compute time while optimizing memory usage.

WO2026128632A1 · PCT designated stage · Publication Date: 2026-06-18 · DEEPMIND TECH LTD +1

Patent Information

Authority / Receiving Office
WO · WO
Patent Type
Applications
Current Assignee / Owner
DEEPMIND TECH LTD
Filing Date
2025-12-10
Publication Date
2026-06-18

AI Technical Summary

Technical Problem

Existing methods for partitioning machine learning models across multiple processing nodes are inefficient and impractical, especially for large models with complex operations, leading to significant computational challenges and resource constraints.

Method used

A system and method that uses static program analysis to partition machine learning models by assigning unique labels to tensor dimensions and applying rules to determine valid partitioning strategies, reducing the search space through dimension unification and conflict detection, and generating program data for optimal sharding across processing nodes.

Benefits of technology

This approach significantly improves partitioning speed and efficiency, optimizing memory usage, reducing compute time, and discovering near-optimal sharding strategies that can be applied across various hardware types, leading to faster training and inference times.

✦ Generated by Eureka AI based on patent content.


Abstract

Methods and systems for partitioning a machine learning model, or neural network training, across multiple processing nodes. The system can obtain mesh data specifying a configuration of processing nodes, and machine learning model code specifying the model or a training step of the model. The system can then perform a named dimension analysis by assigning unique labels to each tensor dimension of the tensor operations, and applying rules based on types of tensor operation to obtain a reduced set of dimension labels that defines valid ways in which tensor operations can be sharded across the processing nodes. The reduced set of dimension labels can be processed to generate program data that, when executed by each of the plurality of processing nodes, partitions the machine learning model, or the training step, across the plurality of processing nodes.

Description

AUTOMATIC PARTITIONING OF MACHINE LEARNING MODELS THROUGH STATIC PROGRAM ANALYSIS

CROSS-REFERENCE TO RELATED APPLICATIONS

[0001] This specification claims priority to Greek Application No. 20240100863, filed on December 10, 2024. The disclosure of the prior application is considered part of and is incorporated by reference in the disclosure of this application.

BACKGROUND

[0002] This specification relates to processing data using machine learning models implemented on multiple processing nodes.

[0003] Machine learning models receive an input and generate an output, e.g., a predicted output, based on the received input. Some machine learning models are parametric models and generate the output based on the received input and on values of the parameters of the model.

[0004] Neural networks are machine learning models that employ one or more layers of nonlinear units to predict an output for a received input. Some neural networks include one or more hidden layers in addition to an output layer. The output of each hidden layer is used as input to the next layer in the network, i.e., the next hidden layer or the output layer. Each layer of the network generates an output from a received input in accordance with current values of a respective set of parameters.

SUMMARY

[0005] This specification describes systems and methods, implemented as computer programs on one or more computers in one or more locations, that partition a machine learning model, or neural network training, across multiple processing nodes.

[0006] Generally, the multiple processing nodes can include any appropriate hardware devices that can carry out operations required by machine learning model code that specifies the model, e.g. to perform neural network training. Examples of such devices include central processing units (CPUs) and processing nodes that include one or more hardware accelerators that have circuitry for performing multiplication, e.g. matrix-vector multiplication, in hardware, e.g., graphics processing units (GPUs), tensor processing units (TPUs), and other ASICs that are optimized for performing machine learning computations.

[0007] The machine learning model may perform, or be trained to perform, many possible tasks. Merely as an example the machine learning model, or trained neural network, may comprise a neural network that is or includes a Transformer model, e.g. that forms part of a large language model (LLM) or vision language model (VLM).

[0008] In one example aspect a method involves receiving mesh data specifying a configuration of a system with a plurality of processing nodes, and obtaining machine learning model code specifying a machine learning model, or a training step for training a neural network. In particular the machine learning model code identifies a plurality of variables, each in the form of a tensor having one or more dimensions, and a plurality of operations, each associated with a number of the variables. Each operation includes at least one input variable with respect to which the respective operation is performed and at least one output variable resulting from the respective operation.

[0009] For each of the operations, and for each of the variables associated with the respective operation, implementations of the method provide unique labels for each of the dimensions of the respective tensor. A rule associated with a type of the operation is applied to the unique labels for the associated variables to obtain a reduced set of dimension labels, corresponding to a number of ways in which the respective operation and its associated variables can be partitioned across the plurality of processing nodes.

[0010] The mesh data, machine learning model code, and the reduced sets of dimension labels for each of the operations, are processed to generate program data that, when executed by each of the plurality of processing nodes, partitions the machine learning model, or the training step, across the plurality of processing nodes.

[0011] After the partitioned machine learning model, or neural network training step, has been implemented, an input can be obtained for the machine learning model or trained neural network, and processed using the machine learning model or trained neural network to generate an output from the machine learning model or trained neural network.

[0012] In another aspect there is provided a computer-implemented method that involves obtaining an input for a machine learning model or trained neural network and processing the input using the machine learning model or trained neural network to generate an output. The machine learning model, or neural network training step, has been partitioned across a plurality of processing nodes of a computing system by a process that involves obtaining mesh data specifying a configuration of the computing system, and machine learning model code defining tensor operations specifying the machine learning model, or the training step. A named dimension analysis is performed by assigning unique labels to each tensor dimension of the tensor operations, and applying rules based on types of tensor operation to obtain a reduced set of dimension labels that defines valid ways in which tensor operations can be sharded across the processing nodes. The reduced set of dimension labels is processed to generate program data that, when executed by each of the plurality of processing nodes, partitions the machine learning model, or the training step, across the plurality of processing nodes.

[0013] There is also described a system comprising one or more computers, and one or more storage devices communicatively coupled to the one or more computers. The storage devices store instructions that, when executed by the one or more computers, cause the one or more computers to perform the operations of the described methods.

[0014] There is further described one or more non-transitory computer storage media storing instructions that when executed by one or more computers perform the operations of the described methods.

[0015] Particular embodiments of the subject matter described in this specification can be implemented so as to realize one or more of the following advantages.

[0016] Implementations of the techniques identify which dimensions of the tensors in the program should be sharded identically, and can therefore avoid the communication that would arise when a tensor is sharded in one way during its defining computation and in another way when used by an operation. This also allows the sharding search space to be significantly pruned. In implementations, rather than rely upon iterative compiler-driven propagation along the dataflow in the program, the dimension analysis uses dimension unification, making it order and propagation-direction independent.

[0017] Implementations of the described techniques first compute this identification and then make the results available for sharding decisions. Pre-computing the identification can substantially improve the partitioning speed because a separate compiler pass is not needed for each decision. For example, a partitioning search algorithm can perform a fast, in-memory mutation to record a sharding decision. This approach to finding sharding strategies can also work across a wide range of different types of processing node and model size.

[0018] Additionally, the labelling of the tensor dimensions supports conflict detection (i.e. where no sharding may be possible for an operation given a possible partitioning of the inputs to the operation). For example, conflicts may be detected by identifying cases where one of the tensors (i.e. variables) in the program has the same dimension label (belonging to the reduced sets of dimension labels) for multiple ones of its dimensions. An automatic partitioning algorithm, e.g. Monte-Carlo tree search, that takes explicit sharding actions on logical, named dimensions (later described as notional colors), may also take into account conflict resolution orders for conflicting dimensions, directly generated by the analysis performed to generate the dimension labels. In implementations this is facilitated by determining conflict “compatibility sets”, and exposing valid resolution options to the automatic partitioning algorithm.

[0019] These technical features translate into a number of significant, practical advantages. For example exposing the valid resolution options to the automatic partitioning algorithm can allow the system to optimize peak memory usage. This in turn allows the system to discover partitioning strategies that fit in local device memory, which other systems can find difficult. This can also result in faster training step times (forward and backward pass), or inference step times (forward pass), and a major reduction in compute, e.g. weeks of saved computation.

[0020] The system can also discover new and superior, and in implementations near-optimal, sharding strategies, demonstrated by the example results given later. Sharding strategies discovered by the system can be relatively general, e.g. they can be successfully used across different types of hardware, unlike manual sharding strategies which can often be brittle.

[0021] The details of one or more embodiments of the subject matter of this specification are set forth in the accompanying drawings and the description below. Other features, aspects, and advantages of the subject matter will become apparent from the description, the drawings, and the claims.

BRIEF DESCRIPTION OF THE DRAWINGS

[0022] FIG. 1 shows an example partitioning system.

[0023] FIGS. 2A and 2B show, respectively, an example dimension graph, and examples of compatible and incompatible conflicts.

[0024] FIG. 3 is a flow diagram of an example process for partitioning a machine learning model, or a neural network training step, across multiple processing nodes.

[0025] FIG. 4 illustrates an example implementation of a partitioning process.

[0026] FIG. 5 illustrates partitioning performance of an example of the described system.

[0027] FIG. 6 illustrates execution times of machine learning models, partitioned for training and inference using an example of the described system.

[0028] FIG. 7 illustrates auto-sharding search time for an example of the described system.

[0029] FIG. 8 illustrates scaling of the execution time of neural network training partitioned using an example of the described system.

[0030] Like reference numbers and designations in the various drawings indicate like elements.

DETAILED DESCRIPTION

[0031] FIG. 1 shows a partitioning system 100, implemented as computer programs on one or more computers in one or more locations, for determining a partition of machine learning model code 102 for a computing system 150.

[0032] The computing system 150 comprises a plurality of processing nodes 152. Mesh data 110 defines a configuration of computing system 150, in particular of the processing nodes 152A..N. The increasing size of machine learning models dictates a distributed computing approach because the memory footprints of model parameters and input data, and during training, optimizer states, routinely exceed the capacity of a single device.

[0033] In implementations a processing node 152 comprises a computing device such as a CPU, GPU, TPU, ASIC, or other device. Often a processing node 152 comprises a hardware accelerator device, e.g. for performing matrix multiplications in hardware, such as a GPU or TPU. In the following description “device” is sometimes used as a shorthand for a processing node.

[0034] In implementations the processing nodes 152 are arranged in a mesh. The mesh data 110 can assign each of the plurality of processing nodes 152 to a respective index along each of one or more axes. For example, the mesh can be a one-dimensional mesh, so that each processing node or device is assigned an index along a single axis. As another example, the mesh can be a two-dimensional mesh, so that each device is assigned a respective index along each of the two axes. As yet another example, the mesh can be a three-dimensional mesh, so that each device is assigned a respective index along each of the three axes.

[0035] Generally, the configuration of the processing nodes 152 is a logical configuration of the processing nodes 152, but this can reflect the underlying system communication topology. For example a cluster of 4 servers with 8 accelerator devices can be viewed as having two axes, a fast communication axis along devices and a slower communication axis across servers. As another example 256 accelerator devices could be connected in a torus topology in which each processing node connects to its neighbors in multiple dimensions, e.g. a 2D torus or a 3D torus, e.g. for LLM training.
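As a minimal sketch of what mesh data of this kind might look like, the following Python assigns each device an index along each named axis. The function name `build_mesh`, the dictionary representation, and the axis names `server` and `device` are illustrative assumptions and are not taken from the specification; the example also assumes that each of the 4 servers hosts 8 accelerators.

```python
from itertools import product

def build_mesh(axis_sizes):
    """Map each device id to an index along every named mesh axis.

    axis_sizes: dict of axis name -> number of positions along that axis
    (a hypothetical representation of mesh data).
    """
    names = list(axis_sizes)
    coords = product(*(range(axis_sizes[n]) for n in names))
    return {dev_id: dict(zip(names, c)) for dev_id, c in enumerate(coords)}

# Assumed layout: 4 servers, 8 accelerators per server.
# "server" is the slower axis across servers, "device" the faster axis within one.
mesh = build_mesh({"server": 4, "device": 8})
print(len(mesh))   # 32
print(mesh[9])     # {'server': 1, 'device': 1}
```

A one-dimensional or three-dimensional mesh follows from passing one or three axis sizes instead of two.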

[0036] The machine learning model code 102 provides an input to the partitioning system 100. The machine learning code defines operations of the machine learning model, e.g. operations of a Transformer neural network (e.g. for an LLM or VLM) or of a U-Net (e.g. for a diffusion model or image processing task). For example the machine learning code can specify a training step for training a neural network, or the machine learning code can specify a model, e.g. operations of a model, for use in inference.

[0037] As used herein a training step for training a neural network can refer to the processing required to train the neural network on a batch of training data. This typically involves updating the parameters, e.g., weights, of the neural network by performing a forward step followed by a backward step through the neural network. The overall training process may be performed over multiple batches of training data; and one or more complete passes through the training data (training epochs) may comprise a plurality of training steps.

[0038] The machine learning model code 102 generally identifies a plurality of variables, each of which is in the form of a tensor having one or more dimensions. The machine learning model code 102 generally also identifies a plurality of operations. Each operation is associated with a number of the variables, including at least one input variable with respect to which the respective operation is performed and at least one output variable resulting from the respective operation.

[0039] In some implementations the machine learning model code 102 can be a so-called intermediate representation (IR) of code written in a higher level machine learning framework such as JAX, PyTorch or TensorFlow. For example, a front end to the system can translate or “lower” from a higher level framework to an IR. For example, the intermediate representation can be an array IR such as StableHLO - which is one MLIR (Multi-Level Intermediate Representation) dialect.

[0040] In general, in an intermediate representation complex calculations are broken down into a sequence of simple operations on tensor variables. For example, an intermediate representation can comprise multiple straight line tensor programs (i.e. without loops or conditional branches). An intermediate representation may be in so-called A-normal form (ANF) or use Static Single Assignment (SSA). In ANF an expression is decomposed so that arguments of its operation(s) are constants or variables rather than other expressions; intermediate computations are bound to a variable using a let expression. In SSA every variable is assigned a value exactly once in the program's static text.
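The decomposition into A-normal form can be illustrated with a short Python sketch. The tuple encoding of expressions, the function name `to_anf`, and the generated names `t0`, `t1`, ... are assumptions made for illustration only; a real IR such as StableHLO has its own syntax.

```python
def to_anf(expr, bindings):
    """Flatten a nested expression into a list of (variable, op, args) bindings.

    expr is either a variable name (str) or a tuple (op, arg, arg, ...).
    Returns the name of the variable holding the value of expr.
    """
    if isinstance(expr, str):
        return expr                          # already a variable
    op, *args = expr
    arg_names = [to_anf(a, bindings) for a in args]
    name = f"t{len(bindings)}"               # fresh name, assigned exactly once
    bindings.append((name, op, arg_names))
    return name

bindings = []
result = to_anf(("matmul", ("relu", ("matmul", "x", "w1")), "w2"), bindings)
for name, op, args in bindings:
    print(f"let {name} = {op}({', '.join(args)})")
# let t0 = matmul(x, w1)
# let t1 = relu(t0)
# let t2 = matmul(t1, w2)
```

Because every generated name is bound exactly once, the flattened program in this sketch also satisfies the single-assignment property.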

[0041] The partitioning system 100 partitions the machine learning model code 102 to generate program data 120, as described in more detail later. That is, the program data 120, when executed by each of the plurality of processing nodes 152A..N, partitions the machine learning model code, e.g. a training step for training a neural network, across the plurality of processing nodes 152A..N.

[0042] Partitioning refers to preparing the machine learning model code 102 for execution on the plurality of processing nodes 152A..N, e.g. on a mesh of devices. In general this involves each device operating on shards of tensors. A shard of a tensor is obtained by slicing the tensor along one of its dimensions. These shards are then distributed across the devices along one of the axes of the mesh. A dimension of a tensor refers to the mathematical shape of a tensor, and an axis refers to a dimension of the logical configuration of the processing nodes 152, e.g. one of the axes of the mesh of devices. As used herein “dimension” refers to a tensor (i.e. data) and “axis” refers to the logical device mesh (i.e. hardware). Broadly speaking, partitioning involves actions that comprise sharding a tensor along one of its dimensions and assigning each shard to an axis of the mesh of devices. Partitioning entails sharding, and as shorthand partitioning is sometimes referred to herein as sharding.
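A minimal Python sketch of sharding, assuming a 2-D tensor stored as a list of rows and a dimension size that divides evenly by the number of devices along the chosen mesh axis. The function name `shard` is hypothetical.

```python
def shard(tensor, dim, num_shards):
    """Split a 2-D tensor (list of rows) into equal slices along dimension `dim`."""
    rows, cols = len(tensor), len(tensor[0])
    size = rows if dim == 0 else cols
    if size % num_shards:
        raise ValueError("dimension size must be divisible by the number of shards")
    step = size // num_shards
    if dim == 0:
        return [tensor[i * step:(i + 1) * step] for i in range(num_shards)]
    return [[row[i * step:(i + 1) * step] for row in tensor] for i in range(num_shards)]

x = [[r * 4 + c for c in range(4)] for r in range(8)]    # shape [8, 4]
shards = shard(x, dim=0, num_shards=4)                   # one shard per index on a mesh axis of size 4
print(len(shards), len(shards[0]), len(shards[0][0]))    # 4 2 4
```

Here the device at index i along the mesh axis would hold `shards[i]`, a tensor of shape [2, 4].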

[0043] The partitioning system 100 can shard the tensors referred to in the machine learning model code 102, to distribute the tensor processing over the processing nodes 152A..N. Implementations of the partitioning system 100 can optimize the sharding, taking account of the mesh data, i.e. the device topology, by making sharding decisions that partition machine learning model code, e.g. the neural network training step, across the plurality of processing nodes. The program data 120 can be a re-written version of the machine learning model code 102, that takes account of a determined sharding strategy, with communication primitives included as needed to implement the strategy e.g. where the sharding strategy requires synchronization.

[0044] For example, the program data 120 may comprise a device-local intermediate representation combined with any necessary collective communication operations. A device-local intermediate representation comprises code to be run on a single processing node 152: it describes the operations to be executed by that node on its specific shard of data (e.g. matmul and other operations that operate on a smaller, sharded part of a tensor). The program data 120 includes communication operations to move data between the node and other nodes, to synchronize tensor processing across the processing nodes, e.g. along a particular axis. As a particular example, the program data 120 may be a sharded version of the machine learning model code 102 in StableHLO or some other MLIR dialect.

[0045] As one particular example, the communication operations may comprise so-called MPI (Message Passing Interface) style primitives. Some examples of MPI primitives follow.

[0046] all_reduce: This sums or aggregates partial results from all concerned devices and broadcasts the combined result back to all concerned devices; the devices concerned can be a subset of the devices along a particular axis. It can be used to combine contributions from different devices, e.g. after computing device-local matrix multiplications. As one example, this can be used to obtain the combined result of an attention operation in a Transformer neural network, for sharding and processing by a subsequent feedforward layer. As another example, in a training step, i.e. in a backwards pass through the neural network, it can be used to sum the local gradients (e.g. from each batch of data) calculated by each concerned device, to determine a global gradient (e.g. over the batch of data), that can be returned to all concerned devices to enable them to update their model weights.

[0047] all_gather: This gathers tensor shards from all concerned devices (e.g. along an axis) to reconstruct the full tensor on every concerned device, e.g. to remove partitioning. As an example, this can be used to gather query, key and value (Q, K, V) tensors along a sequence dimension, in the attention layer of a Transformer neural network, prior to matrix multiplication.

[0048] all_to_all: This enables every concerned device, e.g. along an axis, to send a portion of data to every other concerned device, e.g. to enable resharding, e.g. to change from sharding a tensor along one axis to sharding along another.

[0049] reduce_scatter: This sums or aggregates partial results and splits the sum or aggregation amongst concerned devices, e.g. along an axis, to shard a tensor.
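The data movement performed by these four primitives can be simulated in a single process. The sketch below is an assumption-laden illustration, not an MPI binding: each device's data is a flat Python list, `per_device[i]` is the data held by device i, aggregation is a sum, and list lengths are assumed to divide evenly by the number of devices.

```python
def all_reduce(per_device):
    """Every device ends up with the elementwise sum of all devices' vectors."""
    total = [sum(vals) for vals in zip(*per_device)]
    return [list(total) for _ in per_device]

def all_gather(per_device):
    """Every device ends up with the concatenation of all devices' shards."""
    full = [v for shard in per_device for v in shard]
    return [list(full) for _ in per_device]

def reduce_scatter(per_device):
    """Sum across devices, then leave device i holding only the i-th slice of the sum."""
    n = len(per_device)
    total = [sum(vals) for vals in zip(*per_device)]
    step = len(total) // n
    return [total[i * step:(i + 1) * step] for i in range(n)]

def all_to_all(per_device):
    """Device i sends its j-th chunk to device j (a transpose of chunks)."""
    n = len(per_device)
    step = len(per_device[0]) // n
    chunks = [[d[j * step:(j + 1) * step] for j in range(n)] for d in per_device]
    return [[v for i in range(n) for v in chunks[i][j]] for j in range(n)]

data = [[1, 2], [10, 20]]        # two devices, each holding a length-2 vector
print(all_reduce(data))          # [[11, 22], [11, 22]]
print(all_gather(data))          # [[1, 2, 10, 20], [1, 2, 10, 20]]
print(reduce_scatter(data))      # [[11], [22]]
print(all_to_all(data))          # [[1, 10], [2, 20]]
```

In this toy model, an all_reduce gives the same result as a reduce_scatter followed by an all_gather, which is one way to see how the primitives relate.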

[0050] The partitioning system 100 includes a named dimension analysis subsystem 104. This performs a “static” analysis of the machine learning model code 102, as described in detail later, to identify tensor dimensions that should be sharded identically. In particular the named dimension analysis subsystem 104 generates reduced sets (a reduced set) of dimension labels 106 for each of the operations of the plurality of operations of the machine learning model code 102. The reduced sets of dimension labels 106 for each of the operations correspond to a number of ways in which the respective operation and its associated variables can be partitioned across the plurality of processing nodes 152. The named dimension analysis subsystem 104 can also identify partitioning conflicts, where sharding ambiguity should be resolved, as described later.

[0051] The partitioning system 100 also includes an automatic partitioning system 108 that makes tensor sharding decisions based on the mesh data 110, the machine learning model code 102, and the reduced sets of dimension labels 106 for each of the operations. The automatic partitioning system 108 generates the program data 120 that partitions the machine learning model code 102, e.g. the training step, across the plurality of processing nodes 152.

[0052] The tensor sharding decisions in the program data 120 determine the communication operations that are needed for the decisions. That is, the decisions deterministically imply communication operations, e.g. MPI primitives, that should be added to the program data for it to be compiled and executed. The process of adding communication operations can be referred to as “lowering”; i.e. lowering can be performed before providing the program data 120 to a compiler for compilation, and subsequent execution.

[0053] As one example, a sharding decision that defines that a tensor should be sharded differently from its current state can dictate that an all_gather operation is included, e.g. to convert a sharded Q, K or V tensor to a full tensor for an attention computation. As another example, a sharding decision that defines that a sharded dimension is summed over, e.g. in a matrix multiplication, can dictate that an all_reduce operation is included to implement the sum. As another example, a sharding decision that defines that a variable definition is sharded differently from its use can dictate that an all_to_all operation is included to enable the resharding.

[0054] In implementations, the automatic partitioning system 108 performs a search over possible partitionings of the reduced sets of dimension labels 106 provided by the named dimension analysis subsystem 104, e.g. using a Monte-Carlo Tree Search (MCTS) or any other suitable search technique, to generate the program data 120. As described later, this search can also take account of any conflicts identified by the named dimension analysis subsystem 104.

[0055] Such an approach, based on a static analysis followed by a search process, can reduce the huge search space of all possible partitionings to a search space of a practical size whilst retaining potentially useful partitions. For example, in implementations the reduced sets of dimension labels can identify dimensions of tensors in the model code that should be sharded identically.

[0056] In implementations the search can use a runtime cost model to optimize for computation speed, or cost in FLOPs (Floating Point Operations), subject to any memory constraints, e.g. to optimize the execution time for a single training or inference step. For example the search can find a fastest partitioning plan that fits any memory constraints of the available hardware. The runtime cost model, and hence the search, can also take account of the time needed for any necessary collective communication operations implied by particular partitioning (sharding) decisions made to resolve conflicts. For example, the search can identify the fastest partitioning plan that takes account of the collective communications overhead needed for conflict resolution.
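The selection criterion described here, the fastest plan that fits in device memory once communication time is counted, can be sketched as follows. The `Plan` fields, the function name `best_plan`, and all numeric values are made-up placeholders for illustration; they are not measurements and not the cost model of the described system.

```python
from dataclasses import dataclass

@dataclass
class Plan:
    name: str
    compute_s: float      # estimated compute time per step (illustrative)
    comm_s: float         # estimated collective-communication time per step (illustrative)
    peak_mem_gb: float    # estimated peak memory per device (illustrative)

def best_plan(plans, mem_limit_gb):
    """Return the fastest plan whose peak memory fits on a device, or None."""
    feasible = [p for p in plans if p.peak_mem_gb <= mem_limit_gb]
    return min(feasible, key=lambda p: p.compute_s + p.comm_s, default=None)

plans = [
    Plan("replicated", 1.00, 0.00, 40.0),
    Plan("batch", 0.25, 0.05, 18.0),
    Plan("batch+model", 0.15, 0.20, 9.0),
]
print(best_plan(plans, mem_limit_gb=16.0).name)   # batch+model
print(best_plan(plans, mem_limit_gb=32.0).name)   # batch
```

With the tighter memory limit only the most heavily sharded plan is feasible, even though its communication overhead makes it slower than batch sharding alone.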

[0057] To implement the program data 120 on the computing system 150 the program data 120 can be provided to a compiler for compilation. In implementations this also involves inserting the appropriate collective communication operations, such as MPI primitives (“lowering”). This can be done deterministically based on the sharding decisions: As some examples, if the program data shards a dimension that is contracted, e.g. summed over, in a matrix multiplication, an all_reduce can be inserted; if an operation requires a sharded tensor to be full, an all_gather can be inserted; if the tensor needs to change its sharding layout (e.g. row sharded to column sharded) all_to_all or reduce_scatter can be inserted; and so on.
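The deterministic nature of these rules can be sketched as a pair of small lookup functions. This is a simplified illustration under stated assumptions: a sharding spec is a tuple with one entry per tensor dimension (a mesh axis name or `None`), the function names are hypothetical, and only the all_reduce, all_gather and all_to_all cases named above are covered.

```python
def matmul_collective(lhs_spec, rhs_spec):
    """Return ('all_reduce', axis) if the contracted dimension is sharded, else None."""
    axis = lhs_spec[-1]                       # last dim of lhs is contracted with first dim of rhs
    if axis is not None and axis == rhs_spec[0]:
        return ("all_reduce", axis)
    return None

def reshard_collective(current, wanted):
    """Pick the collective that converts one sharding spec into another."""
    removed = any(c is not None and w != c for c, w in zip(current, wanted))
    added = any(w is not None and w != c for c, w in zip(current, wanted))
    if removed and added:
        return "all_to_all"                   # layout changes, e.g. row sharded to column sharded
    if removed:
        return "all_gather"                   # consumer needs a less-sharded (or full) tensor
    return None                               # unchanged, or a purely local slice

print(matmul_collective(("b", "m"), ("m", None)))      # ('all_reduce', 'm')
print(reshard_collective(("b", None), (None, None)))   # all_gather
print(reshard_collective(("b", None), (None, "b")))    # all_to_all
```

Because the output depends only on the sharding specs, the same decisions always yield the same communication operations.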

[0058] The compiler can be any compiler compatible with the program data 120 format, e.g. to compile intermediate representation code such as StableHLO. One illustrative example, suitable for StableHLO, is XLA (Accelerated Linear Algebra) from OpenXLA. The compiler generates compiled program data, e.g. an executable binary, for the specific hardware of computing system 150, i.e. for the processing nodes 152.

[0059] The computing system 150, in particular the processing nodes 152, can be caused to execute the compiled program data generated by the compiler. This can perform the training step for training the neural network, or it can implement an inference-optimized version of the machine learning model, e.g. an inference optimized Transformer, U-Net, or other model. The inference-optimized version of the machine learning model can be served, i.e. made available to end users or to other systems for use in inference (to generate outputs, e.g. predictions). The compiled model may be used for months, either in training or in inference, and it can therefore be very beneficial to optimize the tensor partitioning.

[0060] In some implementations the plurality of processing nodes 152 operates under a single program multiple data (SPMD) model. That is, each device, i.e. each processing node 152, can execute the same program, which program is specified by the program data 120. The program data 120 can comprise device-local SPMD code for execution by each of the plurality of processing nodes 152. For example executing compiled program data generated by the compiler can perform a training step for training a neural network, and under the SPMD model, the program data can specify a same program to be executed by each of the plurality of devices to perform the training step.

[0061] In this setting the program data 120 can specify instructions for a single processing node 152, and these instructions are replicated across the processing nodes 152 to specify the behavior of the full computing system 150. For example, the compiler can generate a single executable binary that is sent to every device. Operations are performed on tensor shards in the device’s local memory, and data is exchanged as defined by the collective communication operations, using device IDs to coordinate the data exchange.

[0062] It is helpful for understanding the described techniques to outline some of the difficulties that can arise when partitioning machine learning models for execution on a mesh of devices.

[0063] Consider the following simplified code for an MLP (multilayer perceptron), in which samples are passed in as x and the model parameters (weights) are given as arguments w1 and w2:

    def mlp(x: [256, 32], w1: [32, 64], w2: [64, 16]) {
      y: [256, 64] = matmul(x, w1)
      z: [256, 64] = ReLU(y)
      w: [256, 16] = matmul(z, w2)
      return w
    }

[0064] It can be seen that all the operations act as maps along the tensor dimension 256, and thus the operations can be executed in parallel on shards of tensors along these dimensions. This is illustrated below, where {b} indicates sharding along axis b and each operation now operates on smaller tensors:

    def mlp(x: [256{b}, 32], w1: [32, 64], w2: [64, 16]) {
      y: [256{b}, 64] = matmul(x, w1)
      z: [256{b}, 64] = ReLU(y)
      w: [256{b}, 16] = matmul(z, w2)
      return w
    }

[0065] This is device local code in which the returned result, w, on each device represents a shard of the full tensor result. It can be referred to as batch partitioning, and a shard of the samples x can be referred to as a batch. The code can be further partitioned, e.g. across the tensor dimension 64, and along axis m, which leads to device local computations up to and including the definition of z. However in the final matmul the dimension 64 is “reduced out”, i.e. it disappears, and each device only computes a contribution to the output tensor. An all_reduce is used to sum up the contributions along axis m (indicated by the attribute {m}):

    def mlp(x: [256{b}, 32], w1: [32, 64{m}], w2: [64{m}, 16]) {
      y: [256{b}, 64{m}] = matmul(x, w1)
      z: [256{b}, 64{m}] = ReLU(y)
      w_: [256{b}, 16] = matmul(z, w2)
      w: [256{b}, 16] = all_reduce {m} w_
      return w
    }
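That sharding the hidden dimension and then summing the per-device contributions reproduces the unsharded result can be checked numerically. The sketch below uses small integer matrices as stand-ins for the [256, 32], [32, 64] and [64, 16] tensors, simulates two devices on axis m in a single process, and replaces the all_reduce with a plain sum; the helper names and test values are assumptions for illustration.

```python
def matmul(a, b):
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]

def relu(a):
    return [[max(0, v) for v in row] for row in a]

def add(a, b):
    return [[p + q for p, q in zip(r, s)] for r, s in zip(a, b)]

# Small integer stand-ins for x: [256, 32], w1: [32, 64], w2: [64, 16].
x = [[(i - j) % 5 - 2 for j in range(3)] for i in range(4)]        # [4, 3]
w1 = [[(i + 2 * j) % 7 - 3 for j in range(4)] for i in range(3)]   # [3, 4]
w2 = [[(3 * i + j) % 5 - 2 for j in range(2)] for i in range(4)]   # [4, 2]

reference = matmul(relu(matmul(x, w1)), w2)      # unsharded result

# Shard the hidden dimension (the "64" dimension) across two devices on axis m.
partials = []
for dev in range(2):
    cols = slice(2 * dev, 2 * dev + 2)
    w1_shard = [row[cols] for row in w1]         # w1: [3, 4{m}]
    w2_shard = w2[cols]                          # w2: [4{m}, 2]
    z_shard = relu(matmul(x, w1_shard))          # device-local up to and including z
    partials.append(matmul(z_shard, w2_shard))   # partial contribution to w

result = add(partials[0], partials[1])           # stands in for all_reduce along m
print(result == reference)                       # True
```

The ReLU can stay device-local because it is elementwise, so each device's slice of z depends only on its own slice of y.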

[0066] In the above example a useful sharding can be determined by inspection, but this becomes difficult with a large model, which can have tens or hundreds of thousands of operations, and where trying all possibilities is impractical. It would be useful to identify, in advance of partitioning, which dimensions should be sharded together.

[0067] Implementations of the described techniques address this problem by determining a reduced set of dimension labels (using NDA subsystem 104), in which the labels correspond to notional colors or “logical dimension names”. If dimensions that should be sharded together are given a notional color then partitioning can be reduced to: pick an axis a, pick a color C, and shard all tensors whose dimensions include C along the dimension colored with C (unless, as a corner case, the tensor is already sharded across axis a). The colors can be discovered by applying rules that describe parallel and contracting dimensions for every operation, e.g. for every operation in the intermediate representation. The space of possible partitionings can then be explored by the automatic partitioning subsystem 108, e.g. using a search (such as an MCTS) that successively shards tensors along labelled, i.e. colored, dimensions, one axis at a time.
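The "pick an axis, pick a color" step can be sketched as follows. This is an illustrative assumption, not the described subsystem: the data layout (tensor name mapped to a list of (color, axes) pairs) and the names shard_by_color and fmt are hypothetical, and the colors B, X, U and W are the ones the MLP example arrives at later in the description.

```python
def shard_by_color(tensors, color, axis):
    """Shard, along device axis `axis`, every tensor with a dimension colored `color`."""
    sharded = {}
    for name, dims in tensors.items():
        # Corner case from the text: skip tensors already sharded across this axis.
        axis_in_use = any(axis in axes for _, axes in dims)
        new_dims = []
        for dim_color, axes in dims:
            if dim_color == color and not axis_in_use:
                new_dims.append((dim_color, axes + (axis,)))
                axis_in_use = True   # one device axis shards at most one dimension
            else:
                new_dims.append((dim_color, axes))
        sharded[name] = new_dims
    return sharded

def fmt(dims):
    """Render dimensions in the document's style, e.g. 'B{b}'."""
    return [c + "".join("{%s}" % a for a in axes) for c, axes in dims]

mlp = {
    "x": [("B", ()), ("X", ())],
    "w1": [("X", ()), ("U", ())],
    "w2": [("U", ()), ("W", ())],
    "y": [("B", ()), ("U", ())],
    "z": [("B", ()), ("U", ())],
    "w": [("B", ()), ("W", ())],
}
step1 = shard_by_color(mlp, "B", "b")     # batch partitioning
step2 = shard_by_color(step1, "U", "m")   # then shard the hidden dimension
print({name: fmt(dims) for name, dims in step2.items()})
```

Applying the two decisions in sequence reproduces the annotations of the batch-and-model sharded MLP; the all_reduce itself would be introduced at lowering and is not modelled here.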

[0068] The named dimension analysis subsystem 104 can also identify sharding conflicts and possible resolutions of these. For example the named dimension analysis subsystem 104 can identify all possible options for resolving sharding conflicts and provide these to the automatic partitioning subsystem 108, so that these are also exposed in the search space.

[0069] Consider the function f below, in which x flows into the matmul directly and through y (which is a transpose of x):

def f(x: [32, 4]) {
  y: [4, 32] = transpose(x)
  z: [32, 32] = matmul(x, y)
  return z
}

[0070] Propagating the sharding of x along both paths leads to ambiguity in the sharding of the matmul. For example, if x is sharded on its first dimension (of size 32), and y is sharded on the second dimension, it is unclear whether the matmul should be sharded along its first or second dimension. Conceptually, if colors were assigned to tensor dimensions this situation would correspond to the matmul receiving the same color on both dimensions. This is referred to here as a sharding conflict: when partitioning the matmul it is necessary to pick for sharding one of the two dimensions that have the same color.

[0071] Implementations of the described techniques address this issue, and expose possible options for resolving sharding conflicts, by keeping track of different dimension names for different uses of tensors, as explained further later. This is particularly useful for partitioning attention computations, e.g. in a Transformer model.

[0072] Some of the principles underlying the operation of an example implementation of the named dimension analysis subsystem 104 are now described in more detail.

[0073] Each operation in the machine learning model code 102 is assigned unique dimensions for its operands and for its results. Rules are applied to identify, and unify, some of these. The rules define how the dimensions or operands can be sharded together. As an illustration, a matmul can be computed in a sharded manner along its first dimension if provided with shards of its first operand also on the first dimension. Once the dimensions have been identified and unified based on the rules, each remaining (unidentified) dimension name will appear in a set of dimensions analogous to the previously described reduced set of dimension labels, i.e. colors. (A process of allocating “fresh names”, described later, corresponds to the providing of unique labels, that are subsequently unified to obtain the reduced set of dimension labels or colors, that determines a number of partitioning decisions to be made).

[0074] An example set of rules is now described. These can be implemented, e.g. recursively, by the named dimension analysis subsystem 104. Without loss of generality, the example refers to a straight line tensor program in A-normal form or SSA form. For example, in an intermediate representation, such as StableHLO, loops, branches, and other control flow structures can be represented as operators that take functions (regions) as arguments, and the FUNCTION rule given later can be applied.

[0075] To illustrate, a loop can be defined as while(condition_function, body_function, initial_variables). The FUNCTION rule links the dimension names of the operand with the dimension names of the function’s arguments, and for a loop this effectively enforces that the input to the loop, arguments of the body, and output of the body all share compatible dimension names / sharding strategies. Similarly for a branch the named dimension analysis subsystem 104 can require that both branches of a conditional are sharded in the same way. Repeated layers can be characterized as an unrolled loop. These are discussed later; the subsystem 104 can enforce consistent partitioning across iterations (layers), to reduce the search space.

[0076] The rules operate in an environment E that maps free variables to dimension names. In implementations the named dimension analysis (NDA) subsystem 104 determines a triple comprising: (i) an assignment of dimension names, d̄, to tensor operation (“op”) results; (ii) a mapping M connecting dimension names of value definitions to the dimension names of their uses; and (iii) a set of identities I between dimension names:

NDA_{E : Environment} : (e : Expression) → (d̄ : Named dimensions) × (M : Map) × (I : Identities)

Here an “identity” refers to the concept that if one dimension is sharded in a particular way then some other dimension must be sharded in the same way for an operation to be mathematically valid, i.e. they define alignments between dimension names. Sometimes such alignment does not require any collective communication; sometimes the alignment requires collective communication for the operation to be performed.

[0077] An example set of definitions and rules follows; an illustration of their use is given later. In the rules a bar denotes a vector so that, e.g., x̄ denotes a list of arguments (e.g. x₁, x₂, x₃, ...) and ā in VARIABLE USE denotes a list of fresh variable names; and “fresh” refers to a new, unique name.

Tensor expressions in ANF
e ::= x    variable
    | let x = op(x̄) in e    local definition
op ∈ {f, add, transpose, matmul, ...}

Auxiliary definitions
a, b, c, d, d₁, ..., dₖ    dimension names
d̄ ::= [d₁, ..., dₖ]    Named dimensions
E ::= ∅ | E, x : d̄    Environment
M ::= ∅ | M ∪ {c ↦ d}    Map
I ::= ∅ | I ∪ {d ≐ c}    Identities

NDA_E(let x = op(x̄) in e) := (d₂, M₁ ∪ M₂, I₁ ∪ I₂)    (LET)
where (d₁, M₁, I₁) = NDA_E(op(x̄)), (d₂, M₂, I₂) = NDA_{E,x:d₁}(e)

[0078] This LET rule analyses the operation op(x̄) to determine what dimension names its result should have. It produces (d₁, M₁, I₁) where d₁ are the dimension names assigned to the result, M₁ is the mapping of variable uses within the op, and I₁ comprises any sharding identities generated by the op. It also analyses the so-called body, e (after “in”), which defines the subsequent scope of x, i.e. the remainder of the code where the newly defined variable x is valid and can be used. The rule updates the environment E to record that x now exists and has dimension names d₁ (denoted E, x : d₁), and recursively analyses the subsequent expression e using this updated environment. This produces (d₂, M₂, I₂) where d₂ are the dimension names of the final result of the expression e, M₂ is the mapping of variable uses within e, and I₂ comprises any sharding identities found within e.

[0079] The final product of the LET rule comprises the result dimensions, d₂, because the result of let is the result of its body e (the let part defines a new variable x and the value, and hence dimensions, of the expression is determined solely by evaluating the body, e). The final product of the rule also comprises a merged map, M₁ ∪ M₂, combining all variable usage history from both the definition and the body; and merged identities, I₁ ∪ I₂, accumulating all sharding constraints found in both the definition and the body. This rule allows the analysis to propagate dimension names forward. By analyzing op(x̄) and then passing the result d₁ into the environment for e, the NDA subsystem 104 ensures that when x is used inside e, the analysis can determine what dimensions are referred to. In some implementations (but not necessarily) this information is used to build a connected "dimension graph", that can be used for partitioning by the automatic partitioning system 108.

NDA_E(x) := (ā, {dᵢ ↦ aᵢ}, ∅),    (VARIABLE USE)
if x : d̄ ∈ E, with ā fresh

[0080] This VARIABLE USE rule can be used by the NDA subsystem 104 to process the occurrence of a variable x in the model code 102, e.g. as an argument to a function or an operand in a calculation. The environment E already contains a record of x from when it was defined, and the rule can look up that x has original dimension names d̄ (e.g. a list d₁, ..., dᵢ, ..., dₖ). The rule generates a list of “fresh”, i.e. new, dimension names ā (e.g. a list a₁, ..., aᵢ, ..., aₖ) for this specific use of the variable. Here, even though it is a use of the same variable x, the specific instance of the use is given a unique identity, so that if x is used in two different places each different use gets a respective set of fresh names. This helps the NDA subsystem 104 to detect when different parts of the code impose conflicting sharding requirements on the same tensor.

[0081] The VARIABLE USE rule generates a map {dᵢ ↦ aᵢ} that records a directed link from the definition’s dimension name, dᵢ, to the use’s dimension name, aᵢ. In some implementations this map is used to define directed edges in a dimension graph. For example it can be used to trace the history of a value backwards to determine that if, say, a constraint is applied to aᵢ then it also affects (applies to) dᵢ. The ∅ in the rule indicates that no sharding identities (constraints) are produced by the rule. This is because using a variable does not, of itself, force any dimensions to be sharded or replicated: constraints (identities) arise when operations (such as add or matmul) interact with a variable.

NDA_E(f(x)) := (ā, M, {aᵢ ≐ dᵢ}),    (FUNCTION)
where (d̄, M, ∅) = NDA_E(x), with ā fresh

[0082] The FUNCTION rule refers to function calls, where f(x) is a function and x is an argument passed to the function. The argument x has dimension names d̄, and the analysis returns the triple (ā, M, {aᵢ ≐ dᵢ}) where ā represents the dimension names of the result returned by f(x), and where ā are fresh (new) dimension names for the result. The notation aᵢ ≐ dᵢ defines that these dimension names are identified with one another, and hence that they should be sharded in the same way, i.e. that the argument passed to the function and the result returned from the function should have the same sharding strategy, so that the function does not require resharding.

NDA_E(op(x, y)) := (ā, M₁ ∪ M₂, {aᵢ ≐ dᵢ, aᵢ ≐ cᵢ})    (OP)
where ([d₁, ..., dₖ], M₁, ∅) = NDA_E(x), ([c₁, ..., cₖ], M₂, ∅) = NDA_E(y), with aᵢ fresh, for op ∈ {add, mul, ...}

[0083] The OP rule refers to elementwise operations, such as add and mul, that take two arguments x and y; such operations typically require their inputs to be the same shape, or to be broadcastable. Input x has dimension names [d₁, ..., dₖ] and input y has dimension names [c₁, ..., cₖ]; the rule generates fresh (new) names [a₁, ..., aₖ] for the result. The rule combines (merges) the mappings of the two arguments to track the variable use history, and the identities (I) define that the same sharding (partitioning) applies to both the two input arguments and the result.

[0084] In more detail, this rule generates a set of identities {aᵢ ≐ dᵢ, aᵢ ≐ cᵢ} (which implies cᵢ ≐ dᵢ), which asserts that the ith dimension of the result is equivalent for sharding purposes to the ith dimension of each of x and y. This defines that both x and y should be sharded in the same way (e.g. along a particular axis), and that the result should also be sharded in the same way, if collective communication operations are to be avoided, i.e. for local device processing of respective shards. If x and y are not sharded in the same way a collective communication, e.g. all_to_all, operation is inserted during lowering.

NDA_E(reduce_op^r(x)) := ([a₁, ..., aᵣ₋₁, aᵣ₊₁, ..., aₖ], M, {aᵢ ≐ dᵢ}),    (REDUCE)
where ([d₁, ..., dᵣ, ..., dₖ], M, ∅) = NDA_E(x), with aᵢ fresh, for op ∈ {add, mul, ...}

[0085] The REDUCE rule defines how dimension names and sharding constraints propagate when a tensor is reduced, e.g. summed, along a particular dimension. The operation (“op”), such as add or mul, takes a tensor x and reduces it along dimension index r (e.g. operating on a 2D matrix, x, reduce_add^1(x) would sum over rows to reduce x to a vector representing column sums). The rule determines dimension names [d₁, ..., dᵣ, ..., dₖ] of x, and generates a new list of dimension names without dᵣ, assigning fresh (new) names [a₁, ..., aᵣ₋₁, aᵣ₊₁, ..., aₖ] to the remaining dimensions, and a set of identities {aᵢ ≐ dᵢ} without dᵣ or aᵣ. That is, the nonreduced dimensions should maintain their sharding configuration (this requires no collective communication as the reduction is independent for each shard). There is no constraint on the reduced dimension r, but if this dimension is sharded the lowering step will need to introduce, e.g., an all-reduce primitive, to sum the partial results from different devices.

NDA_E(matmul(x, y)) := ([a₁, a₂], M₁ ∪ M₂, {a₁ ≐ d₁, a₂ ≐ c₂, d₂ ≐ c₁}),    (MATMUL)
where ([d₁, d₂], M₁, ∅) = NDA_E(x), ([c₁, c₂], M₂, ∅) = NDA_E(y), with a₁, a₂ fresh

[0086] The MATMUL rule propagates dimension names and identifies sharding constraints as defined: the result of the matmul is assigned fresh dimension names and any existing variable-to-use mappings from the analyses of x and y are combined. The identities define that a₁ ≐ d₁, i.e. that the first dimension of the result corresponds to the first dimension of the first input, x; and that a₂ ≐ c₂, i.e. that the second dimension of the result corresponds to the second dimension of the second input, y. These pairs of dimensions can be sharded in the same way without collective communication, i.e. sharding along one of these dimensions naturally results in sharding along the equivalent dimension. The identity d₂ ≐ c₁ defines that the columns of x should correspond to the rows of y, and hence these dimensions should be sharded in the same way. However this dimension is summed, i.e. reduced, over in the multiplication, and hence collective communication, e.g. an all_reduce, is required (introduced in lowering) if this dimension is sharded.

NDA_E(transpose_{l,r}(x)) := ([a₁, ..., aᵣ, ..., aₗ, ..., aₖ], M, {aᵢ ≐ dᵢ}),    (TRANSPOSE)
where ([d₁, ..., dₗ, ..., dᵣ, ..., dₖ], M, ∅) = NDA_E(x), with aᵢ fresh

[0087] The TRANSPOSE rule relates to the permutation of tensor dimensions, swopping the positions of the dimensions at indices l and r. It first identifies the list of dimension names of x ([d₁, ..., dₗ, ..., dᵣ, ..., dₖ]), then generates fresh dimension names [a₁, ..., aᵣ, ..., aₗ, ..., aₖ] for the result, and creates identities that link each fresh (new) name aᵢ to each corresponding old name dᵢ, i.e. aᵢ and dᵢ should be sharded in the same way. Note that positions r and l are swopped in the list for a, so that, e.g., sharding the dimension at position l of the result in a particular way requires dᵣ to be sharded in the same way (new dimension l corresponds to old dimension r).

NDA_E(broadcastₗ(x)) := ([a₁, ..., aₗ₋₁, a, aₗ, ..., aₖ], M, {a₁ ≐ d₁, ..., aₖ ≐ dₖ}),    (BROADCAST)
where ([d₁, ..., dₗ₋₁, dₗ, ..., dₖ], M, ∅) = NDA_E(x), with a, a₁, ..., aₖ fresh

[0088] The BROADCAST rule relates to the insertion of a new dimension into a tensor at index l. The rule generates a new list of dimension names with a fresh name a at position l, and the existing dimensions are also allocated fresh names aᵢ, where i ranges from 1 to k. The identities, I, define that every dimension name aᵢ in the output is linked back to dimension name dᵢ in the input, so that these are sharded in the same way. The new dimension name does not appear in the set of identities, and this is unconstrained by the input sharding.
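The operation rules above can be sketched compactly in Python. This is an illustrative reading of the rules as written, not the described subsystem: the function names, the representation of M as a set of (definition name, use name) pairs, the "a<n>" naming scheme for fresh names, and the 1-based indices are all assumptions.

```python
import itertools

_ids = itertools.count(1)

def fresh(k):
    """Return k new, unique dimension names (the 'a<n>' scheme is an assumption)."""
    return [f"a{next(_ids)}" for _ in range(k)]

# Each rule takes operand analyses (dims, M) and returns (result dims, M, identities).
# Identities are (name, name) pairs read as "shard these two in the same way".

def op_rule(x, y):
    (d, m1), (c, m2) = x, y
    a = fresh(len(d))
    return a, m1 | m2, list(zip(a, d)) + list(zip(a, c))

def reduce_rule(x, r):
    d, m = x
    kept = d[:r - 1] + d[r:]           # drop the reduced dimension (1-based index r)
    a = fresh(len(kept))
    return a, m, list(zip(a, kept))    # no identity mentions the reduced dimension

def matmul_rule(x, y):
    (d, m1), (c, m2) = x, y
    a = fresh(2)
    return a, m1 | m2, [(a[0], d[0]), (a[1], c[1]), (d[1], c[0])]

def transpose_rule(x, l, r):
    d, m = x
    a = fresh(len(d))
    out = list(a)
    out[l - 1], out[r - 1] = out[r - 1], out[l - 1]   # result lists a_r at l and a_l at r
    return out, m, list(zip(a, d))

def broadcast_rule(x, l):
    d, m = x
    a = fresh(len(d))
    new = fresh(1)[0]                  # inserted dimension: appears in no identity
    return a[:l - 1] + [new] + a[l - 1:], m, list(zip(a, d))

dims, _, ids = matmul_rule((["d1", "d2"], set()), (["c1", "c2"], set()))
print(dims, ids)
```

Each rule only produces identities local to one operation; linking them across operations is left to the map built by the VARIABLE USE rule.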

[0089] As an illustration of how these rules work, consider the definition and use of variable y: let y = x in y, assuming x is assigned dimension names [d₁, d₂] in the current environment E. This invokes the VARIABLE USE rule to determine that the use of x should be assigned fresh names [a₁, a₂], and to populate the map with {d₁ ↦ a₁, d₂ ↦ a₂}. The LET rule then adds the assignment y : [a₁, a₂] to E before invoking the VARIABLE USE rule again, this time for the use of y after “in”. The use of y is assigned fresh names [b₁, b₂] and the map M is extended with {a₁ ↦ b₁, a₂ ↦ b₂}. This map connects dimension names between definitions and uses: b₁ and b₂ can be tracked all the way back to d₁ and d₂ from the definition of x.
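The walk-through of let y = x in y can be reproduced with a small recursive sketch of the VARIABLE USE and LET rules. The expression encoding (a string for a variable, a ("let", x, definition, body) tuple for a local definition), the Fresh class and the trace_back helper are assumptions made for illustration only.

```python
import string

class Fresh:
    """Hands out one new letter per use: [a1, a2], then [b1, b2] (naming is an assumption)."""
    def __init__(self):
        self._letters = iter(string.ascii_lowercase)

    def names(self, k):
        letter = next(self._letters)
        return [f"{letter}{i}" for i in range(1, k + 1)]

def nda(expr, env, fresh):
    """Return (dims, M, I) for a variable use or a ("let", x, definition, body) tuple."""
    if isinstance(expr, str):                        # VARIABLE USE
        d = env[expr]
        a = fresh.names(len(d))
        return a, set(zip(d, a)), set()              # map d_i -> a_i, no identities
    _, x, definition, body = expr                    # LET
    d1, m1, i1 = nda(definition, env, fresh)
    d2, m2, i2 = nda(body, {**env, x: d1}, fresh)    # extend E with x : d1
    return d2, m1 | m2, i1 | i2

def trace_back(name, m):
    """Follow map edges from a use name back to the original definition name."""
    parent = {use: definition for definition, use in m}
    while name in parent:
        name = parent[name]
    return name

dims, m, ids = nda(("let", "y", "x", "y"), {"x": ["d1", "d2"]}, Fresh())
print(dims, sorted(m), ids)
```

Because every use name is fresh, each use has exactly one incoming map edge, which is what makes the backward trace unambiguous.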

[0090] The identities, I, from the NDA subsystem 104 record which dimensions that appear in an operation should be sharded in the same way. As an illustration of this consider matmul(x, y): Assuming that the MATMUL rule has assigned dimension names [d₁, d₂] and [c₁, c₂] to the use of x and y respectively, this instance of the MATMUL rule can be expressed as matmul(x: [d₁, d₂], y: [c₁, c₂]) together with the identities a₁ ≐ d₁, a₂ ≐ c₂, and d₂ ≐ c₁. The first identity, a₁ ≐ d₁, expresses that matmul acts as a map on the leading dimension of the first operand. That is, if the first operand is sharded on that dimension the matmul can be computed in a sharded manner, concatenating the shards on the leading dimension of the result to obtain the result that would have been computed without partitioning. The second identity is analogous. The third identity expresses that matmul can be sharded by sharding both operands along the contracting dimension (but lowering must then introduce an all_reduce).

[0091] As a more detailed illustration, referring back to the previous example MLP code, the example set of rules given above can be applied to this code as follows:

def mlp(x: [B, X], w1: [T, U], w2: [V, W]) {
  y: [A1, A2] = matmul(x: [B1, X1], w1: [T1, U1])
    # from use of x: B -> B1, X -> X1
    # from use of w1: T -> T1, U -> U1
    # from matmul: A1 = B1, A2 = U1, X1 = T1
  z: [C1, C2] = ReLU(y: [D1, D2])
    # from use of y: A1 -> D1, A2 -> D2
    # from ReLU: C1 = D1, C2 = D2
  w: [E1, E2] = matmul(z: [F1, F2], w2: [V1, W1])
    # from use of z: C1 -> F1, C2 -> F2
    # from use of w2: V -> V1, W -> W1
    # from matmul: E1 = F1, E2 = W1, F2 = V1
  return w
}

[0092] In this example, tensor variables are annotated with named dimensions (rather than, as before, their shapes), and the computed map M and identities I are in the comments. For example the input x is defined as [B, X] but is used with fresh names [B1, X1] in the first matmul (the VARIABLE USE rule) and the map M provides the mapping B -> B1 and X -> X1 (also the VARIABLE USE rule), allowing the NDA subsystem 104 to track this. The matmul generates a result y which, because it is a new value, gets new names [A1, A2] from the MATMUL rule, with identities, e.g. A1 = B1, that identify that A1 should be sharded in the same way as B1 (the MATMUL or OP rule). The ReLU operation is an example of an elementwise operation that can be handled by the OP rule.

[0093] The identities shown in the above code can be applied to reveal the different ways in which each operation can be partitioned, as below:

def mlp(x: [B, X], w1: [T, U], w2: [V, W]) {
  y: [A1, A2] = matmul(x: [A1, X1], w1: [X1, A2])
    # B -> A1, X -> X1, T -> X1, U -> A2
  z: [C1, C2] = ReLU(y: [C1, C2])
    # A1 -> C1, A2 -> C2
  w: [E1, E2] = matmul(z: [E1, V1], w2: [V1, E2])
    # C1 -> E1, C2 -> V1, V -> V1, W -> E2
  return w
}

[0094] Here the NDA subsystem 104 is partitioning the operations in isolation, i.e. considering the constraints of individual operations. For example in matmul the output rows should match the input rows, and the contracting dimensions should match. The identities can be applied by updating the dimension names to reflect these requirements, as defined in the rules, e.g. to retain just A1 (for output rows) rather than also having B1 (input rows), instead renaming B1 to A1. In this code the lines of code are self-consistent and the links between these lines, in particular from the dimension names as defined to those as used, are expressed explicitly by the map. For example there is a comment B -> A1 which maps the source, B (a definition), to the destination A1 (a use). This defines a directed graph of dependencies, as explained in more detail later.
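One way to apply the identities is a union-find pass that renames every dimension name to a representative of its identity class, leaving the map edges intact. The sketch below is an assumption about how this could be done, not the described implementation; the pairs are transcribed from the annotated MLP listing, the left-hand name of each pair is kept as the representative, and the last identity is written as (V1, F2) so that V1 is kept as in the listing.

```python
def find(parent, name):
    """Union-find lookup with path halving; unseen names form their own class."""
    parent.setdefault(name, name)
    while parent[name] != name:
        parent[name] = parent[parent[name]]
        name = parent[name]
    return name

def unify(identities):
    """Merge each (keep, drop) pair, keeping the representative of `keep`."""
    parent = {}
    for keep, drop in identities:
        parent[find(parent, drop)] = find(parent, keep)
    return parent

# Identities produced per operation (matmul, ReLU, matmul).
identities = [("A1", "B1"), ("A2", "U1"), ("X1", "T1"), ("C1", "D1"),
              ("C2", "D2"), ("E1", "F1"), ("E2", "W1"), ("V1", "F2")]
# Map edges from definitions to uses (VARIABLE USE rule).
data_flow = [("B", "B1"), ("X", "X1"), ("T", "T1"), ("U", "U1"), ("A1", "D1"),
             ("A2", "D2"), ("C1", "F1"), ("C2", "F2"), ("V", "V1"), ("W", "W1")]

parent = unify(identities)
renamed = [(find(parent, d), find(parent, u)) for d, u in data_flow]
print(renamed)
```

Only the identities are merged here; the map edges are rewritten but kept as edges, so definitions and uses remain distinct names.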

[0095] The dimension names can be further linked using the map, as shown below, but in implementations this further step is not taken, to facilitate dealing with sharding conflicts:

def mlp(x: [B, X], w1: [X, U], w2: [U, W]) {
  y: [B, U] = matmul(x: [B, X], w1: [X, U])
  z: [B, U] = ReLU(y: [B, U])
  w: [B, W] = matmul(z: [B, U], w2: [U, W])
  return w
}

[0096] The above described analysis, which assigns fresh names and uses a map to enable dimension names to be tracked back through the model code, can facilitate conflict detection. For example if a particular variable, say y, were used in two different ways, e.g. in matmul(x, y) and in transpose(y), each use would get its own fresh name. If the analysis later determines that one use, say, requires sharding whilst the other use requires replication, the NDA subsystem 104 can detect this conflict because a single definition would map to two incompatible usage requirements. Note that, in implementations that use a dimension graph to detect conflicts (described later), a different approach to detecting a conflict is used.

[0097] In more detail, the same dimension name may appear more than once amongst the names that annotate the same variable. Consider the following example:

def f(x: [S, T]) {
  y: [T, S] = transpose(x)
  z: [S, S] = matmul(x, y)
  return z
}

[0098] When attempting to shard all dimensions that are labeled S there is a decision as to which dimension of z should be sharded because one device axis cannot shard more than one dimension of a single tensor, i.e. the hardware cannot split the data in two different directions using the same set of devices. Such a situation is referred to herein as a sharding conflict (or conflict).

[0099] Sharding conflicts can occur in an attention neural network layer, e.g. of a Transformer model. The following code represents a QKV (query, key, value) attention mechanism, with a simplification for presentational purposes (the softmax is replaced with an average which also involves reduction and elementwise operations). The code represents the result after the previous rules, that define identities and a data flow map, have been applied. The tensors have dimension names such as S, D, and H1; the code shows a “global” view in which S, the input sequence length, propagates through to keys k, queries q, and the attention matrix a.

def attn(x: [S, D], wq: [D, H1], wk: [D, H1], wv: [D, H2]) {
  k: [S, H1] = matmul(x, wk)   # keys
  v: [S, H2] = matmul(x, wv)   # values
  q: [S, H1] = matmul(x, wq)   # queries
  qt: [H1, S] = transpose(q)
  a: [S, S] = matmul(k, qt)
  # begin: mock softmax computation (averaging)
  b: [S] = reduce_{1, add}(a)
  c: [S, S] = broadcast_{0}(b)
  d: [S, S] = div(a, c)
  # end: mock softmax computation (averaging)
  z: [S, H2] = matmul(d, v)
  return z
}

[0100] This code illustrates a conflict in the definition of a, as S appears twice in its annotation, once from the keys k and once from the transposed queries qt. (This is resolved in the reduce which removes the reduced-over dimension). This conflict propagates to the definitions of c and d, in particular by the OP rule, which states that the dimensions of the input and result should be the same (if the result is to be computed elementwise without moving data). The final matmul contracts over the dimension S and removes the conflict (S only appears once in the result).
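In this fully merged ("global") view, a conflict shows up as a dimension name that occurs more than once in the annotation of one tensor. A minimal sketch of that check follows; the function name and the dictionary encoding of the annotations are assumptions, and the annotations are transcribed from the attention listing above.

```python
from collections import Counter

def sharding_conflicts(annotations):
    """Map each tensor to the dimension names that label more than one of its dimensions."""
    conflicts = {}
    for tensor, dims in annotations.items():
        repeated = sorted(name for name, n in Counter(dims).items() if n > 1)
        if repeated:
            conflicts[tensor] = repeated
    return conflicts

attn = {
    "k": ["S", "H1"], "v": ["S", "H2"], "q": ["S", "H1"], "qt": ["H1", "S"],
    "a": ["S", "S"], "b": ["S"], "c": ["S", "S"], "d": ["S", "S"], "z": ["S", "H2"],
}
print(sharding_conflicts(attn))
```

The check flags a, c and d, while b and z are clean because the reduction and the final matmul each remove one occurrence of S.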

[0101] For this example, one possible resolution of the conflicts in the attention layer is by sharding the last dimensions of a, c, and d, along axis s, introducing an all_gather and a reduce_scatter operation, i.e. by performing sequence sharding. This allows the very large [S, S] matrix to be sharded across devices along its second axis, enabling the model to handle large sequence lengths that could otherwise be difficult to fit into memory:

def attn(x: [S {s}, D], wq: [D, H1], wk: [D, H1], wv: [D, H2]) {
  k: [S {s}, H1] = matmul(x, wk)
  v: [S {s}, H2] = matmul(x, wv)
  q: [S {s}, H1] = matmul(x, wq)
  qt: [H1, S {s}] = transpose(q)
  k_: [S, H1] = all_gather {s} k
  a: [S, S {s}] = matmul(k_, qt)
  b: [S {s}] = reduce_{1, add}(a)
  c: [S, S {s}] = broadcast_{0}(b)
  d: [S, S {s}] = div(a, c)
  z_: [S, H2] = matmul(d, v)
  z: [S {s}, H2] = reduce_scatter {s} z_
  return z
}

[0102] Implementations of the NDA subsystem 104 can identify possible resolutions of conflicts, e.g. to expose them to the automatic partitioning system 108. In implementations this is done by using the identities defined by the sharding rules to merge dimension names but then by keeping variable definition and variable use names distinct rather than by using the map M to merge them.

[0103] Instead, for example, the map can be used to construct a “dimension graph” in which the nodes are dimension names and in which the map defines edges that connect these nodes. In implementations the edges can be directed from variable definition to use. That is, the identities (I) merge dimension names locally (as previously described), but the dimension names are then kept separate rather than being further merged using the map M.

[0104] The dimension graph can then reveal where conflicts arise, i.e. when one variable definition feeds into two incompatible uses. Keeping the nodes distinct allows the automatic partitioning system 108 to choose to shard connected nodes differently to resolve a conflict, from valid options for resolving the conflict determined by the NDA subsystem 104. More particularly the NDA subsystem 104 uses distinct nodes to represent dimension names associated with each individual operation, which allows sharding conflicts, as well as options for resolving these conflicts, to be identified for use by the automatic partitioning system 108.

[0105] Note that use of a dimension graph is optional. For example, dimension names can be merged using the map, as previously described and illustrated. This is sufficient for simpler machine learning models, and the approach can also be used for more complex models, e.g. by manually modifying or splitting the code to address conflicts. However such manual adaption of code can be difficult, and use of a dimension graph can further automate the partitioning process.

[0106] As an illustrative example of the use of a dimension graph, the attention neural network layer code below identifies dimension names using the identities defined in the previously given sharding rules, but has not applied the data flow map.

def attn(x: [S, F], wq: [D1, H1], wk: [D2, H2], wv: [D3, H3]) {
  k: [S1, H21] = matmul(x: [S1, F1], wk: [F1, H21])
  v: [S2, H31] = matmul(x: [S2, F2], wv: [F2, H31])
  q: [S3, H11] = matmul(x: [S3, F3], wq: [F3, H11])
  qt: [H111, S31] = transpose(q: [S31, H111])
  a: [S11, S311] = matmul(k: [S11, H211], qt: [H211, S311])
  b: [S3111] = reduce_{1, add}(a: [S111, S3111])
  c: [Sc, S31111] = broadcast_{0}(b: [S31111])
  d: [Sc1, S3112] = div(a: [Sc1, S3112], c: [Sc1, S3112])
  z: [Sc11, H311] = matmul(d: [Sc11, S21], v: [S21, H311])
  return z
}

[0107] Each operation is analyzed in isolation and generates fresh (new) dimension names, and hence there is a large number of new dimension names, such as S1, S2, S3, S11, S311, Sc1 and so on. For example, the matmul defining a enforces that the inner dimensions (H211) match, and the outer dimensions become the result, but the code does not force S11 to be the same as the S from the function input (from the first dimension of x). Thus the partitioning system 100 can treat the input S and the internal variable S11 as separate variables (so that there is no immediate conflict), whilst maintaining the link between them in the dimension graph.

[0108] FIG. 2A shows an example dimension graph for part of the above attention neural network layer code, specifically the connected component of S (a connected component is a subgraph in which every node is reachable from every other node). This is the only part of the map whose dimension names participate in conflicts.

[0109] More particularly FIG. 2A has nodes that are labelled with, and represent, the dimension names from the preceding analysis (each node represents a point in the code where the respective dimension exists). The directed edges represent the data flow map, M, and an arrow from, say A to B, defines that the data in dimension A flows into and is used as dimension B. It can be seen that all the nodes in FIG. 2A originate from S.

[0110] In a dimension graph a conflict between two dimension names can be identified as occurring when, in a tensor variable definition or use, there are two different dimension names in the dimension graph (different dimension names in a connected component of the graph), i.e. when two different dimension names annotate the same variable definition or use.

[0111] For example in the above attention neural network layer code the tensor a is defined as a: [S11, S311]. Here S11 and S311 are connected (they both trace back to S) in the example dimension graph of FIG. 2A, and they both annotate the attention matrix tensor a. Thus there is a conflict between S11 and S311, shown in FIG. 2A as an undirected line connecting these nodes. By contrast there is no conflict, say, between S11 and S31 because S11 is a dimension name for tensor a and S31 is a dimension name for tensor qt. That is, S11 and S311 cannot be sharded on the same device axis, but it is possible to shard S11 and S31 on the same device axis. The example of FIG. 2A identifies three conflicts in variables a, c and d, which are also immediately visible in the above attention neural network layer code, and two further conflicts that come from the uses of c and d (in div and matmul).
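The conflict test can be sketched as a reachability check followed by a scan over annotations. The code below is illustrative only: the function names are assumptions, and only the handful of map edges needed for the S11 / S311 example are transcribed by hand from the attention listing, so it does not reproduce the whole of FIG. 2A.

```python
from collections import defaultdict, deque
from itertools import combinations

def component_of(start, edges):
    """Nodes reachable from `start` when the directed map edges are read as undirected."""
    neighbours = defaultdict(set)
    for src, dst in edges:
        neighbours[src].add(dst)
        neighbours[dst].add(src)
    seen, queue = {start}, deque([start])
    while queue:
        for nxt in neighbours[queue.popleft()] - seen:
            seen.add(nxt)
            queue.append(nxt)
    return seen

def conflict_edges(component, annotations):
    """Pairs of distinct names from one component that annotate the same definition or use."""
    return {(p, q)
            for dims in annotations.values()
            for p, q in combinations(dims, 2)
            if p != q and p in component and q in component}

# Map edges (definition -> use) for part of the S family.
edges = [("S", "S1"), ("S", "S3"), ("S1", "S11"), ("S3", "S31"), ("S31", "S311")]
annotations = {"a": ["S11", "S311"], "qt": ["H111", "S31"], "k (use)": ["S11", "H211"]}

component = component_of("S", edges)
print(conflict_edges(component, annotations))
```

S11 and S31 are in the same component but never annotate the same tensor, so no conflict edge is reported between them.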

[0112] Each dimension name corresponds to one way in which an operation can be partitioned. In general each conflict edge can be resolved in two ways, by picking one or the other endpoint, i.e. by selecting one or the other of the dimension names at the ends of a conflict edge as a tensor to shard. The selected dimension name can be sharded along a device axis; the other, unselected dimension name can, e.g., be replicated across the devices on that axis. For example S11 could be sharded, splitting the tensor along its rows and replicating S311 (along its columns), or S311 could be sharded, splitting the tensor along its columns and replicating S11 (along its rows). The replication results in inclusion of a collective communication, e.g. an all_gather, at lowering (or an all_to_all if swopping axes).

[0113] For the example of FIG. 2A this results in 32 possible resolutions that can be exposed to the automatic partitioning system 108. Typically the time taken to partition a model is a small fraction of the time for which the partitioned model is used, so that expending time on the partitioning is justified. It can nonetheless be useful to reduce the partitioning search space.
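The count follows from five conflict edges with two choices each, i.e. 2⁵ = 32. A short sketch that enumerates the choices is given below; the conflict edges are placeholders (p1/q1 to p5/q5), not the names drawn in FIG. 2A, and the function name is hypothetical.

```python
from itertools import product

def resolutions(conflicts):
    """Yield one dict per resolution: conflict edge -> endpoint chosen for sharding."""
    for choice in product(*conflicts):
        yield dict(zip(conflicts, choice))

# Five placeholder conflict edges, each a pair of dimension names.
conflicts = [(f"p{i}", f"q{i}") for i in range(1, 6)]
all_resolutions = list(resolutions(conflicts))
print(len(all_resolutions))
```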

[0114] One option is to identify what can be termed “compatible conflicts”, i.e. conflicts that can be resolved in the same way, e.g. shard both or replicate both, to avoid moving data between devices. Two conflicts can be identified as compatible if they form an isomorphic “box” structure in the dimension graph, without crossings. Grouping conflicts in this way reduces the number of independent resolutions that are provided to the automatic partitioning subsystem 108.

[0115] FIG. 2B shows an example of a compatible conflict (left), and examples of conflicts with crossings that are not compatible (middle and right). In each case a tensor defined with dimensions N, O is used with dimensions L, R.

[0116] If the data flows without any crossing paths, N -> L and O -> R in the example of FIG. 2B left, the conflicts between N and O and between L and R are compatible, and these conflicts can be resolved in the same way. That is, the sharding of N and L does not affect the sharding of O and R. The conflicts could be resolved differently but then resharding, e.g. with an all_to_all, would be needed between variable definition and variable use. Resolving compatible conflicts in the same way can reduce the collective communication needed.

[0117] If there are crossing paths, e.g. N -> R or O -> L in the examples of FIG. 2B middle and right, the conflicts are incompatible because the diagonal paths mix the dependencies. The data defined in one dimension, e.g. N, is being used in the opposite dimension, e.g. R, of the conflicting pair. For example the conflict between N and O (and between L and R) could be resolved by sharding N (and L) and replicating O (and R). However the crossing from N to R feeds a sharded tensor dimension into a full tensor dimension, which would require collective communication, e.g. an all_gather. Thus two conflicts with crossing data flow paths are deemed incompatible.
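
One possible reading of the "box without crossings" test is sketched below. It assumes that each conflict is given as an ordered pair of dimension names and that data flow edges are available as a set of (source, target) pairs; the function name `compatible` is hypothetical.

```python
def compatible(def_conflict, use_conflict, flow_edges):
    """Return True if two conflicts form a box with no crossing data flow paths."""
    n, o = def_conflict
    l, r = use_conflict
    straight = (n, l) in flow_edges and (o, r) in flow_edges
    crossing = (n, r) in flow_edges or (o, l) in flow_edges
    return straight and not crossing


# Data flow for the three cases discussed for FIG. 2B (left, middle, right).
left = {("N", "L"), ("O", "R")}
middle = {("N", "L"), ("O", "R"), ("N", "R")}
right = {("N", "L"), ("O", "R"), ("O", "L")}
```

Only the left case is reported as compatible; either diagonal edge makes the pair incompatible.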

[0118] Rather than determine pairs of compatible conflicts the NDA subsystem 104 can determine sets of compatible conflicts, here termed “compatibility sets” by grouping them. This can use the reflexive, symmetric and transitive closure properties of the compatibility relation, i.e. if conflict A is compatible with conflict B, and B with C, then A is compatible with C. This can chain conflicts into a larger compatibility set of conflicts.

[0119] A compatibility set of conflicts can be resolved in (just) two ways, by resolving one of the conflicts as previously described (selecting one or the other of the dimension names at the ends of a conflict edge as a tensor to shard) and by resolving all the other conflicts in the compatibility set in the same way. The automatic partitioning subsystem 108 can make a single sharding decision for the entire compatibility set.

[0120] A compatibility set may correspond to two or more compatible conflicts. A conflict that cannot be grouped with others may notionally be assigned a compatibility set (comprising just a single conflict), so that the automatic partitioning subsystem 108 can process compatibility sets rather than a mix of conflicts and compatibility sets, for simplicity.
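
By way of a non-limiting sketch, compatibility sets can be formed by taking the closure of the pairwise compatibility relation with a union-find structure. The conflict names and the function `compatibility_sets` are assumptions for illustration; conflicts with no compatible partner end up in singleton sets.

```python
def compatibility_sets(conflicts, compatible_pairs):
    """Group conflicts into sets using the closure of pairwise compatibility.

    conflicts:        list of conflict identifiers.
    compatible_pairs: list of (i, j) index pairs found to be compatible.
    """
    parent = list(range(len(conflicts)))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i, j in compatible_pairs:
        parent[find(i)] = find(j)

    groups = {}
    for i, conflict in enumerate(conflicts):
        groups.setdefault(find(i), []).append(conflict)
    return list(groups.values())


# Five chained conflicts plus one conflict that cannot be grouped.
conflicts = [f"conflict_{k}" for k in range(1, 7)]
pairs = [(0, 1), (1, 2), (2, 3), (3, 4)]
sets = compatibility_sets(conflicts, pairs)
```

With this grouping the search makes one binary decision per set rather than one per conflict.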

[0121] Referring again to FIG. 2A, there is a single compatibility set that contains the five conflicts: {(Sill, S311), (Sill, S31111), (Sc, S31111), (Scl, S3112), (Sell, S21)} (respectively referred to as conflicts 1 to 5). Conflict 1 is a root conflict at tensor a, between S11 and S311. Conflict 2 is compatible with conflict 1 (no crossings); conflict 2 is compatible with conflict 3 (no crossings), so is also compatible with conflict 1; and so on. There are two possible resolutions, and there is one sharding decision to make; both options result in collective communications.

[0122] A first option corresponds to sequence sharding (S { s } in the example code given above), where dimensions are sharded to distribute the sequence length across devices. This would result in the inclusion of all_gather and reduce_scatter communication primitives at lowering. The second option swops which dimension is sharded and which is replicated compared with the first option; this would result in the insertion of two all_gather communication primitives. These two options will typically result in different performance (runtime) and memory use profiles.

[0123] Another way to reduce the partitioning search space is to define that compatibility conflicts should be resolved in the same way in repeated layers. These repeated layers will have isomorphic (i.e. structurally identical) dimension graphs. This can be extended to “backwards layers” i.e. to the gradient computations for the (forward) neural network layers, so that compatibility conflicts are resolved in the same way in repeated backwards layers.

[0124] As an example, for a training step for a Transformer neural network with multiple (forward) attention layers, and corresponding backwards layers, this results in just four sharding conflict resolution decisions to be made (a transformer block comprises an attention layer and a subsequent MLP layer, i.e. two decisions for each of the forward and backwards passes). A Transformer neural network or U-Net can have tens or even hundreds of layers and this results in a very substantial reduction in the search space.

[0125] As described previously the automatic partitioning system 108 can make tensor sharding decisions based on the mesh data 110, the machine learning model code 102, and the reduced sets of dimension labels 106 for each of the operations, to generate the program data 120. In implementations the automatic partitioning system 108 can also evaluate options for resolving identified conflicts, e.g. taking account of any conflict compatibility sets that have been identified. This can be done, e.g., using a runtime cost model to evaluate sharding options taking into account any necessary collective communications.

[0126] FIG. 3 is a flow diagram of an example process for partitioning a machine learning model such as a neural network, or a neural network training step, across multiple processing nodes. The process of FIG. 3 may be implemented by one or more computers in one or more locations; for convenience the process is described with reference to FIG. 1.

[0127] At step 300 the partitioning system 100 obtains mesh data 110 that defines a configuration of the computing system 150, in particular of the processing nodes 152A.. N.

[0128] The partitioning system 100 also obtains machine learning model code 102, such as code specifying a neural network, e.g. code specifying a training step for training a neural network (step 302). As previously described, the machine learning model code 102 identifies a plurality of variables, each corresponding to a tensor with one or more dimensions, and a plurality of operations, each operation being associated with a number of the variables including at least one input variable and at least one output variable.

[0129] For each of the operations, and for each of the variables associated with the respective operation, the partitioning system 100, e.g. the NDA subsystem 104, provides unique labels for each of the dimensions of the respective tensor (step 304).

[0130] The partitioning system 100, e.g. the NDA subsystem 104, also applies a rule associated with a type of the operation to the unique labels for the associated variables, to obtain a reduced set of dimension labels, also referred to above as notional colors (step 306). The reduced set of dimension labels corresponds to a number of ways in which the respective operation and its associated variables can be partitioned across the processing nodes 152A.. N.

[0131] As previously described, each of the unique labels is unique across the plurality of operations and plurality of variables. Each of the unique labels may be associated with use of one of the variables for processing or output by one of the operations. The reduced sets of dimension labels for the different operations together comprise a number of different dimension labels corresponding to the number of partitioning decisions to be made for all of the operations. The same one of the dimension labels may belong to multiple ones of the reduced sets of dimension labels for the different operations.

[0132] The partitioning system 100 processes the mesh data 110, the machine learning model code 102, and the reduced sets of dimension labels for each of the operations, to generate the program data 120 that, when executed by each of the plurality of processing nodes, partitions the machine learning model, e.g. the neural network, e.g. the training step, across the plurality of processing nodes 152A.. N (step 308). This can be done by an automatic partitioning algorithm, e.g. implemented by the automatic partitioning system 108.

[0133] The process, in particular the NDA subsystem 104, can also identify a conflict for an operation where no partitioning is possible given one or more possible partitionings for the at least one input variable associated with the operation. The conflict can be identified based on the reduced set of dimension labels, e.g. based on the same dimension label being provided for multiple dimensions of a tensor corresponding to one of the associated variables for the operation. For example (but not necessarily) this can be done using a dimension graph as previously described. In some implementations the neural network is a transformer model, and the conflict is identified in an attention layer of the transformer model.

[0134] In response to identifying the conflict a collective communication operation can be inserted into the program data, to remove at least one of the possible partitionings for the at least one input variable associated with the operation. The collective communication operation can be inserted during the final lowering, or by the automatic partitioning subsystem 108 or some other part of the partitioning system 100.

[0135] The reduced sets of dimension labels for the different operations can together comprise a number of different dimension labels corresponding to the number of partitioning decisions to be made for all of the operations.

[0136] Applying the rule to obtain the reduced set of dimension labels can involve determining identities between different ones of the unique labels associated with different tensor dimensions to be partitioned along a same axis of the computing system 150, e.g. identities I as previously described. For each of the determined identities, for the different unique labels for which the respective identity is defined, the different unique labels can be set to be the same label, e.g. the same notional color.
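
The unification step can be sketched as follows for a single matrix multiplication c = matmul(a, b). The fresh label names (`a0`, `a1`, and so on), the three identities assumed for the matmul rule, and the function `unify` are illustrative assumptions rather than the rule set of the described system.

```python
def unify(labels, identities):
    """Merge labels related by identities; return label -> representative (color)."""
    parent = {label: label for label in labels}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in identities:
        ru, rv = find(u), find(v)
        if ru != rv:
            # Keep the lexicographically smaller label as representative.
            parent[max(ru, rv)] = min(ru, rv)
    return {label: find(label) for label in labels}


# Unique labels for a: [a0, a1], b: [b0, b1], c: [c0, c1].
labels = ["a0", "a1", "b0", "b1", "c0", "c1"]
# Assumed identities for matmul: rows of a and c, the contracted dimension
# shared by a and b, and columns of b and c.
identities = [("a0", "c0"), ("a1", "b0"), ("b1", "c1")]
colors = unify(labels, identities)
reduced_labels = set(colors.values())
```

Six unique labels reduce to three, corresponding to three ways in which this one operation could be partitioned.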

[0137] In implementations each of the unique labels is associated with use of one of the variables for processing or output by one of the operations. For each of the operations the process can determine mappings between the unique labels associated with the uses of the associated variables and dimension labels associated with variable definitions of the associated variables. This can facilitate tracing data flows.

[0138] For example, the mappings (map 𝓜) can be determined by applying the VARIABLE USE rule. By assigning unique labels to every use and mapping them back to the variable definition the partitioning system 100 can construct a (directed) dimension graph, e.g. as shown in FIG. 2A. This can maintain separate edges (and nodes) for each entry in the map (rather than, say, just identifying a conflict in S), which facilitates identifying where sharding conflicts originate.

[0139] Implementations of the process can obtain the reduced sets of dimension labels by setting labels associated with different ones of the uses of the same variable across different ones of the operations to be a same label belonging to the reduced sets of dimension labels. Note that in implementations of the partitioning system 100 this step is not performed by the NDA subsystem 104 (which in implementations keeps the identities defined by the map separate rather than merging them), but is performed by the automatic partitioning subsystem 108. For example it can result from the partitioning, e.g. MCTS, search described later.

[0140] In some implementations of the process the identities (e.g. I) are applied to obtain a reduced number of labels by, for each of the determined identities, setting the different unique labels to be the same label. This can be performed by the NDA subsystem 104, e.g. to create an action space for a partitioning search. The mappings can then be applied, e.g. by the automatic partitioning subsystem 108 during the partitioning search, to obtain the reduced sets of dimension labels. This can be done by setting labels associated with different uses of the same variable across different ones of the operations to be the same label. For example, this can be done selectively by the automatic partitioning subsystem 108, determining which mappings to follow (unify), and which to break to resolve conflicts (requiring the insertion of collective communication).

[0141] Some implementations of the process use the mappings to define a graph, e.g. the previously described dimension graph, comprising nodes representing dimension labels and edges provided by the mappings. In this graph each node of the graph represents one possible partitioning for one of the operations (e.g. implemented by providing fresh names as described earlier). Broadly, in this graph the nodes provide sharding options and the edges represent data flow dependencies.

[0142] Based on the reduced set of dimension labels, conflicts for one or more of the operations can be identified. Such a conflict occurs where no partitioning is possible given one or more possible partitionings for the at least one input variable associated with the operation. The graph can then be augmented, e.g. by the NDA subsystem 104, with one or more further edges (connecting the nodes), where each of the further edges represents one of the conflicts.

[0143] These edges correspond to the undirected lines in FIGS. 2A and 2B. In implementations these edges are computed (i.e. they are not just visualizations), to facilitate determining groups of compatible conflicts (compatibility sets), i.e. groups of these further edges. The automatic partitioning subsystem 108 can then operate on these compatibility sets during the partitioning search.

[0144] In implementations the partitioning, i.e. sharding, decisions made to partition the machine learning model, e.g. the neural network, e.g. the training step, across the plurality of processing nodes 152A.. N are made by an automatic partitioning algorithm on the basis of the reduced sets of dimension labels, i.e. notional colors. The automatic partitioning algorithm can be implemented by the automatic partitioning subsystem 108.

[0145] For example, the automatic partitioning subsystem 108 can explore the space of possible partitionings represented by the sets of dimension labels. In this space, each tensor dimension (of the variables in the program) having the same dimension label is restricted to be partitioned along the same axis of the system of processing nodes (i.e. the plurality of processing nodes 152).

[0146] There are many ways in which this can be done. For example, the automatic partitioning subsystem 108 can apply a tree search to the sets of possible partitionings (where each of the sets of possible partitionings corresponds to one of the labels of the reduced sets of dimension labels), to obtain a set of partitioning decisions for each of the labels belonging to the reduced sets of dimension labels.

[0147] The tree search can be a look ahead tree search, e.g. Monte Carlo tree search (MCTS). The tree search can be applied to perform a search until one or more termination criteria are met. For example, the look ahead tree search may be an in-tree search and a termination criterion may be that a leaf node (an unopened node) is encountered, or the search may be able to select a stop action that enables it to terminate the search, or a termination criterion may depend on a search budget e.g. a budget number of search steps (actions), or on a search depth limit.

[0148] Once at least some of the partitioning decisions have been output by the automatic partitioning subsystem 108, e.g. based on the tree search, a runtime cost model can be applied to the decisions to provide a cost indication associated with those decisions. The cost indication may, for example, indicate the communication resources, processing resources, or memory resources consumed if the model, e.g. neural network, e.g. training step, is executed based on the partitioning decisions that have been made.

[0149] Once the cost indication is produced, the automatic partitioning algorithm may then be applied, e.g. by the automatic partitioning subsystem 108, to provide further partitioning decisions based on the cost indication. For example the automatic partitioning subsystem 108 can further explore the space of possible partitionings, based on the cost indication, to obtain further partitioning decisions. A further cost indication may then be applied based on the further partitioning decisions. The automatic partitioning algorithm and runtime cost model may be applied iteratively in this manner to update the set of partitioning decisions based on the cost indications.

[0150] In an example implementation the automatic partitioning subsystem 108 can perform a Monte Carlo tree search (MCTS). For example, the partitioning process can generate a distribution over the possible sets of partitioning decisions, and then select one of the sets of partitioning decisions from the possible sets of partitioning decisions. In some implementations, the automatic partitioning subsystem 108 generates a search tree distribution, where the distribution depends on statistics of child nodes of a root node. In broad terms the search tree is traversed from the root node, iteratively selecting edges based on, e.g. that maximize, the combination of an action-value and an upper confidence bound, until an unopened, i.e. not yet expanded, leaf node is encountered. This is then expanded by creating at least one new child node for the leaf node.
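
A minimal sketch of the edge selection step is given below. It assumes the standard UCB1 formula, which is one common way to combine an action-value with an upper confidence bound and is not stated to be the formula used by the described system; the exploration constant `c` and the function names are likewise assumptions.

```python
import math


def ucb_score(q, child_visits, parent_visits, c=1.4):
    """UCB1: mean action-value plus an exploration bonus."""
    if child_visits == 0:
        return math.inf  # unvisited children are tried first
    return q + c * math.sqrt(math.log(parent_visits) / child_visits)


def select_child(children, c=1.4):
    """Pick the action with the highest UCB score.

    children: dict mapping action -> (total_reward, visit_count).
    """
    parent_visits = sum(visits for _, visits in children.values())

    def score(item):
        _, (total, visits) = item
        q = total / visits if visits else 0.0
        return ucb_score(q, visits, parent_visits, c)

    return max(children.items(), key=score)[0]
```

For example, `select_child({"shard_S_on_x": (9.0, 10), "shard_D_on_y": (1.0, 10)})` prefers the first action, since both have equal visit counts and the first has the higher mean reward.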

[0151] FIG. 4 illustrates an example implementation of a partitioning process as described herein, showing an example of the overall process. The process starts with the machine learning model code 102, in this example StableHLO, an intermediate representation, e.g. compiled from JAX or PyTorch. The process performs an argument grouping step, i.e. grouping repeated neural network layers, followed by named dimension analysis, performed by NDA subsystem 104. The named dimension analysis identifies notional colors, conflicts, and compatibility sets, and exposes options for conflict resolution. The automatic partitioning subsystem 108 then uses the runtime cost model to make sharding decisions based on performance estimates, to provide a final sharded module, i.e. program data 120.

[0152] A runtime cost model for the automatic partitioning subsystem 108 can be implemented as an abstract interpreter of the partitioned machine learning model code. That is, the runtime cost model does not need to actually run the partitioned model code. Instead the runtime cost model can use the lowered, device-local code to estimate the computational cost (e.g. FLOPS), memory, and communications overhead for a particular tensor shard processed by a particular device. The runtime cost model does not need to be especially accurate to allow the automatic partitioning subsystem 108 to evaluate sharding strategies.

[0153] The runtime cost model does not need to determine an absolute cost, only a relative difference in cost between states that represent model code partitionings. For example, batch partitioning across b devices would be expected to reduce runtime by a factor of b (in inference; all_reduce operations are involved in training). Runtime cost can be accumulated along the critical path (the longest chain of dependent tasks), disregarding operations that happen in parallel. It can take account of characteristics of devices (e.g. FLOPS and network characteristics) and costs can be included for collective communication operations. The runtime cost model can also take account of peak memory use, e.g. by determining which tensors need to exist in memory simultaneously.
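
The critical-path accumulation can be illustrated with the sketch below, which assumes per-operation cost estimates are already available and that the operations form a directed acyclic graph. The operation names, the cost values, and the function `critical_path_cost` are hypothetical.

```python
from graphlib import TopologicalSorter


def critical_path_cost(costs, deps):
    """Return the cost of the longest chain of dependent operations.

    costs: dict mapping every operation to its estimated cost.
    deps:  dict mapping an operation to the operations it depends on.
    """
    graph = {op: tuple(deps.get(op, ())) for op in costs}
    finish = {}
    # static_order() yields each operation after all of its dependencies.
    for op in TopologicalSorter(graph).static_order():
        start = max((finish[d] for d in graph[op]), default=0.0)
        finish[op] = start + costs[op]
    return max(finish.values(), default=0.0)


# q_proj and k_proj run in parallel; the all_gather is a collective
# communication whose cost lies on the critical path.
costs = {"q_proj": 2.0, "k_proj": 2.0, "all_gather": 1.0, "attn": 5.0, "out": 3.0}
deps = {"all_gather": ["k_proj"], "attn": ["q_proj", "all_gather"], "out": ["attn"]}
total = critical_path_cost(costs, deps)
```

Here the parallel q_proj contributes nothing beyond the k_proj and all_gather chain, so the total is 2 + 1 + 5 + 3.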

[0154] As an example, the cost, C(s), of a state, s, e.g. defined by a sharding configuration for every named dimension, can be determined as

$$C(s) = RT(s) + MP(s)$$

where RT(s) is a relative runtime and MP(s) is a memory penalty, respectively determined as:

$$RT(s) = \frac{\text{current runtime}}{\text{initial runtime}}$$

$$MP(s) = C \cdot \frac{\text{current peak} - DM}{\text{initial peak memory}} \cdot \begin{cases} 1, & \text{if current peak} > DM \\ 0, & \text{otherwise} \end{cases}$$

where DM is the available amount of per-device memory. In this example MP only penalizes a state if the memory requirements (current peak) of the partitioned model code exceed the amount of local, per-device memory, where C is a constant that determines the value of the penalty.
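
A small numerical sketch of this cost follows. It assumes that the memory penalty is the constant C multiplied by the excess of the current peak memory over DM, normalized by the initial peak memory; the parameter names and the example values are illustrative only.

```python
def state_cost(current_runtime, initial_runtime, current_peak,
               initial_peak, device_memory, penalty_constant=1.0):
    """Relative runtime plus a penalty applied only when memory is exceeded."""
    relative_runtime = current_runtime / initial_runtime
    if current_peak > device_memory:
        memory_penalty = (penalty_constant
                          * (current_peak - device_memory) / initial_peak)
    else:
        memory_penalty = 0.0
    return relative_runtime + memory_penalty


# Runtime halved and peak memory (8 GB) within the 16 GB device limit.
fits = state_cost(50.0, 100.0, 8.0, 32.0, 16.0)
# Same runtime, but a 24 GB peak exceeds the limit by 8 GB.
exceeds = state_cost(50.0, 100.0, 24.0, 32.0, 16.0, penalty_constant=2.0)
```

A state that fits in memory is ranked purely on relative runtime, while a state that does not fit is pushed up in cost in proportion to how far it overshoots.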

[0155] In some implementations the runtime cost model can disregard cheap operations and may only take account of matrix multiply operations (and collective communication). For example, the runtime cost model can include stored data that defines that, say, c = matmul(a, b) takes a particular number of microseconds on a particular device, given the sizes of a and b, and so on.

[0156] In some implementations of the partitioning system 100, e.g. when automatic partitioning subsystem 108 uses a tree search, the state s can include as yet unsharded tensors. In this case the runtime cost model can estimate the unsharded (replicated) single-device runtime (and collective communication and memory cost). If an operation is sharded the faster, parallelized runtime (and collective communication and memory cost) is used.

[0157] One particular example of how the automatic partitioning subsystem 108 can perform a Monte Carlo tree search (MCTS) is now described. The MCTS operates on the space of possible partitionings exposed by the NDA subsystem 104. It starts with all dimensions unsharded, and aims to find a sequence of sharding actions (taken by a search agent) that progressively shard the model code, whilst minimizing execution time and prioritizing staying within device memory limits.

[0158] The MCTS builds a search tree in which search nodes represent specific sharding decisions or states, s. For example a state can comprise a set of partitioning decisions in which each partitioning decision corresponds to a respective label in the reduced sets of dimension labels and specifies whether the dimensions of the variables associated with that label are partitioned across the plurality of processing nodes.

[0159] In implementations this state comprises a map that indicates which dimension names of the model code have been sharded, and how. This is more efficient than tracking applied actions, which can also be ambiguous if two sequences of actions lead to the same state. In implementations the map can record the configuration of every named dimension in the model code 102, i.e. including every label in the reduced set of dimension labels, i.e. all notional colors.
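
The benefit of a map-based state can be seen in the following sketch, in which two different action orders reach the same state. Representing an unsharded dimension by `None` and the helper names `apply_action` and `state_key` are assumptions for this illustration.

```python
def apply_action(state, dim_name, axis):
    """Return a new state in which dim_name is sharded along the given axis."""
    new_state = dict(state)
    new_state[dim_name] = axis
    return new_state


def state_key(state):
    """Order-independent, hashable key for a state map."""
    return frozenset(state.items())


# Root state: every named dimension is unsharded.
root = {"B": None, "S": None, "D": None}
s1 = apply_action(apply_action(root, "B", "x"), "S", "y")
s2 = apply_action(apply_action(root, "S", "y"), "B", "x")
```

Because `state_key(s1)` equals `state_key(s2)`, search statistics keyed on the state are shared between the two trajectories.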

[0160] The search tree also has edges where an edge represents applying a specific sharding decision, i.e. an action. Each state has a set of actions that can be taken from that state. The search proceeds through the tree, e.g. through several rounds, aiming to minimize the cost indication from the runtime cost model, and may terminate early if a round fails to improve (reduce) the best cost so far. Optionally the MCTS can use multi-threaded processing to perform multiple search rounds in which each round simulates multiple sequences of actions (trajectories) in parallel.

[0161] An action can be defined as a tuple (triplet) of the form:

dim_name × resolution_order × axis

Here dim_name can refer to a label in the reduced set of dimension labels, i.e. a color. The axis refers to an axis of the devices (processing nodes 152A.. N). The resolution_order can be a bitstring that defines a resolution of any sharding conflicts, e.g. a b-bit string for model code with b compatibility sets, where the ith bit defines a resolution for the ith compatibility set of the model. As described, these resolutions can be pre-computed, which enables them to be applied in O(1) time. Including the resolutions in the action tuple facilitates trade-offs related to the order of conflict resolution, which can provide significant performance improvements and, in particular, can help to reduce peak memory usage of the partitioned model.

[0162] As an example, referring to FIG. 2A, if, say, dim_name refers to S, the bitstring can define conflict resolution options for the compatibility set shown in FIG. 2B, i.e. one bit for a forward attention layer, and one bit for a backwards attention layer in training. For example, a resolution option for the forward attention layer could have a single bit that defines shard S11 and replicate S311 if, say, the bit is 0, and vice-versa if the bit is 1.
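
The action triplet can be sketched as below, with the resolution bitstring stored as an integer whose ith bit gives the resolution for the ith compatibility set. The class `Action`, the helper functions, and the axis names `x` and `y` are hypothetical.

```python
from typing import NamedTuple


class Action(NamedTuple):
    dim_name: str          # a label (color) from the reduced set
    resolution_order: int  # bitstring; bit i resolves compatibility set i
    axis: str              # device mesh axis to shard along


def resolution_for(action, set_index):
    """Return 0 or 1: which endpoint is sharded for the given compatibility set."""
    return (action.resolution_order >> set_index) & 1


def all_actions(dim_names, axes, num_compat_sets):
    """Enumerate every (dim_name, resolution_order, axis) triplet."""
    return [
        Action(dim, bits, axis)
        for dim in dim_names
        for bits in range(2 ** num_compat_sets)
        for axis in axes
    ]


# Two colors, two mesh axes, and two compatibility sets (e.g. forward and
# backward attention layers).
actions = all_actions(["S", "D"], ["x", "y"], num_compat_sets=2)
example = Action("S", 0b10, "x")
```

In this sketch `example` uses resolution 0 for the first compatibility set and resolution 1 for the second.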

[0163] Taking an action can involve following a data flow path in the dimension graph. In principle, actions could be augmented with tensor values, to facilitate distinguishing between different ways that an intermediate tensor (i.e. a tensor that is not an argument) can be sharded. Here a tensor value is a specific variable or intermediate result, e.g. in SSA representation. This increases the action space, which can then be pruned by identifying equivalent actions ahead of executing them. However this has been found not to be necessary in practice.

[0164] The initial search space, or root state, can begin with all possible triplets for the model code 102. Optionally actions that affect fewer than a threshold number of dimensions, e.g. 10, can be discarded, as they do not meaningfully affect performance. The search can then proceed by extending trajectories from the root state (MCTS simulation). Broadly, this can involve obtaining the action triples for dimension names in function arguments, recording values that have been sharded as a result of previous actions, and updating a list of available actions based on whether there exist any remaining dimension names and values to be sharded, or on whether a termination condition has been reached. The best trajectory found becomes the final partitioning strategy.

[0165] The MCTS search process can involve selecting a path in the search tree, starting from the root and selecting successive child nodes, until a leaf node is reached. Nodes can be selected based on balanced values of an exploitation term (based on reward / runtime cost) and exploration (to choose moves with fewer simulations, based on visit counts). When the leaf node is reached a new child node is added, representing a new partitioning decision, e.g. sharding a tensor dimension along a particular axis. A simulated partitioning can then be completed from the new node until a terminating state, e.g. fully partitioned model code, e.g. using a fast, randomized policy. The search can then perform a backpropagation step in which, e.g., a reward dependent on the cost indication from the cost model is propagated back up the tree to update the exploitation term and visit count of all the ancestor nodes.

[0166] In some implementations of the partitioning system 100 the search space for a machine learning model that includes repeated layers can usefully be reduced by making sharding decisions consistently across model layers (as well as by resolving sharding conflicts consistently as previously described). This can be done by identifying function arguments that have the same role in different layers, and grouping these. The automatic partitioning subsystem 108, e.g. the MCTS search, can then make a sharding decision for one of these function arguments and copy that decision to the other function arguments in the group. Function arguments can be grouped, e.g., by determining a key for each argument based on how its dimension names are used in the model code or dimension graph, and grouping those with similar keys. For example if matmul(z, w) is used in two layers with different tensors w, a key could define Matmul_Operand_1, Dim_0_Flows_To_Output_1 to indicate that these uses have the same structure, and so on.
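
Argument grouping can be sketched as follows. For simplicity this sketch groups arguments whose keys are identical, which is a narrower criterion than the "similar keys" mentioned above; the argument names, key tuples, and helper functions are hypothetical.

```python
from collections import defaultdict


def group_arguments(arg_keys):
    """Group function arguments that share the same structural key."""
    groups = defaultdict(list)
    for arg, key in arg_keys.items():
        groups[key].append(arg)
    return list(groups.values())


def broadcast_decisions(groups, decisions):
    """Copy the decision made for one member of a group to all its members."""
    result = {}
    for group in groups:
        chosen = next((decisions[a] for a in group if a in decisions), None)
        for arg in group:
            result[arg] = chosen
    return result


# Weights with the same role in two layers share a key; the bias does not.
arg_keys = {
    "layer0.w": ("Matmul_Operand_1", "Dim_0_Flows_To_Output_1"),
    "layer1.w": ("Matmul_Operand_1", "Dim_0_Flows_To_Output_1"),
    "layer0.bias": ("Add_Operand_1",),
}
groups = group_arguments(arg_keys)
shardings = broadcast_decisions(groups, {"layer0.w": "shard_dim0_on_x"})
```

A single decision for `layer0.w` is thereby applied to `layer1.w`, so the number of decisions no longer grows with the number of repeated layers.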

[0167] FIG. 5 compares the partitioning performance of an implementation of the described partitioning system 100 with some other partitioning techniques. The figure relates to a 2B (billion) parameter Transformer model; the y-axis shows Model FLOPS Utilization (MFU) (the ratio between FLOPS achieved and the peak FLOPS that the hardware is theoretically capable of), and the x-axis shows sequence length. Curve 500 is for the implementation of partitioning system 100; curve 502 is for the Alpa automatic partitioner (Zheng et al., arXiv: 2201.12023, 2022); curves 504 and 506 are for a semi-manual approach based on FSDP (fully sharded data parallel). Missing data points indicate an out of memory condition, and hence curves 502, 504 and 506 are not visible: only the partitioning system 100 succeeded at partitioning the model sufficiently to fit into local device memory.

[0168] FIG. 6 illustrates the execution time of machine learning models, comparing partitioning using an implementation of the described partitioning system 100 with some other partitioning techniques, for training a 3.6B parameter U-Net (top left), for training a T7B Transformer model with a 2k sequence length (top right), for training an 875M parameter graph neural network (bottom left), and for a 5B parameter Transformer model in inference (Pope et al., arXiv: 2211.05102, 2022, the model using RoPE and a KV cache). The y-axis shows model step time, the average execution time (in milliseconds) for a single training or inference step, for TPU and P100 platforms. Bar 600 is for partitioning system 100; bar 602 is for manual partitioning; bar 604 is for AutoMap (Schaarschmidt et al. arXiv: 2112.02958, 2021); bar 606 is for Alpa. The described system performs better than the U-Net and T7B manual strategies, although both these have been heavily optimized manually. Note that even a 1% improvement in step time can translate to weeks of saved computation. For the graph neural network the partitioning system 100 significantly improved both runtime and memory performance by comparison with the “industry standard” manual approach.

[0169] FIG. 7 compares the auto-sharding search time of an implementation of the described partitioning system 100 with some other partitioning techniques, for a U-Net (left) and for a 7B Transformer model with a 2k sequence length (right). The y-axis shows the average sharding time (in seconds), for TPU and P100 platforms. Bar 700 is for partitioning system 100; bar 702 is for AutoMap; bar 704 is for Alpa; OOMed indicates an out of memory (OOM) condition. Implementations of the described techniques can be orders of magnitude faster than other techniques which can, e.g., facilitate fast development workflows.

[0170] FIG. 8 illustrates the scaling of execution time of neural network training, comparing partitioning using an implementation of the described partitioning system 100 with some other partitioning techniques, for a 2B Transformer model. The y-axis shows model step time, the execution time in milliseconds for an inference step; the x-axis shows how this scales with increasing sequence length, showing sequence length (below) ranging from 2048 to 32768 and TPU device mesh topology, e.g. 2x32x2, i.e. 128 devices, for 32768 (above). Bar 800 is for partitioning system 100; bar 802 is for manual partitioning; bar 804 is for AutoMap; bar 806 is for Alpa. The described system achieves near optimal performance, compared with expert-driven sharding strategies that have taken hundreds of engineering hours to develop.

[0171] In this specification, the term "configured" is used in relation to computing systems and environments, as well as computer program components. A computing system or environment is considered "configured" to perform specific operations or actions when it possesses the necessary software, firmware, hardware, or a combination thereof, enabling it to carry out those operations or actions during operation. For instance, configuring a system might involve installing a software library with specific algorithms, updating firmware with new instructions for handling data, or adding a hardware component for enhanced processing capabilities. Similarly, one or more computer programs are "configured" to perform particular operations or actions when they contain instructions that, upon execution by a computing device or hardware, cause the device to perform those intended operations or actions.

[0172] The embodiments and functional operations described in this specification can be implemented in various forms, including digital electronic circuitry, software, firmware, computer hardware (encompassing the disclosed structures and their structural equivalents), or any combination thereof. The subject matter can be realized as one or more computer programs, essentially modules of computer program instructions encoded on a tangible non-transitory storage medium for execution by or to control the operation of a computing device or hardware. The storage medium can be a storage device such as a hard drive or solid-state drive (SSD), a storage medium, a random or serial access memory device, or a combination of these. Additionally or alternatively, the program instructions can be encoded on a transmitted signal, such as a machine-generated electrical, optical, or electromagnetic signal, designed to carry information for transmission to a receiving device or system for execution by a computing device or hardware. Furthermore, implementations may leverage emerging technologies like quantum computing or neuromorphic computing for specific applications, and may be deployed in distributed or cloud-based environments where components reside on different machines or within a cloud infrastructure.

[0173] The term "computing device or hardware" refers to the physical components involved in data processing and encompasses all types of devices and machines used for this purpose. Examples include processors or processing units, computers, multiple processors or computers working together, graphics processing units (GPUs), tensor processing units (TPUs), and specialized processing hardware such as field-programmable gate arrays (FPGAs) or application-specific integrated circuits (ASICs). In addition to hardware, a computing device or hardware may also include code that creates an execution environment for computer programs. This code can take the form of processor firmware, a protocol stack, a database management system, an operating system, or a combination of these elements. Embodiments may particularly benefit from utilizing the parallel processing capabilities of GPUs, in a General-Purpose computing on Graphics Processing Units (GPGPU) context, where code specifically designed for GPU execution, often called kernels or shaders, is employed. Similarly, TPUs excel at running optimized tensor operations crucial for many machine learning algorithms. By leveraging these accelerators and their specialized programming models, the system can achieve significant speedups and efficiency gains for tasks involving artificial intelligence and machine learning, particularly in areas such as computer vision, natural language processing, and robotics.

[0174] A computer program, also referred to as software, an application, a module, a script, code, or simply a program, can be written in any programming language, including compiled or interpreted languages, and declarative or procedural languages. It can be deployed in various forms, such as a standalone program, a module, a component, a subroutine, or any other unit suitable for use within a computing environment. A program may or may not correspond to a single file in a file system and can be stored in various ways. This includes being embedded within a file containing other programs or data (e.g., scripts within a markup language document), residing in a dedicated file, or distributed across multiple coordinated files (e.g., files storing modules, subprograms, or code segments). A computer program can be executed on a single computer or across multiple computers, whether located at a single site or distributed across multiple sites and interconnected through a data communication network. The specific implementation of the computer programs may involve a combination of traditional programming languages and specialized languages or libraries designed for GPGPU programming or TPU utilization, depending on the chosen hardware platform and desired performance characteristics.

[0175] In this specification, the term "engine" broadly refers to a software-based system, subsystem, or process designed to perform one or more specific functions. An engine is typically implemented as one or more software modules or components installed on one or more computers, which can be located at a single site or distributed across multiple locations. In some instances, one or more dedicated computers may be used for a particular engine, while in other cases, multiple engines may operate concurrently on the same one or more computers. Examples of engine functions within the context of AI and machine learning could include data pre-processing and cleaning, feature engineering and extraction, model training and optimization, inference and prediction generation, and post-processing of results. The specific design and implementation of engines will depend on the overall architecture and the distribution of computational tasks across various hardware components, including CPUs, GPUs, TPUs, and other specialized processors.

[0176] The processes and logic flows described in this specification can be executed by one or more programmable computers running one or more computer programs to perform functions by operating on input data and generating output. Additionally, graphics processing units (GPUs) and tensor processing units (TPUs) can be utilized to enable concurrent execution of aspects of these processes and logic flows, significantly accelerating performance. This approach offers significant advantages for computationally intensive tasks often found in AI and machine learning applications, such as matrix multiplications, convolutions, and other operations that exhibit a high degree of parallelism. By leveraging the parallel processing capabilities of GPUs and TPUs, significant speedups and efficiency gains compared to relying solely on CPUs can be achieved. Alternatively or in combination with programmable computers and specialized processors, these processes and logic flows can also be implemented using specialized processing hardware, such as field-programmable gate arrays (FPGAs) or application-specific integrated circuits (ASICs), for even greater performance or energy efficiency in specific use cases.

[0177] Computers capable of executing a computer program can be based on general-purpose microprocessors, special-purpose microprocessors, or a combination of both. They can also utilize any other type of central processing unit (CPU). Additionally, graphics processing units (GPUs), tensor processing units (TPUs), and other machine learning accelerators can be employed to enhance performance, particularly for tasks involving artificial intelligence and machine learning. These accelerators often work in conjunction with CPUs, handling specialized computations while the CPU manages overall system operations and other tasks. Typically, a CPU receives instructions and data from read-only memory (ROM), random access memory (RAM), or both. The essential elements of a computer include a CPU for executing instructions and one or more memory devices for storing instructions and data. The specific configuration of processing units and memory will depend on factors like the complexity of the AI model, the volume of data being processed, and the desired performance and latency requirements. Embodiments can be implemented on a wide range of computing platforms, from small embedded devices with limited resources to large-scale data center systems with high-performance computing capabilities. The system may include storage devices like hard drives, SSDs, or flash memory for persistent data storage.

[0178] Computer-readable media suitable for storing computer program instructions and data encompass all forms of non-volatile memory, media, and memory devices. Examples include semiconductor memory devices such as read-only memory (ROM), solid-state drives (SSDs), and flash memory devices; hard disk drives (HDDs); optical media; and optical discs such as CDs, DVDs, and Blu-ray discs. The specific type of computer-readable media used will depend on factors such as the size of the data, access speed requirements, cost considerations, and the desired level of portability or permanence.

[0179] To facilitate user interaction, embodiments of the subject matter described in this specification can be implemented on a computing device equipped with a display device, such as a liquid crystal display (LCD) or an organic light-emitting diode (OLED) display, for presenting information to the user. Input can be provided by the user through various means, including a keyboard, touchscreens, voice commands, gesture recognition, or other input modalities depending on the specific device and application. Additional input methods can include acoustic, speech, or tactile input, while feedback to the user can take the form of visual, auditory, or tactile feedback. Furthermore, computers can interact with users by exchanging documents with a user's device or application. This can involve sending web content or data in response to requests or sending and receiving text messages or other forms of messages through mobile devices or messaging platforms. The selection of input and output modalities will depend on the specific application and the desired form of user interaction.

[0180] Machine learning models can be implemented and deployed using machine learning frameworks, such as TensorFlow or JAX. These frameworks offer comprehensive tools and libraries that facilitate the development, training, and deployment of machine learning models.

[0181] Embodiments of the subject matter described in this specification can be implemented within a computing system comprising one or more components, depending on the specific application and requirements. These may include a back-end component, such as a back-end server or cloud-based infrastructure; an optional middleware component, such as a middleware server or application programming interface (API), to facilitate communication and data exchange; and a front-end component, such as a client device with a user interface, a web browser, or an app, through which a user can interact with the implemented subject matter. For instance, the described functionality could be implemented solely on a client device (e.g., for on-device machine learning) or deployed as a combination of front-end and back-end components for more complex applications. These components, when present, can be interconnected using any form or medium of digital data communication, such as a communication network like a local area network (LAN) or a wide area network (WAN) including the Internet. The specific system architecture and choice of components will depend on factors such as the scale of the application, the need for real-time processing, data security requirements, and the desired user experience.

[0182] The computing system can include clients and servers that may be geographically separated and interact through a communication network. The specific type of network, such as a local area network (LAN), a wide area network (WAN), or the Internet, will depend on the reach and scale of the application. The client-server relationship is established through computer programs running on the respective computers and designed to communicate with each other using appropriate protocols. These protocols may include HTTP, TCP / IP, or other specialized protocols depending on the nature of the data being exchanged and the security requirements of the system. In certain embodiments, a server transmits data or instructions to a user's device, such as a computer, smartphone, or tablet, acting as a client. The client device can then process the received information, display results to the user, and potentially send data or feedback back to the server for further processing or storage. This allows for dynamic interactions between the user and the system, enabling a wide range of applications and functionalities.

[0183] While this specification contains many specific implementation details, these should not be construed as limitations on the scope of any invention or on the scope of what may be claimed, but rather as descriptions of features that may be specific to particular embodiments of particular inventions. Certain features that are described in this specification in the context of separate embodiments can also be implemented in combination in a single embodiment. Conversely, various features that are described in the context of a single embodiment can also be implemented in multiple embodiments separately or in any suitable subcombination. Moreover, although features may be described above as acting in certain combinations and even initially be claimed as such, one or more features from a claimed combination can in some cases be excised from the combination, and the claimed combination may be directed to a subcombination or variation of a subcombination.

[0184] Similarly, while operations are depicted in the drawings and recited in the claims in a particular order, this should not be understood as requiring that such operations be performed in the particular order shown or in sequential order, or that all illustrated operations be performed, to achieve desirable results. In certain circumstances, multitasking and parallel processing may be advantageous. Moreover, the separation of various system modules and components in the embodiments described above should not be understood as requiring such separation in all embodiments, and it should be understood that the described program components and systems can generally be integrated together in a single software product or packaged into multiple software products.

[0185] Particular embodiments of the subject matter have been described. Other embodiments are within the scope of the following claims. For example, the actions recited in the claims can be performed in a different order and still achieve desirable results. As one example, the processes depicted in the accompanying figures do not necessarily require the particular order shown, or sequential order, to achieve desirable results. In some cases, multitasking and parallel processing may be advantageous.

[0186] What is claimed is:

Claims

1. A method performed by one or more computers, the method comprising: receiving mesh data specifying a configuration of a system comprising a plurality of processing nodes; obtaining machine learning model code specifying a machine learning model, or a training step for training a neural network, the machine learning model code identifying: a plurality of variables, each of which is in the form of a tensor having one or more dimensions; and a plurality of operations, each of which is associated with a number of the variables including at least one input variable with respect to which the respective operation is performed and at least one output variable resulting from the respective operation; for each of the operations: for each of the variables associated with the respective operation, providing unique labels for each of the dimensions of the respective tensor; and applying a rule associated with a type of the operation to the unique labels for the associated variables to obtain a reduced set of dimension labels corresponding to a number of ways in which the respective operation and its associated variables can be partitioned across the plurality of processing nodes; and processing the mesh data, machine learning model code, and the reduced sets of dimension labels for each of the operations, to generate program data that, when executed by each of the plurality of processing nodes, partitions the machine learning model, or the training step, across the plurality of processing nodes.

2. A method as claimed in claim 1, further comprising: based on the reduced set of dimension labels, identifying a conflict for one of the operations where no partitioning is possible given one or more possible partitionings for the at least one input variable associated with the one of the operations; and in response to identifying the conflict, causing a collective communication operation to be inserted into the program data to remove at least one of the possible partitionings for the at least one input variable associated with the one of the operations.

3. A method as claimed in claim 2, further comprising detecting the conflict based on a same one of the dimension labels being provided for multiple dimensions of a tensor corresponding to one of the associated variables for the one of the operations.

4. A method as claimed in claim 2 or claim 3, wherein the neural network is a transformer model, wherein the conflict is identified in an attention layer of the transformer model.

5. A method as claimed in any preceding claim, wherein the reduced sets of dimension labels for the different operations together comprise a number of different dimension labels corresponding to the number of partitioning decisions to be made for all of the operations.

6. A method as claimed in any preceding claim, wherein for each of the operations, the applying the rule to obtain the reduced set of dimension labels comprises: determining identities between different ones of the unique labels associated with different tensor dimensions to be partitioned along a same axis of the system; and for each of the determined identities: for the different ones of the unique labels for which the respective identity is defined, setting the different ones of the unique labels to be a same label.

7. A method as claimed in any preceding claim, wherein each of the unique labels is associated with a use of one of the variables for processing or output by one of the operations, wherein the method further comprises, for each of the operations: determining mappings between the unique labels associated with the uses of the associated variables and dimension labels associated with variable definitions of the associated variables.

8. A method as claimed in claim 7, further comprising obtaining the reduced sets of dimension labels by: setting labels associated with different ones of the uses of the same variable across different ones of the operations to be a same label belonging to the reduced sets of dimension labels.

9. A method as claimed in claim 8 when dependent upon claim 6, wherein obtaining the reduced sets of dimension labels comprises: applying the identities to obtain a reduced number of labels by performing the step of, for each of the determined identities, setting the different ones of the unique labels to be a same label; and subsequently, applying the mappings to obtain the reduced sets of dimension labels by performing the step of setting labels associated with different ones of the uses of the same variable across different ones of the operations to be a same label.

10. A method as claimed in any of claims 7 to 9, further comprising: using the mappings to define a graph comprising nodes representing dimension labels and edges provided by the mappings, wherein each node of the graph represents one possible partitioning for one of the operations.

11. A method as claimed in claim 10, further comprising: based on the reduced set of dimension labels, identifying one or more conflicts for at least one of the operations where no partitioning is possible given one or more possible partitionings for the at least one input variable associated with the operation; and adding to the graph one or more further edges connecting ones of the nodes, each of the further edges representing one of the conflicts.

12. A method as claimed in any preceding claim, wherein the processing the mesh data, model code, and the reduced sets of dimension labels for each of the operations comprises applying an automatic partitioning algorithm to explore a space of possible partitionings represented by the reduced sets of dimension labels.

13. A method as claimed in claim 12, wherein the automatic partitioning algorithm comprises a Monte Carlo Tree Search (MCTS) algorithm configured to explore the space of possible partitionings represented by the reduced sets of dimension labels.

14. A method as claimed in claim 12 or claim 13, further comprising: after applying the automatic partitioning algorithm to make one or more partitioning decisions based on the space of possible partitionings, applying a runtime cost model to the one or more partitioning decisions to output a cost indication; and applying the automatic partitioning algorithm to further explore the space of possible partitionings based on the cost indication to obtain a further one or more partitioning decisions.

15. A method as claimed in any preceding claim, further comprising providing the program data to a compiler for compilation.

16. A method as claimed in claim 15, further comprising: causing the plurality of processing nodes to execute compiled program data generated by the compiler in order to implement the machine learning model or perform the training step for training the neural network.

17. The method of any preceding claim, wherein when implementing the machine learning model, the plurality of processing nodes operate under a single program multiple data (SPMD) model and each execute a same program, and wherein the program data specifies the same program to be executed by each of the plurality of devices to implement the machine learning model or to perform the training step.

18. The method of claim 17, wherein the program data comprises device-local SPMD code for execution by each of the plurality of processing nodes.

19. A computer-implemented method, comprising: performing the method of claim 16, or claim 17 or 18 when dependent on claim 16, to implement the machine learning model, or perform the training step of the neural network; obtaining an input for the machine learning model or trained neural network; and processing the input using the machine learning model or trained neural network to generate an output.

20. A computer-implemented method, comprising: obtaining an input for a machine learning model or trained neural network; and processing the input using the machine learning model or trained neural network to generate an output; wherein the machine learning model, or a training step for training the neural network, has been partitioned across a plurality of processing nodes of a computing system by a process comprising: obtaining mesh data specifying a configuration of the computing system; obtaining machine learning model code defining tensor operations specifying the machine learning model, or the training step; performing a named dimension analysis by assigning unique labels to each tensor dimension of the tensor operations, and applying rules based on types of tensor operation to obtain a reduced set of dimension labels that defines valid ways in which tensor operations can be sharded across the processing nodes; and processing the reduced set of dimension labels to generate program data that, when executed by each of the plurality of processing nodes, partitions the machine learning model, or the training step, across the plurality of processing nodes.

21. A system comprising one or more computers and one or more storage devices storing instructions that when executed by the one or more computers cause the one or more computers to perform the operations of the respective method of any one of claims 1-20.

22. One or more computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform the operations of the respective method of any one of claims 1-20.
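By way of illustration only, the named dimension analysis recited in claims 1, 3 and 6 to 9 can be sketched in Python. This is a minimal sketch under stated assumptions, not the described implementation: the union-find structure, the `matmul_rule` and `transpose_rule` functions, the `analyse` helper and the tuple encoding of operations are all hypothetical names introduced for the example. The sketch only flags a conflict; it does not insert the collective communication operation of claim 2 or generate program data.

```python
from itertools import count


class LabelUnifier:
    """Union-find over dimension labels; merged labels share one root."""

    def __init__(self):
        self.parent = {}

    def add(self, label):
        self.parent.setdefault(label, label)

    def find(self, label):
        while self.parent[label] != label:
            # Path halving keeps later lookups short.
            self.parent[label] = self.parent[self.parent[label]]
            label = self.parent[label]
        return label

    def union(self, a, b):
        root_a, root_b = self.find(a), self.find(b)
        if root_a != root_b:
            self.parent[root_b] = root_a


# Hypothetical per-operation rules: each returns pairs of labels that must be
# partitioned along the same mesh axis (the "identities").
def matmul_rule(lhs, rhs, out):
    # [m, k] @ [k, n] -> [m, n]
    return [(lhs[0], out[0]), (lhs[1], rhs[0]), (rhs[1], out[1])]


def transpose_rule(inp, out):
    # [a, b] -> [b, a]
    return [(inp[0], out[1]), (inp[1], out[0])]


RULES = {"matmul": matmul_rule, "transpose": transpose_rule}


def analyse(inputs, ops):
    """inputs: {name: rank}; ops: [(kind, input_names, output_name, output_rank)]."""
    unifier = LabelUnifier()
    fresh = count()

    def new_labels(rank):
        labels = [f"d{next(fresh)}" for _ in range(rank)]
        for label in labels:
            unifier.add(label)
        return labels

    definitions = {name: new_labels(rank) for name, rank in inputs.items()}
    for kind, input_names, output_name, output_rank in ops:
        # Unique labels for every use of a variable by this operation.
        uses = [new_labels(len(definitions[name])) for name in input_names]
        out = new_labels(output_rank)
        # Apply the operation's rule: unify labels tied by an identity.
        for a, b in RULES[kind](*uses, out):
            unifier.union(a, b)
        # Map each use back to the labels of the variable's definition.
        for name, use in zip(input_names, uses):
            for use_label, def_label in zip(use, definitions[name]):
                unifier.union(use_label, def_label)
        definitions[output_name] = out

    resolved = {n: [unifier.find(l) for l in ls] for n, ls in definitions.items()}
    # Conflict: the same label ends up on two dimensions of one tensor.
    conflicts = sorted(n for n, ls in resolved.items() if len(set(ls)) < len(ls))
    return resolved, conflicts, next(fresh)


# y = x @ transpose(x), loosely resembling the Q K^T product in attention.
resolved, conflicts, total_labels = analyse(
    {"x": 2},
    [("transpose", ["x"], "t", 2), ("matmul", ["x", "t"], "y", 2)],
)
reduced = {label for labels in resolved.values() for label in labels}
print(total_labels, len(reduced), conflicts)
```

In this example the 12 initially unique labels collapse to 2, and `y` is flagged because both of its dimensions resolve to the same label, which is the kind of conflict recited in claim 3. Applying each rule's identities before the use-to-definition mappings follows the ordering of claim 9.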