[GCU] Update to develop (#2988)
Mirror of https://github.com/PaddlePaddle/FastDeploy.git
@@ -31,7 +31,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 )
 
 if TYPE_CHECKING:
-    from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
+    from fastdeploy.model_executor.forward_meta import ForwardMeta
 
 from paddleformers.utils.log import logger
 
@@ -44,15 +44,12 @@ class GCUFlashAttnMetadata(AttentionMetadata):
     GCUFlashAttnMetadata
     """
 
-    forward_mode: ForwardMode = ForwardMode.MIXED
-
     _dtype: paddle.dtype = paddle.bfloat16
 
    seq_lens_encoder: Optional[paddle.Tensor] = None
    seq_lens_decoder: Optional[paddle.Tensor] = None
    seq_lens_this_time: Optional[paddle.Tensor] = None
-    cum_offsets: Optional[paddle.Tensor] = None
-    padding_offset: Optional[paddle.Tensor] = None
+    batch_id_per_token: Optional[paddle.Tensor] = None
 
    cu_seqlens_q: Optional[paddle.Tensor] = None
    cu_seqlens_k: Optional[paddle.Tensor] = None
@@ -118,8 +115,7 @@ class GCUFlashAttnBackend(AttentionBackend):
         metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
         metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
         metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
-        metadata.cum_offsets = forward_meta.cum_offsets
-        metadata.padding_offset = forward_meta.padding_offset
+        metadata.batch_id_per_token = forward_meta.batch_id_per_token
 
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
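Note: both GCU attention backends swap the cum_offsets/padding_offset pair for a single batch_id_per_token buffer, mirroring the model-runner changes below. As a rough illustration only (not the FastDeploy implementation), such a buffer can be read as "for each token of the de-padded, packed batch, which batch row did it come from", next to the usual cu_seqlens prefix sums:

import numpy as np

# tokens contributed by each request in this step (made-up values)
seq_lens_this_time = np.array([3, 1, 2])

# hypothetical per-token batch index for the packed (padding-removed) layout
batch_id_per_token = np.repeat(np.arange(len(seq_lens_this_time)), seq_lens_this_time)
print(batch_id_per_token)  # [0 0 0 1 2 2]

# cu_seqlens_q / cu_seqlens_k style prefix sums over the same lengths
cu_seqlens = np.concatenate(([0], np.cumsum(seq_lens_this_time)))
print(cu_seqlens)  # [0 3 4 6]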
@@ -36,7 +36,7 @@ from fastdeploy.model_executor.ops.gcu import (
 )
 
 if TYPE_CHECKING:
-    from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
+    from fastdeploy.model_executor.forward_meta import ForwardMeta
 
 
 @dataclass
@@ -45,14 +45,12 @@ class GCUMemEfficientAttnMetadata(AttentionMetadata):
     GCUMemEfficientAttnMetadata
     """
 
-    forward_mode: ForwardMode = ForwardMode.MIXED
     _dtype: paddle.dtype = paddle.bfloat16
 
    seq_lens_encoder: Optional[paddle.Tensor] = None
    seq_lens_decoder: Optional[paddle.Tensor] = None
    seq_lens_this_time: Optional[paddle.Tensor] = None
-    cum_offsets: Optional[paddle.Tensor] = None
-    padding_offset: Optional[paddle.Tensor] = None
+    batch_id_per_token: Optional[paddle.Tensor] = None
 
    cu_seqlens_q: Optional[paddle.Tensor] = None
    cu_seqlens_k: Optional[paddle.Tensor] = None
@@ -115,8 +113,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
         metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
         metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
-        metadata.cum_offsets = forward_meta.cum_offsets
-        metadata.padding_offset = forward_meta.padding_offset
+        metadata.batch_id_per_token = forward_meta.batch_id_per_token
 
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
@@ -60,6 +60,7 @@ class GCUModelRunner(ModelRunnerBase):
         local_rank: int,
     ):
         super().__init__(fd_config=fd_config, device=device)
+        self.enable_mm = self.model_config.enable_mm
         self.rank = rank
         self.local_rank = local_rank
         self.device_id = device_id
@@ -80,8 +81,6 @@ class GCUModelRunner(ModelRunnerBase):
         # Cuda Graph
         self.use_cudagraph = self.graph_opt_config.use_cudagraph
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
-        self.cudagraph_num_of_warmups = self.graph_opt_config.cudagraph_num_of_warmups
-        self.input_ids = paddle.zeros(self.parallel_config.max_num_seqs, dtype="int32")
 
         # Initialize share inputs
         self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -107,14 +106,14 @@ class GCUModelRunner(ModelRunnerBase):
 
     def exist_prefill(self):
         """
-        check whether prefill stage exist
+        Check whether prefill stage exist
         """
         if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
             return 1
         else:
             return 0
 
-    def init_speculative_proposer(self):
+    def _init_speculative_proposer(self):
         """
         Init speculative proposer
         """
@@ -155,11 +154,19 @@ class GCUModelRunner(ModelRunnerBase):
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
             os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
 
+        def get_attr_from_request(request, attr, default_value=None):
+            res = request.get(attr, default_value)
+            if res is not None:
+                return res
+            else:
+                return default_value
+
         req_len = len(req_dicts)
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
+            assert length > 0, "The prompt requested must not be empty."
 
             prefill_tokens = []
             if (
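Note: the nested get_attr_from_request helper added here (and used further down for top_p, temperature and the penalty scores) differs from a bare request.get in one respect: an attribute that is present but set to None still falls back to the default. A minimal standalone sketch, with a plain dict standing in for the Request object:

def get_attr_from_request(request, attr, default_value=None):
    res = request.get(attr, default_value)
    if res is not None:
        return res
    else:
        return default_value

request = {"top_p": None, "temperature": 0.2}
print(get_attr_from_request(request, "top_p", 0.7))         # 0.7 -> stored None falls back to the default
print(get_attr_from_request(request, "temperature", 0.95))  # 0.2 -> stored value wins
print(request.get("top_p", 0.7))                            # None -> a bare .get keeps the None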
@@ -177,11 +184,13 @@ class GCUModelRunner(ModelRunnerBase):
                 prefill_tokens.append(request.prompt_token_ids[0])
                 self.share_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
                 self.share_inputs["input_ids"][idx : idx + 1, 0] = request.prompt_token_ids[0]
+                self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
                 self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
+                self.share_inputs["prompt_lens"][idx : idx + 1] = length
                 self.share_inputs["step_idx"][idx : idx + 1] = 1
 
                 if self.speculative_decoding:
@@ -195,39 +204,52 @@ class GCUModelRunner(ModelRunnerBase):
                 self.share_inputs["pre_ids"][idx : idx + 1] = -1
                 self.share_inputs["step_idx"][idx : idx + 1] = 0
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
+                self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
 
                 # Use chunked prefill
                 if self.parallel_config.enable_chunked_prefill:
                     request.set("chunk_idx", 1)
                     logger.info(f"prefill_chunk_info: {request.prefill_chunk_info}")
                     token_chunk_size = request.prefill_chunk_info[0]
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
                     self.share_inputs["input_ids"][idx, :token_chunk_size] = np.array(
                         request.prompt_token_ids[:token_chunk_size]
                     )
-                    self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
-                    self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                     self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                     self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
+                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                    self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
+                    self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
+                    self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size
                 else:
                     self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                     self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                     self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
                     self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
                     self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
+                    self.share_inputs["prompt_lens"][idx : idx + 1] = length
 
             if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
                 request.eos_token_ids.append(request.eos_token_ids[0])
             self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
-            self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
+            self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
             self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
-            self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
-            self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
-            self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
-            self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
+            self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
+            self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
+            self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
+                request, "repetition_penalty", 1.0
+            )
+            self.share_inputs["frequency_score"][idx : idx + 1] = get_attr_from_request(
+                request, "frequency_penalty", 0.0
+            )
+            self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request(
+                request, "presence_penalty", 0.0
+            )
 
             self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
-            self.share_inputs["max_dec_len"][idx : idx + 1] = request.get("max_tokens", self.model_config.max_length)
+            self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
+                "max_tokens", self.model_config.max_model_len
+            )
             self.share_inputs["stop_flags"][idx : idx + 1] = False
 
             self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
@@ -273,14 +295,18 @@ class GCUModelRunner(ModelRunnerBase):
         for i in range(batch_size):
             idx = i
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
+            self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
             self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
+            self.share_inputs["prompt_lens"][idx : idx + 1] = 0
             self.share_inputs["step_idx"][idx : idx + 1] = 0
             self.share_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
+            self.share_inputs["min_dec_len"][idx : idx + 1] = max_dec_len
             self.share_inputs["stop_flags"][idx : idx + 1] = False
+            self.share_inputs["temperature"][idx : idx + 1] = 1
 
             self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
             self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = input_length
@@ -291,8 +317,8 @@ class GCUModelRunner(ModelRunnerBase):
             )
 
     def _init_share_inputs(self, max_num_seqs: int):
-        """Initialize all share buffers for model inputs.
-        Note: In the future, we may abandon share buffers.
+        """
+        Initialize all share buffers for model inputs.
         """
         self.MAX_INFER_SEED = 9223372036854775806
         self.share_inputs = {}
@@ -307,9 +333,15 @@ class GCUModelRunner(ModelRunnerBase):
             self.parallel_config.pad_token_id,
             dtype="int64",
         )
+        self.share_inputs["prompt_ids"] = paddle.full(
+            [max_num_seqs, self.parallel_config.max_model_len],
+            self.parallel_config.pad_token_id,
+            dtype="int64",
+        )
         self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
+        self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
         self.share_inputs["temperature"] = paddle.full(
             [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
         )
@@ -326,14 +358,19 @@ class GCUModelRunner(ModelRunnerBase):
         )
 
         self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
-        self.share_inputs["max_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.max_length, dtype="int64")
+        self.share_inputs["max_dec_len"] = paddle.full(
+            [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
+        )
         self.share_inputs["min_length"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
-        self.share_inputs["max_length"] = paddle.full([max_num_seqs, 1], self.model_config.max_length, dtype="int64")
+        self.share_inputs["max_length"] = paddle.full(
+            [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
+        )
         self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32")
         self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+        self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu()
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
@@ -362,7 +399,7 @@ class GCUModelRunner(ModelRunnerBase):
             dtype="int64",
         )
         self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-        self.share_inputs["padding_offset"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+        self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         # AttentionBackend buffers
@@ -438,12 +475,12 @@ class GCUModelRunner(ModelRunnerBase):
         )
 
     def _prepare_inputs(self) -> None:
-        """prepare the model inputs"""
+        """Prepare the model inputs"""
         # Remove padding
         (
             ids_remove_padding,
             cum_offsets,
-            padding_offset,
+            batch_id_per_token,
             cu_seqlens_q,
             cu_seqlens_k,
             output_cum_offsets,
@@ -459,7 +496,7 @@ class GCUModelRunner(ModelRunnerBase):
 
         self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
         self.share_inputs["cum_offsets"].copy_(cum_offsets, False)
-        self.share_inputs["padding_offset"].copy_(padding_offset, False)
+        self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
         self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
         self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
 
@@ -476,8 +513,11 @@ class GCUModelRunner(ModelRunnerBase):
             temperature=self.share_inputs["temperature"],
             top_p=self.share_inputs["top_p"],
             top_k=self.share_inputs["top_k"],
+            min_p=self.share_inputs["min_p"],
             step_idx=self.share_inputs["step_idx"],
             pre_token_ids=self.share_inputs["pre_ids"],
+            prompt_ids=self.share_inputs["prompt_ids"],
+            prompt_lens=self.share_inputs["prompt_lens"],
             frequency_penalties=self.share_inputs["frequency_score"],
             presence_penalties=self.share_inputs["presence_score"],
             repetition_penalties=self.share_inputs["penalty_score"],
@@ -507,10 +547,10 @@ class GCUModelRunner(ModelRunnerBase):
         logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
 
         # 4. Init proposer for speculative method
-        self.init_speculative_proposer()
+        self._init_speculative_proposer()
 
     def get_model(self) -> nn.Layer:
-        """get current model"""
+        """Get current model"""
         return self.model
 
     def initialize_forward_meta(self):
@@ -528,36 +568,21 @@ class GCUModelRunner(ModelRunnerBase):
             seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
             seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
-            cum_offsets=self.share_inputs["cum_offsets"],
-            padding_offset=self.share_inputs["padding_offset"],
+            batch_id_per_token=self.share_inputs["batch_id_per_token"],
             cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
             cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
             block_tables=self.share_inputs["block_tables"],
             caches=self.share_inputs["caches"],
         )
 
+        # Update Batch type for cuda graph
+        is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
+        self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
+
         # Initialzie attention meta data
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-    def clear_cache(self):
-        """Clear cached data from shared inputs and forward metadata."""
-        self.share_inputs.pop("caches", None)
-        if self.forward_meta is not None:
-            self.forward_meta.clear_caches()
-
-    def clear_parameters(self, pid):
-        """ "dynamic model loader use to clear parameters use for RL"""
-        self.dynamic_weight_manager.clear_parameters(pid)
-        self.clear_cache()
-        self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
-
-    def update_parameters(self, pid):
-        """ "dynamic model loader use to update parameters use for RL"""
-        self.dynamic_weight_manager.update_parameters(pid)
-        self.initialize_kv_cache()
-        self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
-
     def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
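Note: the hunk above also adds the decode-only check that gates CUDA graph replay: a step counts as a pure decode batch when no request contributes more than one token. A small paddle sketch of the same expression outside the runner (values are made up):

import paddle

seq_lens_this_time = paddle.to_tensor([1, 1, 0, 1], dtype="int32")  # decode-only step
is_decode_batch = not ((seq_lens_this_time > 1).sum() > 0)
print(is_decode_batch)  # True  -> step_use_cudagraph = use_cudagraph and True

seq_lens_this_time = paddle.to_tensor([128, 1, 0, 1], dtype="int32")  # contains a prefill chunk
is_decode_batch = not ((seq_lens_this_time > 1).sum() > 0)
print(is_decode_batch)  # False -> eager execution for this step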
@@ -606,13 +631,14 @@ class GCUModelRunner(ModelRunnerBase):
 
     def initialize_attn_backend(self) -> None:
         """
-        Initialize attention backends and forward metadata
+        Initialize attention backends
         """
         assert len(self.attn_backends) == 0
 
         num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
-        self.model_config.kv_num_heads = (
-            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size
+        self.model_config.kv_num_heads = max(
+            1,
+            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
         )
         head_dim = self.model_config.head_dim
 
@@ -642,6 +668,7 @@ class GCUModelRunner(ModelRunnerBase):
         Args:
             num_tokens:
             expected_decode_len: Expected number of tokens generated
+            in_capturing: Is cuda graph in capturing state
         """
         self._dummy_prefill_inputs(
             num_tokens=num_tokens,
@@ -656,20 +683,20 @@ class GCUModelRunner(ModelRunnerBase):
         )
         while True:
 
-            # 1. Compute real num_tokens
+            # 1. Initialize forward meta and attention meta data
             self._prepare_inputs()
 
-            # 2. Initialize attention backend and forward meta data
+            # 2. Padding inputs for cuda graph
+            self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
+            self.padding_cudagraph_inputs()
 
-            # 3. Prepare lora
-
-            # 4. Run model
+            # 3. Run model
             model_output = self.model(
                 ids_remove_padding=self.share_inputs["ids_remove_padding"],
                 forward_meta=self.forward_meta,
             )
 
-            hiddden_states = rebuild_padding(
+            hidden_states = rebuild_padding(
                 model_output,
                 self.share_inputs["cum_offsets"],
                 self.share_inputs["seq_lens_this_time"],
@@ -681,8 +708,8 @@ class GCUModelRunner(ModelRunnerBase):
                 self.parallel_config.max_model_len,
             )
 
-            # 5. Execute spec decode
-            logits = self.model.compute_logits(hiddden_states)
+            # 4. Execute spec decode
+            logits = self.model.compute_logits(hidden_states)
 
             if not self.speculative_decoding:
                 set_value_by_flags_and_idx(
@@ -711,7 +738,7 @@ class GCUModelRunner(ModelRunnerBase):
                 paddle.distributed.broadcast(self.share_inputs["step_idx"], 0)
                 paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0)
 
-            # 6. post process
+            # 5. post process
             model_output_data = ModelOutputData(
                 next_tokens=self.share_inputs["next_tokens"],
                 stop_flags=self.share_inputs["stop_flags"],
@@ -736,6 +763,10 @@ class GCUModelRunner(ModelRunnerBase):
                 ),
                 accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
                 accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
+                enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
+                think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
+                need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
+                reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
             )
 
             post_process(
@@ -760,11 +791,10 @@ class GCUModelRunner(ModelRunnerBase):
 
     def _update_chunked_prefill(self, tasks):
         """
-        更新chunked prefill相关参数
+        Update chunked prefill related parameters
         """
         if not self.parallel_config.enable_chunked_prefill:
             return
-
         for task in tasks:
             if task.get("prefill_chunk_info", None) is None:
                 continue
@@ -785,25 +815,22 @@ class GCUModelRunner(ModelRunnerBase):
                     del self.restore_chunked_prefill_request[task.request_id]
             else:
                 token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
-
-                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["input_ids"][idx, :token_chunk_size] = np.array(
                     task.prompt_token_ids[start_idx : start_idx + token_chunk_size]
                 )
-                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
-                self.share_inputs["step_idx"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
+                self.share_inputs["prompt_lens"][idx : idx + 1] += token_chunk_size
+                self.share_inputs["step_idx"][idx : idx + 1] = 0
+
             if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled():
                 self.proposer.update_task_chunk_prefill(task)
             task.chunk_idx += 1
 
-    def _dummy_sampler_run(self) -> paddle.Tensor:
-        """ """
-        pass
-
     def capture_model(self) -> None:
         """
-        Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
+        Trigger CUDA Graph capture for all shapes in cuda graph capture list
         """
         if not self.use_cudagraph:
             logger.info("Skipping CUDA graph capture. Please check GraphOptimizationConfig")
@@ -813,7 +840,7 @@ class GCUModelRunner(ModelRunnerBase):
         capture_sizes = self.cudagraph_capture_sizes.copy()
         for batch_size in sorted(capture_sizes, reverse=True):
             self._dummy_run(
-                num_tokens=self.parallel_config.max_model_len,
+                num_tokens=self.parallel_config.max_num_batched_tokens,
                 batch_size=batch_size,
                 in_capturing=True,
                 expected_decode_len=expected_decode_len,
@@ -823,7 +850,7 @@ class GCUModelRunner(ModelRunnerBase):
         time_after_capture = time.perf_counter()
         logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
 
-    def _get_skip_idx(self, model_forward_batch):
+    def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
         """
         Get the index of the request that needs to be skipped during execution.
         Args:
@@ -866,13 +893,12 @@ class GCUModelRunner(ModelRunnerBase):
             self._execute_empty_input()
             return None
 
-        # 1. Prepare inputs of model and decoder.
-        # sampler create async operation
+        # 1. Prepare inputs of model and sampler.
         skip_idx_list = self._get_skip_idx(model_forward_batch)
         self._prepare_inputs()
         self.sampler.pre_process(skip_idx_list)
 
-        # 2. Padding inputs for cuda grph
+        # 2. Padding inputs for cuda graph
 
         # 3. Execute model
         model_output = self.model(
@@ -880,7 +906,7 @@ class GCUModelRunner(ModelRunnerBase):
             forward_meta=self.forward_meta,
         )
 
-        hiddden_states = rebuild_padding(
+        hidden_states = rebuild_padding(
             model_output,
             self.share_inputs["cum_offsets"],
             self.share_inputs["seq_lens_this_time"],
@@ -891,7 +917,7 @@ class GCUModelRunner(ModelRunnerBase):
         )
 
         # 4. Compute logits, Sample
-        logits = self.model.compute_logits(hiddden_states)
+        logits = self.model.compute_logits(hidden_states)
 
         if not self.speculative_decoding:
             set_value_by_flags_and_idx(
@@ -950,6 +976,10 @@ class GCUModelRunner(ModelRunnerBase):
             ),
             accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
             accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
+            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
+            think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
+            need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
+            reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
         )
 
         if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
@@ -1009,7 +1039,7 @@ class GCUModelRunner(ModelRunnerBase):
             raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
 
     def profile_run(self) -> None:
-        """Execute a forward pass with dummy inputs to profile the memory usage of the model."""
+        """Execute a forward pass with dummy inputs to profile the memory usage of the model"""
 
         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         self.num_gcu_blocks = self.parallel_config.total_block_num
@@ -1093,5 +1123,36 @@ class GCUModelRunner(ModelRunnerBase):
         return required_memory
 
     def not_need_stop(self) -> bool:
-        """ """
+        """Stop decoding if the tensor meets the termination condition"""
         return self.share_inputs["not_need_stop"][0]
+
+    def clear_cache(self):
+        """Clear cached data from shared inputs and forward metadata"""
+        self.share_inputs.pop("caches", None)
+        if self.forward_meta is not None:
+            self.forward_meta.clear_caches()
+
+    def clear_parameters(self, pid):
+        """ " Dynamic model loader use to clear parameters use for RL"""
+        self.dynamic_weight_manager.clear_parameters(pid)
+        self.clear_cache()
+        paddle.device.cuda.empty_cache()
+        self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
+
+    def update_parameters(self, pid):
+        """ " Dynamic model loader use to update parameters use for RL"""
+        self.dynamic_weight_manager.update_parameters(pid)
+        self.initialize_kv_cache()
+        self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
+
+    def padding_cudagraph_inputs(self) -> None:
+        """
+        Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
+        In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
+        """
+        # TODO(gongshaotian): Use more efficient implementation
+        if self.forward_meta.step_use_cudagraph:
+            num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
+            for i in range(1, num_empty_batch + 1):
+                self.forward_meta.decoder_batch_ids[-i] = 0
+                self.forward_meta.decoder_tile_ids_per_batch[-i] = 0