diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp index f719ed9fe..328d806f2 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp @@ -169,7 +169,11 @@ int text_image_gather_scatter( WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids); WRAPPER_CHECK_PTR(ctx, int, token_num, text_index); WRAPPER_CHECK_PTR(ctx, int, token_num, image_index); - WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num); + // When all tokens are image type, text_token_num will be set to 1 in model.py + WRAPPER_ASSERT_EQ(ctx, + true, + (text_token_num + image_token_num == token_num) || + (image_token_num == token_num && text_token_num == 1)); if (ctx->dev().type() == api::kCPU) { return cpu_wrapper(ctx,