Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-07 09:31:35 +08:00)
[Feat] ernie4_5_vl_moe support CudaGraph (#3226)
* delete dynamic control flow for decode
* code-style
* fix scatter/gather typos and use input stream instead of the default stream
* support 0-Size Tensor
* update runner and model
* use static mem address as input
* fix mem leak
* refine code
* update mm_buffer
* fix typo
* fix buffer size
* fix unk token
* refine code
* refine
* support other arch
* open cudagraph in vlci
* fix
* update
* update
* update
* fix cmd
* update

---------

Co-authored-by: aquagull <hongyuh@qq.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
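Several of the bullets above ("use static mem address as input", "update mm_buffer") come down to a basic CUDA graph constraint: capture bakes device pointers into the graph, so replay only sees data that has been written into buffers allocated before capture. Below is a minimal sketch of that pattern; the buffer name, sizes, and dtype are illustrative and not taken from this PR.

import paddle

# Persistent buffer allocated once; its device address stays fixed across steps,
# which is what a captured CUDA graph requires (sizes are illustrative only).
MAX_TOKENS, HIDDEN_SIZE = 2048, 4096
input_buffer = paddle.zeros([MAX_TOKENS, HIDDEN_SIZE], dtype="float16")

def prepare_step(new_inputs: paddle.Tensor) -> None:
    # Write the current batch into the pre-allocated buffer in place instead of
    # handing the graph a freshly allocated tensor; the captured graph then reads
    # from input_buffer's fixed address during replay.
    num_tokens = new_inputs.shape[0]
    input_buffer[:num_tokens] = new_inputs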
@@ -99,3 +99,35 @@ class GraphOptWrapper:
            fd_config.graph_opt_config.graph_opt_level < 1
        ), "Currently unable to update weights in static graph mode."
        self.graph_opt_backend.clear_cudagraph_piecewise_backend()


def cuda_graph_buffers(buffer_meta):
    def decorator(cls):
        original_init = cls.__init__

        def __init__(self, fd_config: FDConfig, **kwargs):
            original_init(self, fd_config=fd_config, **kwargs)

            def _resolve_path(root, path: str):
                cur = root
                for p in path.split("."):
                    cur = getattr(cur, p)
                return cur

            if not hasattr(self, "_mm_buffers"):
                self._mm_buffers = {}
            for name, meta in buffer_meta.items():
                shape = [_resolve_path(fd_config, s) if isinstance(s, str) else s for s in meta["shape"]]
                dtype = meta["dtype"]
                if "." in meta["dtype"]:
                    dtype = _resolve_path(fd_config, meta["dtype"])
                self._mm_buffers[name] = paddle.full(
                    shape=shape,
                    dtype=dtype,
                    fill_value=meta.get("value", 0),
                )

        cls.__init__ = __init__
        return cls

    return decorator
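For context, here is a hypothetical application of the decorator added above; the buffer name, shape values, and config paths are illustrative and not taken from this change. String entries in "shape" are resolved as dotted attribute paths on fd_config, ints are used as-is, and a "dtype" containing "." is resolved the same way. This sketch assumes paddle, FDConfig, and cuda_graph_buffers are importable in the surrounding module.

# Hypothetical usage sketch; names and config paths are illustrative, not from the PR.
@cuda_graph_buffers(
    {
        "image_features_buffer": {
            # Strings are resolved as attribute paths on fd_config; ints are literal dims.
            "shape": ["parallel_config.max_num_batched_tokens", 4096],
            "dtype": "model_config.dtype",  # contains "." -> also resolved via fd_config
            "value": 0,
        },
    }
)
class DemoVLModel(paddle.nn.Layer):
    def __init__(self, fd_config: FDConfig, **kwargs):
        super().__init__()
        self.fd_config = fd_config

# After construction, the decorator has populated self._mm_buffers["image_features_buffer"]
# with a persistent paddle tensor whose device address stays stable across graph replays.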