mirror of https://github.com/PaddlePaddle/FastDeploy.git
[SOT] Make custom_op dy&st unified (#2733)
* make custom_op dy&st unified
* add instance judgement
@@ -931,8 +931,8 @@ class LLMEngine(object):
 
     def _setting_environ_variables(self):
         """
         Configure environment variables.
         """
         variables = {
             "PADDLE_TRAINER_ID": 0,
             "PADDLE_TRAINERS_NUM": 1,
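For context, a minimal sketch of how a `variables` dict like this is typically exported into the process environment (the loop below is illustrative, not FastDeploy's actual code; `os.environ` only accepts strings, hence the `str()` coercion):

import os

variables = {
    "PADDLE_TRAINER_ID": 0,
    "PADDLE_TRAINERS_NUM": 1,
}
for key, value in variables.items():
    # Environment variable values must be strings.
    os.environ[key] = str(value)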
@@ -998,8 +998,8 @@ class LLMEngine(object):
         py_script = os.path.join(current_dir_path, worker_path)
 
         ori_vocab_size = (
             len(self.data_processor.tokenizer.sp_model)
             if hasattr(self.data_processor.tokenizer, 'sp_model')
             else len(self.data_processor.tokenizer.vocab)
         )
 
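The conditional covers two tokenizer flavors: SentencePiece-backed tokenizers expose `sp_model`, while plain tokenizers expose a `vocab` mapping; either way `ori_vocab_size` ends up as the original vocabulary size. A minimal sketch with hypothetical stand-in classes (not FastDeploy code):

class SpTokenizer:
    # Stand-in for a SentencePiece-backed tokenizer.
    class _SpModel:
        def __len__(self):
            return 32000
    sp_model = _SpModel()

class DictTokenizer:
    # Stand-in for a tokenizer that only exposes a vocab mapping.
    vocab = {"<pad>": 0, "<s>": 1, "</s>": 2}

for tokenizer in (SpTokenizer(), DictTokenizer()):
    ori_vocab_size = (
        len(tokenizer.sp_model)
        if hasattr(tokenizer, 'sp_model')
        else len(tokenizer.vocab)
    )
    print(type(tokenizer).__name__, ori_vocab_size)  # 32000, then 3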
@@ -15,7 +15,6 @@
 import functools
 import importlib
 import inspect
-import os
 
 import paddle
 
@@ -77,7 +76,13 @@ def wrap_unified_op(original_cpp_ext_op, original_custom_op):
     @functools.wraps(original_custom_op)
     def unified_op(*args, **kwargs):
         if paddle.in_dynamic_mode():
-            return original_cpp_ext_op(*args, **kwargs)
+            res = original_cpp_ext_op(*args, **kwargs)
+            if res is None:
+                return None
+            # TODO(DrRyanHuang): Remove this if when we align the implementation of custom op and C++ extension
+            if isinstance(res, list) and len(res) == 1:
+                return res[0]
+            return res
         return original_custom_op(*args, **kwargs)
 
     return unified_op
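The new branch normalizes the dynamic-mode result: the C++ extension binding may return None, may wrap a single output tensor in a one-element list, or may return multiple outputs, while the custom op used in static mode returns the bare output. A minimal sketch of the wrapper's behavior with stub callables (the stubs are illustrative, not the real ops):

import paddle

def cpp_ext_stub(x):
    # Mimics a C++ extension op that wraps its single output in a list.
    return [x + 1]

def custom_op_stub(x):
    # Mimics the registered custom op, which returns the output directly.
    return x + 1

unified = wrap_unified_op(cpp_ext_stub, custom_op_stub)
out = unified(paddle.to_tensor([1.0]))
# In dynamic (eager) mode the one-element list is unwrapped, so `out`
# is a Tensor, matching what the custom op returns in static mode.
assert not isinstance(out, list)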
@@ -93,17 +98,13 @@ def preprocess_static_op(global_ns):
     """
     static_op_prefix = "static_op_"
     static_op_names = [k for k in global_ns if k.startswith(static_op_prefix)]
-    enforce_eager = int(os.getenv("FD_ENFORCE_EAGER", "0")) == 1
 
-    for static_op in static_op_names:
-        op_name = static_op[len(static_op_prefix):]
-        has_dynamic_op = op_name in global_ns
+    for static_op_name in static_op_names:
+        op_name = static_op_name.removeprefix(static_op_prefix)
+        if op_name not in global_ns:
+            global_ns[op_name] = global_ns[static_op_name]
+            continue
 
-        if has_dynamic_op:
-            if not enforce_eager:
-                original_cpp_ext_op = global_ns[op_name]
-                original_custom_op = global_ns[static_op]
-                global_ns[op_name] = wrap_unified_op(original_cpp_ext_op,
-                                                     original_custom_op)
-        else:
-            global_ns[op_name] = global_ns[static_op]
+        original_cpp_ext_op = global_ns[op_name]
+        original_custom_op = global_ns[static_op_name]
+        global_ns[op_name] = wrap_unified_op(original_cpp_ext_op, original_custom_op)
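Note that the FD_ENFORCE_EAGER gate is removed, so the unified wrapper is now installed unconditionally whenever both variants exist (`str.removeprefix` requires Python 3.9+). A minimal usage sketch of the simplified dispatch, with hypothetical op names: when both a plain name and its `static_op_`-prefixed twin are in the namespace, the pair is wrapped; when only the static op exists, it is aliased to the plain name.

def my_op(x):                 # hypothetical C++ extension binding
    return [x * 2]

def static_op_my_op(x):       # hypothetical paddle custom op
    return x * 2

preprocess_static_op(globals())
# `my_op` is now the unified_op wrapper: it dispatches on
# paddle.in_dynamic_mode() and unwraps one-element result lists.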
@@ -445,7 +445,7 @@ class Ernie4_5_VLModel(nn.Layer):
             forward_meta.seq_lens_this_time,
             forward_meta.cu_seqlens_q,
             score_text,
-        )[0].cast(self._dtype)
+        ).cast(self._dtype)
         # -----------------------
 
         out = self.norm(hidden_states)
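This call-site change follows from the wrapper above: in dynamic mode the unified op now unwraps a single-element result list itself, so the `[0]` index is dropped and the call site casts the returned tensor directly.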