[SOT] Make custom_op dy&st unified (#2733)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* make_custom_op dy&st unified

* add instance judgement
This commit is contained in:
Ryan
2025-07-08 19:21:44 +08:00
committed by GitHub
parent f6ffbc3cbd
commit f72c4de539
3 changed files with 20 additions and 19 deletions

View File

@@ -931,8 +931,8 @@ class LLMEngine(object):
def _setting_environ_variables(self): def _setting_environ_variables(self):
""" """
配置环境变量 配置环境变量
""" """
variables = { variables = {
"PADDLE_TRAINER_ID": 0, "PADDLE_TRAINER_ID": 0,
"PADDLE_TRAINERS_NUM": 1, "PADDLE_TRAINERS_NUM": 1,
@@ -998,8 +998,8 @@ class LLMEngine(object):
py_script = os.path.join(current_dir_path, worker_path) py_script = os.path.join(current_dir_path, worker_path)
ori_vocab_size = ( ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model) len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model') if hasattr(self.data_processor.tokenizer, 'sp_model')
else len(self.data_processor.tokenizer.vocab) else len(self.data_processor.tokenizer.vocab)
) )

View File

@@ -15,7 +15,6 @@
import functools import functools
import importlib import importlib
import inspect import inspect
import os
import paddle import paddle
@@ -77,7 +76,13 @@ def wrap_unified_op(original_cpp_ext_op, original_custom_op):
@functools.wraps(original_custom_op) @functools.wraps(original_custom_op)
def unified_op(*args, **kwargs): def unified_op(*args, **kwargs):
if paddle.in_dynamic_mode(): if paddle.in_dynamic_mode():
return original_cpp_ext_op(*args, **kwargs) res = original_cpp_ext_op(*args, **kwargs)
if res is None:
return None
# TODO(DrRyanHuang): Remove this if when we align the implementation of custom op and C++ extension
if isinstance(res, list) and len(res) == 1:
return res[0]
return res
return original_custom_op(*args, **kwargs) return original_custom_op(*args, **kwargs)
return unified_op return unified_op
@@ -93,17 +98,13 @@ def preprocess_static_op(global_ns):
""" """
static_op_prefix = "static_op_" static_op_prefix = "static_op_"
static_op_names = [k for k in global_ns if k.startswith(static_op_prefix)] static_op_names = [k for k in global_ns if k.startswith(static_op_prefix)]
enforce_eager = int(os.getenv("FD_ENFORCE_EAGER", "0")) == 1
for static_op in static_op_names: for static_op_name in static_op_names:
op_name = static_op[len(static_op_prefix):] op_name = static_op_name.removeprefix(static_op_prefix)
has_dynamic_op = op_name in global_ns if op_name not in global_ns:
global_ns[op_name] = global_ns[static_op_name]
continue
if has_dynamic_op: original_cpp_ext_op = global_ns[op_name]
if not enforce_eager: original_custom_op = global_ns[static_op_name]
original_cpp_ext_op = global_ns[op_name] global_ns[op_name] = wrap_unified_op(original_cpp_ext_op, original_custom_op)
original_custom_op = global_ns[static_op]
global_ns[op_name] = wrap_unified_op(original_cpp_ext_op,
original_custom_op)
else:
global_ns[op_name] = global_ns[static_op]

View File

@@ -445,7 +445,7 @@ class Ernie4_5_VLModel(nn.Layer):
forward_meta.seq_lens_this_time, forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q, forward_meta.cu_seqlens_q,
score_text, score_text,
)[0].cast(self._dtype) ).cast(self._dtype)
# ----------------------- # -----------------------
out = self.norm(hidden_states) out = self.norm(hidden_states)