From 99ef8fda71cad1c9d0dedc1ea4d77a7702e7cb2f Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 7 Feb 2026 21:14:39 +0800
Subject: [PATCH] =?UTF-8?q?feat(inference):=20=E5=A2=9E=E5=8A=A0cuda=5Fgra?=
 =?UTF-8?q?ph=20=E8=A3=85=E9=A5=B0=E5=99=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/inference/cuda_graph.py | 75 +++++++++++++++++++++++++++-------
 1 file changed, 60 insertions(+), 15 deletions(-)

diff --git a/khaosz/inference/cuda_graph.py b/khaosz/inference/cuda_graph.py
index a7681b5..09c05a7 100644
--- a/khaosz/inference/cuda_graph.py
+++ b/khaosz/inference/cuda_graph.py
@@ -1,39 +1,71 @@
 import torch
 from torch import Tensor
+from functools import wraps
+from inspect import signature
 
 
 class CudaGraphWrapper:
-    def __init__(self, function, device="cuda"):
+    def __init__(self, function, device="cuda", cast=False):
         self.function = function
+        self.cast = cast
         self.device = device
         self.static_input = None
         self.static_output = None
         self.graph = None
+        self.signature = signature(function)
 
     
     def _update_inplace(self, lhs, rhs):
-        if isinstance(lhs, Tensor):
-            if lhs.shape != rhs.shape or lhs.dtype != rhs.dtype or lhs.device != rhs.device:
-                raise ValueError("Tensor metadata must be static for CUDA Graph.")
-            lhs.copy_(rhs)
+        if isinstance(lhs, Tensor) and isinstance(rhs, Tensor):
+            if lhs.shape != rhs.shape:
+                raise ValueError(
+                    f"Tensor shape mismatch! "
+                    f"Expected: {lhs.shape}, Got: {rhs.shape}. "
+                    f"Function: {self.function}"
+                )
+            if self.cast:
+                if lhs.device != rhs.device:
+                    rhs = rhs.to(device=lhs.device)
+
+                if lhs.dtype != rhs.dtype:
+                    rhs = rhs.to(dtype=lhs.dtype)
+            else:
+                if lhs.device != rhs.device:
+                    raise ValueError(
+                        f"Tensor device mismatch! "
+                        f"Expected: {lhs.device}, Got: {rhs.device}. "
+                        f"Function: {self.function}"
+                    )
+                if lhs.dtype != rhs.dtype:
+                    raise ValueError(
+                        f"Tensor dtype mismatch! "
+                        f"Expected: {lhs.dtype}, Got: {rhs.dtype}. "
+                        f"Function: {self.function}"
+                    )
+            lhs.copy_(rhs)
         elif isinstance(lhs, dict):
             for k in lhs:
-                self._update_inplace(lhs[k], rhs[k])
+                if k in rhs:
+                    self._update_inplace(lhs[k], rhs[k])
         elif isinstance(lhs, (list, tuple)):
             for i in range(len(lhs)):
-                self._update_inplace(lhs[i], rhs[i])
+                if i < len(rhs):
+                    self._update_inplace(lhs[i], rhs[i])
         elif isinstance(lhs, (int, float, bool, str, type(None))):
             if lhs != rhs:
                 raise ValueError("Does not support changing control parameters.")
     
-    def _update_kwargs(self, input_kwargs: dict):
-        if self.static_input is None:
-            self.static_input = input_kwargs
-        else:
-            self._update_inplace(self.static_input, input_kwargs)
+    def _update_args(self, input_args, input_kwargs):
+        bound_args = self.signature.bind(*input_args, **input_kwargs)
+        bound_args.apply_defaults()
+        args_dict = bound_args.arguments
+        if self.static_input is None:
+            self.static_input = args_dict
+        else:
+            self._update_inplace(self.static_input, args_dict)
     
-    def run(self, input_kwargs: dict):
-        self._update_kwargs(input_kwargs)
+    def run(self, *args, **kwargs):
+        self._update_args(args, kwargs)
 
         if self.graph is None:
             # warmup
@@ -50,4 +82,17 @@ class CudaGraphWrapper:
 
         self.graph.replay()
 
-        return self.static_output
\ No newline at end of file
+        return self.static_output
+
+
+def cuda_graph(device="cuda", cast=False):
+    def decorator(func):
+        wrapper = CudaGraphWrapper(func, device, cast)
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+            return wrapper.run(*args, **kwargs)
+
+        return wrapped
+
+    return decorator
\ No newline at end of file