[Bug fix] Fix the multi-input accuracy issue in the pooling model. (#5374)

* fix multi-inputs

* fix threshold

* fix threshold

* fix
Author: lizexu123
Date: 2025-12-05 20:18:17 +08:00
Committed by: GitHub
Parent: 96d2d4877b
Commit: d4979347ca
3 changed files with 107 additions and 6 deletions

@@ -933,7 +933,7 @@ class EmbeddingChatRequest(BaseModel):
     )
     add_special_tokens: bool = Field(
-        default=False,
+        default=True,
         description=(
             "If true, special tokens (e.g. BOS) will be added to the prompt "
             "on top of what is added by the chat template. "

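For reference, a minimal sketch of what the new default means at the request-model level; this trimmed-down model is only an illustration and is not the real EmbeddingChatRequest, which defines many more fields:

from pydantic import BaseModel, Field


class EmbeddingChatRequestSketch(BaseModel):
    """Trimmed stand-in for EmbeddingChatRequest, keeping only the field changed here."""

    add_special_tokens: bool = Field(
        default=True,
        description="If true, special tokens (e.g. BOS) will be added to the prompt.",
    )


# Requests that omit the field now opt in to special tokens by default...
print(EmbeddingChatRequestSketch().add_special_tokens)  # True
# ...while callers can still restore the previous behaviour explicitly.
print(EmbeddingChatRequestSketch(add_special_tokens=False).add_special_tokens)  # False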

@@ -230,13 +230,16 @@ class DataProcessor(BaseDataProcessor):
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
                 prompt = request.prompt
+                add_special_tokens = request.get("add_special_tokens", False)
                 assert isinstance(prompt, str) or (
                     isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
                 ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
                 if isinstance(prompt, list):  # if prompt is a token id list
                     request.prompt_token_ids = prompt
                 else:
-                    request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
+                    request.prompt_token_ids = self.text2ids(
+                        request.prompt, max_model_len, add_special_tokens=add_special_tokens
+                    )
         elif request.messages is not None:
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
@@ -305,13 +308,16 @@ class DataProcessor(BaseDataProcessor):
         if not request.get("prompt_token_ids"):
             if request.get("prompt"):
                 prompt = request.get("prompt")
+                add_special_tokens = request.get("add_special_tokens", False)
                 assert isinstance(prompt, str) or (
                     isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
                 ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
                 if isinstance(prompt, list):  # if prompt is a token id list
                     request["prompt_token_ids"] = prompt
                 else:
-                    request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
+                    request["prompt_token_ids"] = self.text2ids(
+                        request["prompt"], max_model_len, add_special_tokens=add_special_tokens
+                    ).tolist()
         elif request.get("messages"):
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
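
Both call sites above follow the same pattern: read the per-request flag once, then forward it to text2ids as a keyword argument. A toy sketch of that plumbing, independent of FastDeploy; ToyProcessor and its fake tokenization are placeholders, not code from this repo:

from typing import Any, Dict, List


class ToyProcessor:
    """Stand-in for DataProcessor that models only the add_special_tokens plumbing."""

    def text2ids(self, text: str, max_model_len: int, **kwargs: Any) -> List[int]:
        add_special_tokens = kwargs.get("add_special_tokens", False)
        # A real tokenizer would run here; fake character ids plus a made-up BOS id of 1.
        ids = [ord(c) for c in text][:max_model_len]
        return ([1] + ids) if add_special_tokens else ids

    def process_request_dict(self, request: Dict[str, Any], max_model_len: int) -> Dict[str, Any]:
        if not request.get("prompt_token_ids") and request.get("prompt"):
            add_special_tokens = request.get("add_special_tokens", False)
            request["prompt_token_ids"] = self.text2ids(
                request["prompt"], max_model_len, add_special_tokens=add_special_tokens
            )
        return request


print(ToyProcessor().process_request_dict({"prompt": "hi", "add_special_tokens": True}, 8))
# {'prompt': 'hi', 'add_special_tokens': True, 'prompt_token_ids': [1, 104, 105]}
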
@@ -503,7 +509,7 @@ class DataProcessor(BaseDataProcessor):
             **kwargs,
         )

-    def text2ids(self, text, max_model_len):
+    def text2ids(self, text, max_model_len, **kwargs):
         """
         text to token ids
@@ -513,6 +519,8 @@ class DataProcessor(BaseDataProcessor):
         Returns:
             List[int]: token ids list
         """
+        add_special_tokens = kwargs.get("add_special_tokens")
         if envs.FD_USE_HF_TOKENIZER:
             tokens = self.tokenizer(
                 text,
@@ -529,7 +537,7 @@ class DataProcessor(BaseDataProcessor):
                 padding=True,
                 truncation=True,
                 max_length=max_model_len,
-                add_special_tokens=False,
+                add_special_tokens=add_special_tokens,
             )
         return tokens["input_ids"][0]
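
Background on why the flag matters for a pooling model: the embedding is computed from the tokenized sequence (typically mean or last-token pooling), so adding or dropping BOS/EOS tokens shifts every output vector. A small sketch of the flag's effect using the Hugging Face tokenizers API; bert-base-uncased is chosen only because its tokenizer visibly adds [CLS]/[SEP], it is not the model exercised by this PR:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
text = "Where is Tiananmen Square?"

with_special = tok(text, add_special_tokens=True)["input_ids"]
without_special = tok(text, add_special_tokens=False)["input_ids"]

# The two id sequences differ only by the special tokens the tokenizer is configured to add.
print(tok.convert_ids_to_tokens(with_special))     # ['[CLS]', 'where', ..., '[SEP]']
print(tok.convert_ids_to_tokens(without_special))  # ['where', 'is', ...]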

@@ -237,4 +237,97 @@ def test_single_text_embedding(embedding_api_url, headers):
         save_embedding_baseline(embedding, baseline_file)
     else:
         print(f"Comparing with baseline: {baseline_file}")
-        check_embedding_against_baseline(embedding, baseline_file, threshold=0.01)
+        check_embedding_against_baseline(embedding, baseline_file, threshold=0.02)
+
+
+def test_multi_text_embedding(embedding_api_url, headers):
+    """Test embedding generation for batch text inputs."""
+    payload = {
+        "model": "default",
+        "input": ["北京天安门在哪里?", "上海东方明珠有多高?", "杭州西湖的面积是多少?"],
+    }
+    resp = requests.post(embedding_api_url, headers=headers, json=payload)
+    assert resp.status_code == 200, f"Unexpected status code: {resp.status_code}, response: {resp.text}"
+
+    result = resp.json()
+    assert "data" in result, "Response missing 'data' field"
+    assert len(result["data"]) == 3, f"Expected 3 embedding results, got {len(result['data'])}"
+
+    # Validate each embedding in the batch
+    for idx, item in enumerate(result["data"]):
+        assert "embedding" in item, f"Item {idx} missing 'embedding' field"
+        assert "index" in item, f"Item {idx} missing 'index' field"
+        assert item["index"] == idx, f"Item index mismatch: expected {idx}, got {item['index']}"
+
+        embedding = item["embedding"]
+        assert isinstance(embedding, list), f"Embedding {idx} should be a list"
+        assert len(embedding) > 0, f"Embedding {idx} vector should not be empty"
+        assert all(isinstance(x, (int, float)) for x in embedding), f"Embedding {idx} values should be numeric"
+        print(f"Text {idx} embedding dimension: {len(embedding)}")
+
+    # Verify all embeddings have the same dimension
+    dimensions = [len(item["embedding"]) for item in result["data"]]
+    assert len(set(dimensions)) == 1, f"All embeddings should have same dimension, got: {dimensions}"
+
+    # Compare embeddings with baseline
+    base_path = os.getenv("MODEL_PATH", "")
+    baseline_filename = "test-Qwen3-Embedding-0.6B-multi-input-baseline.json"
+    if base_path:
+        baseline_file = os.path.join(base_path, "torch", baseline_filename)
+    else:
+        baseline_file = baseline_filename
+
+    # Save all embeddings to baseline
+    batch_embeddings = [item["embedding"] for item in result["data"]]
+    if not os.path.exists(baseline_file):
+        print("Batch baseline file not found. Saving current embeddings as baseline...")
+        baseline_data = {
+            "embeddings": batch_embeddings,
+            "dimension": len(batch_embeddings[0]),
+            "count": len(batch_embeddings),
+            "inputs": payload["input"],
+        }
+        with open(baseline_file, "w", encoding="utf-8") as f:
+            json.dump(baseline_data, f, indent=2)
+        print(f"Batch baseline saved to: {baseline_file}")
+    else:
+        print(f"Comparing batch with baseline: {baseline_file}")
+        with open(baseline_file, "r", encoding="utf-8") as f:
+            baseline_data = json.load(f)
+        baseline_embeddings = baseline_data["embeddings"]
+        assert len(batch_embeddings) == len(
+            baseline_embeddings
+        ), f"Embedding count mismatch: current={len(batch_embeddings)}, baseline={len(baseline_embeddings)}"
+
+        # Compare each embedding
+        for idx, (current_emb, baseline_emb) in enumerate(zip(batch_embeddings, baseline_embeddings)):
+            print(f"\n--- Comparing embedding {idx}: '{payload['input'][idx]}' ---")
+            mean_abs_diff = compare_embeddings(current_emb, baseline_emb, threshold=0.05)
+            if mean_abs_diff >= 0.05:
+                # Save current batch for debugging
+                temp_file = f"{baseline_file}.current"
+                print("temp_file", temp_file)
+                with open(temp_file, "w", encoding="utf-8") as f:
+                    json.dump(
+                        {
+                            "embeddings": batch_embeddings,
+                            "dimension": len(batch_embeddings[0]),
+                            "count": len(batch_embeddings),
+                            "inputs": payload["input"],
+                        },
+                        f,
+                        indent=2,
+                    )
+                raise AssertionError(
+                    f"Embedding {idx} differs from baseline by too much "
+                    f"(mean_abs_diff={mean_abs_diff:.6f} >= 0.05):\n"
+                    f"Current batch saved to: {temp_file}\n"
+                    f"Please check the differences."
+                )
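
compare_embeddings, check_embedding_against_baseline, and save_embedding_baseline are helpers defined elsewhere in this test module and are not shown in the diff. A minimal sketch of the comparison the thresholds above refer to, inferred from how mean_abs_diff is used here rather than taken from the repo:

from typing import Sequence


def compare_embeddings_sketch(current: Sequence[float], baseline: Sequence[float], threshold: float = 0.05) -> float:
    """Hypothetical stand-in for compare_embeddings: mean absolute element-wise difference."""
    assert len(current) == len(baseline), "Embedding dimensions must match"
    mean_abs_diff = sum(abs(c - b) for c, b in zip(current, baseline)) / len(current)
    print(f"mean_abs_diff={mean_abs_diff:.6f} (threshold={threshold})")
    return mean_abs_diff


# Two nearly identical vectors stay comfortably under the 0.05 batch threshold.
assert compare_embeddings_sketch([0.10, 0.20, 0.30], [0.11, 0.19, 0.30]) < 0.05

Writing the failing batch to a .current file next to the baseline keeps the mismatching embeddings available for manual inspection when the comparison fails.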