# Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning

Lianmin Zheng\*, Zhuohan Li\*, and Hao Zhang\*, UC Berkeley; Yonghao Zhuang, Shanghai Jiao Tong University; Zhifeng Chen and Yanping Huang, Google; Yida Wang, Amazon Web Services; Yuanzhong Xu, Google; Danyang Zhuo, Duke University; Eric P. Xing, MBZUAI and Carnegie Mellon University; Joseph E. Gonzalez and Ion Stoica, UC Berkeley

Presented by Khoa Pham and Julian Yu

## Background - Parallel Training

Parallel training can speed up the training of large-scale models.

Many parallelism strategies have been proposed: data parallelism (DP), pipeline parallelism (PP), tensor parallelism (TP), ZeRO, etc.

Some works **combine different parallelisms**, e.g., Megatron-LM.

- Most of them rely heavily on **manual tuning** and require systems expertise

## Background - Google Training Stack

- XLA
	- ML compiler that takes models written in TensorFlow, PyTorch, or JAX and optimizes them for high-performance execution across GPUs, TPUs, Trainium, …
- GSPMD
	- Implemented at the XLA level; infers the sharding of the remaining tensors from users' annotations, e.g.:
		- mesh\_split(tensor, device\_mesh, dims\_mapping)
	- GSPMD automatically generates the parallel instructions and inserts communication collectives.
	- Natively supports intra-op parallelism.
	- Alpa's intra-op sharding spec takes inspiration from and builds heavily on it. See more later!
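The following sketch connects a GSPMD annotation to the S/R sharding-spec notation introduced later. It assumes the usual GSPMD convention: `dims_mapping[i]` names the mesh axis along which tensor dimension `i` is split, and `-1` means the dimension is replicated. `dims_mapping_to_spec` is a hypothetical helper and is not part of GSPMD or Alpa. It handles only one mesh axis per dimension.

```python
def dims_mapping_to_spec(dims_mapping):
    """Convert a GSPMD-style dims_mapping into an Alpa-style sharding spec.

    Assumed convention: dims_mapping[i] is the mesh axis that tensor dim i
    is split along, or -1 if that dim is replicated.
    """
    return ["R" if axis == -1 else f"S{axis}" for axis in dims_mapping]


# A 2D weight split by rows along mesh axis 0 and replicated along its columns:
print(dims_mapping_to_spec([0, -1]))  # ['S0', 'R']
```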

## Target & Challenges

#### **Target**: **Auto-parallelization**

Auto-parallelization can significantly accelerate ML research by freeing developers from low-level systems challenges.

**Main challenge**: It requires navigating **a complex space of plans** that **grows exponentially** with the dimensions of parallelism and the size of the model and cluster:

- 1. how many data-parallel **replicas**
- 2. which **axis** to partition
- 3. how to split the model into pipeline **stages**
- 4. how to **map** devices to the resulting parallel executables

## Target & Challenges (Cont.)

**Existing work** on auto-parallelization:

- **1. DAPPLE:** only DP + PP
- **2. PipeDream:** only PP
- **3. AutoSync:** only DP
- **4. Tofu:** only supports a single node, no PP
- **5. FlexFlow:** randomized search; cannot guarantee an optimal or near-optimal plan

## Design Overview - Recategorizing Parallelism

**Re-categorizing** parallelism as **intra-operator** and **inter-operator**:

- 1. **Intra-op**: data/operator parallelism
	- a. Higher utilization
	- b. Higher communication volume
	- c. Suits devices with fast network connectivity

- 2. **Inter-op**: pipeline parallelism
	- a. Lower communication volume
	- b. Introduces idle time (pipeline bubbles)
	- c. Suits devices with slower network connectivity





## Design Overview - Problem Formulation

**Hierarchically** optimizing the parallel plan **at two levels**: intra-op and inter-op.

**total cost** = inter-op cost + intra-op cost



## Design Overview - Compilation Passes



[Figure: Alpa's compilation passes; surviving label: Inter-op Parallelism]

## Design Overview - API

Annotate `train_step()` with `@parallelize`.

Upon the first call to `train_step()`, Alpa:

- 1. Traces the whole function to get the **model IR**
- 2. Invokes the **compilation passes** to convert the function into an optimized parallel version

```python
import jax
from jax import grad
from alpa import parallelize

# Put @parallelize decorator on top of the Jax functions
@parallelize
def train_step(state, batch):
    # `state` is a user-defined train state exposing forward() and
    # apply_gradient(), following the paper's example
    def loss_func(params):
        out = state.forward(params, batch["x"])
        return jax.numpy.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(state.params)
    new_state = state.apply_gradient(grads)
    return new_state

# A typical training loop (create_train_state and data_loader are user-defined)
state = create_train_state()
for batch in data_loader:
    state = train_step(state, batch)
```
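The "trace and compile on the first call, reuse afterwards" behavior described above can be imitated with a toy decorator. This is a hypothetical sketch: `parallelize_sketch` is not Alpa's API. It skips tracing and compilation entirely and shows only the caching pattern. The cache key is argument type names, which stand in for the shapes and dtypes a real compiler would key on.

```python
import functools


def parallelize_sketch(fn):
    """Hypothetical stand-in for @parallelize: 'compile' once per input signature."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args):
        # Toy signature; a real system would key on argument shapes and dtypes.
        key = tuple(type(a).__name__ for a in args)
        if key not in cache:
            # A real system would trace fn to IR and run the compilation passes here.
            cache[key] = fn
            wrapper.num_compilations += 1
        return cache[key](*args)

    wrapper.num_compilations = 0
    return wrapper


@parallelize_sketch
def step(x, y):
    return x + y


step(1, 2)
step(3, 4)
print(step(5, 6), step.num_compilations)  # 11 1
```

Only the first call with a new signature pays the compilation cost. This is why the slide specifies what happens "upon the first call".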
## Intra-Op Parallelism - Goal

**Goal**: find an **intra-op parallel plan** that minimizes the **intra-op cost**

**How**:

- **Building the search space**: device mesh, sharding spec, resharding
- **Formulating the cost**
- **Optimizing the cost**

#### Intra-Op Parallelism - Device Mesh

A device mesh is a **logical 2D mesh view** of a set of GPUs.



**Which mapping? Optimized by the inter-op pass!**
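The sketch below shows how the same four GPUs can be arranged into different 2D logical meshes. It assumes a row-major layout, and the helper name `make_device_mesh` is hypothetical. Which physical GPUs end up sharing a mesh row is the mapping question that the slide defers to the inter-op pass.

```python
def make_device_mesh(device_ids, shape):
    """Arrange a flat list of device ids into a 2D logical mesh (row-major)."""
    n0, n1 = shape
    assert len(device_ids) == n0 * n1, "mesh shape must match the number of devices"
    return [device_ids[r * n1:(r + 1) * n1] for r in range(n0)]


gpus = [0, 1, 2, 3]
print(make_device_mesh(gpus, (2, 2)))  # [[0, 1], [2, 3]]
print(make_device_mesh(gpus, (1, 4)))  # [[0, 1, 2, 3]]
```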

## Intra-Op Parallelism - Sharding Spec

A sharding spec defines the **layout of a tensor**.

For an n-dimensional tensor, the spec is $X_0 X_1 \cdots X_{n-1}$, where $X_i \in \{S, R\}$ means the $i$-th dimension is **sliced** or **replicated**.

For a 2D matrix:

- SR: row-partitioned
- RS: column-partitioned
- RR: no partitioning (fully replicated)
- SS: row- and column-partitioned

## Intra-Op Parallelism - Sharding Spec (Cont.)

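To map **tensor axes** to **device mesh axes**, add a **superscript** to S. For example, $S^0R$ slices the rows along mesh axis 0, and $S^{01}R$ slices the rows along both mesh axes.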
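The sketch below makes the notation concrete by computing which block of a tensor a given device holds under a spec such as $S^0R$ or $S^{01}R$. It is a hypothetical helper and not Alpa code. It rests on these assumptions:

- Specs are written as strings (`"R"`, `"S0"`, `"S1"`, `"S01"`).
- Every dimension divides evenly.
- For $S^{01}$, the shard index flattens the device's mesh coordinates in row-major order.

```python
def local_slice(shape, spec, mesh_shape, coord):
    """Return the (start, stop) range per tensor dim held by the device at `coord`."""
    ranges = []
    for size, s in zip(shape, spec):
        if s == "R":
            ranges.append((0, size))  # replicated: every device holds the full dim
            continue
        axes = [int(a) for a in s[1:]]  # "S01" -> mesh axes [0, 1]
        parts, index = 1, 0
        for a in axes:
            # Flatten this device's coordinates on the chosen mesh axes (row-major).
            index = index * mesh_shape[a] + coord[a]
            parts *= mesh_shape[a]
        chunk = size // parts  # assumes size is divisible by the number of shards
        ranges.append((index * chunk, (index + 1) * chunk))
    return ranges


# 4 x 4 matrix on a (2, 2) mesh; device at mesh coordinate (1, 0):
print(local_slice((4, 4), ["S0", "R"], (2, 2), (1, 0)))  # [(2, 4), (0, 4)]
print(local_slice((4, 4), ["R", "S1"], (2, 2), (1, 0)))  # [(0, 4), (0, 2)]
```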



## Intra-Op Parallelism - Resharding

Resharding is a **layout conversion**. It is needed when an input tensor does not satisfy the sharding spec that the chosen parallel algorithm of the operator requires, and it introduces **communication cost**.



Several resharding cases and their communication costs:
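One simplified way to see which collective a resharding needs is to compare, for each mesh axis, which tensor dimension that axis shards before and after the conversion. The sketch below uses hypothetical function names and writes specs as lists of strings such as `["S0", "R"]`. Its assumed rules are:

- If an axis stops sharding, the conversion needs an all-gather.
- If an axis moves to a different dimension, it needs an all-to-all.
- If an axis starts sharding, each device only slices its local copy.

It ignores message sizes and the ordering within multi-axis shards such as $S^{01}$.

```python
def axis_to_dim(spec):
    """Map each mesh axis to the tensor dim it shards: ["S0", "S1"] -> {0: 0, 1: 1}."""
    mapping = {}
    for dim, s in enumerate(spec):
        if s.startswith("S"):
            for a in s[1:]:
                mapping[int(a)] = dim
    return mapping


def resharding_collectives(src, dst, num_mesh_axes=2):
    """List the (collective, mesh_axis) pairs needed to go from spec src to spec dst."""
    src_map, dst_map = axis_to_dim(src), axis_to_dim(dst)
    ops = []
    for axis in range(num_mesh_axes):
        s, d = src_map.get(axis), dst_map.get(axis)
        if s == d or s is None:
            # Unchanged, or R -> S: devices slice their local copy, no communication.
            continue
        if d is None:
            ops.append(("all-gather", axis))  # S -> R along this axis
        else:
            ops.append(("all-to-all", axis))  # axis moves to another tensor dim
    return ops


print(resharding_collectives(["S0", "R"], ["R", "R"]))   # [('all-gather', 0)]
print(resharding_collectives(["S0", "R"], ["R", "S0"]))  # [('all-to-all', 0)]
print(resharding_collectives(["R", "R"], ["S0", "S1"]))  # []
```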



#### Intra-Op Parallelism - Parallel Algorithms of An Operator

A parallel algorithm of an operator maps the operator's **loop axes** to **mesh axes**. This may introduce **communication cost**.

Example: batched matmul $C = AB \iff C_{b,i,j} = \sum_k A_{b,i,k} B_{b,k,j}$ on a mesh of shape $(n_0, n_1)$.

- **Loop axes:** $b, i, j, k$
- **Mesh axes:** $0, 1$

**If using the $i \to 0$, $k \to 1$ mapping:**

- Input specs: $RS^0S^1$ (for $A$), $RS^1R$ (for $B$)
- Output spec: $RS^0R$
- Communication cost: $\text{all-reduce}(\frac{M}{n_0}, 1)$, where $M$ is the size of the output $C$

**If using the $k \to \{0, 1\}$ mapping:**

- Input specs: $RRS^{01}$ (for $A$), $RS^{01}R$ (for $B$)
- Output spec: $RRR$
- Communication cost: $\text{all-reduce}(M, \{0, 1\})$
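The $i \to 0$, $k \to 1$ mapping can be checked numerically. The sketch below simulates a $(2, 2)$ mesh with plain Python lists and drops the batch axis $b$, so the specs become $S^0S^1$, $S^1R$ and $S^0R$. The algorithm runs in two steps:

1. Each device multiplies its blocks of $A$ and $B$.
2. An all-reduce along mesh axis 1 sums the partial results.

Each device's all-reduce payload is a $1/n_0$ slice of the output, which matches $\text{all-reduce}(\frac{M}{n_0}, 1)$. The function names are hypothetical.

```python
def matmul(A, B):
    """Plain single-device matmul on nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def block(X, rows, cols):
    """Slice rows[0]:rows[1], cols[0]:cols[1] out of a nested-list matrix."""
    return [row[cols[0]:cols[1]] for row in X[rows[0]:rows[1]]]


def matmul_i0_k1(A, B, n0, n1):
    """Simulate C = A @ B with loop axis i on mesh axis 0 and k on mesh axis 1."""
    I, K, J = len(A), len(B), len(B[0])
    bi, bk = I // n0, K // n1  # assume evenly divisible
    # Device (p, q) holds A[i-block p, k-block q] (S0S1) and B[k-block q, :] (S1R).
    partial = {
        (p, q): matmul(block(A, (p * bi, (p + 1) * bi), (q * bk, (q + 1) * bk)),
                       block(B, (q * bk, (q + 1) * bk), (0, J)))
        for p in range(n0) for q in range(n1)
    }
    # All-reduce (sum) along mesh axis 1: all devices in mesh row p get the same block.
    local = {}
    for p in range(n0):
        summed = [[sum(partial[(p, q)][r][c] for q in range(n1)) for c in range(J)]
                  for r in range(bi)]
        for q in range(n1):
            local[(p, q)] = summed
    return local


A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
B = [[1, 0], [0, 1], [1, 1], [2, -1]]
local = matmul_i0_k1(A, B, n0=2, n1=2)
# Output spec S0R: stacking the blocks of mesh rows 0 and 1 recovers the full C.
C = local[(0, 0)] + local[(1, 0)]
assert C == matmul(A, B)
```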

#### Intra-Op Parallelism - ILP Formulation

Formulate the **total intra-op cost** and **optimize it** with an Integer Linear Programming (ILP) solver on the computational graph $G = (V, E)$, with nodes $u, v \in V$ and edges $e \in E$.



**Comp.** and **comm. cost** of **node** $v$:

- **number** of parallel plans: $k_v$
- **comp. cost** vector of plans: $c_v \in \mathbb{R}^{k_v}$
- **comm. cost** vector of plans: $d_v \in \mathbb{R}^{k_v}$
- **choice** of parallel plan: **one-hot** vector $s_v \in \{0, 1\}^{k_v}$

**Resharding cost** of **edge** $e = (v, u)$:

- **numbers** of parallel plans: $k_v$, $k_u$
- **resharding cost** matrix: $R_{vu} \in \mathbb{R}^{k_v \times k_u}$

**Total intra-op cost**: $\sum_{v \in V} s_v^{\mathsf{T}}(c_v + d_v) + \sum_{(v,u) \in E} s_v^{\mathsf{T}} R_{vu} s_u$
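For a tiny graph, the objective can be minimized by brute force instead of an ILP solver, which makes the roles of $s_v$, $c_v$, $d_v$ and $R_{vu}$ easy to see. In this sketch a one-hot $s_v$ is represented by the index of the chosen plan. With that encoding:

- $s_v^{\mathsf{T}}(c_v + d_v)$ becomes `c[v][k] + d[v][k]`.
- $s_v^{\mathsf{T}} R_{vu} s_u$ becomes `R[(v, u)][k_v][k_u]`.

The cost numbers are made up. Exhaustive enumeration is exponential in $|V|$, which is why an ILP solver is used for real graphs.

```python
import itertools


def intra_op_cost(choice, c, d, R, edges):
    """Objective: sum_v s_v^T (c_v + d_v) + sum_(v,u) s_v^T R_vu s_u, with s_v as an index."""
    node_cost = sum(c[v][choice[v]] + d[v][choice[v]] for v in c)
    edge_cost = sum(R[(v, u)][choice[v]][choice[u]] for v, u in edges)
    return node_cost + edge_cost


def brute_force(c, d, R, edges):
    """Try every combination of per-node plans (exponential; an ILP solver replaces this)."""
    nodes = list(c)
    best = None
    for picks in itertools.product(*(range(len(c[v])) for v in nodes)):
        choice = dict(zip(nodes, picks))
        cost = intra_op_cost(choice, c, d, R, edges)
        if best is None or cost < best[0]:
            best = (cost, choice)
    return best


# Made-up costs: nodes v and u each have 2 plans, connected by one edge (v, u).
c = {"v": [0, 0], "u": [0, 0]}       # comp. cost set to 0
d = {"v": [0, 4], "u": [3, 0]}       # comm. cost of each plan
R = {("v", "u"): [[0, 5], [5, 0]]}   # resharding cost R_vu[plan_v][plan_u]
edges = [("v", "u")]
print(brute_force(c, d, R, edges))   # (3, {'v': 0, 'u': 0})
```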

## Intra-Op Parallelism - ILP Formulation (Cont.)

#### **How to get** $c_v$, $d_v$, $R_{vu}$?

By profiling? Too many cases!

By estimating, for simplicity:

- **comp. cost** $c_v$: set to 0
	- heavy ops (e.g., matmul): replicated computation is not allowed, so every parallel plan divides the work evenly and has the same arithmetic complexity
	- light ops (e.g., element-wise): cost is negligible
- **comm. cost** $d_v$ and **resharding cost** $R_{vu}$: estimated from the number of communicated bytes

$$
\sum_{v \in V} s_v^{\mathsf{T}}(c_v + d_v) + \sum_{(v,u) \in E} s_v^{\mathsf{T}} R_{vu} s_u
$$
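The slide estimates communication cost from the number of communicated bytes. A common way to obtain those bytes is the ring-algorithm formula sketched below. This formula is an assumption for illustration, not necessarily the exact model Alpa uses. Dividing by the bandwidth of the mesh axis that the collective runs on, as `comm_cost` does, is likewise an assumed refinement.

```python
def all_reduce_bytes(M, n):
    """Bytes each device sends in a ring all-reduce of an M-byte buffer over n devices."""
    return 2 * (n - 1) / n * M


def all_gather_bytes(M, n):
    """Bytes each device sends in a ring all-gather that produces an M-byte result."""
    return (n - 1) / n * M


def comm_cost(num_bytes, bandwidth):
    """Estimated time: bytes divided by the bandwidth of the mesh axis used."""
    return num_bytes / bandwidth


print(all_reduce_bytes(1024, 4))  # 1536.0
print(all_gather_bytes(1024, 4))  # 768.0
```

Under this model, the same collective is cheaper along a mesh axis that maps to fast intra-node links than along one that crosses nodes. This matches the intra-op/inter-op split described earlier.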

#### Inter-op Parallelism - Goal

**Goal**: Slice the computation graph and the device cluster into *stage-mesh* pairs so that the pipeline execution latency is minimized and the model fits into memory.

$$
T^* = \min_{\substack{s_1, \dots, s_S;\\(n_1, m_1), \dots, (n_S, m_S)}} \left\{ \sum_{i=1}^S t_i + (B - 1) \cdot \max_{1 \le j \le S} \{t_j\} \right\}. \tag{2}
$$

Here $t_i$ is the latency of stage $s_i$ on submesh $(n_i, m_i)$, and $B$ is the number of micro-batches.

We want to solve (2) subject to these additional constraints:

- Colocate each forward operator with its corresponding backward operator on the same submesh
- The sliced submeshes $(n_1, m_1), ..., (n_S, m_S)$ must fully cover the $N \times M$ cluster mesh (i.e., use all compute devices)
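For one fixed plan, the objective in (2) is simple to evaluate: the sum of the stage latencies plus $(B - 1)$ times the slowest stage. The sketch below computes it and also checks the device-count side of the "fully cover" constraint. The function names are hypothetical and the latencies are made up. `covers_cluster` checks only the total number of devices, not how the submeshes are laid out.

```python
def pipeline_latency(stage_latencies, B):
    """Objective of Eq. (2) for one fixed plan."""
    return sum(stage_latencies) + (B - 1) * max(stage_latencies)


def covers_cluster(submeshes, N, M):
    """Check that the submeshes use exactly N x M devices in total."""
    return sum(n * m for n, m in submeshes) == N * M


# Hypothetical 4-stage plan with 4 micro-batches; stage 3 is the slowest.
print(pipeline_latency([1.0, 2.0, 3.0, 1.0], B=4))                 # 16.0
print(covers_cluster([(1, 2), (1, 2), (1, 4), (1, 8)], N=2, M=8))  # True
```

Balancing the stages matters because the slowest stage is multiplied by $B - 1$.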

Figure 5: Illustration of the total latency of a pipeline, which is determined by two parts: the total latency of all stages $(t_1 + t_2 + t_3 + t_4)$ and the latency of the slowest stage $((B - 1) \cdot t_3)$.

#### Inter-op Parallelism - Challenges

**Challenges:** There are many ways to slice the computation graph and the device cluster into stage-mesh pairs. How do we know which stage-mesh mapping is best?




#### Inter-op Parallelism - DP Formulation

Our submesh space $(n_1, m_1), ..., (n_S, m_S)$ consists of two kinds of shapes:

- One-dimensional submeshes $(1, 1), (1, 2), (1, 4), ..., (1, M)$
	- i.e., use 1, 2, 4, 8, $\ldots$ devices within a single node
- Two-dimensional submeshes $(2, M), (3, M), ..., (N, M)$
	- i.e., use multiple nodes and all the devices of those nodes

Other choices, such as $(n, m)$ where $n > 1$ and $m < M$ (i.e., using multiple nodes but not all devices on those nodes), lead to inferior results. The two kinds above can always fully cover the $N \times M$ device mesh (proof in the paper).

#### Inter-op Parallelism - DP Formulation

$$
F(s, k, d; t_{max}) = \min_{\substack{k \le i \le K \\ n_s \cdot m_s \le d}} \left\{ t_{intra}((o_k, \ldots, o_i), Mesh(n_s, m_s), s) + F(s - 1, i + 1, d - n_s \cdot m_s; t_{max}) \;\middle|\; t_{intra}((o_k, \ldots, o_i), Mesh(n_s, m_s), s) \le t_{max} \right\} \tag{3}
$$

- $t_{intra}((o_k, \ldots, o_i), Mesh(n_s, m_s), s)$: the lowest latency to run $(o_k, \ldots, o_i)$ on $Mesh(n_s, m_s)$; set to infinity if it OOMs
- $F(s, k, d; t_{max})$: the minimal total latency when slicing operators $o_k$ to $o_K$ into $s$ stages and putting them onto $d$ devices so that the latency of each stage is less than $t_{max}$

$$
T^*(t_{max}) = \min_{s} \{ F(s, 0, N \cdot M; t_{max}) \} + (B - 1) \cdot t_{max}. \quad (4)
$$
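The recursion can be prototyped directly. The sketch below is a simplified, hypothetical implementation of Eq. (3) and Eq. (4) with these simplifications:

- It works on layers, indexed from 0.
- It takes a user-supplied `t_intra(i, j, shape)` that ignores the stage-position argument $s$, so memory effects must be folded into returning `math.inf`.
- It enumerates every distinct stage latency as a candidate $t_{max}$, without the pruning a production implementation would add.

The toy cost model assumes perfect linear scaling and is made up.

```python
import functools
import math


def submesh_shapes(N, M):
    """(1, 1), (1, 2), (1, 4), ..., (1, M) plus (2, M), ..., (N, M); assumes M is a power of 2."""
    shapes, m = [], 1
    while m <= M:
        shapes.append((1, m))
        m *= 2
    return shapes + [(n, M) for n in range(2, N + 1)]


def inter_op_dp(L, N, M, B, t_intra):
    """Minimal pipeline latency T* via Eq. (3)/(4), enumerating t_max."""
    shapes = submesh_shapes(N, M)
    costs = {t_intra(i, j, sh) for i in range(L) for j in range(i, L) for sh in shapes}
    best = math.inf
    for t_max in sorted(c for c in costs if c < math.inf):

        @functools.lru_cache(maxsize=None)
        def F(s, k, d):
            # Min total latency of layers k..L-1 in s stages on exactly d devices.
            if s == 0:
                return 0.0 if (k == L and d == 0) else math.inf
            if k == L:
                return math.inf
            result = math.inf
            for i in range(k, L):
                for n, m in shapes:
                    t = t_intra(k, i, (n, m))
                    if n * m <= d and t <= t_max:
                        result = min(result, t + F(s - 1, i + 1, d - n * m))
            return result

        for s in range(1, L + 1):
            total = F(s, 0, N * M)
            if total < math.inf:
                best = min(best, total + (B - 1) * t_max)  # Eq. (4)
    return best


def toy_t_intra(i, j, shape):
    """Made-up cost: each layer costs 1.0 on one device and scales perfectly."""
    n, m = shape
    return (j - i + 1) / (n * m)


print(inter_op_dp(L=4, N=1, M=2, B=5, t_intra=toy_t_intra))  # 10.0
```

In this toy case, one stage on both GPUs (latency 2.0) beats two single-GPU stages (2.0 each, sum 4.0), because pipelining only adds bubbles when the scaling is perfect.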

Algorithm 1: Inter-op pass summary.

- 1: **Input:** Model graph $G$ and cluster $C$ with shape $(N, M)$.
- 2: **Output:** The minimal pipeline execution latency $T^*$.
- 3: // Preprocess graph.
- 4: $(o_1, \ldots, o_K) \leftarrow \text{Flatten}(G)$
- 5: $(l_1, \ldots, l_L) \leftarrow \text{OperatorClustering}(o_1, \ldots, o_K)$
- 6: // Run the intra-op pass to get costs of different stage-mesh pairs.
- 7: submesh\_shapes $\leftarrow \{(1,1), (1,2), (1,4), \ldots, (1,M)\} \cup \{(2,M), (3,M), \ldots, (N,M)\}$
- 8: **for** $1 \le i \le j \le L$ **do**
	- 9: $stage \leftarrow (l_i, \ldots, l_j)$
	- 10: **for** $(n, m) \in$ submesh\_shapes **do**
		- 11: **for** $s$ from 1 to $L$ **do**
			- 12: $t_{intra}(stage, Mesh(n, m), s) \leftarrow \infty$
		- 13: **end for**
		- 14: **for** $(n_l, m_l), opt \in \text{LogicalMeshShapeAndIntraOpOptions}(n, m)$ **do**
			- 15: $plan \leftarrow \text{IntraOpPass}(stage, Mesh(n_l, m_l), opt)$
			- 16: $t_l, mem_{stage}, mem_{act} \leftarrow \text{Profile}(plan)$
			- 17: **for** $s$ satisfies Eq. 5 **do**
				- 18: **if** $t_l < t_{intra}(stage, Mesh(n, m), s)$ **then**
					- 19: $t_{intra}(stage, Mesh(n, m), s) \leftarrow t_l$
				- 20: **end if**
			- 21: **end for**
		- 22: **end for**
	- 23: **end for**
- 24: **end for**
- 25: // Run the inter-op dynamic programming

Steps 4-5 flatten the computation graph and condense the operators into layers.

Steps 6-24 precalculate the lowest execution latency for every stage-mesh pair using IntraOpPass.


An $(n, m)$ physical mesh can be interpreted as any $(n', m')$ logical (virtual) mesh such that $n'm' = nm$.
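Enumerating these logical views amounts to listing the factor pairs of $nm$. `logical_mesh_shapes` is a hypothetical helper for illustration.

```python
def logical_mesh_shapes(n, m):
    """All (n', m') with n' * m' == n * m, i.e. the logical views of an n x m submesh."""
    total = n * m
    return [(a, total // a) for a in range(1, total + 1) if total % a == 0]


print(logical_mesh_shapes(1, 4))  # [(1, 4), (2, 2), (4, 1)]
```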


Profile memory usage and keep only the intra-op parallelism plans that do not cause OOM.
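Steps 11-19 of Algorithm 1 can be sketched for a single stage-mesh pair as follows. Eq. 5 is not reproduced on these slides, so the memory check here, $mem_{stage} + s \cdot mem_{act} \le mem_{device}$, is an assumed form. It models a stage that must hold activations for $s$ in-flight micro-batches. All names and numbers are hypothetical.

```python
import math


def fill_t_intra(profiles, L, mem_device):
    """Best latency per s for one stage-mesh pair, skipping plans that would OOM.

    profiles: one (t_l, mem_stage, mem_act) tuple per (logical mesh, intra-op option).
    Returns {s: latency} for s = 1..L, with math.inf where every plan OOMs.
    """
    t_intra = {s: math.inf for s in range(1, L + 1)}  # steps 11-12
    for t_l, mem_stage, mem_act in profiles:  # steps 14-16 (profiling assumed done)
        for s in range(1, L + 1):
            # Assumed form of the Eq. 5 memory check (hypothetical).
            fits = mem_stage + s * mem_act <= mem_device
            if fits and t_l < t_intra[s]:  # steps 17-19
                t_intra[s] = t_l
    return t_intra


# Made-up profiles: a fast but memory-hungry plan and a slower, leaner one.
print(fill_t_intra([(2.0, 10, 4), (3.0, 6, 1)], L=3, mem_device=16))
# {1: 2.0, 2: 3.0, 3: 3.0}
```

The fast plan wins only when few micro-batches are in flight. Larger $s$ forces the leaner plan, which is why $t_{intra}$ carries the $s$ argument.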


#### Parallelism Orchestration



[Figure: Alpa's compilation passes; surviving label: Inter-op Parallelism]

#### Parallelism Orchestration - Cross-mesh resharding

- In Megatron-LM, all pipeline stages have the same degrees of data and tensor parallelism, so stages communicate point-to-point between corresponding devices
- In Alpa, the device meshes holding two consecutive stages may have different shapes, so tensors must be resharded across meshes
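To see why cross-mesh resharding is needed, consider a tensor sliced evenly along one dimension across the sending mesh. It must be re-sliced across a receiving mesh with a different number of devices. The hypothetical sketch below computes the required point-to-point transfers by intersecting row ranges. It covers only this 1D case and ignores replication and the bandwidth-aware optimizations a real runtime would apply.

```python
def cross_mesh_transfers(length, num_src, num_dst):
    """Rows sliced evenly over num_src senders must be re-sliced over num_dst receivers.

    Returns a list of (src_device, dst_device, start_row, stop_row) transfers.
    """
    transfers = []
    for dst in range(num_dst):
        d_lo, d_hi = dst * length // num_dst, (dst + 1) * length // num_dst
        for src in range(num_src):
            s_lo, s_hi = src * length // num_src, (src + 1) * length // num_src
            lo, hi = max(d_lo, s_lo), min(d_hi, s_hi)
            if lo < hi:  # the sender's rows overlap what this receiver needs
                transfers.append((src, dst, lo, hi))
    return transfers


# 8 rows: a 2-device stage feeding a 4-device stage.
print(cross_mesh_transfers(8, 2, 4))
# [(0, 0, 0, 2), (0, 1, 2, 4), (1, 2, 4, 6), (1, 3, 6, 8)]
```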



#### Evaluation - Setup

- Each node is an Amazon EC2 p3.16xlarge instance with 8 NVIDIA V100 16 GB GPUs, 64 vCPUs, and 488 GB memory.
	- The 8 GPUs in a node are connected via NVLink; cross-node bandwidth is 25 Gbps.
- Alpa respects the semantics of synchronous gradient descent, so it does not affect model convergence
- Weak scaling is evaluated by increasing the model size along with the number of GPUs

Table 4: Models used in the end-to-end evaluation. LM = language model. IC = image classification.



## Evaluation - End-to-end (Weak Scaling)



- Alpa's generated parallelization plans closely resemble Megatron-LM's best-performing plans
- Key difference: Alpa also partitions the weight-update operation when data parallelism is used, giving a slight improvement over Megatron-LM in some configurations

## Evaluation - End-to-end (Weak Scaling)



- On a single node, Alpa matches or slightly outperforms DeepSpeed
- DeepSpeed's MoE implementation does not support PP; Alpa achieves a 3.5x speedup over it on 2 nodes and 9.7x on 4 nodes
- Heterogeneous architecture: manual parallelization plans are very hard to design
- Alpa still finds a parallelization plan with 80% scaling efficiency

#### Evaluation - Intra-op only study



- ZeRO optimizes for memory but not communication overhead
- Alpa's ILP finds a plan that minimizes communication overhead in all cases, achieving near-linear scaling while making sure the model fits into memory
- For MoE, Alpa's ILP finds and combines expert parallelism with ZeRO-style data parallelism

#### Case Study: Wide-ResNet



Figure 12: Visualization of the parallel strategy of Wide-ResNet on 16 GPUs. Different colors represent the devices a tensor is distributed on. Grey blocks indicate a tensor is replicated across the devices. The input data and resulting activation of each convolution and dense layer can be partitioned along the batch axis and the hidden axis. The weights can be partitioned along the input and output channel axis.

## Compilation Overhead (Runtime of Algorithm 1)

- Most of the time is spent enumerating and profiling stage-mesh pairs (preprocessing)
- Profiling is sped up with a simple cost model built at the XLA instruction level
- The executables for different stages are compiled in parallel by distributed workers



Figure 10: Alpa's compilation time on all GPT models. The model size and #GPUs are simultaneously scaled.

Table 5: Compilation time breakdown of GPT-39B.

| **Step**                  | Ours      | w/o optimization |
|---------------------------|-----------|------------------|
| Compilation               | 1582.66 s | > 16 hr          |
| Profiling                 | 804.48 s  | > 24 hr          |
| **Stage Construction DP** | 1.65 s    | N/A              |
| Other                     | 4.47 s    | N/A              |
| Total                     | 2393.26 s | > 40 hr          |

#### Alpa Present and Future

- The [Alpa project](https://github.com/alpa-projects/alpa) is no longer actively maintained
- Instead, its ideas are being integrated into **XLA's auto-sharding**. The goal is to compile model code (PyTorch, JAX, TensorFlow) into automatically parallelized executables without relying on users' sharding annotations, unlike GSPMD

## **Thoughts**

- The only functional open-source automatic parallelism framework as of today!
- Works for any model without user code changes
- Builds automatic plan search on top of GSPMD intra-op parallelism
	- Generalizable view of parallelism
	- All about choosing which dimensions to replicate or shard
- Matches Megatron-LM's performance on GPT, and its search results closely resemble Megatron-LM's best-performing plans
- Cross-mesh resharding is not optimal (also acknowledged in the paper)
	- Follow-up [work](https://arxiv.org/pdf/2211.05322), MLSys '23