diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 6a57cf472..000861470 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -933,7 +933,7 @@ class EmbeddingChatRequest(BaseModel):
     )
     add_special_tokens: bool = Field(
-        default=False,
+        default=True,
         description=(
             "If true, special tokens (e.g. BOS) will be added to the prompt "
             "on top of what is added by the chat template. "
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index a27d125eb..366244e52 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -230,13 +230,16 @@ class DataProcessor(BaseDataProcessor):
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
                 prompt = request.prompt
+                add_special_tokens = request.get("add_special_tokens", False)
                 assert isinstance(prompt, str) or (
                     isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
                 ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
                 if isinstance(prompt, list):  # if prompt is a token id list
                     request.prompt_token_ids = prompt
                 else:
-                    request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
+                    request.prompt_token_ids = self.text2ids(
+                        request.prompt, max_model_len, add_special_tokens=add_special_tokens
+                    )
             elif request.messages is not None:
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
@@ -305,13 +308,16 @@ class DataProcessor(BaseDataProcessor):
         if not request.get("prompt_token_ids"):
             if request.get("prompt"):
                 prompt = request.get("prompt")
+                add_special_tokens = request.get("add_special_tokens", False)
                 assert isinstance(prompt, str) or (
                     isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
                 ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
                 if isinstance(prompt, list):  # if prompt is a token id list
                     request["prompt_token_ids"] = prompt
                 else:
-                    request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
+                    request["prompt_token_ids"] = self.text2ids(
+                        request["prompt"], max_model_len, add_special_tokens=add_special_tokens
+                    ).tolist()
             elif request.get("messages"):
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
@@ -503,7 +509,7 @@ class DataProcessor(BaseDataProcessor):
             **kwargs,
         )
 
-    def text2ids(self, text, max_model_len):
+    def text2ids(self, text, max_model_len, **kwargs):
         """
         text to token ids
 
@@ -513,6 +519,8 @@
         Returns:
             List[int]: token ids list
         """
+
+        add_special_tokens = kwargs.get("add_special_tokens")
         if envs.FD_USE_HF_TOKENIZER:
             tokens = self.tokenizer(
                 text,
@@ -529,7 +537,7 @@
                 padding=True,
                 truncation=True,
                 max_length=max_model_len,
-                add_special_tokens=False,
+                add_special_tokens=add_special_tokens,
             )
 
         return tokens["input_ids"][0]
diff --git a/tests/pooling/test_Qwen3-Embedding_serving.py b/tests/pooling/test_Qwen3-Embedding_serving.py
index 80a4410d9..74365bd8e 100644
--- a/tests/pooling/test_Qwen3-Embedding_serving.py
+++ b/tests/pooling/test_Qwen3-Embedding_serving.py
@@ -237,4 +237,97 @@ def test_single_text_embedding(embedding_api_url, headers):
         save_embedding_baseline(embedding, baseline_file)
     else:
         print(f"Comparing with baseline: {baseline_file}")
-        check_embedding_against_baseline(embedding, baseline_file, threshold=0.01)
+        check_embedding_against_baseline(embedding, baseline_file, threshold=0.02)
+
+
+def test_multi_text_embedding(embedding_api_url, headers):
+    """Test embedding generation for batch text inputs."""
+    payload = {
+        "model": "default",
+        "input": ["北京天安门在哪里?", "上海东方明珠有多高?", "杭州西湖的面积是多少?"],
+    }
+
+    resp = requests.post(embedding_api_url, headers=headers, json=payload)
+    assert resp.status_code == 200, f"Unexpected status code: {resp.status_code}, response: {resp.text}"
+
+    result = resp.json()
+    assert "data" in result, "Response missing 'data' field"
+    assert len(result["data"]) == 3, f"Expected 3 embedding results, got {len(result['data'])}"
+
+    # Validate each embedding in the batch
+    for idx, item in enumerate(result["data"]):
+        assert "embedding" in item, f"Item {idx} missing 'embedding' field"
+        assert "index" in item, f"Item {idx} missing 'index' field"
+        assert item["index"] == idx, f"Item index mismatch: expected {idx}, got {item['index']}"
+
+        embedding = item["embedding"]
+        assert isinstance(embedding, list), f"Embedding {idx} should be a list"
+        assert len(embedding) > 0, f"Embedding {idx} vector should not be empty"
+        assert all(isinstance(x, (int, float)) for x in embedding), f"Embedding {idx} values should be numeric"
+
+        print(f"Text {idx} embedding dimension: {len(embedding)}")
+
+    # Verify all embeddings have the same dimension
+    dimensions = [len(item["embedding"]) for item in result["data"]]
+    assert len(set(dimensions)) == 1, f"All embeddings should have same dimension, got: {dimensions}"
+
+    # Compare embeddings with baseline
+    base_path = os.getenv("MODEL_PATH", "")
+    baseline_filename = "test-Qwen3-Embedding-0.6B-multi-input-baseline.json"
+
+    if base_path:
+        baseline_file = os.path.join(base_path, "torch", baseline_filename)
+    else:
+        baseline_file = baseline_filename
+
+    # Save all embeddings to baseline
+    batch_embeddings = [item["embedding"] for item in result["data"]]
+
+    if not os.path.exists(baseline_file):
+        print("Batch baseline file not found. Saving current embeddings as baseline...")
+        baseline_data = {
+            "embeddings": batch_embeddings,
+            "dimension": len(batch_embeddings[0]),
+            "count": len(batch_embeddings),
+            "inputs": payload["input"],
+        }
+        with open(baseline_file, "w", encoding="utf-8") as f:
+            json.dump(baseline_data, f, indent=2)
+        print(f"Batch baseline saved to: {baseline_file}")
+    else:
+        print(f"Comparing batch with baseline: {baseline_file}")
+        with open(baseline_file, "r", encoding="utf-8") as f:
+            baseline_data = json.load(f)
+        baseline_embeddings = baseline_data["embeddings"]
+
+        assert len(batch_embeddings) == len(
+            baseline_embeddings
+        ), f"Embedding count mismatch: current={len(batch_embeddings)}, baseline={len(baseline_embeddings)}"
+
+        # Compare each embedding
+        for idx, (current_emb, baseline_emb) in enumerate(zip(batch_embeddings, baseline_embeddings)):
+            print(f"\n--- Comparing embedding {idx}: '{payload['input'][idx]}' ---")
+            mean_abs_diff = compare_embeddings(current_emb, baseline_emb, threshold=0.05)
+
+            if mean_abs_diff >= 0.05:
+                # Save current batch for debugging
+                temp_file = f"{baseline_file}.current"
+                print("temp_file", temp_file)
+                with open(temp_file, "w", encoding="utf-8") as f:
+                    json.dump(
+                        {
+                            "embeddings": batch_embeddings,
+                            "dimension": len(batch_embeddings[0]),
+                            "count": len(batch_embeddings),
+                            "inputs": payload["input"],
+                        },
+                        f,
+                        indent=2,
+                    )
+
+                raise AssertionError(
+                    f"Embedding {idx} differs from baseline by too much "
+                    f"(mean_abs_diff={mean_abs_diff:.6f} >= 0.05):\n"
+                    f"Current batch saved to: {temp_file}\n"
+                    f"Please check the differences."
+                )
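For context, a minimal client-side sketch of the per-request flag this change exposes. The endpoint URL, port, and timeout are illustrative assumptions, not taken from the diff; only the payload fields ("model", "input", "add_special_tokens") follow the schema and tests above.

# Hypothetical usage sketch, not part of the PR: assumes a FastDeploy
# OpenAI-compatible server is listening locally and that the embeddings
# route accepts the add_special_tokens field added to EmbeddingChatRequest.
import requests

payload = {
    "model": "default",
    "input": ["北京天安门在哪里?"],
    "add_special_tokens": False,  # explicit per-request value; the request-schema default is now True
}
resp = requests.post("http://localhost:8000/v1/embeddings", json=payload, timeout=30)
resp.raise_for_status()
print(len(resp.json()["data"][0]["embedding"]))  # embedding dimension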