

Written by AI Team
Published on Apr 2, 2025
Efficient and Portable Mixture-of-Experts Communication
An overview of portable Mixture-of-Experts (MoE) communication, focusing on optimizing GPU parallelism and reducing latency in large-scale AI models
We present a high-performance, portable, open-source library for Mixture-of-Experts (MoE) communication that achieves 10x faster performance compared to standard All-to-All communication primitives.
Our implementation features several key technical innovations that deliver superior MoE communication efficiency:
GPU-initiated communication (IBGDA): Supports direct GPU-to-NIC communication, significantly reducing latency by bypassing CPU involvement
Communication and computation overlap: Split kernel architecture with separate send and receive stages enables computation to proceed concurrently with network transfers
Fastest single-node performance: 2.5x lower latency than the previously fastest implementation on single-node configurations
Efficient and portable multi-node performance: Our implementation achieves speeds up to 10x faster than standard all-to-all communication. Although approximately 2x slower than highly specialized implementations, our approach offers better portability across NVSHMEM versions and network environments (NVLink, CX-7, and EFA)
The library is fully open-source and available at https://github.com/ppl-ai/pplx-kernels.
In this article, we explore the challenges of expert parallelism in large-scale MoE models and describe our approach to efficient token dispatch and combination across distributed GPU environments.
Introduction
Mixture-of-Experts (MoE) models, such as DeepSeek R1 and Mixtral 8x7B, improve upon dense models by limiting the number of weights that are activated for each token. Instead of achieving high parameter counts by increasing the number of layers or increasing the size of the linear layers in the Multi-Layer Perceptron (MLP) of each decoder layer, MoE models replace the traditional MLP with multiple experts and a router. For example, out of the 671B parameters of DeepSeek R1, only 37B are multiplied with a given token during inference. This greatly reduces the amount of computation required to decode a token and improves latency compared to a dense model.
MoE models present some additional challenges for inference, compared to dense models. While the experts themselves are small MLP layers, each decoder layer includes a router that decides which experts a token is dispatched to, with each token being dispatched to multiple experts. The router is typically a small linear layer producing a probability distribution. Usually the experts with the top-K scores are picked and the final activation is computed as a weighted average, summing the expert outputs multiplied by a weight derived from the probability distribution of the router.
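Concretely, for a token with hidden state $x$, a generic top-K formulation (individual models differ in how the selected weights are normalized) computes

$$ p = \mathrm{softmax}(W_r x), \qquad \mathcal{T} = \mathrm{TopK}(p, K), \qquad y = \sum_{e \in \mathcal{T}} \frac{p_e}{\sum_{e' \in \mathcal{T}} p_{e'}} \, E_e(x), $$

where $W_r$ is the router's linear layer and $E_e$ is the MLP of expert $e$.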
To minimize latency in distributed MoE systems, parallelism can be exploited across multiple devices, but this introduces communication challenges. Models such as Mixtral 8x7B or Llama-70B fit within the 8 devices of a single node, benefiting from fast, low-latency NVLink interconnects (up to 900 GB/s per GPU). However, larger models require sharding experts across multiple nodes over InfiniBand (peaking at 400 Gb/s per link), which introduces additional latency challenges.

In this article, we explore the problem of expert parallelism, describing our implementation of GPU kernels that efficiently dispatch tokens to devices holding the assigned experts and collect them back by computing the weighted sum of expert outputs. While aggressively optimized implementations targeting specific network hardware exist, we present a highly portable implementation that relies on a minimal set of NVSHMEM primitives yet delivers performance 10x faster than standard all-to-all communication. Our implementation achieves state-of-the-art performance on single-node configurations while maintaining excellent portability across various network environments.
Sharding and Parallelism
For efficient inference, the weights of a model must be held in device memory, while also leaving sufficient space for activations, KV caches and other buffers required by the forward pass through the model. The most capable models exceed the capacity of even the most capable GPUs, thus inference must be spread across multiple devices, which can collectively store the weights of the model. Based on the sharding schemes of weights, different communication and computation schemes must be used to offload computation and synchronize devices.

Expert Parallelism (EP), illustrated in the first figure, only parallelizes the expert computation. Different experts are assigned to different devices, which hold their weights. After routing, tokens are sent to the corresponding device, with the results being gathered and accumulated afterwards. The complexity of routing depends on the degree of parallelism in the other parts of the model: replicating the other layers could eliminate the need for routing altogether, as each rank can select the tokens from a locally replicated routing table. However, if only one of the ranks runs routing, a broadcast is required to dispatch tokens, indices and weights to their respective experts. Finally, an all-gather or an all-to-all broadcast synchronizes the output tokens with whichever rank continues the execution of the model. Such an implementation is relatively simple, as torch already exposes the required primitives, although some benefits could be gained by fusing accumulation with the gather operation collecting the tokens from the experts.

Expert-only parallelism does not scale ideally, as nodes in a cluster might be idle while the model is running non-expert layers, such as attention, norm and sampling. However, the computation of these layers, primarily attention, can also benefit from Tensor Parallelism (TP). Most models rely on multi-head attention, meaning that the attention heads and their corresponding Q, K and V projections can also be sharded across devices, replicating or gathering the slices between various layers. If attention is spread across all devices, an all-gather can synchronize the activations, allowing routing to be replicated, requiring synchronization primitives similar to the expert parallelism case for an efficient implementation. However, there are limits to parallelism at this level, as reducing the number of attention heads below a certain threshold will yield diminishing returns.

To best utilize all devices and support a very high degree of expert parallelism, of up to 128 or 256 GPUs, Data Parallelism (DP) is required. Under this scheme, the devices a model is split across are grouped to handle requests concurrently, computing attention and maintaining KV caches sharded across their local group. Multiple instances of these parallel groups collaborate on expert evaluation, with each hosting a different subset of the expert weights. Based on the number of attention heads, a group may typically span up to the size of an entire node, as intra-node communication is faster. For example, in the figure above, one DP rank independently services two requests, A and B, handling attention, norm and any other bookkeeping for the requests. The other DP rank processes a distinct request C. However, the first node hosts half of the experts while the other node hosts the other half, thus after routing, tokens from A and B might be sent to the second node and vice versa. This leads to a sparse communication problem: each device might send a different number of tokens to any other destination rank. Existing primitives from torch, primarily all_to_all, are not particularly well suited, as they might require some form of padding or GPU-to-CPU synchronization and metadata broadcast. To implement communication effectively, custom kernels are required to process the routing information and initiate communication from the devices without CPU intervention. After routing, a dispatch kernel must send tokens to the ranks they were routed to, while on the combine side the activations belonging to the requests in the current DP group must be collected. Additionally, work must be balanced within a DP group, which may have multiple devices, in order to minimize broadcasts and correctly replicate the tokens on all ranks for the subsequent layers.
NVSHMEM
NVSHMEM is an NVIDIA-specific OpenSHMEM implementation, providing portable inter-device communication facilities that abstract away the complexity of the underlying hardware. The API can express device-to-device reads and writes, which are mapped to the primitives of individual transport layers. Both NVLink and RDMA are supported, granting a degree of portability and allowing NVSHMEM kernels to operate within a node or across a cluster of multiple nodes. In our kernels, we use remote memory writes and signaling operations in order to transfer data and synchronize the participating devices.
NVSHMEM operations are built around the concept of symmetric memory: they operate on buffers which have been allocated on all the devices participating in inter-device communication. A local device can use an address derived within its own local buffer, alongside the index of the target rank, to specify the destination of an operation. The figure below illustrates this concept: both GPUs allocate symmetric buffers of the same size, retaining src pointers to them. The first rank wants to send 3 integers to the second one, placing them at the start of the buffer: nvshmem_int_put_nbi derives the start address from the local buffer, specifying the target device. The second rank derives an offset from its own buffer, sending one element to the first device, offset by one. While destination addresses must always be symmetric buffers allocated using nvshmem_malloc, source buffers can be arbitrary regions of device memory, provided they are pre-registered with NVSHMEM.
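As a concrete illustration of the figure, the minimal sketch below reproduces the two puts described above using the NVSHMEM host and device APIs. It assumes exactly two ranks and an 8-integer buffer; the names are ours and error handling is omitted.

```cuda
#include <nvshmem.h>
#include <nvshmemx.h>

// Each rank issues one non-blocking put into the other rank's symmetric buffer.
__global__ void exchange(int *buf, int my_pe) {
    if (threadIdx.x == 0) {
        if (my_pe == 0) {
            // Rank 0 writes 3 integers to the start of rank 1's buffer.
            nvshmem_int_put_nbi(buf, buf, 3, 1);
        } else if (my_pe == 1) {
            // Rank 1 writes 1 integer at offset 1 of rank 0's buffer.
            nvshmem_int_put_nbi(buf + 1, buf + 1, 1, 0);
        }
        nvshmem_quiet();  // complete the non-blocking put before the kernel exits
    }
}

int main() {
    nvshmem_init();
    int my_pe = nvshmem_my_pe();
    // Symmetric allocation: every rank allocates a buffer of the same size.
    int *buf = (int *)nvshmem_malloc(8 * sizeof(int));
    exchange<<<1, 32>>>(buf, my_pe);
    cudaDeviceSynchronize();
    nvshmem_barrier_all();  // make sure all puts have landed before freeing
    nvshmem_free(buf);
    nvshmem_finalize();
    return 0;
}
```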
While NVSHMEM provides a wide range of primitives, our kernels rely on only three functions, building all synchronization and fencing upon them:

nvshmemx_putmem_signal_nbi_warp: Transfers a block of data from one device to another, while also setting a flag on the remote device. The operation either sets (NVSHMEM_SIGNAL_SET) or increments (NVSHMEM_SIGNAL_ADD) a 64-bit location. The flag is updated only after the entire block of memory has been transferred, so once the remote device observes a change in the flag, it can safely access the buffer in its own memory. This function is useful for coalescing data transfer and synchronization.
nvshmemx_signal_op: Operates on a single memory location, typically a 64-bit flag, atomically setting or incrementing it. It is useful for sending over metadata and synchronizing devices.
nvshmem_uint64_wait_until: Used on the receiving end of a signal, polling a flag until the remote updates it.
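The typical pattern built from these three primitives looks roughly like the device-side sketch below: the sender pushes a payload and bumps a flag in a single call (or sends metadata with a bare signal), while the receiver spins on its local flag before touching the data. All names are hypothetical; this is an illustration of the primitives, not our kernel code.

```cuda
#include <nvshmem.h>
#include <nvshmemx.h>

// Sender: one warp pushes `bytes` of payload into the peer's symmetric buffer
// and atomically increments a 64-bit flag on the peer once the payload lands.
__device__ void send_chunk(void *dst_sym, const void *src, size_t bytes,
                           uint64_t *flag_sym, int peer) {
    nvshmemx_putmem_signal_nbi_warp(dst_sym, src, bytes,
                                    flag_sym, 1, NVSHMEM_SIGNAL_ADD, peer);
}

// Sender: metadata-only signal, e.g. announcing a token count to a peer.
__device__ void send_count(uint64_t *flag_sym, uint64_t count, int peer) {
    nvshmemx_signal_op(flag_sym, count, NVSHMEM_SIGNAL_SET, peer);
}

// Receiver: poll the local flag until the expected number of chunks has been
// signalled; afterwards the corresponding payload buffer is safe to read.
__device__ void wait_chunks(uint64_t *flag_local, uint64_t expected) {
    nvshmem_uint64_wait_until(flag_local, NVSHMEM_CMP_EQ, expected);
}
```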
Ensuring any form of ordering between operations through nvshmem_quiet or nvshmem_fence is expensive, thus our implementation of the kernels avoids the use of any other NVSHMEM primitives and minimizes any dependencies other than individual send-receive pairs. When sending data, we always use the non-blocking version of the functions, without even waiting for the data to leave the local rank. Other, implicit synchronization across the group of kernels ensures that these buffers are not destructively overwritten while the NICs still need access to them.
NVSHMEM uses GPUDirect RDMA so that the NIC can directly read and write GPU memory for data exchange without involving the CPU. Furthermore, on ConnectX NICs, NVSHMEM supports GPU-initiated communication (also known as InfiniBand GPUDirect Async, or IBGDA), which allows the GPU to post RDMA operations directly to the NIC and to poll the NIC's completion queue directly. On platforms that do not support GPU-initiated communication, NVSHMEM spawns a "proxy" thread on the host CPU to initiate communication and poll for completion on behalf of the GPU. An NVSHMEM program is portable regardless of whether GPU-initiated communication is supported. However, GPU-initiated communication significantly cuts latency because it completely bypasses the detour through the CPU.
Portable Kernels
We implement MoE communication through a pair of dispatch and combine kernels. The dispatch kernels are responsible for reading tokens and routing information on each rank and dispatching them to the appropriate experts. The combine kernels collect the activations produced by the experts and send them back to their source ranks, while also computing the weighted average of the selected experts based on the weights computed by the router. The kernels are further split into a send and a receive component, in order to allow data transfers to be overlapped with computation. The send kernels are all non-blocking and non-synchronizing: they simply dispatch all the writes to the remotes. On the other end, the receive kernels only read from memory, waiting until all required data has been transferred. After dispatching work to the NICs, while the data is transferred asynchronously over the wire, the GPUs can do other useful work locally, such as applying shared experts or computing attention.
Each pair of kernels has its own symmetric memory buffers, with each rank allocating private storage to receive data from all other ranks. The maximal number of tokens that can be dispatched from a DP group is fixed across the ranks, which also sets an upper bound on what each rank can receive for each local expert from each DP rank. This allows sender ranks to derive a unique address on the destination rank to write to, without requiring any synchronization among them. After data is received, the kernels shuffle the tokens around to lay them out in memory in the format required by the receiving kernels. While the buffers are sizable, they are reused across all sequential layers of a model.
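For example, a sender can compute the destination slot inside the receiver's symmetric buffer purely from static quantities, along the lines of the sketch below. The field names and layout are illustrative, not the actual buffer layout of our kernels.

```cuda
// Illustrative address derivation (not the exact layout used by our kernels):
// because every rank reserves max_m slots per (source DP group, local expert)
// pair, a sender can compute a unique, collision-free byte offset into the
// receiver's symmetric buffer without coordinating with any other sender.
__device__ size_t dispatch_slot_offset(int src_group,         // sending DP group index
                                       int local_expert,      // expert index on the receiver
                                       int token_slot,        // token index within the group, < max_m
                                       int num_local_experts, // experts hosted per rank
                                       int max_m,
                                       size_t token_bytes) {
    size_t slot = ((size_t)src_group * num_local_experts + local_expert)
                  * (size_t)max_m + token_slot;
    return slot * token_bytes;
}
```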

The only form of global synchronization, as illustrated in the figure above, is implemented in the combine-receive kernel. Once data is dispatched from the send kernel, a rank returns and continues execution as soon as it has received all the tokens it requires from the other ranks, entering the combine-send kernel once they are processed. The barrier in combine-receive ensures that no rank can run ahead and start dispatch-send while any other rank is still waiting to receive data, potentially overwriting buffers that are still in use. Synchronization is done in the combine kernel for simplicity; as an optimization, the same barrier could be implemented between the dispatch-receive and combine-send kernels, further reducing latency by allowing the NICs more time to settle the transfers while the GPU is running other kernels.
Both kernels are split across all available SMs of the device: while the dispatch-send and combine-receive kernels must parallelize across a per-rank maximum token count (max_m), the dispatch-receive and combine-send kernels must handle that many tokens from each DP rank and local expert, for a maximum of max_m * (num_experts // EP) * (EP // TP) tokens.
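As a hypothetical example (illustrative values, not our benchmark configuration): with max_m = 128, num_experts = 256, EP = 64 and TP = 8, the receive-side kernels must provision for up to

$$ 128 \times \frac{256}{64} \times \frac{64}{8} = 128 \times 4 \times 8 = 4096 $$

token slots per rank.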
Dispatch
The dispatch kernel is responsible for dispatching tokens and block scaling factors to the peers which host the experts the tokens were routed to. The sender side relies on warp specialization to parallelize two tasks: aggregating routing information to determine how many tokens are sent from each rank, and packing the tokens into messages to be dispatched to the remote ranks. The receiver side first waits for all the token counts to be received, then reads and unpacks the payloads from symmetric memory into the tensors that will be passed on to the downstream Grouped GEMM kernels. It also stores information into buffers shared with the combine kernels, to indicate where each token should be sent back. This information is required because the receive kernel shuffles tokens into contiguous buffers in a non-deterministic order.

In the sender part, a single warp of 32 threads in each block is responsible for reading the routing information and counting the number of tokens each destination expert is expected to receive from this rank, sending the count plus one using nvshmemx_signal_op. The count is incremented by one because the transition from zero to non-zero on the remote end signifies the receipt of the counts from a source rank. In parallel, the remaining warps cooperate across all blocks to copy tokens into symmetric memory, packing activations, scaling factors and their index on the local rank into a contiguous chunk. The index is required by the combine sender to determine the address to which the token will be written back. Next, after a barrier across the warp groups ensures all the data has been copied, the warps again operate independently, each sending the same buffer to a different expert in parallel. The tokens are sent using nvshmemx_putmem_signal_nbi_warp, which also atomically increments the count of sent tokens from the local rank on the remote device. Within a DP group, since each rank owns a replica of the tokens to be sent, dispatch is balanced evenly, with each device sending out a subset of the tokens.
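A heavily simplified skeleton of this warp specialization is sketched below. It shows the structure only: all names, buffer layouts and index math are illustrative simplifications (it assumes num_tokens <= max_m and skips quantization), not the actual kernel code.

```cuda
#include <nvshmem.h>
#include <nvshmemx.h>

// Warp 0 of each block publishes per-expert token counts; the other warps
// each pack one token at a time into symmetric memory and push it to the
// rank hosting each selected expert.
__global__ void dispatch_send(const int *routing,   // [num_tokens, k] expert ids
                              const char *tokens,   // num_tokens * token_bytes
                              char *stage_sym,      // symmetric staging buffer
                              char *recv_sym,       // symmetric receive buffer (on peers)
                              uint64_t *count_sym,  // per (src rank, local expert) count flags
                              uint64_t *arrived_sym,// per (src rank, local expert) arrival counters
                              int num_tokens, int k, int num_experts,
                              int ranks, int max_m, size_t token_bytes, int my_rank) {
    int warp = threadIdx.x / 32, lane = threadIdx.x % 32;
    int experts_per_rank = num_experts / ranks;

    if (warp == 0) {
        // Count the tokens routed to each expert and signal count + 1, so a
        // non-zero flag on the receiver means "the count has arrived".
        for (int e = blockIdx.x * 32 + lane; e < num_experts; e += gridDim.x * 32) {
            uint64_t count = 0;
            for (int i = 0; i < num_tokens * k; ++i)
                count += (routing[i] == e);
            nvshmemx_signal_op(&count_sym[my_rank * experts_per_rank + e % experts_per_rank],
                               count + 1, NVSHMEM_SIGNAL_SET, e / experts_per_rank);
        }
        return;
    }

    // Remaining warps: each warp owns a disjoint set of tokens.
    int nwarps = blockDim.x / 32 - 1;
    for (int t = blockIdx.x * nwarps + (warp - 1); t < num_tokens; t += gridDim.x * nwarps) {
        char *chunk = stage_sym + (size_t)t * token_bytes;
        for (size_t i = lane; i < token_bytes; i += 32)   // pack the token
            chunk[i] = tokens[(size_t)t * token_bytes + i];
        __syncwarp();

        for (int j = 0; j < k; ++j) {                     // send to each selected expert
            int e = routing[t * k + j];
            int peer = e / experts_per_rank;
            char *dst = recv_sym + (((size_t)my_rank * experts_per_rank
                        + e % experts_per_rank) * max_m + t) * token_bytes;
            nvshmemx_putmem_signal_nbi_warp(dst, chunk, token_bytes,
                &arrived_sym[my_rank * experts_per_rank + e % experts_per_rank],
                1, NVSHMEM_SIGNAL_ADD, peer);
        }
    }
}
```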
On the receiving end, all kernels first synchronize with the sender ranks by waiting for the total token counts to be sent over. Afterwards, they wait for the atomically incremented sent-token counts to settle to the total counts, indicating that all the payloads from the source ranks have also arrived, thanks to the semantics of the putmem call. The kernels poll on the counts using nvshmem_uint64_wait_until, parallelizing the operation across all blocks and threads. Subsequently, a cross-block barrier ensures that no block reads the buffers until all data has been correctly received. Spread across blocks and synchronized via an atomically incremented counter, the tokens are copied from symmetric memory into regular contiguous buffers which become the inputs to the downstream kernels implementing the experts. The source rank, expert index and token index are stored separately, exactly pinpointing the location to which the combine kernel has to send the activations back. Even though tokens from within a DP group are sent from different devices, they are all grouped together to be passed on to the corresponding expert.
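A short, illustrative sketch of this two-stage wait is shown below (hypothetical names; the cross-block barrier and the unpacking copy are omitted):

```cuda
#include <nvshmem.h>

// First wait for the expected count from every source rank, then wait for
// that many payload arrivals, as signalled by the putmem-with-signal calls.
__global__ void dispatch_recv_wait(uint64_t *count_sym,   // per (src rank, local expert), count + 1
                                   uint64_t *arrived_sym, // per (src rank, local expert) arrivals
                                   int ranks, int experts_per_rank) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = ranks * experts_per_rank;
    for (int i = idx; i < total; i += gridDim.x * blockDim.x) {
        // A non-zero flag means the sender's count has landed.
        nvshmem_uint64_wait_until(&count_sym[i], NVSHMEM_CMP_GT, 0);
        uint64_t expected = count_sym[i] - 1;  // counts were sent as count + 1
        // Wait until the per-source arrival counter settles to that count.
        nvshmem_uint64_wait_until(&arrived_sym[i], NVSHMEM_CMP_GE, expected);
    }
    // A cross-block barrier (omitted) then lets every block safely unpack
    // the received tokens from symmetric memory into contiguous buffers.
}
```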
Combine
The combine kernels are again split into send and receive halves: the senders copy the un-quantized 16-bit activations over to the remote devices, with the receive kernels waiting for the data to arrive and computing the weighted average of all expert contributions locally. Additionally, they also act as a barrier to synchronize the dispatch-combine sequence: each rank sets a flag on each peer on entry to the send kernels, and the receive kernels are not allowed to return until they observe the flag being set. The latency of this synchronization is minimal, as it overlaps with the actual communication and computation.

On the sender side, the kernels traverse the list of tokens assigned to all local experts in parallel, writing them to a buffer on the destination rank. The target rank, expert index and token index are read from the per-token buffers populated by the dispatch receive kernels upon receipt of the tokens. Each sender has its own private memory region per expert to write to, as indicated in the figure above, avoiding the need to synchronize. Similarly to dispatch, combine atomically increments per-token counters on the destination rank to indicate the receipt of the data: when the counter matches the number of experts a token was dispatched to, the token contents can be accessed.
In the receive kernel, the list of tokens is traversed in parallel across multiple blocks, waiting for their contents to arrive by polling the flag set by the signalling operation. Upon arrival, the payloads are read from the private buffers, with the routing table indicating which buffer to read from and what weight to assign to each expert. The results are then written to externally allocated tensors, with the kernel finishing execution once all devices have passed the barrier. Across a DP group, all ranks receive a copy of each expert activation to compute their own replicas.
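The per-token accumulation can be pictured as in the fragment below, a minimal sketch assuming one block per token and half-precision expert outputs (the data type and all names are illustrative):

```cuda
#include <cuda_fp16.h>
#include <nvshmem.h>

// Each of the k selected experts wrote its output into a private slot; once
// all k arrivals are signalled, accumulate the weighted sum into the output.
__device__ void combine_token(uint64_t *arrived,       // per-token arrival counter
                              const __half *slots,     // k private slots, hidden_dim each
                              const float *weights,    // router weights for the k experts
                              float *out, int k, int hidden_dim) {
    // All k contributions have landed once the counter reaches k.
    nvshmem_uint64_wait_until(arrived, NVSHMEM_CMP_EQ, (uint64_t)k);
    for (int d = threadIdx.x; d < hidden_dim; d += blockDim.x) {
        float acc = 0.f;
        for (int e = 0; e < k; ++e)
            acc += weights[e] * __half2float(slots[(size_t)e * hidden_dim + d]);
        out[d] = acc;
    }
}
```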
Benchmarks
We evaluate our kernels on a cluster of 128 GPUs spread across 16 nodes connected via InfiniBand with CX-7 NICs. We try both GPUDirect Async (IBGDA) and Reliable Connection with a CPU proxy (IBRC). We compare them to the highly optimized DeepSeek implementation, as well as the dense primitives provided by PyTorch (through NCCL) or NVSHMEM.

Although the performance of IBRC and IBGDA is similar for the dense NVSHMEM all-to-all operation (6378 µs vs 6180 µs), IBGDA is significantly faster with the sparse kernels. Adding up dispatch and combine, IBGDA takes 902 µs whereas IBRC takes 3223 µs, a 3.6x improvement in latency. While all-to-all is bandwidth-bound, the sparse kernels transfer orders of magnitude less data and are instead bound by latency. By triggering network transfers directly from the GPU, without requiring a CPU proxy to coordinate the GPU and the NIC, end-to-end latency is significantly reduced.
While our portable kernels are about 2x slower than the highly optimized DeepSeek kernels, they improve latency by 10x compared to the dense kernels (902 µs vs 9944 µs). Additionally, the split into sender and receiver components allows some of the latency to be hidden, unlike with the library primitives.

On a single node (EP8), NVSHMEM uses NVLink as the transport, delivering lower latency and higher throughput than inter-node networking. Our portable kernels are about 2.5x faster than DeepEP on a single node (186 µs vs 481 µs).
Further Improvements
The kernels described here outperform the built-in primitives of ML frameworks and offer decent performance without over-specializing for particular inter-device transports, such as NVLink or InfiniBand. Besides the already-mentioned opportunities, further performance gains are attainable by replacing the communication primitives with more specialized versions. For example, across NVLink, the use of symmetric memory could be replaced with buffers shared directly across the devices, eliminating some copying and allowing for finer-grained synchronization schemes, across individual tokens instead of token batches. Across InfiniBand, an implementation could access the underlying queue pairs directly, eliminating the need to transfer both the expected and sent token counts and flushing tokens from the queues into output buffers as they arrive. However, such implementations come at the cost of portability, while our implementation allows clusters with specific configurations to be evaluated before committing development effort to optimize for particular hardware.
Conclusion
We have presented a high-performance, portable library for MoE communication that achieves 10x faster performance compared to standard all-to-all communication while maintaining compatibility across diverse hardware configurations. On single-node deployments with NVLink, our solution demonstrates 2.5x lower latency than previous implementations.
Our approach balances performance with portability through key innovations including GPU-initiated communication support, a split kernel architecture enabling computation-communication overlap, and efficient token dispatch using minimal NVSHMEM primitives. While approximately 2x slower than highly specialized implementations on multi-node setups, our library offers superior flexibility across various network environments (NVLink, CX-7, and EFA).
As MoE models continue to scale, efficient communication strategies like ours will become increasingly important for practical deployment. Our fully open-source implementation is available at https://github.com/ppl-ai/pplx-kernels.
References
Efficient and Portable Mixture-of-Experts Communication
An overview of portable Mixture-of-Experts (MoE) communication, focusing on optimizing GPU parallelism and reducing latency in large-scale AI models
We present a high-performance, portable, open-source library for Mixture-of-Experts (MoE) communication that achieves 10x faster performance compared to standard All-to-All communication primitives.
Our implementation features several key technical innovations that deliver superior MoE communication efficiency:
GPU-initiated communication (IBGDA): Supports direct GPU-to-NIC communication, significantly reducing latency by bypassing CPU involvement
Communication and computation overlap: Split kernel architecture with separate send and receive stages enables computation to proceed concurrently with network transfers
Fastest single-node performance: 2.5x lower latency than the previously fastest implementation on single-node configurations
Efficient and portable multi-node performance: Our implementation achieves speeds up to 10x faster than standard all-to-all communication. Although approximately 2x slower than highly specialized implementations, our approach offers better portability across NVSHMEM versions and network environments (NVLink, CX-7, and EFA)
The library is fully open-source and available at https://github.com/ppl-ai/pplx-kernels.
In this article, we explore the challenges of expert parallelism in large-scale MoE models and describe our approach to efficient token dispatch and combination across distributed GPU environments.
Introduction
Mixture-of-Experts (MoE) models, such as DeepSeek R1 and Mixtral 8x7B, improve upon dense models by limiting the number of weights that are activated for each token. Instead of achieving high parameter counts by increasing the number of layers or increasing the size of the linear layers in the Multi-Layer Perceptron (MLP) of each decoder layer, MoE models replace the traditional MLP with multiple experts and a router. For example, out of the 671B parameters of DeepSeek R1, only 37B are multiplied with a given token during inference. This greatly reduces the amount of computation required to decode a token and improves latency compared to a dense model.
MoE models present some additional challenges for inference, compared to dense models. While the experts themselves are small MLP layers, each decoder layer includes a router that decides which experts a token is dispatched to, with each token being dispatched to multiple experts. The router is typically a small linear layer producing a probability distribution. Usually the experts with the top-K scores are picked and the final activation is computed as a weighted average, summing the expert outputs multiplied by a weight derived from the probability distribution of the router.
To minimize latency in distributed MoE systems, parallelism can be exploited across multiple devices, but this introduces communication challenges. Models such as Mixtral 8x7B or Llama-70B fit within 8 devices across a single node, benefiting from fast and low-latency NVLink interconnections (up to 900Gbps). However, larger models require sharding experts across multiple nodes using InfiniBand (peaking at 400Gbps), which introduces additional latency challenges.

In this article, we explore the problem of expert parallelism, describing our implementation of GPU kernels that efficiently dispatch tokens to devices holding the assigned experts and collect them back by computing the weighted sum of expert outputs. While aggressively optimized implementations targeting specific network hardware exist, we present a highly portable implementation that relies on a minimal set of NVSHMEM primitives yet delivers performance 10x faster than standard all-to-all communication. Our implementation achieves state-of-the-art performance on single-node configurations while maintaining excellent portability across various network environments.
Sharding and Parallelism
For efficient inference, the weights of a model must be held in device memory, while also leaving sufficient space for activations, KV caches and other buffers required by the forward pass through the model. The most capable models exceed the capacity of even the most capable GPUs, thus inference must be spread across multiple devices, which can collectively store the weights of the model. Based on the sharding schemes of weights, different communication and computation schemes must be used to offload computation and synchronize devices.

Expert Parallelism (EP), illustrated in the first figure, only parallelizes the expert computation. Different experts are assigned to different devices, which hold their weights. After routing, tokens are sent to the corresponding device, with the results being gathered and accumulated afterwards. The complexity of routing depends on the degree of parallelism in the other parts of the model: replicating other layers could eliminate the need for routing altogether, as each rank can select the tokens from a locally replicated routing table. However, if only one of the rank run routing, a broadcast is required to dispatch tokens, indices and weights to their respective experts. Finally, an all-gather or an all-to-all broadcast synchronizes the output tokens with whichever rank continues the execution of the model. Such an implementation is relatively simple as torch already exposes the required primitives, albeit some benefits could be gained by fusing accumulation with the gather operation collecting the tokens from the experts.

Expert-only parallelism does not scale ideally, as nodes in a cluster might be idle while the model is running non-expert layers, such as attention, norm and sampling. However, the computation of these layers, primarily attention, can also benefit from Tensor Parallelism (TP). Most models rely on multi-head attention, meaning that the attention heads and their corresponding Q, K and V projections can also be sharded across devices, replicating or gathering the slices between various layers. If attention is spread across all devices, an all-gather can synchronize the activations, allowing routing to be replicated, requiring synchronization primitives similar to the expert parallelism case for an efficient implementation. However, there are limits to parallelism at this level, as reducing the number of attention heads below a certain threshold will yield diminishing returns.

To best utilize all devices and support a very high degree of expert parallelism, of up to 128 or 256 GPUs, Data Parallelism (DP) is required. Under this scheme, the devices a model is split across are grouped to handle requests concurrently, computing attention and maintaining KV caches sharded across their local group. Multiple instances of these parallel groups collaborate on expert evaluation, with each hosting a different subset of the expert weights. Based on the number of attention heads, a group may typically span up to the size of an entire node, as intra-node communication is faster. For example, in the figure above, one DP rank independently services two requests, A and B, handling attention, norm and any other bookkeeping for the requests. The other DP rank processes a distinct request C. However, the first node hosts half of the experts, while the other node the other half, thus after routing tokens from A and B might be sent to the second node and vice-versa. This leads to a sparse communication problem: each device might send a different number of tokens to any other destination rank. Existing primitives from torch, primarily all_to_all, are not particularly well suited, as they might require some form of padding or GPU-to-CPU synchronization and metadata broadcast. To implement communication effectively, custom kernels are required to process the routing information and initiate communication from the devices without CPU intervention. After routing, a dispatch kernel must send tokens to the ranks they were routed to, while on the combine side the activations belonging to the requests in the current DP group must be collected. Additionally, work must be balanced within a DP group which may have multiple devices, in order to minimize broadcast and correctly replicate the tokens on all ranks for the subsequent layers.
NVSHMEM
NVSHMEM is an NVIDIA-specific OpenSHMEM implementation, providing portable inter-device communication facilities that abstract away the complexity of the underlying hardware. The API can express device-to-device reads and writes, which are mapped to the primitives of individual transport layers. Both NVLink and RDMA are supported, granting a degree of portability and allowing NVSHMEM kernels to operate within a node or across a cluster of multiple nodes. In our kernels, we use remote memory writes and signaling operations in order to transfer data and synchronize the participating devices.
NVSHMEM operations are built around the concept of symmetric memory: they operate on buffers which have been allocated on all the devices participating in inter-device communication. A local device can use an address derived within its own local buffer alongside the index of the target rank to specify the destination of an operation. The figure below illustrates this concept: both GPUs allocate symmetric buffers of the same size, retaining src
pointers to them. The first rank wants to send 3 integers to the second one, placing them at the start of the buffer: nvshmem_int_put_nbi
derives the start address from the local buffer, specifying the target device. The second rank derives an offset from its own buffer, sending one element to the first device, offsetting by one. While destination addresses must always be symmetric buffers allocated using nvshmem_alloc
, source buffers can be arbitrary regions of device memory, provided they are pre-registered with NVSHMEM.
While NVSHMEM provides a wide range of primitives, our kernels rely on only 3 functions, building all synchronization and fencing upon them.

nvshmemx_putmem_signal_nbi_warp
: Transfers a block of data from one device to another, while also setting a flag on the remote device. The operation either sets (NVSHMEM_SIGNAL_SET
) or increments (NVSHMEM_SIGNAL_ADD
) a 64-bit location. The flag is updated after the entire block of memory is transferred. If the remote device observed a change in the flag, it can safely access the buffer in its own memory. This function is useful for coalescing data transfer and synchronization.nvshmemx_signal_op
operates on a single memory location, typically a 64-bit flag, atomically setting or incrementing it. It is useful in sending over metadata and synchronizing devices.nvshmem_uint64_wait_until
is used on the receiving end of a signal, to poll a flag until the remote updates it.
Ensuring any form of ordering between operations through nvshmem_quiet
or nvshmem_fence
is expensive, thus our implementation of the kernels avoids the use of any other NVSHMEM primitives and minimizes any dependencies other than individual send-receive pairs. When sending data, we always use the non-blocking version of functions, without waiting for the data to be even sent out of the local rank. Other, implicit synchronization across the group of kernels ensures that these buffers are not destructively overwritten while the NICs need access to them.
NVSHMEM uses GPUDirect RDMA so that the NIC can directly read and write GPU memory for data exchange without involving CPUs. Furthermore, on ConnectX NICs, NVSHEMEM supports GPU-initiated communication (also known as Infiniband GPUDirect Async, or IBGDA), which allows GPU to post RDMA operations directly to NIC and to poll completion queue of NIC directly. On platforms that does not support GPU-initiated communication, NVSHMEM spawns a “proxy” thread on host CPU to initiate communication and poll completion on behalf of the GPU. NVSHMEM program is portable regardless of whether GPU-initiated communication is supported or not. However, GPU-initiated communication significantly cuts latency because it completely bypasses the detour to CPU.
Portable Kernels
We implement MoE communication through a pair of dispatch and combine kernels. The dispatch kernels are responsible for reading tokens and routing information on each rank, dispatching them to the appropriate experts. The combine kernels collect the activations produced by the experts and send them back to their source ranks, while also computing the weighted average from the selected experts based on the weights computed by the router. The kernels are further split into a send and receive component, in order to allow data transfers to be overlapped with computation. The send kernels are all non-blocking and non-synchronizing: they simply dispatch all the writes to the remotes. On the other end, the combine kernels only read from memory, waiting until all required data has been transferred. After dispatching work to the NICs, while the data is transferred asynchronously over the wire, the GPUs can do other useful work locally, such as applying shared experts or computing attention.
Each pair of kernels has its own symmetric memory buffers, which each rank allocating private storage to receive data from all other ranks. The maximal number of tokens that can be dispatched from a DP group is fixed across the ranks, which also sets the upper bound each rank can receive for each local expert from each DP rank. This allows sender ranks to derive a unique address on the destination rank to write to, without requiring any synchronization among them. After data is received, the kernels shuffle the tokens around to lay them out in memory in the optimal format required by the receiving kernels. While the buffers have a sizable dimension, they are re-used across all sequential layers of a model.

The only form of global synchronization, as illustrated in the figure above, is implemented in the combine-receive kernel. Once data is dispatched from send, a rank returns and continues execution when it receives all the tokens it requires from the other ranks and enters the combine send kernels as soon as they are processed. The barrier in combine-receive ensures that no rank can run ahead and start dispatch-send while any other rank is still waiting to receive data, potentially causing destructive overlapping. Synchronization is done in the combine kernel for simplicity, as an optimization the same barrier could be implementing between the dispatch receive and the combine send kernels, further reducing latency by allowing the NICs more time to settle the transfers while the GPU is running other kernels.
Both kernels are split across all available SMs of the devices: while the dispatch send and combine receive kernels must parallelize across a per-rank maximum token count (max_m
), the dispatch receive and combine send kernels must handle that many tokens from each rank and local expert, for a maximum of max_m * (num_experts // EP) * (EP // TP)
.
Dispatch
The dispatch kernel is responsible for dispatching tokens and block scaling factors to peers which host the experts the tokens were routed to. The sender side relies on warp specialization to parallelize two tasks: aggregating routing information to determine how many tokens are sent from each rank and packing the tokens into messages to be dispatched to the remote. The receiver side first waits for all the token counts to be received, then it reads and unpacks the payloads from shared memory into the tensors that will be passed on to the downstream Grouped GEMM kernels. It also stores information into buffers shared with the combine kernels, to indicate where each token should be sent back. This information is required as the receive kernel shuffles tokens around in contiguous buffers in a non-deterministic order.

In the sender part, a single warp of 32 threads on each block is responsible for reading the routing information and counting the number of tokens each destination expert is expected to receive from this rank, sending the count plus one using nvshmem_signal_op
. The count is incremented by one, as the transition from zero to non-zero on the remote end signifies the receipt of the counts from a source rank. In parallel, the remaining warps co-operate to copy tokens into symmetric memory across all blocks in parallel, packing activations, scaling factors and their index on the local rank into a contiguous chunk. The index is required by the combine sender to determine the address where the token will be written to. Next, after ensuring all the data has been copied through a barrier across the warp groups, the warps yet again operate independently, each sending the same buffer to a different expert in paralle. The tokens are sent using nvshmemx_putmem_signal_nbi_warp
, which also atomically increases the count of sent tokens from the local rank on the remote device. Within a DP group, since each rank ows a replica of the token to be sent, dispatch is balanced evenly, with each device sending out a subset of the tokens.
On the receive end, all kernels first synchronize with the sender ranks by waiting for the total token counts to be sent over. Afterwards, they all wait for the atomically incremented sent token counts to settle to the total counts, indicating that all the payloads from the source ranks have also been sent over, thanks to the semantics of the putmem
call. The kernels poll on the counts using nvshmem_uint64_wait_until
, parallelizing the operation across all blocks and threads. Subsequently, a cross-block barrier ensures that no block reads the buffers unless all data has been correctly received. Spread across blocks and synchronized via an atomically incremented counter, the tokens are copied from symmetric memory into regular contiguous buffers which become the inputs to the downstream kernels implementing the experts. The source rank, expert index and token index are stored separately, exactly pinpointing the location where the combine kernel has to send the activations. Even though tokens from within a DP group are sent from different devices, they are all grouped together to be passed on to the corresponding expert.
Combine
The combine kernels are yet again split into send and receiver halves: the senders copy the un-quantized 16-bit activations over to the remote devices, with the receive kernels waiting for the data to be sent over and computing the weighted average of all expert contributions locally. Additionally, they also act as a barrier, to synchronize the dispatch-combine sequence: each rank sets a flag on each peer on entry to the send kernels, with the receive kernels not being allowed to return unless they observe the flag being set. The latency of synchronization is minimal, as it overlaps with the actual communication and computation.

On the sender side, the kernels traverse the list of tokens assigned to all local experts in parallel, writing them to a buffer on the destination rank. The target rank, expert index and token index are read from the per-token buffers populated by the scater kernels upon the receipt of the tokens. Each sender has its own private memory region per expert to write to, as indicated in the figure above, avoiding the need to synchronize. Similarly to dispatch, combine atomically increments per token counters on the destination rank to indicate the receipt of the data: when the counter matches the number of experts a token was dispatched to, the token contents can be accessed.
In the receive kernel, the list of tokens is traversed in parallel across multiple blocks, waiting for their contents to arrive by polling the flag set by the signalling operation. Upon arrival, the payloads are read from the private buffers, with the routing table indicating which buffer to read from and what weight to assign to each expert. The results are then written to externally allocated tensors, with the kernel finishing execution once all devices passed the barrier. Across a DP group, all ranks receive a copy of each expert activation to compute their own replicas.
Benchmarks
We evaluate our kernels on a cluster of 128 GPUs spread across 16 nodes connected via InfiniBand and CX-7 NICs. We try both GPUDirect Async (IBGDA) and Reliable Connection (RC) with a CPU proxy. We compare them to the highly optimized DeepSeek implementation, as well as the dense primitives provided by PyTorch (through NCCL) or NVSHMEM.

Although on the dense NVSHMEM all-to-all operation, the performance of IBRC and IBGDA is similar (6378 µs vs 6180 µs), IBGDA is significantly faster with the sparse kernels. Adding up Dispatch and Combine, IBGDA uses 902 µs whereas IBRC takes 3223 µs - a 3.6x improvement in latency. While all-to-all is bandwidth bound, the sparse kernels broadcast orders of magnitudes less data, being bound by latency. By triggering network transfers directly from a GPU, without requiring a CPU proxy to coordinate the GPU and the NIC, end-to-end latency is significantly reduced.
While our portable kernels are about 2x slower than the highly optimized DeepSeek kernels, they improve latency by 10x compared to the dense kernels (902 µs vs 9944 µs).
Additionally, the split into sender-receiver components also allow some of the latency to be hidden away, unlike the library primitives.

On single-node (EP8), NVSHMEM utilizes NVLINK for transportation, delivering lower latency and higher throughput than inter-node networking. Our portable kernels are about 2.5x faster than DeepEP on single-node (186µs vs 481 µs).
Further Improvements
The kernels described here outperform the built-in primitives of ML frameworks and offer decent performance without over-specializing for particular inter-device transports, such as NVLink or InfiniBand. Besides the already-mentioned opportunities, further performance gains are attainable by replacing the communication primitives with more specialized versions. For example, across NVLink, the use of symmetric memory could be replaced with buffers shared across the devices, eliminating some copying and allowing for finer-grained synchronizations schemes, across individual tokens instead of token batches. Across InfiniBand, an implementation could access the underlying queue pairs directly, eliminating the need to transfer both the expected and sent token counts and flushing tokens from the queues into output buffers as they arrive. However, such implementations come at the cost of portability, whilst our implementations allows clusters with specific configurations to be evaluated before committing development effort to optimize for particular hardware.
Conclusion
We have presented a high-performance, portable library for MoE communication that achieves 10x faster performance compared to standard all-to-all communication while maintaining compatibility across diverse hardware configurations. On single-node deployments with NVLink, our solution demonstrates 2.5x lower latency than previous implementations.
Our approach balances performance with portability through key innovations including GPU-initiated communication support, a split kernel architecture enabling computation-communication overlap, and efficient token dispatch using minimal NVSHMEM primitives. While approximately 2x slower than highly specialized implementations on multi-node setups, our library offers superior flexibility across various network environments (NVLink, CX-7, and EFA).
As MoE models continue to scale, efficient communication strategies like ours will become increasingly important for practical deployment. Our fully open-source implementation is available at https://github.com/ppl-ai/pplx-kernels.
References
Efficient and Portable Mixture-of-Experts Communication
An overview of portable Mixture-of-Experts (MoE) communication, focusing on optimizing GPU parallelism and reducing latency in large-scale AI models
We present a high-performance, portable, open-source library for Mixture-of-Experts (MoE) communication that achieves 10x faster performance compared to standard All-to-All communication primitives.
Our implementation features several key technical innovations that deliver superior MoE communication efficiency:
GPU-initiated communication (IBGDA): Supports direct GPU-to-NIC communication, significantly reducing latency by bypassing CPU involvement
Communication and computation overlap: Split kernel architecture with separate send and receive stages enables computation to proceed concurrently with network transfers
Fastest single-node performance: 2.5x lower latency than the previously fastest implementation on single-node configurations
Efficient and portable multi-node performance: Our implementation achieves speeds up to 10x faster than standard all-to-all communication. Although approximately 2x slower than highly specialized implementations, our approach offers better portability across NVSHMEM versions and network environments (NVLink, CX-7, and EFA)
The library is fully open-source and available at https://github.com/ppl-ai/pplx-kernels.
In this article, we explore the challenges of expert parallelism in large-scale MoE models and describe our approach to efficient token dispatch and combination across distributed GPU environments.
Introduction
Mixture-of-Experts (MoE) models, such as DeepSeek R1 and Mixtral 8x7B, improve upon dense models by limiting the number of weights that are activated for each token. Instead of achieving high parameter counts by increasing the number of layers or increasing the size of the linear layers in the Multi-Layer Perceptron (MLP) of each decoder layer, MoE models replace the traditional MLP with multiple experts and a router. For example, out of the 671B parameters of DeepSeek R1, only 37B are multiplied with a given token during inference. This greatly reduces the amount of computation required to decode a token and improves latency compared to a dense model.
MoE models present some additional challenges for inference, compared to dense models. While the experts themselves are small MLP layers, each decoder layer includes a router that decides which experts a token is dispatched to, with each token being dispatched to multiple experts. The router is typically a small linear layer producing a probability distribution. Usually the experts with the top-K scores are picked and the final activation is computed as a weighted average, summing the expert outputs multiplied by a weight derived from the probability distribution of the router.
To minimize latency in distributed MoE systems, parallelism can be exploited across multiple devices, but this introduces communication challenges. Models such as Mixtral 8x7B or Llama-70B fit within 8 devices across a single node, benefiting from fast and low-latency NVLink interconnections (up to 900Gbps). However, larger models require sharding experts across multiple nodes using InfiniBand (peaking at 400Gbps), which introduces additional latency challenges.

In this article, we explore the problem of expert parallelism, describing our implementation of GPU kernels that efficiently dispatch tokens to devices holding the assigned experts and collect them back by computing the weighted sum of expert outputs. While aggressively optimized implementations targeting specific network hardware exist, we present a highly portable implementation that relies on a minimal set of NVSHMEM primitives yet delivers performance 10x faster than standard all-to-all communication. Our implementation achieves state-of-the-art performance on single-node configurations while maintaining excellent portability across various network environments.
Sharding and Parallelism
For efficient inference, the weights of a model must be held in device memory, while also leaving sufficient space for activations, KV caches and other buffers required by the forward pass through the model. The most capable models exceed the capacity of even the most capable GPUs, thus inference must be spread across multiple devices, which can collectively store the weights of the model. Based on the sharding schemes of weights, different communication and computation schemes must be used to offload computation and synchronize devices.

Expert Parallelism (EP), illustrated in the first figure, only parallelizes the expert computation. Different experts are assigned to different devices, which hold their weights. After routing, tokens are sent to the corresponding device, with the results being gathered and accumulated afterwards. The complexity of routing depends on the degree of parallelism in the other parts of the model: replicating other layers could eliminate the need for routing altogether, as each rank can select the tokens from a locally replicated routing table. However, if only one of the rank run routing, a broadcast is required to dispatch tokens, indices and weights to their respective experts. Finally, an all-gather or an all-to-all broadcast synchronizes the output tokens with whichever rank continues the execution of the model. Such an implementation is relatively simple as torch already exposes the required primitives, albeit some benefits could be gained by fusing accumulation with the gather operation collecting the tokens from the experts.

Expert-only parallelism does not scale ideally, as nodes in a cluster might be idle while the model is running non-expert layers, such as attention, norm and sampling. However, the computation of these layers, primarily attention, can also benefit from Tensor Parallelism (TP). Most models rely on multi-head attention, meaning that the attention heads and their corresponding Q, K and V projections can also be sharded across devices, replicating or gathering the slices between various layers. If attention is spread across all devices, an all-gather can synchronize the activations, allowing routing to be replicated, requiring synchronization primitives similar to the expert parallelism case for an efficient implementation. However, there are limits to parallelism at this level, as reducing the number of attention heads below a certain threshold will yield diminishing returns.

To best utilize all devices and support a very high degree of expert parallelism, of up to 128 or 256 GPUs, Data Parallelism (DP) is required. Under this scheme, the devices a model is split across are grouped to handle requests concurrently, computing attention and maintaining KV caches sharded across their local group. Multiple instances of these parallel groups collaborate on expert evaluation, with each hosting a different subset of the expert weights. Based on the number of attention heads, a group may typically span up to the size of an entire node, as intra-node communication is faster. For example, in the figure above, one DP rank independently services two requests, A and B, handling attention, norm and any other bookkeeping for the requests. The other DP rank processes a distinct request C. However, the first node hosts half of the experts, while the other node the other half, thus after routing tokens from A and B might be sent to the second node and vice-versa. This leads to a sparse communication problem: each device might send a different number of tokens to any other destination rank. Existing primitives from torch, primarily all_to_all, are not particularly well suited, as they might require some form of padding or GPU-to-CPU synchronization and metadata broadcast. To implement communication effectively, custom kernels are required to process the routing information and initiate communication from the devices without CPU intervention. After routing, a dispatch kernel must send tokens to the ranks they were routed to, while on the combine side the activations belonging to the requests in the current DP group must be collected. Additionally, work must be balanced within a DP group which may have multiple devices, in order to minimize broadcast and correctly replicate the tokens on all ranks for the subsequent layers.
NVSHMEM
NVSHMEM is an NVIDIA-specific OpenSHMEM implementation, providing portable inter-device communication facilities that abstract away the complexity of the underlying hardware. The API can express device-to-device reads and writes, which are mapped to the primitives of individual transport layers. Both NVLink and RDMA are supported, granting a degree of portability and allowing NVSHMEM kernels to operate within a node or across a cluster of multiple nodes. In our kernels, we use remote memory writes and signaling operations in order to transfer data and synchronize the participating devices.
NVSHMEM operations are built around the concept of symmetric memory: they operate on buffers which have been allocated on all the devices participating in inter-device communication. A local device can use an address derived within its own local buffer alongside the index of the target rank to specify the destination of an operation. The figure below illustrates this concept: both GPUs allocate symmetric buffers of the same size, retaining src pointers to them. The first rank wants to send 3 integers to the second one, placing them at the start of the buffer: nvshmem_int_put_nbi derives the start address from the local buffer, specifying the target device. The second rank derives an offset from its own buffer, sending one element to the first device, offset by one. While destination addresses must always be symmetric buffers allocated using nvshmem_malloc, source buffers can be arbitrary regions of device memory, provided they are pre-registered with NVSHMEM.
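As a concrete rendering of the scenario in the figure, the minimal sketch below (our own example, not part of the library; the buffer size, values and kernel name are arbitrary) has rank 0 write three integers to the start of rank 1's symmetric buffer, while rank 1 writes a single integer into rank 0's buffer at an offset of one element.

```cuda
#include <nvshmem.h>
#include <nvshmemx.h>

// Minimal symmetric-memory example: the same pointer value names the local
// buffer on the calling PE and the corresponding buffer on the target PE.
__global__ void exchange(int *sym_buf, int my_pe) {
  if (threadIdx.x != 0) return;
  if (my_pe == 0) {
    // Stage three integers locally, then write them to the start of PE 1's buffer.
    sym_buf[0] = 1; sym_buf[1] = 2; sym_buf[2] = 3;
    nvshmem_int_put_nbi(sym_buf, sym_buf, 3, /*pe=*/1);
  } else {
    // Write a single integer into PE 0's buffer at offset 1.
    sym_buf[1] = 7;
    nvshmem_int_put_nbi(sym_buf + 1, sym_buf + 1, 1, /*pe=*/0);
  }
}

int main() {
  // Assumes one process per GPU, started with an NVSHMEM-aware launcher.
  nvshmem_init();
  int my_pe = nvshmem_my_pe();
  // Symmetric allocation: every PE allocates a buffer of the same size.
  int *sym_buf = static_cast<int *>(nvshmem_malloc(16 * sizeof(int)));
  exchange<<<1, 32>>>(sym_buf, my_pe);
  cudaDeviceSynchronize();
  nvshmem_barrier_all();  // ensure the remote writes have completed and are visible
  nvshmem_free(sym_buf);
  nvshmem_finalize();
  return 0;
}
```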
While NVSHMEM provides a wide range of primitives, our kernels rely on only three functions, building all synchronization and fencing upon them:

nvshmemx_putmem_signal_nbi_warp: transfers a block of data from one device to another while also setting a flag on the remote device. The operation either sets (NVSHMEM_SIGNAL_SET) or increments (NVSHMEM_SIGNAL_ADD) a 64-bit location. The flag is only updated after the entire block of memory has been transferred, so once the remote device observes a change in the flag, it can safely access the buffer in its own memory. This function is useful for coalescing data transfer and synchronization.
nvshmemx_signal_op: operates on a single memory location, typically a 64-bit flag, atomically setting or incrementing it. It is useful for sending metadata and synchronizing devices.
nvshmem_uint64_wait_until: used on the receiving end of a signal, to poll a flag until the remote updates it.
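As a rough sketch of how these three primitives compose (device-side helpers of our own; the symmetric buffer, flag names and the particular choice of signal operations are illustrative assumptions), a sender coalesces a payload with a signal, publishes standalone metadata with a plain signal, and a receiver polls until the expected value appears:

```cuda
#include <cstdint>
#include <nvshmem.h>
#include <nvshmemx.h>

// Sender, called collectively by one warp: push nbytes into a symmetric buffer
// on dst_pe and increment a 64-bit flag there once the payload has landed.
__device__ void send_chunk_warp(void *remote_buf, const void *local_src,
                                size_t nbytes, uint64_t *remote_flag, int dst_pe) {
  nvshmemx_putmem_signal_nbi_warp(remote_buf, local_src, nbytes,
                                  remote_flag, 1, NVSHMEM_SIGNAL_ADD, dst_pe);
}

// Sender, single thread: publish a standalone 64-bit value (e.g. a token count).
__device__ void publish_count(uint64_t *remote_counter, uint64_t value, int dst_pe) {
  nvshmemx_signal_op(remote_counter, value, NVSHMEM_SIGNAL_SET, dst_pe);
}

// Receiver: block until the flag reaches the expected number of chunks; the
// buffers protected by the flag are then safe to read.
__device__ void wait_for_chunks(uint64_t *local_flag, uint64_t expected) {
  nvshmem_uint64_wait_until(local_flag, NVSHMEM_CMP_GE, expected);
}
```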
Ensuring any form of ordering between operations through nvshmem_quiet or nvshmem_fence is expensive, so our kernels avoid all other NVSHMEM primitives and minimize any dependencies beyond individual send-receive pairs. When sending data, we always use the non-blocking versions of the functions, without even waiting for the data to leave the local rank. Other, implicit synchronization across the group of kernels ensures that these buffers are not destructively overwritten while the NICs still need access to them.
NVSHMEM uses GPUDirect RDMA so that the NIC can directly read and write GPU memory for data exchange without involving the CPU. Furthermore, on ConnectX NICs, NVSHMEM supports GPU-initiated communication (also known as InfiniBand GPUDirect Async, or IBGDA), which allows the GPU to post RDMA operations directly to the NIC and to poll the NIC's completion queue itself. On platforms that do not support GPU-initiated communication, NVSHMEM spawns a "proxy" thread on the host CPU to initiate communication and poll for completion on behalf of the GPU. An NVSHMEM program is portable regardless of whether GPU-initiated communication is supported. However, GPU-initiated communication significantly cuts latency because it completely bypasses the detour through the CPU.
Portable Kernels
We implement MoE communication through a pair of dispatch and combine kernels. The dispatch kernels are responsible for reading tokens and routing information on each rank, dispatching them to the appropriate experts. The combine kernels collect the activations produced by the experts and send them back to their source ranks, while also computing the weighted average over the selected experts based on the weights produced by the router. The kernels are further split into a send and receive component, in order to allow data transfers to be overlapped with computation. The send kernels are all non-blocking and non-synchronizing: they simply dispatch all the writes to the remotes. On the other end, the receive kernels only read from memory, waiting until all required data has been transferred. After dispatching work to the NICs, while the data is transferred asynchronously over the wire, the GPUs can do other useful work locally, such as applying shared experts or computing attention.
Each pair of kernels has its own symmetric memory buffers, with each rank allocating private storage to receive data from all other ranks. The maximum number of tokens that can be dispatched from a DP group is fixed across the ranks, which also sets the upper bound on what each rank can receive for each local expert from each DP rank. This allows sender ranks to derive a unique address on the destination rank to write to, without requiring any synchronization among them. After data is received, the kernels shuffle the tokens to lay them out in memory in the format required by the downstream kernels. While the buffers are sizable, they are re-used across all of the model's sequential layers.
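One possible address computation is sketched below; the exact layout used by the library is not spelled out here, so the ordering of the (local expert, source rank, slot) dimensions and the helper name are assumptions. The point is that every operand is known locally by the sender, so no handshake is needed to pick a destination slot.

```cuda
#include <cstddef>

// Assumed receive-buffer layout: for every local expert, a region holding a
// fixed capacity of max_m token slots per source DP rank. Any layout works, as
// long as sender and receiver agree on it and capacities are fixed up front.
__host__ __device__ inline size_t recv_slot_offset(int local_expert, int src_dp_rank,
                                                   int slot, int num_dp_ranks,
                                                   int max_m, size_t token_bytes) {
  size_t index = (static_cast<size_t>(local_expert) * num_dp_ranks + src_dp_rank)
               * static_cast<size_t>(max_m) + slot;
  return index * token_bytes;  // byte offset into the symmetric receive buffer
}
```

A sender writing its slot-th token for a given local expert simply adds this offset to the symmetric base pointer and targets the destination rank's PE.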

The only form of global synchronization, as illustrated in the figure above, is implemented in the combine-receive kernel. Once data is dispatched by dispatch-send, a rank returns and continues execution, entering the combine-send kernel as soon as it has received and processed all the tokens it requires from the other ranks. The barrier in combine-receive ensures that no rank can run ahead and start the next dispatch-send while any other rank is still waiting to receive data, which could destructively overwrite buffers still in use. The synchronization is done in the combine kernel for simplicity; as an optimization, the same barrier could be implemented between the dispatch-receive and combine-send kernels, further reducing latency by giving the NICs more time to settle the transfers while the GPU is running other kernels.
Both kernels are split across all available SMs of the devices: while the dispatch send and combine receive kernels must parallelize across a per-rank maximum token count (max_m), the dispatch receive and combine send kernels must handle that many tokens from each rank and local expert, for a maximum of max_m * (num_experts // EP) * (EP // TP).
Dispatch
The dispatch kernel is responsible for dispatching tokens and block scaling factors to the peers hosting the experts the tokens were routed to. The sender side relies on warp specialization to parallelize two tasks: aggregating routing information to determine how many tokens are sent from each rank, and packing the tokens into messages to be dispatched to the remote. The receiver side first waits for all the token counts to be received, then reads and unpacks the payloads from symmetric memory into the tensors that will be passed on to the downstream Grouped GEMM kernels. It also stores information into buffers shared with the combine kernels, to indicate where each token should be sent back. This information is required because the receive kernel shuffles tokens into contiguous buffers in a non-deterministic order.

In the sender part, a single warp of 32 threads in each block is responsible for reading the routing information and counting the number of tokens each destination expert is expected to receive from this rank, sending the count plus one using nvshmemx_signal_op. The count is incremented by one because the transition from zero to non-zero on the remote end signifies the receipt of the counts from a source rank. In parallel, the remaining warps cooperate across all blocks to copy tokens into symmetric memory, packing activations, scaling factors and their index on the local rank into a contiguous chunk. The index is required by the combine sender to determine the address the token will be written back to. Next, after a barrier across the warp groups ensures all the data has been copied, the warps again operate independently, each sending the same buffer to a different expert in parallel. The tokens are sent using nvshmemx_putmem_signal_nbi_warp, which also atomically increments the count of sent tokens from the local rank on the remote device. Within a DP group, since each rank owns a replica of the token to be sent, dispatch is balanced evenly, with each device sending out a subset of the tokens.
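The sketch below condenses this sender logic into a single block with top-1 routing to keep it readable; it is not the library's kernel. The [source rank][local expert] flag indexing, the slot layout, the parameter names and the assumption that the token buffer is registered with NVSHMEM are all illustrative.

```cuda
#include <cstdint>
#include <nvshmem.h>
#include <nvshmemx.h>

// Greatly simplified dispatch-send sketch: one block, top-1 routing, and the
// local token index reused as the destination slot (assumes num_tokens <= max_m
// and at least two warps per block).
__global__ void dispatch_send_sketch(
    const int *expert_of_token,  // [num_tokens] destination expert per token
    const char *tokens,          // [num_tokens * token_bytes] packed payloads
    char *recv_buf,              // symmetric receive buffer (same layout on every PE)
    uint64_t *expected_counts,   // symmetric: [num_ranks * experts_per_rank]
    uint64_t *arrived_counts,    // symmetric: [num_ranks * experts_per_rank]
    int num_tokens, int num_experts, int experts_per_rank,
    int num_ranks, int my_rank, int max_m, size_t token_bytes) {
  int warp = threadIdx.x / 32;
  int lane = threadIdx.x % 32;
  int num_warps = blockDim.x / 32;

  if (warp == 0) {
    // Specialized role 1: count tokens per destination expert and publish
    // count + 1, so the transition away from zero marks "count received".
    for (int e = lane; e < num_experts; e += 32) {
      uint64_t count = 0;
      for (int t = 0; t < num_tokens; ++t) count += (expert_of_token[t] == e);
      nvshmemx_signal_op(
          &expected_counts[my_rank * experts_per_rank + e % experts_per_rank],
          count + 1, NVSHMEM_SIGNAL_SET, e / experts_per_rank);
    }
  } else {
    // Specialized role 2: the remaining warps each push a disjoint subset of
    // tokens; every lane of a warp issues the collective put with identical arguments.
    for (int t = warp - 1; t < num_tokens; t += num_warps - 1) {
      int e = expert_of_token[t];
      int dst_pe = e / experts_per_rank;
      int local_e = e % experts_per_rank;
      size_t slot = (static_cast<size_t>(local_e) * num_ranks + my_rank)
                    * static_cast<size_t>(max_m) + t;
      // One call transfers the payload and, once it has landed, bumps the
      // per-(source rank, local expert) arrival counter on the destination.
      nvshmemx_putmem_signal_nbi_warp(
          recv_buf + slot * token_bytes,
          tokens + static_cast<size_t>(t) * token_bytes, token_bytes,
          &arrived_counts[my_rank * experts_per_rank + local_e], 1,
          NVSHMEM_SIGNAL_ADD, dst_pe);
    }
  }
}
```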
On the receiving end, all kernels first synchronize with the sender ranks by waiting for the total token counts to arrive. Afterwards, they wait for the atomically incremented sent-token counts to settle at those totals, indicating that all payloads from the source ranks have also arrived, thanks to the semantics of the putmem signal call. The kernels poll on the counts using nvshmem_uint64_wait_until, parallelizing the operation across all blocks and threads. Subsequently, a cross-block barrier ensures that no block reads the buffers before all data has been correctly received. Spread across blocks and synchronized via an atomically incremented counter, the tokens are then copied from symmetric memory into regular contiguous buffers which become the inputs to the downstream kernels implementing the experts. The source rank, expert index and token index are stored separately, exactly pinpointing the location the combine kernel has to send the activations back to. Even though tokens from within a DP group are sent from different devices, they are all grouped together to be passed on to the corresponding expert.
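A condensed sketch of this waiting logic follows (again our own illustration, matching the assumed [source rank][local expert] flag layout of the send sketch and the count-plus-one convention described above):

```cuda
#include <cstdint>
#include <nvshmem.h>
#include <nvshmemx.h>

// Simplified dispatch-receive waiting sketch. Each flag pair is owned by one
// thread; a grid-wide barrier (e.g. an atomically incremented counter in
// global memory) would follow before any block starts unpacking tokens from
// symmetric memory into contiguous buffers.
__global__ void dispatch_recv_wait_sketch(uint64_t *expected_counts,
                                          uint64_t *arrived_counts,
                                          int num_ranks, int experts_per_rank) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  int total = num_ranks * experts_per_rank;
  for (int i = idx; i < total; i += stride) {
    volatile uint64_t *published = &expected_counts[i];
    // First wait for the expected count (published as count + 1) ...
    nvshmem_uint64_wait_until(&expected_counts[i], NVSHMEM_CMP_GT, 0);
    uint64_t expected = *published - 1;
    // ... then wait until that many payloads have actually arrived; the
    // put-with-signal semantics guarantee the data is visible at this point.
    nvshmem_uint64_wait_until(&arrived_counts[i], NVSHMEM_CMP_GE, expected);
  }
}
```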
Combine
The combine kernels are again split into send and receive halves: the senders copy the un-quantized 16-bit activations over to the remote devices, while the receive kernels wait for the data to arrive and compute the weighted average of all expert contributions locally. Additionally, the combine kernels act as a barrier to synchronize the dispatch-combine sequence: each rank sets a flag on each peer on entry to the send kernels, and the receive kernels are not allowed to return until they observe the flag being set. The latency of this synchronization is minimal, as it overlaps with the actual communication and computation.
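A minimal version of that barrier, built only from the signal and wait primitives introduced earlier, might look like the following (our own sketch; the flag array layout and names are assumptions, and a real implementation would also reset or epoch the flags between layers):

```cuda
#include <cstdint>
#include <nvshmem.h>
#include <nvshmemx.h>

// barrier_flags is a symmetric array with one slot per PE, zeroed beforehand.
// On entry to combine-send, every rank raises its own slot on every peer; the
// combine-receive kernel may not return until it has observed all slots set.
__device__ void barrier_arrive(uint64_t *barrier_flags, int my_pe, int n_pes) {
  for (int pe = threadIdx.x; pe < n_pes; pe += blockDim.x)
    nvshmemx_signal_op(&barrier_flags[my_pe], 1, NVSHMEM_SIGNAL_SET, pe);
}

__device__ void barrier_wait(uint64_t *barrier_flags, int n_pes) {
  for (int pe = threadIdx.x; pe < n_pes; pe += blockDim.x)
    nvshmem_uint64_wait_until(&barrier_flags[pe], NVSHMEM_CMP_EQ, 1);
}
```

Because the flags ride on the same transport as the token payloads, the cost of this barrier hides behind the transfers it protects.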

On the sender side, the kernels traverse the list of tokens assigned to all local experts in parallel, writing them to a buffer on the destination rank. The target rank, expert index and token index are read from the per-token buffers populated by the dispatch receive kernels upon receipt of the tokens. Each sender has its own private memory region per expert to write to, as indicated in the figure above, avoiding the need to synchronize. Similarly to dispatch, combine atomically increments per-token counters on the destination rank to indicate the receipt of the data: when the counter matches the number of experts a token was dispatched to, the token contents can be accessed.
In the receive kernel, the list of tokens is traversed in parallel across multiple blocks, waiting for their contents to arrive by polling the flags set by the signalling operations. Upon arrival, the payloads are read from the private buffers, with the routing table indicating which buffer to read from and what weight to assign to each expert. The results are then written to externally allocated tensors, and the kernel finishes execution once all devices have passed the barrier. Across a DP group, all ranks receive a copy of each expert activation to compute their own replicas.
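The per-token accumulation can be pictured with the following sketch (our own, one block per token; the fp16 payload type, the [expert][hidden] layout of the private buffers and all names are assumptions used only for illustration):

```cuda
#include <cstdint>
#include <cuda_fp16.h>
#include <nvshmem.h>
#include <nvshmemx.h>

// Simplified combine-receive sketch for a single token.
__global__ void combine_recv_token_sketch(
    const __half *expert_payloads,  // symmetric: [num_topk * hidden] for this token
    uint64_t *token_ready,          // symmetric: incremented once per expert copy
    const float *weights,           // [num_topk] router weights for this token
    float *out,                     // [hidden] externally allocated output tensor
    int num_topk, int hidden) {
  // One thread polls until every expert the token was dispatched to has
  // written its contribution back.
  if (threadIdx.x == 0)
    nvshmem_uint64_wait_until(token_ready, NVSHMEM_CMP_GE, (uint64_t)num_topk);
  __syncthreads();

  // All threads then cooperate on the router-weighted sum across experts.
  for (int h = threadIdx.x; h < hidden; h += blockDim.x) {
    float acc = 0.f;
    for (int k = 0; k < num_topk; ++k)
      acc += weights[k] * __half2float(expert_payloads[(size_t)k * hidden + h]);
    out[h] = acc;
  }
}
```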
Benchmarks
We evaluate our kernels on a cluster of 128 GPUs spread across 16 nodes connected via InfiniBand with CX-7 NICs. We test both GPU-initiated communication (IBGDA) and InfiniBand Reliable Connection (IBRC) with a CPU proxy, comparing them to the highly optimized DeepSeek implementation as well as to the dense primitives provided by PyTorch (through NCCL) or NVSHMEM.

While IBRC and IBGDA perform similarly on the dense NVSHMEM all-to-all operation (6378 µs vs 6180 µs), IBGDA is significantly faster with the sparse kernels. Adding up dispatch and combine, IBGDA takes 902 µs whereas IBRC takes 3223 µs, a 3.6x improvement in latency. While all-to-all is bandwidth-bound, the sparse kernels move orders of magnitude less data and are therefore latency-bound. By triggering network transfers directly from the GPU, without requiring a CPU proxy to coordinate the GPU and the NIC, end-to-end latency is significantly reduced.
While our portable kernels are about 2x slower than the highly optimized DeepSeek kernels, they improve latency by 10x compared to the dense kernels (902 µs vs 9944 µs). Additionally, the split into send and receive components allows some of the latency to be hidden, unlike with the library primitives.

On a single node (EP8), NVSHMEM uses NVLink as the transport, delivering lower latency and higher throughput than inter-node networking. Our portable kernels are about 2.5x faster than DeepEP in this setting (186 µs vs 481 µs).
Further Improvements
The kernels described here outperform the built-in primitives of ML frameworks and offer solid performance without over-specializing for particular inter-device transports, such as NVLink or InfiniBand. Besides the opportunities already mentioned, further performance gains are attainable by replacing the communication primitives with more specialized versions. For example, across NVLink, symmetric memory could be replaced with buffers shared directly between devices, eliminating some copying and allowing finer-grained synchronization schemes that operate on individual tokens instead of token batches. Across InfiniBand, an implementation could access the underlying queue pairs directly, eliminating the need to transfer both the expected and sent token counts and flushing tokens from the queues into output buffers as they arrive. However, such implementations come at the cost of portability, whereas our implementation allows clusters with specific configurations to be evaluated before committing development effort to optimizing for particular hardware.
Conclusion
We have presented a high-performance, portable library for MoE communication that achieves 10x faster performance compared to standard all-to-all communication while maintaining compatibility across diverse hardware configurations. On single-node deployments with NVLink, our solution demonstrates 2.5x lower latency than previous implementations.
Our approach balances performance with portability through key innovations including GPU-initiated communication support, a split kernel architecture enabling computation-communication overlap, and efficient token dispatch using minimal NVSHMEM primitives. While approximately 2x slower than highly specialized implementations on multi-node setups, our library offers superior flexibility across various network environments (NVLink, CX-7, and EFA).
As MoE models continue to scale, efficient communication strategies like ours will become increasingly important for practical deployment. Our fully open-source implementation is available at https://github.com/ppl-ai/pplx-kernels.