LLM Inference Performance Benchmarking from Scratch
An introduction to LLM inference metrics like TTFT, ITL, and output TPS, and a walkthrough on building a minimal inference benchmarking script from scratch.
LLM inference benchmarking is the task of measuring the throughput and latency of an LLM system under different types of load. This is in contrast to LLM evaluation benchmarking, which measures how closely a model's output resembles the expected output for a series of tasks, such as question answering. It's easier for me to explain inference benchmarking in code, so in this post I'll walk through how to build a simple LLM inference benchmarking script in Python as a means of clearly explaining these metrics and how they're computed.
The four stages of the script are:
- Data generation: Synthesizes a list of prompts for the load generation stage.
- Load generation: Sends streaming requests to the LLM system at a desired concurrency level and collects the output.
- Response processing: Filters and processes the responses to generate clean input for the metrics calculations.
- Performance analysis: Computes our key metrics using the request timestamps, response timestamps, and output from the previous stage.
Below is the main function for the script.
async def main(args: argparse.Namespace):
    prompts = get_prompts(
        model_name=args.tokenizer,
        request_count=args.request_count,
        input_tokens_mean=args.input_tokens_mean,
        input_tokens_stddev=args.input_tokens_stddev,
        seed=args.seed,
    )
    responses = await generate_outputs(
        url=args.url,
        model_name=args.model_name,
        max_completion_tokens=args.max_completion_tokens,
        concurrency=args.concurrency,
        prompts=prompts,
    )
    processed_responses = process_responses(responses)
    metrics = calculate_metrics(
        model_name=args.tokenizer,
        prompts=prompts,
        processed_responses=processed_responses,
    )
    statistics = calculate_statistics(metrics)
    display_table(statistics)

I'm going to show you how to build this application from beginning to end, with sections matching the functions called from main:
- Data generation will cover get_prompts,
- Load generation will cover generate_outputs,
- Response processing will cover process_responses, and
- Performance analysis will cover calculate_metrics and calculate_statistics. This is also where we'll define all the acronyms like TTFT and ITL.
I don't cover argument parsing or output printing in this post since they aren't essential to understanding inference benchmarking.
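That said, if you want to run the script yourself, here's a rough sketch of what the argument parsing and entry point could look like. The flag names mirror the attributes read from args in main, but the defaults below are placeholders I picked, not values from the original script.

import argparse
import asyncio


def parse_args() -> argparse.Namespace:
    # Defaults are illustrative placeholders, not the original script's values.
    parser = argparse.ArgumentParser(description="Minimal LLM inference benchmark")
    parser.add_argument("--url", default="http://localhost:8000/v1/chat/completions")
    parser.add_argument("--model-name", default="llama3.2")
    parser.add_argument("--tokenizer", default="gpt2")
    parser.add_argument("--request-count", type=int, default=20)
    parser.add_argument("--input-tokens-mean", type=float, default=100.0)
    parser.add_argument("--input-tokens-stddev", type=float, default=10.0)
    parser.add_argument("--max-completion-tokens", type=int, default=300)
    parser.add_argument("--concurrency", type=int, default=2)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


if __name__ == "__main__":
    asyncio.run(main(parse_args()))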
Data generation
To calculate these metrics, let's start with generating the request dataset. Ideally the request dataset would match the distribution of data that you expect to receive in production. For example, we want to use similar input sequence lengths, messages, and request parameters. The results of our benchmarking will vary greatly depending on these values. But to keep things simple, we will generate a synthetic dataset using random tokens1 from the tokenizer's vocabulary.
To start, we write a helper function get_tokenizer that returns the HuggingFace tokenizer for the specified model.
@cache
def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
    return AutoTokenizer.from_pretrained(model_name)

Using this function, get_prompts returns a list of strings, where each string is the decoding of a sequence of tokens randomly drawn from the tokenizer's vocabulary.
The length of the prompt is drawn from a normal distribution with the specified mean and standard deviation.
def get_prompts(
    *,
    model_name: str,
    request_count: int,
    input_tokens_mean: float,
    input_tokens_stddev: float,
    seed: int,
) -> list[str]:
    tokenizer = get_tokenizer(model_name)
    vocab_size: int = tokenizer.vocab_size
    rng = np.random.default_rng(seed)
    prompts: list[str] = []
    for _ in range(request_count):
        input_tokens_count = max(
            0, int(rng.normal(loc=input_tokens_mean, scale=input_tokens_stddev))
        )
        tokens = rng.integers(low=0, high=vocab_size, size=input_tokens_count)
        prompts.append(tokenizer.decode(tokens))
    return prompts

The result is a list of strings of random tokens. For example, one possible prompt could be:
"govstrings传输 inf исследованияoundation roboticет różnic_laulsesdistanceBrush котор narrowed tmp..."
Load generation
At a high level, our load generation loop looks like the following.
We create one asyncio task per request and constrain the concurrency using a semaphore, so that at most concurrency requests are in flight at any time.
We only control the concurrency, not the request rate.
This means that the system is pushed as hard as possible at each concurrency level.
These requests are sent to an OpenAI Chat Completions API2 compatible endpoint specified by the url parameter.
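The Response objects returned by the load generation stage aren't defined in the post's excerpts. Here's a minimal sketch, assuming a plain dataclass with the fields used below:

from dataclasses import dataclass


@dataclass(frozen=True)
class Response:
    # Sketch: the original post doesn't show this definition.
    # Timestamp (from time.perf_counter()) taken just before the request is sent.
    request_timestamp: float
    # One timestamp per non-empty SSE chunk received.
    response_timestamps: list[float]
    # The raw SSE chunk bytes, processed later in process_responses.
    chunks: list[bytes]

With that container in mind, here is the load generation function.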
async def generate_outputs(
    *,
    prompts: list[str],
    url: str,
    model_name: str,
    max_completion_tokens: int,
    concurrency: int,
) -> list[Response]:
    semaphore = asyncio.Semaphore(concurrency)

    async def send_request(content: str) -> Response:
        async with semaphore:
            return await request_func(
                url=url,
                model_name=model_name,
                max_completion_tokens=max_completion_tokens,
                content=content,
            )

    tasks: list[asyncio.Task[Response]] = []
    for prompt in prompts:
        tasks.append(asyncio.create_task(send_request(prompt)))
    responses = await asyncio.gather(*tasks)
    return responses

The request_func function handles sending the requests to the LLM backend.
We use "temperature": 0.0 to turn off sampling, max_completion_tokens to set the desired output sequence length, and "stream": True to stream the output to get more granular metrics.
We also set "ignore_eos": True so that generation doesn't stop when the model generates an end-of-sequence token.
Instead, generation will end when the model has generated max_completion_tokens tokens.
This is to ensure that the output doesn't get truncated so we can measure the performance precisely at the specified output sequence length.
async def request_func(
    *,
    url: str,
    model_name: str,
    max_completion_tokens: int,
    content: str,
) -> Response:
    async with aiohttp.ClientSession(raise_for_status=True) as session:
        request_timestamp = time.perf_counter()
        async with session.post(
            url=url,
            json={
                "model": model_name,
                "messages": [{"role": "user", "content": content}],
                "temperature": 0.0,
                "max_completion_tokens": max_completion_tokens,
                "stream": True,
                "ignore_eos": True,
            },
            headers={"Content-Type": "application/json"},
        ) as response:

After receiving the response, this function iterates over the Server-sent events (SSE)3, skips empty chunks, and records both the content of each event and the timestamp when each one is received. We do a minimal amount of processing in this function to avoid adding overhead when timing the per-chunk latency. We'll do additional processing and filtering in the next section.
        ) as response:
            response_timestamps: list[float] = []
            chunks: list[bytes] = []
            async for chunk_bytes in response.content:
                response_timestamp = time.perf_counter()
                chunk_bytes = chunk_bytes.strip()
                if not chunk_bytes:
                    continue
                response_timestamps.append(response_timestamp)
                chunks.append(chunk_bytes)
            return Response(
                request_timestamp=request_timestamp,
                response_timestamps=response_timestamps,
                chunks=chunks,
            )

Response processing
At this point we have a request timestamp, a series of response timestamps, and the corresponding SSE chunks. In the Chat Completions API, the chunks only specify the data and look like the following:
data: {"id":"chatcmpl-421","object":"chat.completion.chunk","created":1762666106,"model":"llama3.2","system_fingerprint":"fp_ollama","choices":[{"index":0,"delta":{"role":"assistant","content":" can"},"finish_reason":null}]}
To process the responses, we remove the data: prefix from each chunk...
def process_responses(responses: list[Response]) -> list[ProcessedResponse]:
    results: list[ProcessedResponse] = []
    for response in responses:
        response_timestamps: list[float] = []
        contents: list[str] = []
        for timestamp, chunk in zip(response.response_timestamps,
                                    response.chunks):
            chunk = chunk.decode("utf-8").removeprefix("data: ")

... skip the final [DONE] chunk...
            chunk = chunk.decode("utf-8").removeprefix("data: ")
            if chunk == "[DONE]":
                continue

... load the JSON, extract the content, and filter out chunks with empty content...
                continue
            data = json.loads(chunk)
            if choices := data.get("choices"):
                content = choices[0]["delta"].get("content")
                if not content:
                    continue

... and finally collect all the processed responses for the next stage.
                    continue
                response_timestamps.append(timestamp)
                contents.append(content)
        results.append(
            ProcessedResponse(
                request_timestamp=response.request_timestamp,
                first_response_timestamp=response_timestamps[0],
                last_response_timestamp=response_timestamps[-1],
                contents=contents,
            )
        )
    return results

Now we finally have everything we need to compute our metrics.
Performance analysis
The eight metrics we will compute are input sequence length, output sequence length, TTFT, request latency, ITL, output token throughput per request, output token throughput, and request throughput.
The calculate_metrics function iterates over the prompts and processed responses and computes each of the first six metrics for each request.
Then it calculates the output token throughput (in tokens per second) and request throughput (in requests per second) over the entire benchmark.
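The Metrics container isn't defined in the post. Based on how it's constructed and consumed, a minimal sketch could be a dataclass of lists, with the benchmark-wide values wrapped in single-element lists so the same statistics code can handle them:

@dataclass(frozen=True)
class Metrics:
    # Sketch: the original post doesn't show this definition.
    ttft: list[float]                    # per-request TTFT, in ms
    request_latency: list[float]         # per-request latency, in ms
    itls: list[float]                    # per-request inter-token latency, in ms
    request_tps: list[float]             # per-request output token throughput
    input_sequence_length: list[float]   # per-request ISL, in tokens
    output_sequence_length: list[float]  # per-request OSL, in tokens
    total_tps: list[float]               # single element: benchmark-wide tokens/s
    rps: list[float]                     # single element: benchmark-wide requests/s

With that container in mind, here is the skeleton of calculate_metrics.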
def calculate_metrics(
    *,
    model_name: str,
    prompts: list[str],
    processed_responses: list[ProcessedResponse],
) -> Metrics:
    # The tokenizer is needed below to compute the sequence lengths.
    tokenizer = get_tokenizer(model_name)
    ttfts_ms: list[float] = []
    request_latencies_ms: list[float] = []
    itls_ms: list[float] = []
    requests_tps: list[float] = []
    input_sequence_lengths: list[float] = []
    output_sequence_lengths: list[float] = []
    for prompt, processed_response in zip(prompts, processed_responses):
        pass  # the per-request calculations below go here
    return Metrics(
        ttft=ttfts_ms,
        request_latency=request_latencies_ms,
        itls=itls_ms,
        request_tps=requests_tps,
        input_sequence_length=input_sequence_lengths,
        output_sequence_length=output_sequence_lengths,
        total_tps=[total_tps],
        rps=[rps],
    )

The input sequence length (ISL) and output sequence length (OSL) are computed by tokenizing the prompts and outputs, respectively, and counting the number of tokens.
    for prompt, processed_response in zip(prompts, processed_responses):
        input_sequence_length = len(
            tokenizer.encode(prompt, add_special_tokens=False)
        )
        output_sequence_length = len(
            tokenizer.encode(
                "".join(processed_response.contents), add_special_tokens=False
            )
        )
    return Metrics(

It's important to first concatenate all the contents together before tokenizing the output. There is no guarantee that the SSE content chunk boundaries correspond to token boundaries, so tokenizing the individual chunks can lead to very different (and incorrect) results. For example, SSE chunks may contain multiple tokens or may end in the middle of a UTF-8 byte sequence.
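As a quick illustration of the difference (the chunk strings and the tokenizer name here are made up for the example):

tokenizer = get_tokenizer("gpt2")  # any Hugging Face tokenizer illustrates the point
chunks = ["Hel", "lo, wor", "ld!"]
per_chunk = sum(
    len(tokenizer.encode(c, add_special_tokens=False)) for c in chunks
)
joined = len(tokenizer.encode("".join(chunks), add_special_tokens=False))
# per_chunk is typically larger than joined because the chunk
# boundaries split what would otherwise be single tokens.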
Time-to-first-token (TTFT) is the number of seconds between sending our request and receiving back the first token of our response.
        )
        ttft = (
            processed_response.first_response_timestamp
            - processed_response.request_timestamp
        )
    return Metrics(

The request latency is the number of seconds between sending the request and receiving the last chunk of its response.
        )
        request_latency = (
            processed_response.last_response_timestamp
            - processed_response.request_timestamp
        )
    return Metrics(

The inter-token latency (ITL) is the average number of seconds between consecutive tokens. This is sometimes referred to as the time per output token (TPOT). In some other sources, ITL refers to the collection of durations between consecutive tokens and TPOT refers to the average of the ITLs. Using our definition, we compute ITL by taking the duration between receiving the first token and the completion of the request and dividing it by the number of inter-token periods.
        )
        inter_token_latency = (request_latency - ttft) / (output_sequence_length - 1)
    return Metrics(

We measure the throughput of an individual request using per-request output token throughput (per-request TPS).
        inter_token_latency = (request_latency - ttft) / (output_sequence_length - 1)
        request_tps = output_sequence_length / request_latency
    return Metrics(

When aggregating the per-request metrics, I convert TTFT, ITL, and request latency to milliseconds for readability.
        request_tps = output_sequence_length / request_latency
        input_sequence_lengths.append(float(input_sequence_length))
        output_sequence_lengths.append(float(output_sequence_length))
        itls_ms.append(inter_token_latency * 1000)
        ttfts_ms.append(ttft * 1000)
        request_latencies_ms.append(request_latency * 1000)
        requests_tps.append(request_tps)
    return Metrics(

Another useful metric is the output token throughput (TPS). This is the total number of output tokens divided by the overall duration of the benchmark.
        requests_tps.append(request_tps)
    min_request_timestamp = min(
        processed_response.request_timestamp
        for processed_response in processed_responses
    )
    max_response_timestamp = max(
        processed_response.last_response_timestamp
        for processed_response in processed_responses
    )
    benchmark_duration = max_response_timestamp - min_request_timestamp
    total_output_tokens = sum(output_sequence_lengths)
    total_tps = total_output_tokens / benchmark_duration
    return Metrics(

And finally the request throughput (RPS) is simply the number of requests divided by the benchmark duration.
    total_tps = total_output_tokens / benchmark_duration
    rps = len(processed_responses) / benchmark_duration
    return Metrics(

For each of these metrics, we can compute the min, max, mean, standard deviation, and percentiles.
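Statistics is another small container that the post doesn't define; a minimal sketch matching the fields used below could be:

@dataclass(frozen=True)
class Statistics:
    # Sketch: the original post doesn't show this definition.
    min: float
    max: float
    mean: float
    stddev: float
    p75: float
    p90: float
    p99: float

The statistics helper fills one of these in per metric: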
def statistics(metric: list[float]) -> Statistics:
    p75, p90, p99 = np.percentile(metric, [75, 90, 99])
    return Statistics(
        min=np.min(metric),
        max=np.max(metric),
        mean=float(np.mean(metric)),
        stddev=float(np.std(metric)),
        p75=p75,
        p90=p90,
        p99=p99,
    )

Finally, I calculate all of the statistics in calculate_statistics:
def calculate_statistics(metrics: Metrics) -> dict[str, Statistics]:
    return {
        "Time to First Token (ms)": statistics(metrics.ttft),
        "Request Latency (ms)": statistics(metrics.request_latency),
        "Inter-Token Latency (ms)": statistics(metrics.itls),
        "Per-Request Output Token Throughput (tokens/s)": statistics(
            metrics.request_tps
        ),
        "Input Sequence Length (tokens)": statistics(metrics.input_sequence_length),
        "Output Sequence Length (tokens)": statistics(metrics.output_sequence_length),
        "Output Token Throughput (tokens/s)": statistics(metrics.total_tps),
        "Request Throughput (req/s)": statistics(metrics.rps),
    }

Using rich4 to implement display_table, my end result is a table that looks like this:
Benchmark Statistics
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓
┃ Metric ┃ Mean ┃ Min ┃ Max ┃ P99 ┃ P90 ┃ P75 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩
│ Time to First Token (ms) │ 31.29 │ 26.11 │ 89.68 │ 89.40 │ 33.66 │ 31.18 │
│ Request Latency (ms) │ 534.22 │ 521.16 │ 581.68 │ 581.04 │ 541.30 │ 537.44 │
│ Inter-Token Latency (ms) │ 1.69 │ 1.64 │ 1.78 │ 1.77 │ 1.72 │ 1.71 │
│ Per-Request Output Token Throughput │ 558.77 │ 512.31 │ 578.19 │ 576.94 │ 569.78 │ 565.46 │
│ (tokens/s) │ │ │ │ │ │ │
│ Input Sequence Length (tokens) │ 104.40 │ 82.00 │ 135.00 │ 131.04 │ 118.00 │ 110.00 │
│ Output Sequence Length (tokens) │ 298.41 │ 290.00 │ 304.00 │ 304.00 │ 300.00 │ 300.00 │
│ Output Token Throughput (tokens/s) │ 1116.43 │ 1116.43 │ 1116.43 │ 1116.43 │ 1116.43 │ 1116.43 │
│ Request Throughput (req/s) │ 3.74 │ 3.74 │ 3.74 │ 3.74 │ 3.74 │ 3.74 │
└─────────────────────────────────────────────┴─────────┴─────────┴─────────┴─────────┴─────────┴─────────┘
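The post skips display_table, but for reference, a minimal sketch using rich might look like the following; the column order matches the table above, and the formatting details are my own guesses.

from rich.console import Console
from rich.table import Table


def display_table(statistics: dict[str, Statistics]) -> None:
    # Sketch of a possible display_table; the original isn't shown in the post.
    table = Table(title="Benchmark Statistics")
    table.add_column("Metric")
    for column in ("Mean", "Min", "Max", "P99", "P90", "P75"):
        table.add_column(column, justify="right")
    for name, stats in statistics.items():
        table.add_row(
            name,
            f"{stats.mean:.2f}",
            f"{stats.min:.2f}",
            f"{stats.max:.2f}",
            f"{stats.p99:.2f}",
            f"{stats.p90:.2f}",
            f"{stats.p75:.2f}",
        )
    Console().print(table)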
That's all there is to it!
This was just an introduction to LLM inference benchmarking, and there are many improvements that could be made to get more insight into how your LLM system performs on real-world workloads. Although I haven't yet explored it in depth, I took some inspiration from NVIDIA's AIPerf tool and would recommend looking into it to see what a production-level LLM inference benchmarking tool looks like.
Footnotes
1. Using random tokens is unrealistic and will lead to an artificially low KV cache hit rate. This is fine for illustrating how benchmarking works, but for real workloads you should sample from production logs or construct representative prompt templates. ↩
2. Here's the official reference for the API: https://platform.openai.com/docs/api-reference/chat/create. Inference providers often provide endpoints that are compatible with this API (to varying degrees). ↩
3. The spec is great for understanding how SSE works: https://html.spec.whatwg.org/multipage/server-sent-events.html ↩
4. rich is a great Python library for formatting and styling text in the terminal: https://github.com/Textualize/rich ↩
I’m @phillippe.siclait.com on Bluesky — reach out to continue the conversation.