diff --git a/test/layers/test_sampler.py b/test/layers/test_sampler.py index 65a6bfbe6..c2fb69018 100644 --- a/test/layers/test_sampler.py +++ b/test/layers/test_sampler.py @@ -56,6 +56,7 @@ def _create_default_sampling_metadata( min_dec_lens=paddle.full(shape=[batch_size, 1], fill_value=min_seq_len, dtype="int64"), bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"), eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"), + min_p=paddle.randn([batch_size]), ) return fake_sampling_metadata