feat(inference): 增加cuda graph 设置

This commit is contained in:
ViperEkura 2026-02-07 15:42:41 +08:00
parent a5869d89ba
commit dbd57e30e5
1 changed files with 53 additions and 0 deletions

View File

@ -0,0 +1,53 @@
import torch
from torch import Tensor
class CudaGraphWrapper:
def __init__(self, function, device="cuda"):
self.function = function
self.device = device
self.static_input = None
self.static_output = None
self.graph = None
def _update_inplace(self, lhs, rhs):
if isinstance(lhs, Tensor):
if lhs.shape != rhs.shape or lhs.dtype != rhs.dtype or lhs.device != rhs.device:
raise ValueError("Tensor metadata must be static for CUDA Graph.")
lhs.copy_(rhs)
elif isinstance(lhs, dict):
for k in lhs:
self._update_inplace(lhs[k], rhs[k])
elif isinstance(lhs, (list, tuple)):
for i in range(len(lhs)):
self._update_inplace(lhs[i], rhs[i])
elif isinstance(lhs, (int, float, bool, str, type(None))):
if lhs != rhs:
raise ValueError("Does not support changing control parameters.")
def _update_kwargs(self, input_kwargs: dict):
if self.static_input is None:
self.static_input = input_kwargs
else:
self._update_inplace(self.static_input, input_kwargs)
def run(self, input_kwargs: dict):
self._update_kwargs(input_kwargs)
if self.graph is None:
# warmup
_ = torch.matmul(
torch.randn(100, 100, device=self.device),
torch.randn(100, 100, device=self.device)
)
torch.cuda.synchronize()
# capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.function(**self.static_input)
self.graph.replay()
return self.static_output