Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Bug fix] Fix the multi-input accuracy issue in the pooling model. (#5374)
* fix multi-inputs
* fix threshold
* fix threshold
* fix
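For context, the accuracy issue shows up when several texts are embedded in one request, so each input's vector should match what the same text yields on its own. Below is a minimal reproduction sketch, not part of the commit; the server URL, port, and model name are assumptions about a local FastDeploy deployment exposing an OpenAI-compatible embeddings endpoint, and whether this exact batch-vs-single scenario is the one fixed is itself an assumption based on the title.

import requests

# Assumed local endpoint and model name; adjust to the actual deployment.
URL = "http://localhost:8188/v1/embeddings"

batch = {"model": "default", "input": ["first sentence", "second sentence", "third sentence"]}
single = {"model": "default", "input": ["first sentence"]}

batch_emb = requests.post(URL, json=batch, timeout=60).json()["data"][0]["embedding"]
single_emb = requests.post(URL, json=single, timeout=60).json()["data"][0]["embedding"]

# If the pooling model handles multi-input requests correctly, the first input's
# embedding from the batch should agree closely with the single-input result.
max_diff = max(abs(a - b) for a, b in zip(batch_emb, single_emb))
print(f"max element-wise difference: {max_diff:.6f}")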
@@ -933,7 +933,7 @@ class EmbeddingChatRequest(BaseModel):
     )

     add_special_tokens: bool = Field(
-        default=False,
+        default=True,
         description=(
             "If true, special tokens (e.g. BOS) will be added to the prompt "
             "on top of what is added by the chat template. "
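The only change in this hunk is the request-level default flipping from False to True. A standalone sketch of how that Field default behaves for callers; MiniEmbeddingRequest is a made-up stand-in, not a FastDeploy class:

from pydantic import BaseModel, Field

class MiniEmbeddingRequest(BaseModel):
    # Mirrors the field above: defaults to True, callers may still override it per request.
    add_special_tokens: bool = Field(default=True)

print(MiniEmbeddingRequest().add_special_tokens)                          # True
print(MiniEmbeddingRequest(add_special_tokens=False).add_special_tokens)  # False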
@@ -230,13 +230,16 @@ class DataProcessor(BaseDataProcessor):
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
                 prompt = request.prompt
+                add_special_tokens = request.get("add_special_tokens", False)
                 assert isinstance(prompt, str) or (
                     isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
                 ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
                 if isinstance(prompt, list):  # if prompt is a token id list
                     request.prompt_token_ids = prompt
                 else:
-                    request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
+                    request.prompt_token_ids = self.text2ids(
+                        request.prompt, max_model_len, add_special_tokens=add_special_tokens
+                    )
             elif request.messages is not None:
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
@@ -305,13 +308,16 @@ class DataProcessor(BaseDataProcessor):
         if not request.get("prompt_token_ids"):
             if request.get("prompt"):
                 prompt = request.get("prompt")
+                add_special_tokens = request.get("add_special_tokens", False)
                 assert isinstance(prompt, str) or (
                     isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
                 ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
                 if isinstance(prompt, list):  # if prompt is a token id list
                     request["prompt_token_ids"] = prompt
                 else:
-                    request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
+                    request["prompt_token_ids"] = self.text2ids(
+                        request["prompt"], max_model_len, add_special_tokens=add_special_tokens
+                    ).tolist()
             elif request.get("messages"):
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
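Both hunks above follow the same pattern: read add_special_tokens off the request (falling back to False when the request does not carry the field) and forward it into text2ids. A self-contained sketch of that flow with a stub tokenizer so it runs without FastDeploy; text2ids here is a simplified stand-in, not the real implementation:

def text2ids(text, max_model_len, add_special_tokens=False):
    """Stub tokenizer: one fake id per character, optionally with a BOS-style marker."""
    ids = [ord(ch) for ch in text][:max_model_len]
    return ([0] + ids) if add_special_tokens else ids

request = {"prompt": "hello", "add_special_tokens": True}
add_special_tokens = request.get("add_special_tokens", False)
request["prompt_token_ids"] = text2ids(request["prompt"], 8, add_special_tokens=add_special_tokens)
print(request["prompt_token_ids"])  # leading 0 only because the flag was set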
@@ -503,7 +509,7 @@ class DataProcessor(BaseDataProcessor):
             **kwargs,
         )

-    def text2ids(self, text, max_model_len):
+    def text2ids(self, text, max_model_len, **kwargs):
         """
         text to token ids
@@ -513,6 +519,8 @@ class DataProcessor(BaseDataProcessor):
         Returns:
             List[int]: token ids list
         """
+
+        add_special_tokens = kwargs.get("add_special_tokens")
         if envs.FD_USE_HF_TOKENIZER:
             tokens = self.tokenizer(
                 text,
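One detail worth noting in the hunk above: kwargs.get("add_special_tokens") has no fallback, so when a caller omits the kwarg the downstream tokenizer call receives None rather than False. A tiny illustration of that dict.get behaviour:

kwargs = {}
print(kwargs.get("add_special_tokens"))         # None: no explicit default, still falsy
print(kwargs.get("add_special_tokens", False))  # False: the fallback the callers above use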
@@ -529,7 +537,7 @@ class DataProcessor(BaseDataProcessor):
                 padding=True,
                 truncation=True,
                 max_length=max_model_len,
-                add_special_tokens=False,
+                add_special_tokens=add_special_tokens,
             )

         return tokens["input_ids"][0]
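For intuition on why this flag matters for a pooling model: when the tokenizer defines BOS/EOS-style special tokens, toggling add_special_tokens changes the token sequence, and therefore the hidden state a last-token or mean pooler reduces over. A hedged sketch calling a Hugging Face tokenizer directly; the model id is assumed from the baseline filename in the test below:

from transformers import AutoTokenizer

# Model id assumed; any tokenizer that adds special tokens illustrates the same point.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")

text = "Where is Tiananmen Square?"
ids_off = tok(text, add_special_tokens=False)["input_ids"]
ids_on = tok(text, add_special_tokens=True)["input_ids"]

# If the two differ, the same text is pooled over a different sequence depending on
# the flag, which is exactly the kind of mismatch the commit targets.
print(ids_off == ids_on, len(ids_off), len(ids_on))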
@@ -237,4 +237,97 @@ def test_single_text_embedding(embedding_api_url, headers):
         save_embedding_baseline(embedding, baseline_file)
     else:
         print(f"Comparing with baseline: {baseline_file}")
-        check_embedding_against_baseline(embedding, baseline_file, threshold=0.01)
+        check_embedding_against_baseline(embedding, baseline_file, threshold=0.02)
+
+
+def test_multi_text_embedding(embedding_api_url, headers):
+    """Test embedding generation for batch text inputs."""
+    payload = {
+        "model": "default",
+        "input": ["北京天安门在哪里?", "上海东方明珠有多高?", "杭州西湖的面积是多少?"],
+    }
+
+    resp = requests.post(embedding_api_url, headers=headers, json=payload)
+    assert resp.status_code == 200, f"Unexpected status code: {resp.status_code}, response: {resp.text}"
+
+    result = resp.json()
+    assert "data" in result, "Response missing 'data' field"
+    assert len(result["data"]) == 3, f"Expected 3 embedding results, got {len(result['data'])}"
+
+    # Validate each embedding in the batch
+    for idx, item in enumerate(result["data"]):
+        assert "embedding" in item, f"Item {idx} missing 'embedding' field"
+        assert "index" in item, f"Item {idx} missing 'index' field"
+        assert item["index"] == idx, f"Item index mismatch: expected {idx}, got {item['index']}"
+
+        embedding = item["embedding"]
+        assert isinstance(embedding, list), f"Embedding {idx} should be a list"
+        assert len(embedding) > 0, f"Embedding {idx} vector should not be empty"
+        assert all(isinstance(x, (int, float)) for x in embedding), f"Embedding {idx} values should be numeric"
+
+        print(f"Text {idx} embedding dimension: {len(embedding)}")
+
+    # Verify all embeddings have the same dimension
+    dimensions = [len(item["embedding"]) for item in result["data"]]
+    assert len(set(dimensions)) == 1, f"All embeddings should have same dimension, got: {dimensions}"
+
+    # Compare embeddings with baseline
+    base_path = os.getenv("MODEL_PATH", "")
+    baseline_filename = "test-Qwen3-Embedding-0.6B-multi-input-baseline.json"
+
+    if base_path:
+        baseline_file = os.path.join(base_path, "torch", baseline_filename)
+    else:
+        baseline_file = baseline_filename
+
+    # Save all embeddings to baseline
+    batch_embeddings = [item["embedding"] for item in result["data"]]
+
+    if not os.path.exists(baseline_file):
+        print("Batch baseline file not found. Saving current embeddings as baseline...")
+        baseline_data = {
+            "embeddings": batch_embeddings,
+            "dimension": len(batch_embeddings[0]),
+            "count": len(batch_embeddings),
+            "inputs": payload["input"],
+        }
+        with open(baseline_file, "w", encoding="utf-8") as f:
+            json.dump(baseline_data, f, indent=2)
+        print(f"Batch baseline saved to: {baseline_file}")
+    else:
+        print(f"Comparing batch with baseline: {baseline_file}")
+        with open(baseline_file, "r", encoding="utf-8") as f:
+            baseline_data = json.load(f)
+        baseline_embeddings = baseline_data["embeddings"]
+
+        assert len(batch_embeddings) == len(
+            baseline_embeddings
+        ), f"Embedding count mismatch: current={len(batch_embeddings)}, baseline={len(baseline_embeddings)}"
+
+        # Compare each embedding
+        for idx, (current_emb, baseline_emb) in enumerate(zip(batch_embeddings, baseline_embeddings)):
+            print(f"\n--- Comparing embedding {idx}: '{payload['input'][idx]}' ---")
+            mean_abs_diff = compare_embeddings(current_emb, baseline_emb, threshold=0.05)
+
+            if mean_abs_diff >= 0.05:
+                # Save current batch for debugging
+                temp_file = f"{baseline_file}.current"
+                print("temp_file", temp_file)
+                with open(temp_file, "w", encoding="utf-8") as f:
+                    json.dump(
+                        {
+                            "embeddings": batch_embeddings,
+                            "dimension": len(batch_embeddings[0]),
+                            "count": len(batch_embeddings),
+                            "inputs": payload["input"],
+                        },
+                        f,
+                        indent=2,
+                    )
+
+                raise AssertionError(
+                    f"Embedding {idx} differs from baseline by too much "
+                    f"(mean_abs_diff={mean_abs_diff:.6f} >= 0.01):\n"
+                    f"Current batch saved to: {temp_file}\n"
+                    f"Please check the differences."
+                )
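The new test relies on compare_embeddings, which is defined elsewhere in the test utilities and not shown in this diff. A hedged sketch of what it plausibly computes, based only on how its return value is used above (a mean absolute element-wise difference checked against a threshold); the real helper may differ:

import numpy as np

def compare_embeddings(current, baseline, threshold=0.05):
    """Plausible stand-in: report the mean absolute element-wise difference."""
    current = np.asarray(current, dtype=np.float64)
    baseline = np.asarray(baseline, dtype=np.float64)
    assert current.shape == baseline.shape, f"shape mismatch: {current.shape} vs {baseline.shape}"
    mean_abs_diff = float(np.mean(np.abs(current - baseline)))
    print(f"mean_abs_diff={mean_abs_diff:.6f} (threshold={threshold})")
    return mean_abs_diff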