

Written by AI Team
Published on Aug 8, 2025
GPT-OSS on Day 0
Perplexity is committed to making the best AI models available to curious people who demand accurate and trustworthy answers and agentic actions. OpenAI recently announced their open-weight models, gpt-oss-20b and gpt-oss-120b. As with many leading models, Perplexity is among the first organizations to evaluate gpt-oss-20b and gpt-oss-120b.
In this post, we share the infrastructure decisions of our in-house inference stack that made Day-0 support possible. We focus on serving these models on NVIDIA H200 GPUs, detailing the kernel changes, deployment choices, and speed-cost trade-offs.
GPT-OSS on Hopper
The open-weight models are shipped using MXFP4 quantization, which helps them fit in memory on consumer-grade hardware and achieve peak throughput on NVIDIA Blackwell. However, for initial evaluation, we wanted to run them on existing H200 Hopper clusters with minimal inference-engine changes. Hopper does not have dedicated FP4 tensor cores, which were introduced by Blackwell. Consequently, we decided to serve the models with FP8 precision instead, to minimize the need for custom kernels and make the best use of the available hardware.
At a high level, transformer LLMs are structurally simple:
Input embedding
A sequence of transformer layers. Each contains an attention block, and a dense MLP or sparse MoE block.
Output logit projection
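For orientation, a minimal PyTorch skeleton of this structure might look like the following. The module and dimension names are illustrative, not ROSE's actual layer classes, and it assumes a recent PyTorch with nn.RMSNorm:

```python
import torch
import torch.nn as nn

class TransformerLayer(nn.Module):
    """One layer: attention block followed by a dense MLP or sparse MoE block."""
    def __init__(self, hidden_size: int, attn: nn.Module, mlp_or_moe: nn.Module):
        super().__init__()
        self.attn_norm = nn.RMSNorm(hidden_size)
        self.mlp_norm = nn.RMSNorm(hidden_size)
        self.attn = attn
        self.mlp = mlp_or_moe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x

class TransformerLM(nn.Module):
    """Input embedding -> transformer layers -> output logit projection."""
    def __init__(self, vocab_size: int, hidden_size: int, layers: list[nn.Module]):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList(layers)
        self.final_norm = nn.RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        x = self.embed(token_ids)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.final_norm(x))
```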
When bringing up a new transformer model, we are primarily interested in learning to what extent it matches existing models, to identify the changes required. OpenAI kindly shared an overview of their models ahead of time, allowing us to tweak our inference setup accordingly. We found that the models did require a number of small changes.
Attention: GQA with Attention Sinks
Each attention head has a pre-trained sink parameter which is factored into Softmax
YaRN positional encoding, similar to DeepSeek-V3
QKV projections have biases, similar to Qwen 2
Output projection has bias
MLP: Sparse Mixture-of-Experts (MoE)
SwiGLU activation function, different from most open-weight models
Softmax after Top-K for expert weights
Experts have biases
Expecting the number of parameters to be around 100B-200B, we decided to:
Extend FlashInfer to support attention sinks
Re-use the MoE implementation already built for the DeepSeek-V3 family of models
Choose FP8 128x128 block scaling as the quantization scheme, extending the DeepGEMM kernels to support a bias term.
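As a rough illustration of that third choice, the sketch below quantizes a weight matrix to FP8 (e4m3) with one scale per 128x128 block. It is a plain-PyTorch reference, not the DeepGEMM path we actually ship; the helper name and the explicit dequantization are ours, and production kernels fold the per-block scales into the GEMM epilogue instead:

```python
import torch

def quantize_fp8_blockwise(w: torch.Tensor, block: int = 128):
    """Quantize a 2-D weight to FP8 (e4m3) with one scale per 128x128 block."""
    rows, cols = w.shape
    assert rows % block == 0 and cols % block == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # View as (row_blocks, block, col_blocks, block) so we can reduce per block.
    blocks = w.reshape(rows // block, block, cols // block, block)
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max                                 # one scale per block
    q = (blocks / scale).to(torch.float8_e4m3fn)
    return q.reshape(rows, cols), scale[:, 0, :, 0]

w = torch.randn(256, 384)
q, scales = quantize_fp8_blockwise(w)
w_hat = q.float().reshape(2, 128, 3, 128) * scales[:, None, :, None]
print((w - w_hat.reshape(256, 384)).abs().max())           # small quantization error
```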
In-house Inference Engine
Our inference efforts are backed by our in-house inference engine, ROSE. We are building ROSE out to be a flexible framework that allows new models to be quickly brought up. Additionally, once we are committed to serving a model at scale, it provides a platform to optimize and boost performance to reach the state of the art.
ROSE exposes an LLM engine, which can load a model and generate decoded tokens for it through multiple decoders, including speculative and MTP decoders. The models themselves are composed of a custom set of layers, which support various configuration knobs for quantization and parallelism. Underneath the layers, it bundles both Triton and CUDA kernels that efficiently implement and fuse the computations performed by individual layers.
ROSE is primarily built in Python and relies on PyTorch for model definitions. While Python allows us to be flexible and adapt to new models, we are migrating most performance-sensitive components, such as serving and batch scheduling, to Rust.
To bring up a new model, we first define the model hierarchy and build a weight converter for it. After all weights are accounted for, we define a forward pass for the simplest TP=1 use case and ensure it works end-to-end. Finally, we implement and test support for various forms of parallelism before we release a new container and deploy the model on a cluster for evaluation. We then iterate on performance, relying on feedback from both dashboards and finer-grained benchmarks.
ROSE operates on numeric tokens, with little knowledge of chat formats outside of integration tests. In parallel, we had to look into adjusting surrounding infrastructure to support the Harmony tokenizer and correctly map requests to input token sequences.
With the GPT-OSS models, we found that no changes were required to our decoders, CUDA graph harnesses, or surrounding infrastructure, allowing us to focus our efforts on tweaking the kernels backing the model.
Kernels for GPT-OSS
Sink Attention
Sink attention introduces an extra bias factor prior to softmax, accumulated with the product of Q and K:
attn = softmax(q @ k^T * sm_scale + sink) @ V
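To make the formula concrete, here is a deliberately naive single-head PyTorch sketch, reading the sink as an extra per-head logit that joins the softmax normalization but contributes no value vector (consistent with the online-softmax treatment below). It is an illustration, not a production kernel, and causal masking is omitted:

```python
import torch

def naive_sink_attention(q, k, v, sink, sm_scale):
    """q: [Tq, D], k/v: [Tk, D], sink: per-head scalar logit (shape [1])."""
    logits = q @ k.transpose(-2, -1) * sm_scale             # [Tq, Tk]
    sink_col = sink * torch.ones(logits.shape[0], 1)         # sink logit per query row
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    return probs[:, :-1] @ v                                  # drop the sink column

q, k, v = (torch.randn(4, 64) for _ in range(3))
out = naive_sink_attention(q, k, v, sink=torch.tensor([0.5]), sm_scale=64 ** -0.5)
```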
While a naive implementation is fairly trivial, efficient attention implementations, such as FlashInfer, parallelize the computation across both independent heads and the KV sequence length. Additionally, the softmax is computed online, subtracting the maximal element for numerical stability:
softmax(x) = e^x / sum(e^x) = e^(x - max(x)) / sum(e^(x - max(x)))
The online implementation tracks a running maximum m, initialized to -inf, and a running scale d, initialized to 1. By initializing the maximum to the sink value of the current head, we ensure that the sink is accumulated into the softmax correctly. We had to take care that this maximum is only included in the first chunk when attention is split across multiple blocks along the sequence length.
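The same computation in chunked form, following the initialization described above. This is a minimal sketch of the scheme, not the FlashInfer FA2 template itself, and it reuses naive_sink_attention from the sketch above as a reference:

```python
import torch

def online_sink_attention(q, k, v, sink, sm_scale, chunk=128):
    """Chunked online softmax over the KV length with an attention sink.
    The running max m starts at the sink logit and the normalizer d at 1,
    so exp(sink - m_final) ends up in the softmax denominator. If the KV
    length is split across blocks, only the first block starts this way."""
    Tq, D = q.shape
    m = torch.full((Tq, 1), float(sink))          # running max, init to the sink
    d = torch.ones(Tq, 1)                         # running denominator, init to 1
    acc = torch.zeros(Tq, D)
    for start in range(0, k.shape[0], chunk):
        kc, vc = k[start:start + chunk], v[start:start + chunk]
        s = q @ kc.transpose(-2, -1) * sm_scale   # [Tq, chunk]
        m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
        alpha = torch.exp(m - m_new)              # rescale previously accumulated state
        p = torch.exp(s - m_new)
        d = d * alpha + p.sum(dim=-1, keepdim=True)
        acc = acc * alpha + p @ vc
        m = m_new
    return acc / d

torch.manual_seed(0)
q, k, v = torch.randn(4, 64), torch.randn(300, 64), torch.randn(300, 64)
out = online_sink_attention(q, k, v, sink=0.5, sm_scale=64 ** -0.5)
ref = naive_sink_attention(q, k, v, sink=torch.tensor([0.5]), sm_scale=64 ** -0.5)
print(torch.allclose(out, ref, atol=1e-5))        # True
```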
We have adjusted the FlashInfer FA2 kernel template, used for both prefill and decode, to include the sink parameters.
MoE
The main difference between the MoE layers of different models is usually the routing scheme. Expecting substantial variability, ROSE implements this kernel in Triton, allowing us to easily adapt it to new models. In the case of GPT-OSS, we had to adjust weight scoring to perform softmax only across the Top-K selected experts, unlike some DeepSeek-style models, which compute softmax across all experts. We found it critical for accuracy to compute the expert weights in bfloat16 precision.
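In PyTorch terms, the routing change amounts to the following. This is a sketch only, with illustrative expert counts; as noted, the production kernel is written in Triton:

```python
import torch

def route_topk_then_softmax(router_logits, top_k):
    """GPT-OSS-style routing: select the Top-K experts first, then softmax
    only over the selected logits, computed in bfloat16."""
    top_vals, top_idx = router_logits.to(torch.bfloat16).topk(top_k, dim=-1)
    return torch.softmax(top_vals, dim=-1), top_idx

def route_softmax_then_topk(router_logits, top_k):
    """DeepSeek-style alternative: softmax across all experts, then Top-K."""
    probs = torch.softmax(router_logits.float(), dim=-1)
    top_vals, top_idx = probs.topk(top_k, dim=-1)
    return top_vals, top_idx

logits = torch.randn(2, 128)                       # [tokens, num_experts]
w, idx = route_topk_then_softmax(logits, top_k=4)
print(w.sum(dim=-1))                               # each row sums to 1
```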
The challenge with the MoE layer stemmed from our decision to re-use the existing DeepGEMM kernels, which lacked support for bias terms. Fortunately, we had already integrated deeply with these kernels over time, making changes to better connect them with our custom all-to-all dispatch kernels. We added bias support to DeepGEMM by initializing the WGMMA accumulator registers to the bias values loaded from memory. This is a one-off cost during the processing of an MN block, so it is likely negligible, and the L2 cache ensures efficient access to these values.
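The idea can be emulated in plain PyTorch: in a tiled GEMM, start each MN tile's accumulator from the bias instead of zeros. This is a sketch of the concept; the actual change initializes WGMMA accumulator registers inside DeepGEMM's CUDA kernels:

```python
import torch

def blocked_gemm_with_bias(a, b, bias, block_m=128, block_n=128, block_k=128):
    """Tiled C = A @ B + bias where each MN block's accumulator is initialized
    to the bias, so the bias add costs nothing inside the K loop."""
    M, K = a.shape
    _, N = b.shape
    c = torch.empty(M, N, dtype=a.dtype)
    for m0 in range(0, M, block_m):
        rows = min(block_m, M - m0)
        for n0 in range(0, N, block_n):
            # Accumulator starts at the bias values for this MN block.
            acc = bias[n0:n0 + block_n].expand(rows, -1).clone()
            for k0 in range(0, K, block_k):
                acc = acc + a[m0:m0 + block_m, k0:k0 + block_k] @ b[k0:k0 + block_k, n0:n0 + block_n]
            c[m0:m0 + block_m, n0:n0 + block_n] = acc
    return c

a, b, bias = torch.randn(256, 512), torch.randn(512, 384), torch.randn(384)
assert torch.allclose(blocked_gemm_with_bias(a, b, bias), a @ b + bias, atol=1e-3)
```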
Since these models do not readily benefit from Data Parallelism, we re-used TP-only dispatch and combine kernels. The inputs to dispatch are replicated across all ranks, so dispatch selects the tokens for the locally available experts. Combine performs the weighted accumulation across NVLink.
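The following single-process simulation shows what TP-only dispatch and combine compute. The rank loop stands in for the actual GPUs, and the NVLink reduction is just a local sum; names and shapes are illustrative, not our kernel code:

```python
import torch

def simulate_tp_dispatch_combine(x, expert_idx, expert_w, experts, tp_size):
    """Inputs are replicated on every 'rank'; each rank runs only its local
    experts on the tokens routed to them, and combine sums the weighted
    partial outputs (done over NVLink by the real kernels)."""
    experts_per_rank = len(experts) // tp_size
    out = torch.zeros_like(x)                         # "combine" accumulator
    for rank in range(tp_size):
        lo = rank * experts_per_rank
        for e in range(lo, lo + experts_per_rank):    # local experts only
            rows, slots = (expert_idx == e).nonzero(as_tuple=True)
            if rows.numel() == 0:
                continue
            y = experts[e](x[rows])                   # expert MLP on its tokens
            out[rows] += expert_w[rows, slots].unsqueeze(-1) * y
    return out

tokens, hidden, num_experts, top_k = 16, 32, 8, 2
x = torch.randn(tokens, hidden)
experts = [torch.nn.Linear(hidden, hidden) for _ in range(num_experts)]
top = torch.randn(tokens, num_experts).topk(top_k, dim=-1)
weights = torch.softmax(top.values, dim=-1)
print(simulate_tp_dispatch_combine(x, top.indices, weights, experts, tp_size=4).shape)
```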
Performance Tweaks
Even though our goal was to bring up the models quickly, we did want to ensure that performance was still reasonable. After ensuring correctness against a reference implementation provided by OpenAI, we profiled our implementation using the default PyTorch profiler, watching out for block and thread allocations across kernels. The functionality we re-used was built out for models with a hidden dimension of 7168, substantially larger than the 2880 of the GPT-OSS models. We specialized kernel launchers to pick better grid dimensions, significantly boosting throughput.
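For reference, the profiling pass looked roughly like the snippet below, shown here with a tiny stand-in model; in practice we profile the full GPT-OSS forward pass and then inspect grid and block sizes for the hottest kernels in the exported trace:

```python
import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(torch.nn.Linear(2880, 2880), torch.nn.GELU()).to(device)
x = torch.randn(128, 2880, device=device)

activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if device == "cuda" else [])
with profile(activities=activities, record_shapes=True) as prof:
    with torch.no_grad():
        model(x)

sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))
prof.export_chrome_trace("forward_trace.json")   # open in Perfetto / chrome://tracing
```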
Tokenization and Channels
Harmony’s chat format has several important concepts for structuring conversations, reasoning content, and function calls. In this format, each message consists of a header and content:
<|start|>{header}<|message|>{content}<|end|>
The header carries metadata: most importantly the role (system, developer, user, assistant, tool), the channel, and a recipient such as to=functions.<function-name> when invoking a tool call.
Channels (such as analysis, commentary, and final) make the model's outputs more transparent and segmented. Because the format enforces this separation, the model can reason transparently (and be inspected for debugging) without leaking chain-of-thought into production. Recipient tags (to= headers) tell the runtime exactly which actor should receive the next payload, whether that's a function, the model itself, or the end user.
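As an illustration of the pieces named above, a hand-rolled renderer might look like the sketch below. The real stack uses the Harmony format library; the exact header ordering and special tokens follow the Harmony spec rather than this simplified version, and to=functions.get_weather is a hypothetical tool name:

```python
def render_harmony_message(role: str, content: str,
                           channel: str | None = None,
                           recipient: str | None = None) -> str:
    """Simplified, illustrative rendering of a Harmony-style message."""
    header = role
    if recipient is not None:
        header += f" to={recipient}"          # e.g. to=functions.get_weather
    if channel is not None:
        header += f"<|channel|>{channel}"     # analysis, commentary, or final
    return f"<|start|>{header}<|message|>{content}<|end|>"

print(render_harmony_message("assistant", "The answer is 42.", channel="final"))
print(render_harmony_message("assistant", '{"city": "Paris"}',
                             channel="commentary",
                             recipient="functions.get_weather"))
```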
Our inference engine ROSE is format-agnostic. It receives a batch of input tokens and streams back output tokens. In our inference stack, the task of formatting inputs into a tokenized prompt and parsing LLM results back into structured messages is handled by the dedicated JSON API frontend service. We integrated Harmony by adding a new pluggable formatter/parser to this frontend, allowing the backend to remain unchanged while the frontend cleanly manages chat-format rules, tool calls, and streaming. This separation of concerns keeps the architecture modular, enabling quick adoption of new formats and response features with minimal backend changes.
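Conceptually, the pluggable piece is just a small interface; the names below are illustrative, not our actual service code:

```python
from typing import Protocol, Sequence

class ChatFormat(Protocol):
    """Each chat format plugs a formatter/parser pair into the frontend;
    the backend only ever sees token IDs."""

    def format_prompt(self, messages: Sequence[dict]) -> list[int]:
        """Render structured messages into the input token sequence."""
        ...

    def parse_output(self, output_tokens: Sequence[int]) -> dict:
        """Map generated tokens back to channels, tool calls, and final text."""
        ...
```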
Cost and Performance
Given the relatively small size of GPT-OSS models, we limit the deployment of a replica to a single node to avoid inter-node communication overhead. Our implementation supports flexible combinations of expert parallelism (EP), data parallelism (DP), and tensor parallelism (TP). To determine the optimal deployment setup, we benchmarked all combinations of sharding strategies across batch sizes from 1 to 128. Thanks to Prefill-Decode Disaggregation, we are able to discuss the performance-cost trade-off separately for prefill and decode.
For cost estimates, we assume a nominal H200 market rate of $3.00/hour (note: this is not our actual GPU price).
We discovered that setting batch_size=1 and DP=1 yields the best prefill performance. Different TP configurations then provide varying trade-offs between first-token latency and cost. The tables below present prefill latency and cost per million input tokens for various configurations and input lengths.
GPT-OSS 120B:
| Configuration \ Input Length | 8192 | 32768 | 65536 | 128000 |
|---|---|---|---|---|
| EP1 DP1 TP1 | 0.364s, $0.037 | 2.421s, $0.062 | 7.375s, $0.094 | 23.902s, $0.156 |
| EP2 DP1 TP2 | 0.190s, $0.039 | 1.245s, $0.063 | 3.786s, $0.096 | 12.116s, $0.158 |
| EP4 DP1 TP4 | 0.165s, $0.067 | 0.900s, $0.092 | 2.449s, $0.125 | 7.165s, $0.187 |
| EP8 DP1 TP8 | 0.179s, $0.145 | 0.839s, $0.171 | 2.003s, $0.204 | 5.119s, $0.267 |
GPT-OSS 20B:
| Configuration \ Input Length | 8192 | 32768 | 65536 | 128000 |
|---|---|---|---|---|
| EP1 DP1 TP1 | 0.239s, $0.024 | 1.610s, $0.041 | 4.895s, $0.062 | 16.037s, $0.104 |
| EP2 DP1 TP2 | 0.127s, $0.026 | 0.833s, $0.042 | 2.522s, $0.064 | 8.143s, $0.106 |
| EP4 DP1 TP4 | 0.108s, $0.044 | 0.593s, $0.060 | 1.616s, $0.082 | 4.766s, $0.124 |
| EP8 DP1 TP8 | 0.116s, $0.094 | 0.547s, $0.111 | 1.310s, $0.133 | 3.365s, $0.175 |
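As a sanity check, the cost columns follow directly from the measured latency, the number of GPUs (equal to TP here, since DP=1), and the assumed $3.00/GPU-hour rate:

```python
def prefill_cost_per_million_tokens(latency_s: float, num_gpus: int,
                                    input_tokens: int, usd_per_gpu_hour: float = 3.00) -> float:
    """Cost to prefill one million input tokens at the measured latency."""
    usd = latency_s * num_gpus * usd_per_gpu_hour / 3600
    return usd / (input_tokens / 1e6)

# 120B, EP2 DP1 TP2, 8192-token input -> ~$0.039 per million input tokens.
print(round(prefill_cost_per_million_tokens(0.190, num_gpus=2, input_tokens=8192), 3))
```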
For decode, we discovered that EP4 DP1 TP4 is the best configuration for most cases. The four plots below illustrate the trade-off between decode speed and cost per million output tokens for both the 120B and 20B models at sequence lengths of 8192 and 32768.
[Four plots: decode speed versus cost per million output tokens for GPT-OSS 120B and 20B at sequence lengths 8192 and 32768.]
We also considered using the smaller 20B model as a speculative decoding draft model for the larger 120B model. However, because the number of activated parameters is similar (3B vs. 5B), the resulting speedups were marginal.