mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
Add loader test for mtp (#3724)
* add test for mtp
* fix unittest
* fix
tests/model_loader/test_load_mtp.py (new file, 81 additions)
@@ -0,0 +1,81 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import sys
import unittest
from unittest.mock import Mock

import numpy as np
import paddle
import paddle.distributed.fleet as fleet

from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPForCausalLM

sys.path.append("../")
from utils import get_default_test_fd_config

strategy = fleet.DistributedStrategy()
fleet.init(strategy=strategy)


class TestErnie4_5_MTPLoadWeights(unittest.TestCase):
    def setUp(self):
        self.fd_config = get_default_test_fd_config()
        # Mock the shared main model; only embed_tokens needs to be real,
        # since the MTP head reuses the main model's embedding table.
        self.fd_config.speculative_config = Mock()
        self.fd_config.speculative_config.sharing_model = Mock()
        self.fd_config.speculative_config.sharing_model.ernie = Mock()
        self.fd_config.parallel_config.tp_group = None
        self.fd_config.speculative_config.sharing_model.ernie.embed_tokens = VocabParallelEmbedding(
            fd_config=self.fd_config,
            num_embeddings=self.fd_config.model_config.vocab_size,
            embedding_dim=self.fd_config.model_config.hidden_size,
            params_dtype=paddle.get_default_dtype(),  # pass the dtype value, not the function object
            prefix="embed_tokens",
        )
        self.fd_config.speculative_config.sharing_model.ernie.lm_head = Mock()
        self.model = Ernie4_5_MTPForCausalLM(self.fd_config)

    def test_load_weights_normal_case(self):
        weights_iterator = [
            ("ernie.embed_tokens.weight", np.random.rand(32000, 768).astype("float32")),
            ("ernie.mtp_block.0.self_attn.qkv_proj.weight", np.random.rand(768, 768 * 3).astype("float32")),
        ]
        # Print parameter names to aid debugging when the loader's name mapping changes.
        for name, _ in self.model.named_parameters():
            print(name)

        self.model.load_weights(iter(weights_iterator))

        self.assertTrue(np.allclose(self.model.ernie.embed_tokens.embeddings.weight.numpy(), weights_iterator[0][1]))

    def test_load_weights_with_unexpected_keys(self):
        # An unknown key should be skipped without aborting the load.
        weights_iterator = [
            ("unknown_key", np.random.rand(10, 10).astype("float32")),
            ("ernie.embed_tokens.weight", np.random.rand(32000, 768).astype("float32")),
        ]

        self.model.load_weights(iter(weights_iterator))

        self.assertTrue(np.allclose(self.model.ernie.embed_tokens.embeddings.weight.numpy(), weights_iterator[1][1]))

    def test_load_weights_empty_iterator(self):
        # Loading from an empty iterator should be a no-op, not an error.
        weights_iterator = []

        self.model.load_weights(iter(weights_iterator))


if __name__ == "__main__":
    unittest.main()
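For context, the weights_iterator lists above stand in for what a real checkpoint loader yields: (parameter_name, ndarray) pairs. A minimal sketch of producing the same shape of iterator from a plain NumPy .npz archive — the helper name and checkpoint path are hypothetical, not part of this PR:

import numpy as np

def iter_npz_weights(path):
    # Yield (name, ndarray) pairs in the form load_weights() consumes.
    data = np.load(path)
    for name in data.files:
        yield name, data[name]

# Hypothetical usage: model.load_weights(iter_npz_weights("ckpt.npz"))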
tests/utils.py (new file, 61 additions)
@@ -0,0 +1,61 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from unittest.mock import Mock

from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    ParallelConfig,
)


class FakeModelConfig:
    """Minimal stand-in for a model config, carrying just the fields the loaders read."""

    def __init__(self):
        self.hidden_size = 768
        self.intermediate_size = 768
        self.num_hidden_layers = 12
        self.num_attention_heads = 12
        self.rms_norm_eps = 1e-6
        self.tie_word_embeddings = True
        self.ori_vocab_size = 32000
        self.moe_layer_start_index = 8
        self.pretrained_config = Mock()
        self.pretrained_config.prefix_name = "test"
        self.num_key_value_heads = 1
        self.head_dim = 1
        self.is_quantized = False
        self.hidden_act = "relu"
        self.vocab_size = 32000
        self.hidden_dropout_prob = 0.1
        self.initializer_range = 0.02
        self.max_position_embeddings = 512
        self.model_format = "auto"


def get_default_test_fd_config():
    """Build an FDConfig suitable for unit tests, with a FakeModelConfig attached."""
    graph_opt_config = GraphOptimizationConfig(args={})
    parallel_config = ParallelConfig(args={})
    parallel_config.max_num_seqs = 1
    parallel_config.data_parallel_rank = 1
    cache_config = CacheConfig({})
    fd_config = FDConfig(
        graph_opt_config=graph_opt_config,
        parallel_config=parallel_config,
        cache_config=cache_config,
        test_mode=True,
    )
    fd_config.model_config = FakeModelConfig()
    return fd_config
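A short usage sketch: because FakeModelConfig is a plain attribute bag, tests that need a non-default model shape can mutate the config returned by get_default_test_fd_config() before constructing a model. The override values below are illustrative only:

from utils import get_default_test_fd_config

fd_config = get_default_test_fd_config()
fd_config.model_config.vocab_size = 50304      # illustrative override
fd_config.model_config.num_hidden_layers = 2   # smaller model for faster tests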