您的当前位置:首页正文

DeepSpeed ZeRO 通信量分析

2024-11-08 来源:个人技术集锦

假设模型参数为 Ψ \Psi Ψ

传统数据并行通信量分析

传统的数据并行,在每一步计算完梯度后,需要通过 All-Reduce 来计算梯度的平均值。

All-Reduce 操作分为 Reduce-Scatter 和 All-Gather 两部分,每张卡在发送和接收两个方向上的通信量是2*模型参数。

Reduce-Scatter 阶段

在 Reduce-Scatter 阶段,如下图所示:

  1. GPU 0 发送了 c3+c2+c4+c0,接收了 b2+b1+b3+b4

总的算下来的话,GPU 0 总共发送/接收了4次,每次发送/接收的参数量为 Ψ 5 \frac{\Psi}{5} 5Ψ,因此总的发送量为 4 Ψ 5 \frac{4\Psi}{5} 5

All-Gather 阶段

在 All-Gather 阶段,如下图所示:

  1. GPU 0 发送了 b2+b1+b3+b4+b0

  1. GPU 0 发送了 a1+a0+a2+a3+a4

  1. GPU 0 发送了 e0+e4+e1+e2+e3

  1. GPU 0 发送了 d4+d3+d0+d1+d2


总的算下来的话,GPU 0 也是总共发送了4次,每次发送的参数量为 Ψ 5 \frac{\Psi}{5} 5Ψ,因此总的发送量为 4 Ψ 5 \frac{4\Psi}{5} 5

通信量分析

我们来计算下整个AllReduce过程的通信成本。假设我们有N个设备,每个设备的数据大小为 Ψ \Psi Ψ,在一次Allreduce过程中,我们进行了N-1次Scatter-Reduce操作和N-1次Allgather操作,每一次操作所需要传递的数据大小为 Ψ \Psi Ψ/N,所以整个Allreduce过程所传输的数据大小为2(N-1) * Ψ \Psi Ψ/N (< 2 Ψ \Psi Ψ,与设备数N无关)。

整个Allreduce的通信速度只受限于逻辑环中最慢的两个GPU的连接(每次需要通信的数据大小仅为K/N,随着N增大,通信量减少,一般小于network bandwidth)。

ZeRO-DP 通信量分析

ZeRO-1

单个显卡会保存完整的模型参数和梯度。

随后使用reduce-scatter将梯度reduce至不同的显卡上(此时不同显卡仅拥有完整平均梯度的一部分),该步骤的通信量是 Ψ \Psi Ψ。各个显卡使用部分梯度更新对应的优化器状态,然后再更新对应的参数(此时每个显卡上的模型都更新了一部分参数)。

最后,使用all-gather将分布在各个显卡上的更新后参数分发自所有显卡上(此时所有显卡上都有了完整的更新后参数),该步骤的通信量是 Ψ \Psi Ψ。总的来说,各个显卡仅需要持有部分优化器状态即可,且总的通信量仍然是2 Ψ \Psi Ψ

ZeRO-2

每张卡只存储 1 N \frac{1}{N} N1 的优化器状态和梯度,对于 GPU 0 来说,为了计算它这 1 N \frac{1}{N} N1 梯度的均值,需要进行一次 Reduce 操作,通信数据量是 Ψ \Psi Ψ,然后其余显卡则不需要保存这部分梯度值了。实现中使用了bucket策略,保证 1 N \frac{1}{N} N1 的梯度只发送一次。

当 GPU 0 计算好梯度均值后,就可以更新局部的优化器状态了,当反向传播过程结束,进行一次Gather操作,更新模型参数,通信数据量是 Ψ \Psi Ψ

总的通信量仍然是2 Ψ \Psi Ψ

ZeRO-3

若使用参数划分,每个显卡仅保存部分参数。因此在前向传播和后向传播过程中需要从其他显卡那里接收必要的模型参数。为了避免参数广播的显存开销,可以使用流水线的方式。这里假设模型在计算第一层前向传播时,持有第一层参数的显卡会将参数广播至其他显卡。当所有显卡都拿到参数后,进行第一层的前向传播。得到前向传播结果后,其他显卡就可以丢弃这部分模型的参数。

每个显卡都有 Ψ / N \frac{\Psi}{/N} /NΨ的模型参数,有N个显卡,则需要广播N次,所以前向传播过程中的参数通信量为 Ψ / N ∗ N = Ψ \frac{\Psi}{/N} * N = \Psi /NΨN=Ψ

后向传播时也需要逆向完成一次参数广播,通信量同样是 Ψ \Psi Ψ

最后,梯度完成计算后还需要经过一次reduce-scatter,通信量也是 Ψ \Psi Ψ

由于各个显卡持有不同的参数,所以不需要前面将所有更新后参数进行all-gather的操作了。总的来说,通信量为"前向传播的参数广播"+“后向传播的参数广播”+“梯度的reduce-scatter”=3 Ψ \Psi Ψ,也就是标准通信量的1.5倍。

参考资料





Top