perf(benchmark): 添加基准性能测试
This commit is contained in:
parent
64c4d2d2e3
commit
cd4877e490
|
|
@ -0,0 +1,215 @@
|
||||||
|
import torch
|
||||||
|
from typing import Dict, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from khaosz.core.transformer import TransformerConfig, Transformer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkResult:
|
||||||
|
total_tokens: int
|
||||||
|
total_time: float
|
||||||
|
tokens_per_second: float
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationBenchmark:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: TransformerConfig,
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: torch.dtype = torch.float16
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self.device = device
|
||||||
|
self.dtype = dtype
|
||||||
|
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list:
|
||||||
|
"""初始化KV缓存"""
|
||||||
|
kv_cache = []
|
||||||
|
head_dim = self.config.n_dim // self.config.n_head
|
||||||
|
for _ in range(self.config.n_layer):
|
||||||
|
k_cache = torch.zeros(
|
||||||
|
(batch_size, max_len, self.config.n_kvhead, head_dim),
|
||||||
|
device=self.device, dtype=self.dtype
|
||||||
|
)
|
||||||
|
v_cache = torch.zeros(
|
||||||
|
(batch_size, max_len, self.config.n_kvhead, head_dim),
|
||||||
|
device=self.device, dtype=self.dtype
|
||||||
|
)
|
||||||
|
kv_cache.append((k_cache, v_cache))
|
||||||
|
return kv_cache
|
||||||
|
|
||||||
|
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||||
|
prompt_ids = torch.randint(
|
||||||
|
low=0,
|
||||||
|
high=self.config.vocab_size,
|
||||||
|
size=(batch_size, prompt_length),
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_ids = torch.randint(
|
||||||
|
low=0,
|
||||||
|
high=self.config.vocab_size,
|
||||||
|
size=(batch_size, total_length - prompt_length),
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt_ids, gen_ids
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def run_prefill_benchmark(
|
||||||
|
self,
|
||||||
|
batch_size: int = 1,
|
||||||
|
prompt_length: int = 512,
|
||||||
|
num_trials: int = 10,
|
||||||
|
) -> BenchmarkResult:
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
||||||
|
_ = self.model(prompt_ids)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
total_time = 0.0
|
||||||
|
total_tokens = batch_size * prompt_length * num_trials
|
||||||
|
|
||||||
|
for trial in range(num_trials):
|
||||||
|
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
||||||
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
start_event.record()
|
||||||
|
_ = self.model(prompt_ids)
|
||||||
|
end_event.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||||
|
total_time += trial_time
|
||||||
|
|
||||||
|
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||||
|
f"({prompt_length / trial_time:.1f} tokens/s)")
|
||||||
|
|
||||||
|
return BenchmarkResult(
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
total_time=total_time,
|
||||||
|
tokens_per_second=total_tokens / total_time,
|
||||||
|
metadata={
|
||||||
|
"benchmark_type": "prefill",
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"prompt_length": prompt_length,
|
||||||
|
"dtype": self.dtype,
|
||||||
|
"device": self.device,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def run_decoding_benchmark(
|
||||||
|
self,
|
||||||
|
batch_size: int = 1,
|
||||||
|
prompt_length: int = 512,
|
||||||
|
gen_length: int = 128,
|
||||||
|
num_trials: int = 5,
|
||||||
|
) -> BenchmarkResult:
|
||||||
|
|
||||||
|
total_time = 0.0
|
||||||
|
total_tokens = batch_size * gen_length * num_trials
|
||||||
|
|
||||||
|
for trial in range(num_trials):
|
||||||
|
|
||||||
|
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
|
||||||
|
kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len)
|
||||||
|
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
start_event.record()
|
||||||
|
|
||||||
|
current_pos = prompt_length
|
||||||
|
for i in range(gen_length):
|
||||||
|
input_token = gen_ids[:, i:i+1]
|
||||||
|
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos)
|
||||||
|
current_pos += 1
|
||||||
|
|
||||||
|
end_event.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||||
|
total_time += trial_time
|
||||||
|
|
||||||
|
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||||
|
f"({gen_length / trial_time:.1f} tokens/s)")
|
||||||
|
|
||||||
|
|
||||||
|
return BenchmarkResult(
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
total_time=total_time,
|
||||||
|
tokens_per_second=total_tokens / total_time,
|
||||||
|
metadata={
|
||||||
|
"benchmark_type": "generation",
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"prompt_length": prompt_length,
|
||||||
|
"gen_length": gen_length,
|
||||||
|
"dtype": self.dtype,
|
||||||
|
"device": self.device,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def print_benchmark_result(result: BenchmarkResult):
|
||||||
|
"""打印基准测试结果"""
|
||||||
|
benchmark_type = result.metadata["benchmark_type"]
|
||||||
|
|
||||||
|
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
|
||||||
|
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
||||||
|
print(f"Time Consumed: {result.total_time:.3f}s")
|
||||||
|
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
||||||
|
|
||||||
|
if benchmark_type == "prefill":
|
||||||
|
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
|
||||||
|
elif benchmark_type == "generation":
|
||||||
|
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
|
||||||
|
|
||||||
|
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
config = TransformerConfig(
|
||||||
|
vocab_size=10000,
|
||||||
|
n_dim=1536,
|
||||||
|
n_head=24,
|
||||||
|
n_kvhead=4,
|
||||||
|
d_ffn=6912,
|
||||||
|
m_len=2048,
|
||||||
|
n_layer=24,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
benchmark = GenerationBenchmark(config)
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("Running Transformer Generation Benchmark")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
prefill_result = benchmark.run_prefill_benchmark(
|
||||||
|
batch_size=4,
|
||||||
|
prompt_length=512,
|
||||||
|
num_trials=5
|
||||||
|
)
|
||||||
|
print_benchmark_result(prefill_result)
|
||||||
|
|
||||||
|
gen_result = benchmark.run_decoding_benchmark(
|
||||||
|
batch_size=4,
|
||||||
|
prompt_length=512,
|
||||||
|
gen_length=128,
|
||||||
|
num_trials=5
|
||||||
|
)
|
||||||
|
print_benchmark_result(gen_result)
|
||||||
|
|
||||||
Loading…
Reference in New Issue