Distributed Llama 作为github上的一个明星项目,有2.6kstar,可以在8台树莓派上面跑起来一个Llama3-405B参数的模型。截至今天(2025/4/02)我都没有看到中文互联网上有人写过它的源码分析,我很难过(其实是我打ASC希望学这个来打好我分布式代码的底层基础),今天我们就来分析一下这个dllama的源码,主要侧重分布式推理过程中的信息传递以及数据同步。

入口函数

这个语言作为一个几乎纯CPP实现的项目,目前支持树莓派(arm设备)以及具有AVX2指令集的CPU,并且使用vulkan做到了对GPU的支持。首先我们从它启动脚本调用的函数入手:

可以看到这个仓库从 dllama 这个二进制文件开始执行,并用其中向外提供的api inference 以及 worker 提供了分布式的调用基础。首先参数被传入 src/dllama.cpp 的main()函数中,在进行 sockets 的初始化之后前面的参数被丢进 src/app.cpp 一个AppCliArgs的类里面去解析,接着根据解析出来结果的 mode 属性来判断这个进程要进入哪种模式 (inference/worker) 。

推理主体

在确定好模式之后头结点的推理进程就会直接进到 runInferenceApp 这个主推理函数里面, 我们去 src/app.cpp 里面去看看这个函数在干嘛。这个函数会进行一些LLM的初始化比如LLM头部信息、采样器初始化、Tokenizer初始化等工作,接下来重点来了(敲黑板敲黑板);

网络初始化与交互流程

左侧是head结点的 runInferenceApp 函数,在初始化完LLM的一些信息后 LlmNet net = buildLlmNet(&header, nNodes, args->nBatches);非常关键的一步,这里根据模型头信息和总节点数 nNodes 来构建 Llama 网络 LlmNet)。这暗示着模型的计算图被划分,以便分配到 nNodes 个节点上执行net.nodeConfigs 可能是一个数组或列表,存储了每个节点(包括 root 节点)负责计算的模型部分的配置信息(例如,哪些层、哪些 attention head 由哪个节点处理)。

核心网络交互逻辑 (在 if (nNodes > 1) 分支中):

  1. 红线:左侧head结点中的networkPtr = NnNetwork::connect();head节点会尝试主动建立与 worker 节点的网络连接,head 节点在启动时 runInferenceApp) 会尝试连接所有配置的 worker。在右侧这个 while(true) 结构中,使得 worker 会一直等到直到能够被head 结点连接。networkPtr = NnNetwork::serve(args->port);这里主要干的事情是创建一个 socket,将该 socket 绑定到本机的所有网络接口和指定的 args->port 端口,开始监听 listen) 这个端口,等待传入的连接请求。成功建立的连接(socket 文件描述符)会被 NnNetwork 类实例 networkPtr 指向的对象)管理起来。这个 NnNetwork 对象就封装了与所有 worker 的通信通道。

  2. 黄线NnNetworkNodeSynchronizer (network, &execution, &netConfig, &nodeConfig);创建网络同步器,同步器是分布式计算中至关重要的组件,这个同步器持有前面创建的 network 对象,它将使用这些建立好的网络连接来进行实际的同步操作。它需要传入 network 连接、执行环境、全局配置和本地节点配置。这个同步器将负责 head 在运行时与其他节点进行协调和数据交换,当一方需要发送同步信号或数据时,它的 synchronizer 会使用 network 对象发送,用于协调不同节点间的计算步骤和数据交换。

  3. 绿线:configWriter(network);创建了一个专门用于从 head节点向 worker 节点发送配置信息的写入器 NnRootConfigWriter,并且它也持有 network 对象,说明配置信息的发送也是通过已建立的网络连接进行的。随后调用写入器的 writeToWorkers 方法,将全局网络配置 net.netConfig,(可能包含模型划分方式、张量分布信息等,和每个 worker 节点的具体配置 net.nodeConfigs 中对应 worker 的部分)发送给相应的 worker 节点。Worker结点解析从network对象中读取到的信息, 首先获取 NnNetwork 对象的原始指针,随后创建一个 configReader 对象,并将刚刚建立的 network 连接传递给它,这表明 configReader 将使用这条网络连接来读取配置信息。接着NnNetConfig netConfig = configReader.readNet();会调用 readNet() 方法,这个方法内部会通过 network 对象接收来自 root 节点的数据,并将其反序列化为 NnNetConfig 结构NnNetConfig 很可能包含了整个分布式网络的全局信息,比如总节点数、模型的整体划分策略等。 类似地,调用 readNode() 方法,通过 network 接收 head 发送的、专门针对当前这个 worker 节点的配置信息 NnNodeConfig。这部分配置会告诉 worker 它具体负责计算模型的哪些部分(比如哪些层、哪些头)。

  4. 蓝线:随后head结点创建了一个权重加载的对象 weightLoader 它需要 executor(可能要知道应该把权重加载到哪里)和 network 连接(因为权重数据要从网络来)以及nNodes(需要知道有多少个结点)然后通过调用loadLlmNetWeight() 方法执行读取权重的操作,在这一步可能会同时把权重加载到head结点以及worker结点中。后worker结点创建了一个专门用于从网络读取模型权重的对象 weightReader,可能是在 read() 方法内部循环接收这些数据块,反序列化,并通过 executor 将它们加载到正确的 device 内存位置。

最后Worker结点创建了一个WorkerLlmInference 对象。这个对象在封装了 worker 在接收完配置和权重后,实际执行推理计算的主循环或处理逻辑。它需要 execution 来执行计算,需要 network (或者更可能是通过内部持有的 synchronizer) 来在每一步推理中与 root (或其他节点) 进行通信。