[XPU] fix text_image_gather_scatter when image_token_num == token_num && text_token_num == 1 (#4881)

This commit is contained in:
Lucas
2025-11-07 16:35:49 +08:00
committed by GitHub
parent 71bbedaf50
commit 3b0bdbae65

View File

@@ -169,7 +169,11 @@ int text_image_gather_scatter(
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num);
// When all tokens are image type, text_token_num will be set to 1 in model.py
WRAPPER_ASSERT_EQ(ctx,
true,
(text_token_num + image_token_num == token_num) ||
(image_token_num == token_num && text_token_num == 1));
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(ctx,