perf(benchmark): add baseline performance benchmarks
parent 64c4d2d2e3
commit cd4877e490

@@ -0,0 +1,215 @@
import torch
from typing import Dict, Any
from dataclasses import dataclass
from khaosz.core.transformer import TransformerConfig, Transformer


@dataclass
class BenchmarkResult:
    total_tokens: int
    total_time: float
    tokens_per_second: float
    metadata: Dict[str, Any]


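# GenerationBenchmark drives a randomly initialized Transformer through the two
# phases of autoregressive generation and reports throughput for each:
# prefill (one forward pass over the full prompt) and decode (one token at a
# time against a preallocated KV cache).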
class GenerationBenchmark:
    def __init__(
        self,
        config: TransformerConfig,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16
    ):
        self.config = config
        self.device = device
        self.dtype = dtype
        self.model = Transformer(config).to(device=device, dtype=dtype)
        self.model.eval()

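    # Each layer gets a preallocated (K, V) pair of shape
    # (batch, max_len, n_kvhead, head_dim). n_kvhead (4) being smaller than
    # n_head (24) suggests grouped-query attention, so the cache is several
    # times smaller than a full multi-head cache would be.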
    def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list:
        """Initialize per-layer KV caches."""
        kv_cache = []
        head_dim = self.config.n_dim // self.config.n_head
        for _ in range(self.config.n_layer):
            k_cache = torch.zeros(
                (batch_size, max_len, self.config.n_kvhead, head_dim),
                device=self.device, dtype=self.dtype
            )
            v_cache = torch.zeros(
                (batch_size, max_len, self.config.n_kvhead, head_dim),
                device=self.device, dtype=self.dtype
            )
            kv_cache.append((k_cache, v_cache))
        return kv_cache

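    # Random token ids are sufficient for a throughput benchmark: for a dense
    # transformer, the cost of a forward pass does not depend on token values.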
    def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
        prompt_ids = torch.randint(
            low=0,
            high=self.config.vocab_size,
            size=(batch_size, prompt_length),
            device=self.device,
            dtype=torch.long
        )

        gen_ids = torch.randint(
            low=0,
            high=self.config.vocab_size,
            size=(batch_size, total_length - prompt_length),
            device=self.device,
            dtype=torch.long
        )

        return prompt_ids, gen_ids

    @torch.inference_mode()
    def run_prefill_benchmark(
        self,
        batch_size: int = 1,
        prompt_length: int = 512,
        num_trials: int = 10,
    ) -> BenchmarkResult:
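        # Warmup: untimed forward passes so one-time costs (CUDA context
        # creation, kernel selection, allocator growth) do not skew the
        # measurement. The timing path below assumes device="cuda".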
        for _ in range(3):
            prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
            _ = self.model(prompt_ids)

        torch.cuda.synchronize()

        total_time = 0.0
        total_tokens = batch_size * prompt_length * num_trials

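        # CUDA events time the GPU work itself; kernel launches are
        # asynchronous, so elapsed_time is only valid after the
        # synchronize() that follows end_event.record().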
        for trial in range(num_trials):
            prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            _ = self.model(prompt_ids)
            end_event.record()
            torch.cuda.synchronize()

            trial_time = start_event.elapsed_time(end_event) / 1000
            total_time += trial_time

            trial_tokens = batch_size * prompt_length
            print(f"Trial {trial + 1}/{num_trials}: {trial_tokens} tokens in {trial_time:.3f}s "
                  f"({trial_tokens / trial_time:.1f} tokens/s)")

        return BenchmarkResult(
            total_tokens=total_tokens,
            total_time=total_time,
            tokens_per_second=total_tokens / total_time,
            metadata={
                "benchmark_type": "prefill",
                "batch_size": batch_size,
                "prompt_length": prompt_length,
                "dtype": self.dtype,
                "device": self.device,
            }
        )

    @torch.inference_mode()
    def run_decoding_benchmark(
        self,
        batch_size: int = 1,
        prompt_length: int = 512,
        gen_length: int = 128,
        num_trials: int = 5,
    ) -> BenchmarkResult:
        total_time = 0.0
        total_tokens = batch_size * gen_length * num_trials

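        # Each trial: run the prompt through once to fill the KV cache
        # (untimed prefill), then time only the token-by-token decode loop.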
        for trial in range(num_trials):
            prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
            kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len)
            _ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)

            torch.cuda.synchronize()

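            # Timed region: gen_length single-token steps, each reading and
            # extending the KV cache at position current_pos.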
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()

            current_pos = prompt_length
            for i in range(gen_length):
                input_token = gen_ids[:, i:i+1]
                _ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos)
                current_pos += 1

            end_event.record()
            torch.cuda.synchronize()

            trial_time = start_event.elapsed_time(end_event) / 1000
            total_time += trial_time

            trial_tokens = batch_size * gen_length
            print(f"Trial {trial + 1}/{num_trials}: {trial_tokens} tokens in {trial_time:.3f}s "
                  f"({trial_tokens / trial_time:.1f} tokens/s)")

        return BenchmarkResult(
            total_tokens=total_tokens,
            total_time=total_time,
            tokens_per_second=total_tokens / total_time,
            metadata={
                "benchmark_type": "generation",
                "batch_size": batch_size,
                "prompt_length": prompt_length,
                "gen_length": gen_length,
                "dtype": self.dtype,
                "device": self.device,
            }
        )


def print_benchmark_result(result: BenchmarkResult):
    """Print benchmark results."""
    benchmark_type = result.metadata["benchmark_type"]

    print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
    print(f"Total Tokens Processed: {result.total_tokens:,}")
    print(f"Time Consumed: {result.total_time:.3f}s")
    print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")

    if benchmark_type == "prefill":
        print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
    elif benchmark_type == "generation":
        print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")

    print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
    print("-" * 80)


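# Demo run with a small 24-layer configuration; shrink m_len, batch_size, or
# prompt_length if GPU memory is tight.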
if __name__ == "__main__":
    config = TransformerConfig(
        vocab_size=10000,
        n_dim=1536,
        n_head=24,
        n_kvhead=4,
        d_ffn=6912,
        m_len=2048,
        n_layer=24,
        norm_eps=1e-5,
    )

    benchmark = GenerationBenchmark(config)

    print("=" * 80)
    print("Running Transformer Generation Benchmark")
    print("=" * 80)

    prefill_result = benchmark.run_prefill_benchmark(
        batch_size=4,
        prompt_length=512,
        num_trials=5
    )
    print_benchmark_result(prefill_result)

    gen_result = benchmark.run_decoding_benchmark(
        batch_size=4,
        prompt_length=512,
        gen_length=128,
        num_trials=5
    )
    print_benchmark_result(gen_result)