"""Throughput benchmarks (prefill and decoding) for the khaosz Transformer."""

import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple

import torch

from khaosz.model.transformer import ModelConfig, Transformer

# Number of untimed passes run before measuring, so one-time startup cost
# does not pollute the first timed trial.
_WARMUP_ITERATIONS = 3


def _throughput(tokens: int, seconds: float) -> float:
    """Return tokens per second, or 0.0 when no time was measured."""
    return tokens / seconds if seconds > 0 else 0.0


@dataclass
class BenchmarkResult:
    """Aggregated outcome of one benchmark run."""

    total_tokens: int  # tokens processed across all timed trials
    total_time: float  # summed duration of the timed trials, in seconds
    tokens_per_second: float  # total_tokens / total_time (0.0 if nothing was timed)
    metadata: Dict[str, Any]  # benchmark parameters; always has "benchmark_type"


class GenerationBenchmark:
    """Measure prefill and decoding throughput of a randomly initialised model."""

    def __init__(
        self,
        config: ModelConfig,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ):
        """Build a ``Transformer`` from *config* on *device* in eval mode.

        Args:
            config: Model hyper-parameters.
            device: Device the model and all inputs are placed on.
            dtype: Parameter and KV-cache dtype.
        """
        self.config = config
        self.device = device
        self.dtype = dtype
        self.model = Transformer(config).to(device=device, dtype=dtype)
        self.model.eval()

    def _on_cuda(self) -> bool:
        """Return True when the benchmark device is a CUDA device."""
        return torch.device(self.device).type == "cuda"

    def _synchronize(self) -> None:
        """Wait for queued CUDA work; a no-op on non-CUDA devices."""
        if self._on_cuda():
            torch.cuda.synchronize()

    def _timed(self, fn: Callable[[], Any]) -> float:
        """Call ``fn()`` once and return its duration in seconds.

        CUDA events are used on CUDA devices so that only GPU work issued
        between the two records is measured; elsewhere a wall clock is used.
        """
        if self._on_cuda():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            fn()
            end_event.record()
            torch.cuda.synchronize()
            # elapsed_time() reports milliseconds.
            return start_event.elapsed_time(end_event) / 1000
        started = time.perf_counter()
        fn()
        return time.perf_counter() - started

    def _initialize_kv_cache(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Allocate zero-filled key and value caches.

        Each tensor has shape
        ``(batch_size, n_layer, m_len, n_kvhead, n_dim // n_head)``.

        Returns:
            A ``(k_cache, v_cache)`` tuple on the benchmark device and dtype.
        """
        config = self.config
        shape = (
            batch_size,
            config.n_layer,
            config.m_len,
            config.n_kvhead,
            config.n_dim // config.n_head,
        )
        k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
        v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
        return (k_cache, v_cache)

    def _prepare_inputs(
        self, batch_size: int, prompt_length: int, total_length: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw random token ids for the prompt and for the generated part.

        Returns:
            ``(prompt_ids, gen_ids)`` with shapes ``(batch_size, prompt_length)``
            and ``(batch_size, total_length - prompt_length)``.
        """
        prompt_ids = torch.randint(
            low=0,
            high=self.config.vocab_size,
            size=(batch_size, prompt_length),
            device=self.device,
            dtype=torch.long,
        )
        gen_ids = torch.randint(
            low=0,
            high=self.config.vocab_size,
            size=(batch_size, total_length - prompt_length),
            device=self.device,
            dtype=torch.long,
        )
        return prompt_ids, gen_ids

    def _decode_steps(
        self,
        gen_ids: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        start_pos: int,
    ) -> None:
        """Feed *gen_ids* to the model one token column at a time.

        Column ``i`` is passed with ``start_pos=start_pos + i``.
        NOTE(review): presumably the model writes into ``kv_cache`` in place;
        confirm against ``Transformer.forward``.
        """
        for step in range(gen_ids.size(1)):
            self.model(
                gen_ids[:, step:step + 1],
                persistent_key_values=kv_cache,
                start_pos=start_pos + step,
            )

    @torch.inference_mode()
    def run_prefill_benchmark(
        self,
        batch_size: int = 1,
        prompt_length: int = 512,
        num_trials: int = 10,
    ) -> BenchmarkResult:
        """Time full-prompt forward passes without a KV cache.

        Args:
            batch_size: Sequences per forward pass.
            prompt_length: Tokens per sequence.
            num_trials: Number of timed forward passes.

        Returns:
            Totals over all trials; ``tokens_per_second`` is 0.0 when nothing
            was timed (e.g. ``num_trials=0``).
        """
        for _ in range(_WARMUP_ITERATIONS):
            warmup_ids, _unused = self._prepare_inputs(
                batch_size, prompt_length, prompt_length
            )
            self.model(warmup_ids)
        self._synchronize()

        # Every trial processes the whole batch, so count all of its tokens.
        tokens_per_trial = batch_size * prompt_length
        total_tokens = tokens_per_trial * num_trials
        total_time = 0.0

        for trial in range(num_trials):
            prompt_ids, _unused = self._prepare_inputs(
                batch_size, prompt_length, prompt_length
            )
            trial_time = self._timed(lambda: self.model(prompt_ids))
            total_time += trial_time

            print(f"Trial {trial + 1}/{num_trials}: {tokens_per_trial} tokens in {trial_time:.3f}s "
                  f"({_throughput(tokens_per_trial, trial_time):.1f} tokens/s)")

        return BenchmarkResult(
            total_tokens=total_tokens,
            total_time=total_time,
            tokens_per_second=_throughput(total_tokens, total_time),
            metadata={
                "benchmark_type": "prefill",
                "batch_size": batch_size,
                "prompt_length": prompt_length,
                "dtype": self.dtype,
                "device": self.device,
            },
        )

    @torch.inference_mode()
    def run_decoding_benchmark(
        self,
        batch_size: int = 1,
        prompt_length: int = 512,
        gen_length: int = 128,
        num_trials: int = 5,
    ) -> BenchmarkResult:
        """Time token-by-token decoding on top of a prefilled KV cache.

        The prompt pass of each trial is executed but not timed; only the
        ``gen_length`` single-token steps are measured.

        Args:
            batch_size: Sequences decoded in parallel.
            prompt_length: Prompt tokens fed before decoding starts.
            gen_length: Single-token steps timed per trial.
            num_trials: Number of timed decoding runs.

        Returns:
            Totals over all trials; ``tokens_per_second`` is 0.0 when nothing
            was timed (e.g. ``num_trials=0``).
        """
        total_length = prompt_length + gen_length

        # Untimed warm-up so the first measured trial does not pay one-time
        # startup cost for the single-token code path.
        prompt_ids, gen_ids = self._prepare_inputs(
            batch_size, prompt_length, total_length
        )
        kv_cache = self._initialize_kv_cache(batch_size)
        self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
        self._decode_steps(
            gen_ids[:, :_WARMUP_ITERATIONS], kv_cache, prompt_length
        )
        self._synchronize()
        del kv_cache  # release the warm-up cache before the timed trials

        tokens_per_trial = batch_size * gen_length
        total_tokens = tokens_per_trial * num_trials
        total_time = 0.0

        for trial in range(num_trials):
            prompt_ids, gen_ids = self._prepare_inputs(
                batch_size, prompt_length, total_length
            )
            kv_cache = self._initialize_kv_cache(batch_size)
            self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
            # Finish the prompt pass before the timed region begins.
            self._synchronize()

            trial_time = self._timed(
                lambda: self._decode_steps(gen_ids, kv_cache, prompt_length)
            )
            total_time += trial_time

            print(f"Trial {trial + 1}/{num_trials}: {tokens_per_trial} tokens in {trial_time:.3f}s "
                  f"({_throughput(tokens_per_trial, trial_time):.1f} tokens/s)")

        return BenchmarkResult(
            total_tokens=total_tokens,
            total_time=total_time,
            tokens_per_second=_throughput(total_tokens, total_time),
            metadata={
                "benchmark_type": "decoding",
                "batch_size": batch_size,
                "prompt_length": prompt_length,
                "gen_length": gen_length,
                "dtype": self.dtype,
                "device": self.device,
            },
        )


def print_benchmark_result(result: BenchmarkResult):
    """Print a formatted summary of a benchmark result."""
    benchmark_type = result.metadata["benchmark_type"]
    title = " " + benchmark_type.upper().replace("_", " ") + " Benchmark "

    print(f"\n{title:-^80}")
    print(f"Total Tokens Processed: {result.total_tokens:,}")
    print(f"Time Consumed: {result.total_time:.3f}s")
    print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")

    if benchmark_type == "prefill":
        print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
    elif benchmark_type == "decoding":
        print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")

    print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
    print("-" * 80)


def main() -> None:
    """Run the prefill and decoding benchmarks with a fixed configuration."""
    config = ModelConfig(
        vocab_size=10000,
        n_dim=1536,
        n_head=24,
        n_kvhead=4,
        d_ffn=6912,
        m_len=2048,
        n_layer=24,
        norm_eps=1e-5,
    )

    benchmark = GenerationBenchmark(config)

    print("=" * 80)
    print("Running Transformer Generation Benchmark")
    print("=" * 80)

    prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5)
    print_benchmark_result(prefill_result)

    gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5)
    print_benchmark_result(gen_result)


if __name__ == "__main__":
    main()