mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-31 11:56:44 +08:00)

Commit: support openai client
@@ -1,5 +1,5 @@
 
-<h1 align="center"><b><em>FastDeploy: PaddlePaddle's High-Performance LLM Deployment Tool</em></b></h1>
+<h1 align="center"><b><em>FastDeploy: LLM Serving Deployment</em></b></h1>
 
 *FastDeploy is a solution built on NVIDIA's Triton framework, designed for serving large language models in server scenarios. It provides gRPC- and HTTP-based service interfaces as well as streaming token output. The underlying inference engine supports acceleration strategies such as continuous batching, weight-only int8, and post-training quantization (PTQ), delivering an easy-to-use, high-performance deployment experience.*
 
@@ -4,7 +4,7 @@ WORKDIR /opt/output/
 COPY ./server/ /opt/output/Serving/
 COPY ./client/ /opt/output/client/
 
-ENV LD_LIBRARY_PATH "/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH"
+ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH"
 
 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
 RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \
@@ -15,7 +15,7 @@ RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc
     && python3 setup_cuda.py build && python3 setup_cuda.py install --user \
     && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
     && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \
-    && rm -rf PaddleNLP
+    && rm -rf /opt/output/PaddleNLP
 
 RUN cd /opt/output/client && pip install -r requirements.txt && pip install .
 
@@ -30,7 +30,5 @@ RUN cd /opt/output/Serving/ \
     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
     && rm -rf scripts
 
-RUN python3 -m pip install protobuf==3.20.0
-
-ENV http_proxy ""
-ENV https_proxy ""
+ENV http_proxy=""
+ENV https_proxy=""

@@ -4,7 +4,7 @@ WORKDIR /opt/output/
 COPY ./server/ /opt/output/Serving/
 COPY ./client/ /opt/output/client/
 
-ENV LD_LIBRARY_PATH "/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH"
+ENV LD_LIBRARY_PATH="/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH"
 
 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
 RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ \
@@ -15,7 +15,7 @@ RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc
     && python3 setup_cuda.py build && python3 setup_cuda.py install --user \
     && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
     && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \
-    && rm -rf PaddleNLP
+    && rm -rf /opt/output/PaddleNLP
 
 RUN cd /opt/output/client && pip install -r requirements.txt && pip install .
 
@@ -30,7 +30,5 @@ RUN cd /opt/output/Serving/ \
     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
     && rm -rf scripts
 
-RUN python3 -m pip install protobuf==3.20.0
-
-ENV http_proxy ""
-ENV https_proxy ""
+ENV http_proxy=""
+ENV https_proxy=""

@@ -66,7 +66,7 @@ ls /fastdeploy/models/
 git clone https://github.com/PaddlePaddle/FastDeploy.git
 cd FastDeploy/llm
 
-docker build -f ./dockerfiles/Dockerfile_serving_cuda123_cudnn9 -t llm-serving-cu123-self .
+docker build --network=host -f ./dockerfiles/Dockerfile_serving_cuda123_cudnn9 -t llm-serving-cu123-self .
 ```
 
 After building your own image, you can [create a container](#创建容器) from it.
@@ -196,6 +196,77 @@ for line in res.iter_lines():
     On error, the response is {'error_msg': xxx, 'error_code': xxx}, where the error_msg field is non-empty and the error_code field is non-zero
 ```
 
+### OpenAI Client
+
+We also support the OpenAI client; usage is as follows:
+
+Note: using the OpenAI client requires `PUSH_MODE_HTTP_PORT` to be configured!
+
+```python
+import openai
+
+client = openai.Client(base_url="http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions", api_key="EMPTY_API_KEY")
+
+# Non-streaming response
+response = client.completions.create(
+    model="default",
+    prompt="Hello, how are you?",
+    max_tokens=50,
+    stream=False,
+)
+
+print(response)
+print("\n")
+
+# Streaming response
+response = client.completions.create(
+    model="default",
+    prompt="Hello, how are you?",
+    max_tokens=100,
+    stream=True,
+)
+
+for chunk in response:
+    if chunk.choices[0] is not None:
+        print(chunk.choices[0].text, end='')
+print("\n")
+
+# Chat completion, non-streaming
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "user", "content": "Hello, who are you"},
+        {"role": "system", "content": "I'm a helpful AI assistant."},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0,
+    max_tokens=64,
+    stream=False,
+)
+
+print(response)
+print("\n")
+
+# Chat completion, streaming
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "user", "content": "Hello, who are you"},
+        {"role": "system", "content": "I'm a helpful AI assistant."},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0,
+    max_tokens=64,
+    stream=True,
+)
+
+for chunk in response:
+    if chunk.choices[0].delta is not None:
+        print(chunk.choices[0].delta.content, end='')
+print("\n")
+```
+
 ## Model Configuration Parameters
 
 | Field | Type | Description | Required | Default | Notes |
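Note: the `base_url` in the README example above deliberately ends in `/v1/chat/completions`. The OpenAI SDK appends the per-method path (`/completions` or `/chat/completions`) to `base_url`, which is exactly why the server in this commit registers the doubled-looking routes shown later in the diff. A minimal sketch of the resolved URLs (the port placeholder follows the README's convention):

```python
# Sketch: how the OpenAI SDK joins its method paths onto base_url.
# Substitute your PUSH_MODE_HTTP_PORT for the placeholder.
base_url = "http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions"

print(base_url + "/completions")       # -> handled by openai_v1_completions
print(base_url + "/chat/completions")  # -> handled by openai_v1_chat_completions
```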
@@ -1,5 +1,4 @@
 # model server
-paddlenlp==2.7.2
 sentencepiece
 pycryptodome
 tritonclient[all]==2.41.1
@@ -10,7 +9,7 @@ transformers
 # http server
 fastapi
 httpx
-openai==1.9.0
+openai==1.44.1
 asyncio
 uvicorn
 shortuuid
@@ -20,4 +19,3 @@ pynvml
 
 # paddlenlp
 tiktoken
-transformers

@@ -6,8 +6,7 @@ export PYTHONIOENCODING=utf8
 export LC_ALL=C.UTF-8
 
 # PaddlePaddle environment variables
-export FLAGS_allocator_strategy=naive_best_fit
-export FLAGS_fraction_of_gpu_memory_to_use=0.96
+export FLAGS_allocator_strategy=auto_growth
 export FLAGS_dynamic_static_unified_comm=0
 export FLAGS_use_xqa_optim=1
 export FLAGS_gemm_use_half_precision_compute_type=0

@@ -40,8 +40,6 @@ def check_basic_params(req_dict):
             error_msg.append("The `input_ids` in input parameters must be a list")
         if "messages" in req_dict:
             msg_len = len(req_dict["messages"])
-            if msg_len % 2 == 0:
-                error_msg.append(f"The number of the message {msg_len} must be odd")
             if not all("content" in item for item in req_dict["messages"]):
                 error_msg.append("The item in messages must include `content`")
 
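The hunk above removes the requirement that a conversation contain an odd number of messages, which is what lets the OpenAI-style user/system/user examples in the README pass validation. A standalone sketch of the check that remains (`validate_messages` is a hypothetical helper written for illustration, not part of the codebase):

```python
# Hypothetical reconstruction of the remaining `messages` validation:
# any message count is accepted; each item must still carry `content`.
def validate_messages(req_dict):
    error_msg = []
    if "messages" in req_dict:
        if not all("content" in item for item in req_dict["messages"]):
            error_msg.append("The item in messages must include `content`")
    return error_msg

print(validate_messages({"messages": [
    {"role": "user", "content": "Hello"},
    {"role": "system", "content": "I'm a helpful AI assistant."},
]}))  # -> [] (an even-length conversation now passes)
```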
@@ -125,8 +125,8 @@ class DataProcessor(BaseDataProcessor):
 
         self.decode_status = dict()
         self.tokenizer = self._load_tokenizer()
-        data_processor_logger.info(f"tokenizer infomation: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, "+
-                    f"eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id}, ")
+        data_processor_logger.info(f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
+                                eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} ")
 
     def process_request(self, request, max_seq_len=None):
         """
@@ -143,14 +143,19 @@ class DataProcessor(BaseDataProcessor):
             request["eos_token_ids"] = []
         request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
 
-        if "input_ids" in request:
-            input_ids = request["input_ids"]
-        else:
-            input_ids = self.text2ids(request['text'])
+        if "input_ids" not in request or \
+            (isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
+            if "text" in request:
+                request["input_ids"] = self.text2ids(request["text"])
+            elif "messages" in request:
+                if self.tokenizer.chat_template is None:
+                    raise ValueError("This model does not support chat_template.")
+                request["input_ids"] = self.messages2ids(request["messages"])
+            else:
+                raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
 
-        if max_seq_len is not None and len(input_ids) > max_seq_len:
-            input_ids = input_ids[:max_seq_len-1]
-        request["input_ids"] = input_ids
+        if max_seq_len is not None and len(request["input_ids"]) > max_seq_len:
+            request["input_ids"] = request["input_ids"][:max_seq_len-1]
         data_processor_logger.info(f"processed request: {request}")
         return request
 
@@ -221,7 +226,8 @@ class DataProcessor(BaseDataProcessor):
         Returns:
             List[int]: ID sequences
         """
-        return
+        message_result = self.tokenizer.apply_chat_template(messages, return_tensors="pd")
+        return message_result["input_ids"][0]
 
     def ids2tokens(self, token_id, task_id):
         """
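The rewritten `process_request` resolves `input_ids` in a fixed order: explicit `input_ids`, then `text`, then `messages` (which requires a chat template). A condensed sketch of that flow as a hypothetical free function with the tokenizer calls stubbed out:

```python
# Hypothetical, condensed sketch of the resolution order in process_request.
def resolve_input_ids(request, text2ids, messages2ids, chat_template=None):
    if not request.get("input_ids"):  # missing or empty
        if "text" in request:
            request["input_ids"] = text2ids(request["text"])
        elif "messages" in request:
            if chat_template is None:
                raise ValueError("This model does not support chat_template.")
            request["input_ids"] = messages2ids(request["messages"])
        else:
            raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
    return request

# Example: a plain text request falls through to the tokenizer path.
req = resolve_input_ids({"text": "Hello"},
                        text2ids=lambda t: [1, 2, 3],
                        messages2ids=lambda m: [],
                        chat_template="default")
print(req["input_ids"])  # [1, 2, 3]
```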
llm/server/server/http_server/adapter_openai.py (new file)

@@ -0,0 +1,97 @@
+import time
+
+from typing import Dict
+
+from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_choice import CompletionChoice
+from openai.types.completion import Completion
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDelta,
+    ChatCompletionChunk,
+    Choice as ChatCompletionChoice
+)
+
+from server.http_server.api import Req, chat_completion_generator
+
+
+def format_openai_message_completions(req: Req, result: Dict) -> str:
+    choice_data = CompletionChoice(
+        index=0,
+        text=result['token'],
+        finish_reason=result.get("finish_reason", "stop"),
+    )
+    chunk = Completion(
+        id=req.req_id,
+        choices=[choice_data],
+        model=req.model,
+        created=int(time.time()),
+        object="text_completion",
+        usage=CompletionUsage(
+            completion_tokens=result["usage"]["completion_tokens"],
+            prompt_tokens=result["usage"]["prompt_tokens"],
+            total_tokens=result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"],
+        ),
+    )
+    return chunk.model_dump_json(exclude_unset=True)
+
+
+def format_openai_message_chat_completions(req: Req, result: Dict) -> str:
+    choice_data = ChatCompletionChoice(
+        index=0,
+        delta=ChoiceDelta(
+            content=result['token'],
+            role="assistant",
+        ),
+        finish_reason=result.get("finish_reason", "stop"),
+    )
+    chunk = ChatCompletionChunk(
+        id=req.req_id,
+        choices=[choice_data],
+        model=req.model,
+        created=int(time.time()),
+        object="chat.completion.chunk",
+        usage=CompletionUsage(
+            completion_tokens=result["usage"]["completion_tokens"],
+            prompt_tokens=result["usage"]["prompt_tokens"],
+            total_tokens=result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"],
+        ),
+    )
+    return chunk.model_dump_json(exclude_unset=True)
+
+
+def openai_chat_completion_generator(infer_grpc_url: str, req: Req, chat_interface: bool):
+
+    def _openai_format_resp(resp_dict):
+        return f"data: {resp_dict}\n\n"
+
+    for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
+        if chat_interface:
+            yield _openai_format_resp(format_openai_message_chat_completions(req, resp))
+        else:
+            yield _openai_format_resp(format_openai_message_completions(req, resp))
+
+        # emit the SSE terminator only after the final data chunk
+        if resp.get("is_end") == 1:
+            yield _openai_format_resp("[DONE]")
+
+
+def openai_chat_completion_result(infer_grpc_url: str, req: Req, chat_interface: bool):
+    result = ""
+    usage = None
+    error_resp = None
+    for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
+        if resp.get("error_msg") or resp.get("error_code"):
+            error_resp = resp
+            error_resp["result"] = ""
+        else:
+            result += resp.get("token", "")
+        usage = resp.get("usage", None)
+
+    if error_resp:
+        return error_resp
+    response = {'token': result, 'error_msg': '', 'error_code': 0, 'usage': usage}
+
+    if chat_interface:
+        return format_openai_message_chat_completions(req, response)
+    else:
+        return format_openai_message_completions(req, response)
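Because `_openai_format_resp` frames each chunk as a standard `data: ...\n\n` server-sent event terminated by `[DONE]`, the stream can also be consumed without the OpenAI SDK. A minimal sketch, assuming `requests` is installed and the server is reachable on `PUSH_MODE_HTTP_PORT`:

```python
import json

import requests  # assumption: available in the client environment


def iter_sse_chunks(url: str, payload: dict):
    """Yield parsed JSON chunks from a data:-framed SSE stream."""
    with requests.post(url, json=payload, stream=True) as resp:
        for raw in resp.iter_lines():
            if not raw:
                continue
            line = raw.decode("utf-8")
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":  # terminator emitted by _openai_format_resp
                break
            yield json.loads(data)


# Usage (port placeholder as in the README):
# url = "http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions/chat/completions"
# for chunk in iter_sse_chunks(url, {"messages": [{"role": "user", "content": "hi"}],
#                                    "stream": True, "model": "default"}):
#     print(chunk["choices"][0]["delta"]["content"], end="")
```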
@@ -16,6 +16,7 @@ import json
 import queue
 import time
 import uuid
+import shortuuid
 from datetime import datetime
 from functools import partial
 from typing import Dict, List, Optional
@@ -46,6 +47,7 @@ class Req(BaseModel):
     return_usage: Optional[bool] = False
     stream: bool = False
     timeout: int = 300
+    model: Optional[str] = None
 
     def to_dict_for_infer(self):
         """
@@ -54,14 +56,38 @@
         Returns:
             dict: request parameters in dict format
         """
+        self.compatible_with_OpenAI()
+
         req_dict = {}
         for key, value in self.dict().items():
             if value is not None:
                 req_dict[key] = value
         return req_dict
 
+    def load_openai_request(self, request_dict: dict):
+        """
+        Convert an OpenAI request into a Req.
+        Official OpenAI API documentation: https://platform.openai.com/docs/api-reference/completions/create
+        """
+        convert_dict = {
+            "text": "prompt",
+            "frequency_score": "frequency_penalty",
+            "max_dec_len": "max_tokens",
+            "stream": "stream",
+            "return_all_tokens": "best_of",
+            "temperature": "temperature",
+            "topp": "top_p",
+            "presence_score": "presence_penalty",
+            "eos_token_ids": "stop",
+            "req_id": "id",
+            "model": "model",
+            "messages": "messages",
+        }
+
+        self.req_id = f"chatcmpl-{shortuuid.random()}"
+        for key, value in convert_dict.items():
+            # explicit None check so falsy values such as temperature=0 are preserved
+            if request_dict.get(value) is not None:
+                setattr(self, key, request_dict[value])
+
 
 def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -> Dict:
     """
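A short usage sketch of `load_openai_request` (the request body values are illustrative; the field mapping follows `convert_dict` above):

```python
from server.http_server.api import Req

# Illustrative OpenAI-style request body.
openai_body = {
    "model": "default",
    "prompt": "Hello, how are you?",
    "max_tokens": 50,    # mapped onto Req.max_dec_len
    "temperature": 0.0,  # preserved thanks to the explicit None check
    "stream": False,
}

req = Req()
req.load_openai_request(openai_body)
print(req.req_id)       # e.g. "chatcmpl-..." (shortuuid-generated)
print(req.max_dec_len)  # 50
```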
@@ -16,10 +16,14 @@ import argparse
 import os
 
 import uvicorn
-from fastapi import FastAPI
+from typing import Dict
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
 from server.http_server.api import (Req, chat_completion_generator,
                                     chat_completion_result)
+from server.http_server.adapter_openai import (
+    openai_chat_completion_generator, openai_chat_completion_result
+)
 from server.utils import http_server_logger
 
 http_server_logger.info(f"create fastapi app...")
@@ -58,6 +62,48 @@ def create_chat_completion(req: Req):
         return resp
 
 
+@app.post("/v1/chat/completions/completions")
+def openai_v1_completions(request: Dict):
+    return create_openai_completion(request, chat_interface=False)
+
+
+@app.post("/v1/chat/completions/chat/completions")
+def openai_v1_chat_completions(request: Dict):
+    return create_openai_completion(request, chat_interface=True)
+
+
+def create_openai_completion(request: Dict, chat_interface: bool):
+    try:
+        req = Req()
+        req.load_openai_request(request)
+    except Exception as e:
+        return {"error_msg": f"request body is not in a valid format: {e}", "error_code": 400, "result": ''}
+
+    try:
+        http_server_logger.info(f"receive request: {req.req_id}")
+
+        grpc_port = int(os.getenv("GRPC_PORT", 0))
+        if grpc_port == 0:
+            return {"error_msg": f"GRPC_PORT ({grpc_port}) for infer service is invalid",
+                    "error_code": 400}
+        grpc_url = f"localhost:{grpc_port}"
+
+        if req.stream:
+            generator = openai_chat_completion_generator(
+                                infer_grpc_url=grpc_url,
+                                req=req,
+                                chat_interface=chat_interface,
+                            )
+            resp = StreamingResponse(generator, media_type="text/event-stream")
+        else:
+            resp = openai_chat_completion_result(infer_grpc_url=grpc_url, req=req, chat_interface=chat_interface)
+    except Exception as e:
+        resp = {'error_msg': str(e), 'error_code': 501}
+    finally:
+        http_server_logger.info(f"finish request: {req.req_id}")
+    return resp
+
+
 def launch_http_server(port: int, workers: int) -> None:
     """
     launch http server
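For the non-streaming path the handlers return the serialized `Completion`/`ChatCompletionChunk` payload directly, so any plain HTTP client works too. A hedged sketch (host and port are placeholders, and the payload mirrors the README example):

```python
import requests  # assumption: available in the client environment

# Substitute your PUSH_MODE_HTTP_PORT for the placeholder below.
url = "http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions/completions"

resp = requests.post(url, json={
    "model": "default",
    "prompt": "Hello, how are you?",
    "max_tokens": 50,
    "stream": False,
})
print(resp.json())  # the serialized Completion payload, or an error dict
```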