feat(inference): 增加cuda_graph 装饰器

This commit is contained in:
ViperEkura 2026-02-07 21:14:39 +08:00
parent dbd57e30e5
commit 99ef8fda71
1 changed file with 60 additions and 15 deletions

View File

@@ -1,39 +1,71 @@
import torch
from torch import Tensor
from functools import wraps
from inspect import signature
class CudaGraphWrapper:
    """Wraps a function so its CUDA work can be captured once into a CUDA
    graph and replayed on later calls (graph capture/replay is visible in
    ``run``; static input/output buffers are kept on the instance).
    """

    def __init__(self, function, device="cuda", cast=False):
        """Initialize the wrapper.

        Args:
            function: The callable whose execution will be graph-captured.
            device: Target device string for capture (default ``"cuda"``).
            cast: If True, input tensors whose device/dtype differ from the
                captured static tensors are converted to match instead of
                raising an error.
        """
        self.function = function
        self.cast = cast
        self.device = device
        # Argument structure adopted on the first call and then updated
        # in place on every subsequent call (addresses must stay stable
        # for graph replay).
        self.static_input = None
        self.static_output = None
        self.graph = None
        # Bound-signature helper: lets *args/**kwargs be normalized into
        # one canonical name->value dict (see _update_args).
        self.signature = signature(function)
def _update_inplace(self, lhs, rhs):
    """Copy values from ``rhs`` into the static structure ``lhs`` in place.

    Recurses through dicts, lists and tuples; leaf tensors are updated
    with ``copy_`` so the storage captured by the CUDA graph keeps its
    address. Plain scalars/strings/None act as control parameters and
    must not change between calls.

    Raises:
        ValueError: On tensor shape mismatch; on device/dtype mismatch
            when ``self.cast`` is False; or when a control parameter
            changed between calls.
    """
    if isinstance(lhs, Tensor) and isinstance(rhs, Tensor):
        # Shape must be static: the captured graph replays fixed-size kernels.
        if lhs.shape != rhs.shape:
            raise ValueError(
                f"Tensor shape mismatch! "
                f"Expected: {lhs.shape}, Got: {rhs.shape}. "
                f"Function: {self.function}"
            )
        if self.cast:
            # Best effort: convert the incoming tensor to match the
            # captured buffer's device/dtype before copying.
            if lhs.device != rhs.device:
                rhs = rhs.to(device=lhs.device)
            if lhs.dtype != rhs.dtype:
                rhs = rhs.to(dtype=lhs.dtype)
        else:
            if lhs.device != rhs.device:
                raise ValueError(
                    f"Tensor device mismatch! "
                    f"Expected: {lhs.device}, Got: {rhs.device}. "
                    f"Function: {self.function}"
                )
            if lhs.dtype != rhs.dtype:
                raise ValueError(
                    f"Tensor dtype mismatch! "
                    f"Expected: {lhs.dtype}, Got: {rhs.dtype}. "
                    f"Function: {self.function}"
                )
        lhs.copy_(rhs)
    elif isinstance(lhs, dict):
        # Only keys present in both sides are refreshed; extra keys in
        # lhs keep their previously captured values.
        for k in lhs:
            if k in rhs:
                self._update_inplace(lhs[k], rhs[k])
    elif isinstance(lhs, (list, tuple)):
        # Same best-effort pairing for sequences of unequal length.
        for i in range(len(lhs)):
            if i < len(rhs):
                self._update_inplace(lhs[i], rhs[i])
    elif isinstance(lhs, (int, float, bool, str, type(None))):
        # Non-tensor arguments are baked into the captured graph, so a
        # changed value would silently replay stale computation.
        if lhs != rhs:
            raise ValueError("Does not support changing control parameters.")
def _update_args(self, input_args, input_kwargs):
    """Normalize a call's arguments and refresh the static input buffers.

    Binds positional and keyword arguments against the wrapped function's
    signature (applying defaults) so that equivalent calls always produce
    the same canonical name->value dict.
    """
    bound_args = self.signature.bind(*input_args, **input_kwargs)
    bound_args.apply_defaults()
    args_dict = bound_args.arguments
    if self.static_input is None:
        # First call: adopt the arguments as the graph's static inputs.
        self.static_input = args_dict
    else:
        # Later calls: copy new values into the captured buffers in place.
        self._update_inplace(self.static_input, args_dict)
def run(self, input_kwargs: dict):
self._update_kwargs(input_kwargs)
def run(self, *args, **kwargs):
self._update_args(args, kwargs)
if self.graph is None:
# warmup
@@ -51,3 +83,16 @@ class CudaGraphWrapper:
self.graph.replay()
return self.static_output
def cuda_graph(device="cuda", cast=False):
    """Decorator factory that runs the decorated function through a
    ``CudaGraphWrapper`` (capture on first call, replay afterwards).

    Args:
        device: Device string forwarded to the wrapper (default ``"cuda"``).
        cast: Forwarded to the wrapper; if True, mismatched input tensor
            device/dtype is converted instead of rejected.
    """
    def decorator(func):
        # One wrapper per decorated function: the captured graph and its
        # static buffers live for the lifetime of the decorated callable.
        graph_runner = CudaGraphWrapper(func, device, cast)

        @wraps(func)
        def wrapped(*args, **kwargs):
            return graph_runner.run(*args, **kwargs)

        return wrapped

    return decorator