LLM Inference Performance Benchmarking from Scratch

13 min read

An introduction to LLM inference metrics like TTFT, ITL, and output TPS, and a walkthrough on building a minimal inference benchmarking script from scratch.

Whether you're interested in reducing the environmental impact of large language models (LLMs), increasing their usefulness, or reducing the cost of serving them, performance engineering of LLM systems is essential. By improving the performance of the system, you can get the same output with a smaller hardware, energy, or time budget, or get more output from a fixed budget. However, to engineer an LLM system for performance, you first need to know how to benchmark it so you can measure the performance before and after any modification. In this post, I'll dive into the key metrics used for benchmarking LLMs and how to measure them. By the end, you'll know what TTFT, TPOT, ITL, and TPS mean.

LLM inference (performance) benchmarking is the task of measuring the throughput and latency of an LLM system under different types of load. This is in contrast to LLM evaluation (quality) benchmarking, which measures how closely a model's actual output resembles the expected output for a series of tasks, such as question answering. It's easier for me to explain inference benchmarking in code, so in this post I'll walk through how to build a simple LLM inference benchmarking script in Python.

The four stages of the script are:

  1. Data generation: Synthesizes a list of prompts for the load generation stage.
  2. Load generation: Sends streaming requests to the LLM system at a desired concurrency level and collects the output.
  3. Response processing: Filters and processes the responses to generate clean input for the metrics calculations.
  4. Performance analysis: Using the request timestamps, response timestamps, and output from the previous stage, computes our key metrics.

Below is the main function for the script.

async def main(args: argparse.Namespace):
    prompts = get_prompts(
        model_name=args.tokenizer,
        request_count=args.request_count,
        input_tokens_mean=args.input_tokens_mean,
        input_tokens_stddev=args.input_tokens_stddev,
        seed=args.seed,
    )
    responses = await generate_outputs(
        url=args.url,
        model_name=args.model_name,
        max_completion_tokens=args.max_completion_tokens,
        concurrency=args.concurrency,
        prompts=prompts,
    )
    processed_responses = process_responses(responses)
    metrics = calculate_metrics(
        model_name=args.tokenizer,
        prompts=prompts,
        processed_responses=processed_responses,
    )
    statistics = calculate_statistics(metrics)
 
    display_table(statistics)

I'll walk through how to build this application from beginning to end, with the post divided into sections that correspond to the functions called from main.

  1. Data generation will cover get_prompts,
  2. Load generation will cover generate_outputs,
  3. Response processing will cover process_responses, and
  4. Performance analysis will cover calculate_metrics and calculate_statistics. This is also where we'll define all the acronyms like TTFT and ITL.

I don't cover argument parsing or output printing in detail in this post since they aren't essential to understanding inference benchmarking, but a rough sketch of that glue code is included below for completeness.
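The flag names in this sketch mirror the attributes that main reads from args; the default values and parser description are placeholders of my own choosing, not something prescribed by the post.

import argparse
import asyncio
import json
import time
from functools import cache

import aiohttp
import numpy as np
from transformers import AutoTokenizer, PreTrainedTokenizerBase


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Minimal LLM inference benchmark")
    parser.add_argument("--url", required=True)  # Chat Completions-compatible endpoint
    parser.add_argument("--model-name", required=True)  # model to benchmark
    parser.add_argument("--tokenizer", required=True)  # HuggingFace tokenizer name
    parser.add_argument("--request-count", type=int, default=100)
    parser.add_argument("--input-tokens-mean", type=float, default=128)
    parser.add_argument("--input-tokens-stddev", type=float, default=16)
    parser.add_argument("--max-completion-tokens", type=int, default=256)
    parser.add_argument("--concurrency", type=int, default=8)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


if __name__ == "__main__":
    asyncio.run(main(parse_args()))

With that out of the way, let's get started by diving into data generation.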

Data generation

To benchmark the LLM, we're going to send a bunch of requests to it and measure how long it takes to get responses back. Ideally, the request dataset should match the distribution of data that you expect to receive in production. This means that we want to use similar input sequence lengths, messages, and request parameters. The results of our benchmarking will vary greatly depending on these values. But to keep things simple, we will generate a synthetic dataset using random tokens1 from the tokenizer's vocabulary.

To start, we write a helper function get_tokenizer that returns the HuggingFace tokenizer2 for the specified model.

@cache
def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
    return AutoTokenizer.from_pretrained(model_name)

Using this function, get_prompts returns a list of strings, where each string is the decoding of tokens drawn at random from the tokenizer's vocabulary. For example, one possible prompt could be:

"govstrings传输 inf исследованияoundation roboticет różnic_laulsesdistanceBrush котор narrowed tmp..."

The length of the prompt is drawn from a normal distribution with the specified mean and standard deviation.

def get_prompts(
    *,
    model_name: str,
    request_count: int,
    input_tokens_mean: float,
    input_tokens_stddev: float,
    seed: int,
) -> list[str]:
    tokenizer = get_tokenizer(model_name)
    vocab_size: int = tokenizer.vocab_size
    rng = np.random.default_rng(seed)
    prompts: list[str] = []
 
    for _ in range(request_count):
        input_tokens_count = max(
            0, int(rng.normal(loc=input_tokens_mean, scale=input_tokens_stddev))
        )
        tokens = rng.integers(low=0, high=vocab_size, size=input_tokens_count)
        prompts.append(tokenizer.decode(tokens))
 
    return prompts
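As a quick sanity check, you can call get_prompts and re-tokenize the results to confirm that the prompt lengths cluster around the requested mean. The GPT-2 tokenizer is used here purely for illustration, and because decoding random token IDs and re-encoding the resulting string doesn't round-trip exactly, the counts will be close to, but not exactly, the sampled lengths.

tokenizer = get_tokenizer("gpt2")
prompts = get_prompts(
    model_name="gpt2",
    request_count=5,
    input_tokens_mean=128,
    input_tokens_stddev=16,
    seed=0,
)
for prompt in prompts:
    # Each length should land roughly around 128 tokens.
    print(len(tokenizer.encode(prompt, add_special_tokens=False)))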

Now that we have the prompts, we need to send them to the model and record the response times.

Load generation

At a high level, our load generation loop sends a bunch of concurrent requests and records the timestamps of each streaming response. More specifically, we create one asyncio task per request and constrain the concurrency using a semaphore3 so that at most concurrency requests are in flight at a time. We only control the concurrency, not the rate at which we send requests: the request rate is not limited, and the system is pushed as hard as possible at each concurrency level. These requests are sent to an OpenAI Chat Completions API4 compatible endpoint specified by the url parameter.

async def generate_outputs(
    *,
    prompts: list[str],
    url: str,
    model_name: str,
    max_completion_tokens: int,
    concurrency: int,
) -> list[Response]:
    semaphore = asyncio.Semaphore(concurrency)
 
    async def send_request(content: str) -> Response:
        async with semaphore:
            return await request_func(
                url=url,
                model_name=model_name,
                max_completion_tokens=max_completion_tokens,
                content=content,
            )

    tasks: list[asyncio.Task[Response]] = []
    for prompt in prompts:
        tasks.append(asyncio.create_task(send_request(prompt)))

    responses = await asyncio.gather(*tasks)

    return responses

The request_func function sends the requests to the LLM backend. We use "temperature": 0.0 to turn off sampling, max_completion_tokens to set the desired output sequence length, and "stream": True to stream the output for more granular metrics, like the time until the first token is received. We also set "ignore_eos": True so that generation doesn't stop when the model generates an end-of-sequence token; instead, generation ends once the model has produced max_completion_tokens tokens. (Note that ignore_eos isn't part of the official Chat Completions API; it's an extra parameter supported by inference servers such as vLLM.) This ensures that the output doesn't get truncated, so we can measure the performance precisely at the specified output sequence length.

async def request_func(
    *,
    url: str,
    model_name: str,
    max_completion_tokens: int,
    content: str,
) -> Response:
    async with aiohttp.ClientSession(raise_for_status=True) as session:
        request_timestamp = time.perf_counter()
        async with session.post(
            url=url,
            json={
                "model": model_name,
                "messages": [{"role": "user", "content": content}],
                "temperature": 0.0,
                "max_completion_tokens": max_completion_tokens,
                "stream": True,
                "ignore_eos": True,
            },
            headers={"Content-Type": "application/json"},
        ) as response:

After receiving the response, this function iterates over the Server-sent events (SSE)5, skips empty chunks, and records both the content of each event and the timestamp when each one is received. We do a minimal amount of processing in this function to avoid adding overhead when timing the per-chunk latency. We'll do additional processing and filtering in the next section.

        ) as response:
            response_timestamps: list[float] = []
            chunks: list[bytes] = []
            async for chunk_bytes in response.content:
                response_timestamp = time.perf_counter()
 
                chunk_bytes = chunk_bytes.strip()
                if not chunk_bytes:
                    continue
 
                response_timestamps.append(response_timestamp)
                chunks.append(chunk_bytes)
 
            return Response(
                request_timestamp=request_timestamp,
                response_timestamps=response_timestamps,
                chunks=chunks,
            )
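The Response container isn't defined anywhere in this post; based on the fields used above, a minimal sketch could look like this (the field names are the ones the snippets rely on, while the dataclass itself is my own assumption):

from dataclasses import dataclass


@dataclass
class Response:
    request_timestamp: float          # when the request was sent
    response_timestamps: list[float]  # when each non-empty SSE chunk arrived
    chunks: list[bytes]               # raw SSE chunk bytes, processed later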

Response processing

Now is when we do the processing that we skipped in the previous section, to get the data ready for metric calculation. At this point we have a request timestamp, a series of response timestamps, and the corresponding SSE chunks. In the Chat Completions API, each chunk carries only a data: field and looks like the following:

data: {"id":"chatcmpl-421","object":"chat.completion.chunk","created":1762666106,"model":"llama3.2","system_fingerprint":"fp_ollama","choices":[{"index":0,"delta":{"role":"assistant","content":" can"},"finish_reason":null}]}

To process the responses, we remove the data: prefix from each chunk...

def process_responses(responses: list[Response]) -> list[ProcessedResponse]:
    results: list[ProcessedResponse] = []
 
    for response in responses:
        response_timestamps: list[float] = []
        contents: list[str] = []
 
        for timestamp, chunk in zip(response.response_timestamps, response.chunks):
            chunk = chunk.decode("utf-8").removeprefix("data: ")

... skip the final [DONE] chunk...

            chunk = chunk.decode("utf-8").removeprefix("data: ")
            if chunk == "[DONE]":
                continue

... load the JSON, extract the content, and filter out chunks with empty content...

                continue
            data = json.loads(chunk)
            if choices := data.get("choices"):
                content = choices[0]["delta"].get("content")
                if not content:
                    continue

... and finally collect all the processed responses for the next stage.

                    continue
                response_timestamps.append(timestamp)
                contents.append(content)
 
        results.append(
            ProcessedResponse(
                request_timestamp=response.request_timestamp,
                first_response_timestamp=response_timestamps[0],
                last_response_timestamp=response_timestamps[-1],
                contents=contents,
            )
        )
 
    return results
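As with Response, the ProcessedResponse type isn't shown in the post; judging by how it's constructed above, a sketch might be:

from dataclasses import dataclass


@dataclass
class ProcessedResponse:
    request_timestamp: float         # when the request was sent
    first_response_timestamp: float  # arrival time of the first content chunk
    last_response_timestamp: float   # arrival time of the last content chunk
    contents: list[str]              # decoded content deltas, in order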

Finally, we have everything we need to compute our metrics.

Performance analysis

The eight metrics we will compute are input sequence length, output sequence length, TTFT, request latency, ITL, output token throughput per request, output token throughput, and request throughput. The calculate_metrics function iterates over the prompts and processed responses and computes each of the first six metrics for each request. Then it calculates the output token throughput (in tokens per second) and request throughput (in requests per second) over the entire benchmark. Don't worry if you didn't catch all of that. In this section we'll finally define each of those terms.

def calculate_metrics(
    *,
    model_name: str,
    prompts: list[str],
    processed_responses: list[ProcessedResponse],
) -> Metrics:
    tokenizer = get_tokenizer(model_name)
    ttfts_ms: list[float] = []
    request_latencies_ms: list[float] = []
    itls_ms: list[float] = []
    requests_tps: list[float] = []
    input_sequence_lengths: list[float] = []
    output_sequence_lengths: list[float] = []
 
    for prompt, processed_response in zip(prompts, processed_responses):
        pass  # the per-request metric calculations are filled in below
 
    return Metrics(
        ttft=ttfts_ms,
        request_latency=request_latencies_ms,
        itls=itls_ms,
        request_tps=requests_tps,
        input_sequence_length=input_sequence_lengths,
        output_sequence_length=output_sequence_lengths,
        total_tps=[total_tps],
        rps=[rps],
    )
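The Metrics container is also left undefined; given the constructor call above and how calculate_statistics reads it later, a plausible sketch is a dataclass of per-metric lists (the single-element lists for total_tps and rps match how they're passed in above):

from dataclasses import dataclass


@dataclass
class Metrics:
    ttft: list[float]                    # per-request, in milliseconds
    request_latency: list[float]         # per-request, in milliseconds
    itls: list[float]                    # per-request, in milliseconds
    request_tps: list[float]             # per-request output tokens/s
    input_sequence_length: list[float]   # per-request token counts
    output_sequence_length: list[float]  # per-request token counts
    total_tps: list[float]               # single element: benchmark-wide tokens/s
    rps: list[float]                     # single element: benchmark-wide requests/s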

The input sequence length (ISL) and output sequence length (OSL) are the lengths of the input (prompt) and output, in number of tokens. They are computed by tokenizing the prompts and outputs, respectively, and getting the lengths of those lists.

    for prompt, processed_response in zip(prompts, processed_responses):
        input_sequence_length = len(
            tokenizer.encode(prompt, add_special_tokens=False)
        )
        output_sequence_length = len(
            tokenizer.encode(
                "".join(processed_response.contents), add_special_tokens=False
            )
        )

    return Metrics(

It's important to first concatenate all the contents together before tokenizing the output. There is no guarantee that the SSE content chunk boundaries correspond to token boundaries, so tokenizing each chunk separately can lead to very different (and incorrect) results. For example, SSE chunks may contain multiple tokens or may end in the middle of a UTF-8 byte sequence.
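To see why, here's a small hypothetical illustration (the chunk contents are made up, and any tokenizer from get_tokenizer will do): tokenizing the chunks one at a time typically yields more tokens than tokenizing the concatenated text, because chunk boundaries can split what would otherwise be a single token.

tokenizer = get_tokenizer("gpt2")  # any tokenizer works for this illustration
chunks = ["Hel", "lo the", "re, wor", "ld!"]  # made-up content deltas
separate = sum(
    len(tokenizer.encode(chunk, add_special_tokens=False)) for chunk in chunks
)
joined = len(tokenizer.encode("".join(chunks), add_special_tokens=False))
print(separate, joined)  # separate is typically larger than joined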

Time-to-first-token (TTFT) is the number of seconds between sending our request and receiving back the first token of our response.

        )
        ttft = (
            processed_response.first_response_timestamp
            - processed_response.request_timestamp
        )

    return Metrics(

The request latency is the number of seconds between sending the request and receiving its final response chunk.

        )
        request_latency = (
            processed_response.last_response_timestamp
            - processed_response.request_timestamp
        )

    return Metrics(

The inter-token latency (ITL) is the average number of seconds between consecutive tokens. This is sometimes referred to as the time per output token (TPOT). In some other sources, ITL refers to the collection of durations between consecutive tokens and TPOT refers to the average of the ITLs. Using our definition, we compute ITL by taking the duration between receiving the first token and the completion of the request and dividing it by the number of inter-token periods.

        )
        inter_token_latency = (request_latency - ttft) / (output_sequence_length - 1)

    return Metrics(

We measure the throughput of an individual request using the per-request output token throughput, or per-request tokens per second (per-request TPS): the number of output tokens generated per second over the lifetime of that request.

        inter_token_latency = (request_latency - ttft) / (output_sequence_length - 1)
        request_tps = output_sequence_length / request_latency

    return Metrics(

When aggregating the per-request metrics, I convert TTFT, ITL, and request latency to milliseconds for readability.

        request_tps = output_sequence_length / request_latency
        input_sequence_lengths.append(float(input_sequence_length))
        output_sequence_lengths.append(float(output_sequence_length))
        itls_ms.append(inter_token_latency * 1000)
        ttfts_ms.append(ttft * 1000)
        request_latencies_ms.append(request_latency * 1000)
        requests_tps.append(request_tps)

    return Metrics(

Another useful metric is the output token throughput / tokens per second (TPS). This is the total number of tokens output divided by the overall duration of the benchmark.

        requests_tps.append(request_tps)

    min_request_timestamp = min(
        processed_response.request_timestamp
        for processed_response in processed_responses
    )
    max_response_timestamp = max(
        processed_response.last_response_timestamp
        for processed_response in processed_responses
    )
    benchmark_duration = max_response_timestamp - min_request_timestamp
    total_output_tokens = sum(output_sequence_lengths)
    total_tps = total_output_tokens / benchmark_duration

    return Metrics(

And finally the request throughput / requests per second (RPS) is simply the number of requests divided by the benchmark duration.

    total_tps = total_output_tokens / benchmark_duration
    rps = len(processed_responses) / benchmark_duration

    return Metrics(

For each of these metrics, we can compute the min, max, mean, standard deviation, and percentiles.

def statistics(metric: list[float]) -> Statistics:
    p75, p90, p99 = np.percentile(metric, [75, 90, 99])
    return Statistics(
        min=np.min(metric),
        max=np.max(metric),
        mean=float(np.mean(metric)),
        stddev=float(np.std(metric)),
        p75=p75,
        p90=p90,
        p99=p99,
    )
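The Statistics container simply mirrors the fields constructed here; a minimal sketch (again, the dataclass itself is my assumption):

from dataclasses import dataclass


@dataclass
class Statistics:
    min: float
    max: float
    mean: float
    stddev: float
    p75: float
    p90: float
    p99: float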

Finally, I calculate all of the statistics in calculate_statistics:

def calculate_statistics(metrics: Metrics) -> dict[str, Statistics]:
    return {
        "Time to First Token (ms)": statistics(metrics.ttft),
        "Request Latency (ms)": statistics(metrics.request_latency),
        "Inter-Token Latency (ms)": statistics(metrics.itls),
        "Per-Request Output Token Throughput (tokens/s)": statistics(
            metrics.request_tps
        ),
        "Input Sequence Length (tokens)": statistics(metrics.input_sequence_length),
        "Output Sequence Length (tokens)": statistics(metrics.output_sequence_length),
        "Output Token Throughput (tokens/s)": statistics(metrics.total_tps),
        "Request Throughput (req/s)": statistics(metrics.rps),
    }

Using rich6 to implement display_table, my end result is a table that looks like this:

                                           Benchmark Statistics
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓
┃ Metric                                      ┃    Mean ┃     Min ┃     Max ┃     P99 ┃     P90 ┃     P75 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩
│ Time to First Token (ms)                    │   31.29 │   26.11 │   89.68 │   89.40 │   33.66 │   31.18 │
│ Request Latency (ms)                        │  534.22 │  521.16 │  581.68 │  581.04 │  541.30 │  537.44 │
│ Inter-Token Latency (ms)                    │    1.69 │    1.64 │    1.78 │    1.77 │    1.72 │    1.71 │
│ Per-Request Output Token Throughput         │  558.77 │  512.31 │  578.19 │  576.94 │  569.78 │  565.46 │
│ (tokens/s)                                  │         │         │         │         │         │         │
│ Input Sequence Length (tokens)              │  104.40 │   82.00 │  135.00 │  131.04 │  118.00 │  110.00 │
│ Output Sequence Length (tokens)             │  298.41 │  290.00 │  304.00 │  304.00 │  300.00 │  300.00 │
│ Output Token Throughput (tokens/s)          │ 1116.43 │ 1116.43 │ 1116.43 │ 1116.43 │ 1116.43 │ 1116.43 │
│ Request Throughput (req/s)                  │    3.74 │    3.74 │    3.74 │    3.74 │    3.74 │    3.74 │
└─────────────────────────────────────────────┴─────────┴─────────┴─────────┴─────────┴─────────┴─────────┘
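The display_table implementation isn't shown in the post; one way to produce a table like the one above with rich is sketched below, with the column layout guessed from the output shown.

from rich.console import Console
from rich.table import Table


def display_table(statistics: dict[str, Statistics]) -> None:
    table = Table(title="Benchmark Statistics")
    table.add_column("Metric")
    for column in ("Mean", "Min", "Max", "P99", "P90", "P75"):
        table.add_column(column, justify="right")

    for name, stats in statistics.items():
        # Format every statistic to two decimal places, matching the output above.
        table.add_row(
            name,
            f"{stats.mean:.2f}",
            f"{stats.min:.2f}",
            f"{stats.max:.2f}",
            f"{stats.p99:.2f}",
            f"{stats.p90:.2f}",
            f"{stats.p75:.2f}",
        )

    Console().print(table)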

That's all there is to it!

This was just an introduction to LLM inference benchmarking, and there are many improvements that could be made to get more insight into how your LLM system performs against real-world workloads. Although I haven't yet explored it in depth, I took a lot of inspiration from NVIDIA's AIPerf tool and would recommend looking into it to see what a production-level LLM inference benchmarking tool looks like.

Acknowledgements

Thank you Cai Wangwilt for your feedback on an earlier draft of this post.

Footnotes

  1. Using random tokens is unrealistic and will lead to an artificially low KV cache hit rate. This is fine for illustrating how benchmarking works, but for real workloads you should sample from production logs or construct representative prompt templates.

  2. Each trained model has its own vocabulary: a mapping from integer IDs to their corresponding strings. The tokenizer for a given model uses the vocabulary to encode a string into a vector of integers and decode a vector of integers back to its corresponding string.

  3. Semaphores are concurrency primitives that have an initial integer value and two operations, wait and signal. The wait operation decreases the value and waits if the value is negative, and signal increases the value and wakes one of the waiting threads (or in this case, async coroutines).

  4. Here's the official reference for the API: https://platform.openai.com/docs/api-reference/chat/create. Inference providers often provide endpoints that are compatible with this API (to varying degrees).

  5. The spec is great for understanding how SSE works: https://html.spec.whatwg.org/multipage/server-sent-events.html

  6. rich is a great Python library for formatting and styling text in the terminal: https://github.com/Textualize/rich


Discussion

Thanks for reading! I'd love to hear your thoughts — please join the discussion on Bluesky.
