# File: tests/model_loader/test_load_mtp.py
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import sys
import unittest
from pathlib import Path
from unittest.mock import Mock

import numpy as np
import paddle
import paddle.distributed.fleet as fleet

from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPForCausalLM

# tests/utils.py lives one directory above this file. Anchor the search path
# to __file__ instead of "../" so the import works from any working directory.
sys.path.append(str(Path(__file__).resolve().parent.parent))
from utils import get_default_test_fd_config  # noqa: E402

strategy = fleet.DistributedStrategy()
fleet.init(strategy=strategy)


class TestErnie4_5_MTPLoadWeights(unittest.TestCase):
    """Exercise ``Ernie4_5_MTPForCausalLM.load_weights`` with small in-memory inputs."""

    def setUp(self):
        """Build an MTP model whose sharing model is mocked except for ``embed_tokens``."""
        self.fd_config = get_default_test_fd_config()
        self.fd_config.speculative_config = Mock()
        self.fd_config.speculative_config.sharing_model = Mock()
        self.fd_config.speculative_config.sharing_model.ernie = Mock()
        self.fd_config.parallel_config.tp_group = None
        self.fd_config.speculative_config.sharing_model.ernie.embed_tokens = VocabParallelEmbedding(
            fd_config=self.fd_config,
            num_embeddings=self.fd_config.model_config.vocab_size,
            embedding_dim=self.fd_config.model_config.hidden_size,
            # Pass the dtype itself; the original handed over the uncalled function.
            params_dtype=paddle.get_default_dtype(),
            prefix="embed_tokens",
        )
        self.fd_config.speculative_config.sharing_model.ernie.lm_head = Mock()
        self.model = Ernie4_5_MTPForCausalLM(self.fd_config)

    def _random_weight(self, *shape):
        """Return a random float32 array of the given shape."""
        return np.random.rand(*shape).astype("float32")

    def _embedding_shape(self):
        """Return ``(vocab_size, hidden_size)`` as configured on the fake model config."""
        model_config = self.fd_config.model_config
        return model_config.vocab_size, model_config.hidden_size

    def _assert_embedding_matches(self, expected):
        """Assert the model's embedding weight is numerically close to ``expected``."""
        actual = self.model.ernie.embed_tokens.embeddings.weight.numpy()
        # Same tolerances as np.allclose's defaults, but with a readable diff on failure.
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8)

    def test_load_weights_normal_case(self):
        """The embedding weight from the iterator ends up in the model."""
        vocab_size, hidden_size = self._embedding_shape()
        weights_iterator = [
            ("ernie.embed_tokens.weight", self._random_weight(vocab_size, hidden_size)),
            # Passed through load_weights as well, but not asserted on below.
            ("ernie.mtp_block.0.self_attn.qkv_proj.weight", self._random_weight(hidden_size, hidden_size * 3)),
        ]

        self.model.load_weights(iter(weights_iterator))

        self._assert_embedding_matches(weights_iterator[0][1])

    def test_load_weights_with_unexpected_keys(self):
        """An unrecognised key ahead of a valid one does not stop the valid one loading."""
        vocab_size, hidden_size = self._embedding_shape()
        weights_iterator = [
            ("unknown_key", self._random_weight(10, 10)),
            ("ernie.embed_tokens.weight", self._random_weight(vocab_size, hidden_size)),
        ]

        self.model.load_weights(iter(weights_iterator))

        self._assert_embedding_matches(weights_iterator[1][1])

    def test_load_weights_empty_iterator(self):
        """Smoke test: an empty iterator must be accepted without raising."""
        weights_iterator = []

        self.model.load_weights(iter(weights_iterator))


if __name__ == "__main__":
    unittest.main()
# File: tests/utils.py
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from unittest.mock import Mock


class FakeModelConfig:
    """Hard-coded stand-in for a model config, used by the unit tests.

    Every attribute is a fixed literal set in ``__init__``;
    ``pretrained_config`` is a ``Mock`` whose ``prefix_name`` is ``"test"``.
    """

    def __init__(self):
        self.hidden_size = 768
        self.intermediate_size = 768
        self.num_hidden_layers = 12
        self.num_attention_heads = 12
        self.rms_norm_eps = 1e-6
        # The original assigned this attribute twice with the same value.
        self.tie_word_embeddings = True
        self.ori_vocab_size = 32000
        self.moe_layer_start_index = 8
        self.pretrained_config = Mock()
        self.pretrained_config.prefix_name = "test"
        self.num_key_value_heads = 1
        self.head_dim = 1
        self.is_quantized = False
        self.hidden_act = "relu"
        self.vocab_size = 32000
        self.hidden_dropout_prob = 0.1
        self.initializer_range = 0.02
        self.max_position_embeddings = 512
        self.model_format = "auto"


def get_default_test_fd_config():
    """Build an ``FDConfig`` in test mode with a ``FakeModelConfig`` attached.

    Returns:
        An ``FDConfig`` constructed with ``test_mode=True`` from empty-argument
        graph-optimization, parallel and cache configs. The parallel config has
        ``max_num_seqs`` and ``data_parallel_rank`` set to 1, and
        ``model_config`` is replaced by a fresh ``FakeModelConfig``.
    """
    # Imported here so that importing this module (for example to use
    # FakeModelConfig alone) does not require fastdeploy.config to be loadable.
    from fastdeploy.config import (
        CacheConfig,
        FDConfig,
        GraphOptimizationConfig,
        ParallelConfig,
    )

    graph_opt_config = GraphOptimizationConfig(args={})
    parallel_config = ParallelConfig(args={})
    parallel_config.max_num_seqs = 1
    parallel_config.data_parallel_rank = 1
    cache_config = CacheConfig({})
    fd_config = FDConfig(
        graph_opt_config=graph_opt_config,
        parallel_config=parallel_config,
        cache_config=cache_config,
        test_mode=True,
    )
    fd_config.model_config = FakeModelConfig()
    return fd_config