feat(inference): 增加cuda_graph 装饰器
This commit is contained in:
parent
dbd57e30e5
commit
99ef8fda71
|
|
@ -1,39 +1,71 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
|
||||
|
||||
class CudaGraphWrapper:
|
||||
def __init__(self, function, device="cuda"):
|
||||
def __init__(self, function, device="cuda", cast=False):
|
||||
self.function = function
|
||||
self.cast = cast
|
||||
self.device = device
|
||||
self.static_input = None
|
||||
self.static_output = None
|
||||
self.graph = None
|
||||
self.signature = signature(function)
|
||||
|
||||
def _update_inplace(self, lhs, rhs):
|
||||
if isinstance(lhs, Tensor):
|
||||
if lhs.shape != rhs.shape or lhs.dtype != rhs.dtype or lhs.device != rhs.device:
|
||||
raise ValueError("Tensor metadata must be static for CUDA Graph.")
|
||||
lhs.copy_(rhs)
|
||||
if isinstance(lhs, Tensor) and isinstance(rhs, Tensor):
|
||||
if lhs.shape != rhs.shape:
|
||||
raise ValueError(
|
||||
f"Tensor shape mismatch! "
|
||||
f"Expected: {lhs.shape}, Got: {rhs.shape}. "
|
||||
f"Function: {self.function}"
|
||||
)
|
||||
if self.cast:
|
||||
if lhs.device != rhs.device:
|
||||
rhs = rhs.to(device=lhs.device)
|
||||
|
||||
if lhs.dtype != rhs.dtype:
|
||||
rhs = rhs.to(dtype=lhs.dtype)
|
||||
else:
|
||||
if lhs.device != rhs.device:
|
||||
raise ValueError(
|
||||
f"Tensor device mismatch! "
|
||||
f"Expected: {lhs.device}, Got: {rhs.device}. "
|
||||
f"Function: {self.function}"
|
||||
)
|
||||
if lhs.dtype != rhs.dtype:
|
||||
raise ValueError(
|
||||
f"Tensor dtype mismatch! "
|
||||
f"Expected: {lhs.dtype}, Got: {rhs.dtype}. "
|
||||
f"Function: {self.function}"
|
||||
)
|
||||
lhs.copy_(rhs)
|
||||
elif isinstance(lhs, dict):
|
||||
for k in lhs:
|
||||
self._update_inplace(lhs[k], rhs[k])
|
||||
if k in rhs:
|
||||
self._update_inplace(lhs[k], rhs[k])
|
||||
elif isinstance(lhs, (list, tuple)):
|
||||
for i in range(len(lhs)):
|
||||
self._update_inplace(lhs[i], rhs[i])
|
||||
if i < len(rhs):
|
||||
self._update_inplace(lhs[i], rhs[i])
|
||||
elif isinstance(lhs, (int, float, bool, str, type(None))):
|
||||
if lhs != rhs:
|
||||
raise ValueError("Does not support changing control parameters.")
|
||||
|
||||
def _update_kwargs(self, input_kwargs: dict):
|
||||
def _update_args(self, input_args, input_kwargs):
|
||||
bound_args = self.signature.bind(*input_args, **input_kwargs)
|
||||
bound_args.apply_defaults()
|
||||
args_dict = bound_args.arguments
|
||||
|
||||
if self.static_input is None:
|
||||
self.static_input = input_kwargs
|
||||
self.static_input = args_dict
|
||||
else:
|
||||
self._update_inplace(self.static_input, input_kwargs)
|
||||
self._update_inplace(self.static_input, args_dict)
|
||||
|
||||
|
||||
def run(self, input_kwargs: dict):
|
||||
self._update_kwargs(input_kwargs)
|
||||
def run(self, *args, **kwargs):
|
||||
self._update_args(args, kwargs)
|
||||
|
||||
if self.graph is None:
|
||||
# warmup
|
||||
|
|
@ -51,3 +83,16 @@ class CudaGraphWrapper:
|
|||
self.graph.replay()
|
||||
|
||||
return self.static_output
|
||||
|
||||
|
||||
def cuda_graph(device="cuda", cast=False):
|
||||
def decorator(func):
|
||||
wrapper = CudaGraphWrapper(func, device, cast)
|
||||
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
return wrapper.run(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
Loading…
Reference in New Issue