mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-21 15:49:31 +08:00
[Excutor] Increase buffer size to prevent address corruption; add forward metadata debug tool (#3404)
* 修复buffer申请不够大,增加打印forwardmetadata的工具 * fix mistake * Make CPU tensor in CPUPlace * Add test about forward_meta_str and Add unitest_requirement --------- Co-authored-by: RAM <gstian5555@outlook.com>
This commit is contained in:
@@ -289,7 +289,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
|||||||
kv_tile_ids_per_batch =
|
kv_tile_ids_per_batch =
|
||||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
kv_num_blocks_x_cpu =
|
kv_num_blocks_x_cpu =
|
||||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (max_just_dec_len_this_time > 0) {
|
if (max_just_dec_len_this_time > 0) {
|
||||||
|
@@ -114,6 +114,39 @@ class ForwardMeta:
|
|||||||
if self.caches:
|
if self.caches:
|
||||||
del self.caches
|
del self.caches
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns a concise string representation of the ForwardMeta object in a compact format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def format_str(obj):
|
||||||
|
"""
|
||||||
|
A helper function to recursively get a concise string representation of objects.
|
||||||
|
"""
|
||||||
|
if obj is None:
|
||||||
|
return "None"
|
||||||
|
elif isinstance(obj, paddle.Tensor):
|
||||||
|
tensor_info = {
|
||||||
|
"data_ptr": obj.data_ptr(),
|
||||||
|
"shape": obj.shape,
|
||||||
|
"dtype": str(obj.dtype),
|
||||||
|
"place": str(obj.place),
|
||||||
|
}
|
||||||
|
return tensor_info
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
return [format_str(item) for item in obj]
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {key: format_str(value) for key, value in obj.items()}
|
||||||
|
elif not isinstance(obj, (int, float, str, bool)) and hasattr(obj, "__dict__"):
|
||||||
|
info = {key: format_str(value) for key, value in obj.__dict__.items() if not key.startswith("_")}
|
||||||
|
return f"<{obj.__class__.__name__} object info: {info}>"
|
||||||
|
else:
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
simplified_info = format_str(self.__dict__)
|
||||||
|
lines = [f" {key}: {value}" for key, value in simplified_info.items()]
|
||||||
|
return "{\n" + ",\n".join(lines) + "\n}"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class XPUForwardMeta(ForwardMeta):
|
class XPUForwardMeta(ForwardMeta):
|
||||||
|
@@ -681,9 +681,11 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
dtype="int64",
|
dtype="int64",
|
||||||
)
|
)
|
||||||
self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||||
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
self.share_inputs["batch_id_per_token"] = paddle.full(
|
||||||
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
[max_num_seqs * self.parallel_config.max_model_len, 1], 0, dtype="int32"
|
||||||
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
)
|
||||||
|
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32")
|
||||||
|
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32")
|
||||||
|
|
||||||
# Declare AttentionBackend buffers
|
# Declare AttentionBackend buffers
|
||||||
self.share_inputs["decoder_batch_ids"] = None
|
self.share_inputs["decoder_batch_ids"] = None
|
||||||
|
@@ -7,3 +7,4 @@ pytest-twisted
|
|||||||
anyio
|
anyio
|
||||||
coverage
|
coverage
|
||||||
diff-cover
|
diff-cover
|
||||||
|
partial_json_parser
|
||||||
|
106
test/model_executor/test_forward_meta_str.py
Normal file
106
test/model_executor/test_forward_meta_str.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
|
|
||||||
|
|
||||||
|
class TOYGPUModelRunner:
|
||||||
|
def __init__(self):
|
||||||
|
self.forward_meta: ForwardMeta = None
|
||||||
|
|
||||||
|
self.max_num_seqs = 64
|
||||||
|
self.max_model_len = 1024
|
||||||
|
self.pre_max_block_num = 16
|
||||||
|
# Not the tensor in real sense, just for make ForwardMeta
|
||||||
|
self.share_inputs = {}
|
||||||
|
self.share_inputs["input_ids"] = paddle.full(
|
||||||
|
[self.max_num_seqs, self.max_model_len],
|
||||||
|
0,
|
||||||
|
dtype="int64",
|
||||||
|
)
|
||||||
|
self.share_inputs["ids_remove_padding"] = paddle.full(
|
||||||
|
[self.max_num_seqs * self.max_model_len],
|
||||||
|
0,
|
||||||
|
dtype="int64",
|
||||||
|
)
|
||||||
|
self.share_inputs["decoder_batch_ids"] = None
|
||||||
|
self.share_inputs["decoder_tile_ids_per_batch"] = None
|
||||||
|
self.share_inputs["decoder_num_blocks_cpu"] = None
|
||||||
|
self.share_inputs["max_len_tensor_cpu"] = None
|
||||||
|
self.share_inputs["seq_lens_encoder"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32")
|
||||||
|
self.share_inputs["seq_lens_decoder"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32")
|
||||||
|
self.share_inputs["seq_lens_this_time"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32")
|
||||||
|
self.share_inputs["batch_id_per_token"] = paddle.full(
|
||||||
|
[self.max_num_seqs * self.max_model_len, 1], 0, dtype="int32"
|
||||||
|
)
|
||||||
|
self.share_inputs["cu_seqlens_q"] = paddle.full([self.max_num_seqs + 1, 1], 0, dtype="int32")
|
||||||
|
self.share_inputs["cu_seqlens_k"] = paddle.full([self.max_num_seqs + 1, 1], 0, dtype="int32")
|
||||||
|
self.share_inputs["block_tables"] = paddle.full([self.max_num_seqs, self.pre_max_block_num], -1, dtype="int32")
|
||||||
|
self.share_inputs["caches"] = [
|
||||||
|
paddle.full([self.max_num_seqs, 4, self.max_model_len, self.pre_max_block_num], 0, dtype="int32")
|
||||||
|
] * 16
|
||||||
|
|
||||||
|
def initialize_forward_meta(self):
|
||||||
|
"""
|
||||||
|
Initialize forward meta
|
||||||
|
"""
|
||||||
|
# Ignore the attentionbackbend for simplify
|
||||||
|
self.forward_meta = ForwardMeta(
|
||||||
|
input_ids=self.share_inputs["input_ids"],
|
||||||
|
ids_remove_padding=self.share_inputs["ids_remove_padding"],
|
||||||
|
# rotary_embs=self.share_inputs["rope_emb"],# Ignore the rope_emb for simplify
|
||||||
|
# attn_backend=self.attn_backends[0],# Ignore the attn_backbend for simplify
|
||||||
|
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
|
||||||
|
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
|
||||||
|
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
|
||||||
|
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
|
||||||
|
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
|
||||||
|
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
|
||||||
|
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
|
||||||
|
batch_id_per_token=self.share_inputs["batch_id_per_token"],
|
||||||
|
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
|
||||||
|
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
|
||||||
|
block_tables=self.share_inputs["block_tables"],
|
||||||
|
caches=self.share_inputs["caches"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Test(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
Initialize the test environment
|
||||||
|
"""
|
||||||
|
self.runner = TOYGPUModelRunner()
|
||||||
|
|
||||||
|
def test_case(self):
|
||||||
|
"""
|
||||||
|
Check if the CustomAllreduce function works properly.
|
||||||
|
"""
|
||||||
|
print(
|
||||||
|
"in test/model_executor/test_forward_meta_str.py, forward_meta :", self.runner.forward_meta
|
||||||
|
) # Get None, Not Error
|
||||||
|
self.runner.initialize_forward_meta()
|
||||||
|
print(
|
||||||
|
"in test/model_executor/test_forward_meta_str.py, forward_meta :", self.runner.forward_meta
|
||||||
|
) # Get information
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Reference in New Issue
Block a user