	support openai client
		| @@ -1,5 +1,5 @@ | |||||||
|  |  | ||||||
| <h1 align="center"><b><em>FastDeploy: PaddlePaddle High-Performance LLM Deployment Toolkit</em></b></h1> | <h1 align="center"><b><em>FastDeploy LLM Serving Deployment</em></b></h1> | ||||||
|  |  | ||||||
| *FastDeploy is a solution built on NVIDIA's Triton framework, designed for serving large language models in server-side scenarios. It provides gRPC and HTTP service interfaces as well as streaming token output. The underlying inference engine supports acceleration strategies such as continuous batching, weight-only int8, and post-training quantization (PTQ), delivering an easy-to-use, high-performance deployment experience.* | *FastDeploy is a solution built on NVIDIA's Triton framework, designed for serving large language models in server-side scenarios. It provides gRPC and HTTP service interfaces as well as streaming token output. The underlying inference engine supports acceleration strategies such as continuous batching, weight-only int8, and post-training quantization (PTQ), delivering an easy-to-use, high-performance deployment experience.* | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ WORKDIR /opt/output/ | |||||||
| COPY ./server/ /opt/output/Serving/ | COPY ./server/ /opt/output/Serving/ | ||||||
| COPY ./client/ /opt/output/client/ | COPY ./client/ /opt/output/client/ | ||||||
|  |  | ||||||
| ENV LD_LIBRARY_PATH "/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH" | ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH" | ||||||
|  |  | ||||||
| RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple | RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple | ||||||
| RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \ | RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \ | ||||||
| @@ -15,7 +15,7 @@ RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc | |||||||
|     && python3 setup_cuda.py build && python3 setup_cuda.py install --user \ |     && python3 setup_cuda.py build && python3 setup_cuda.py install --user \ | ||||||
|     && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \ |     && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \ | ||||||
|     && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \ |     && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \ | ||||||
|     && rm -rf PaddleNLP |     && rm -rf /opt/output/PaddleNLP | ||||||
|  |  | ||||||
| RUN cd /opt/output/client && pip install -r requirements.txt && pip install . | RUN cd /opt/output/client && pip install -r requirements.txt && pip install . | ||||||
|  |  | ||||||
| @@ -30,7 +30,5 @@ RUN cd /opt/output/Serving/ \ | |||||||
|     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \ |     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \ | ||||||
|     && rm -rf scripts |     && rm -rf scripts | ||||||
|  |  | ||||||
| RUN python3 -m pip install protobuf==3.20.0 | ENV http_proxy="" | ||||||
|  | ENV https_proxy="" | ||||||
| ENV http_proxy "" |  | ||||||
| ENV https_proxy "" |  | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ WORKDIR /opt/output/ | |||||||
| COPY ./server/ /opt/output/Serving/ | COPY ./server/ /opt/output/Serving/ | ||||||
| COPY ./client/ /opt/output/client/ | COPY ./client/ /opt/output/client/ | ||||||
|  |  | ||||||
| ENV LD_LIBRARY_PATH "/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH" | ENV LD_LIBRARY_PATH="/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH" | ||||||
|  |  | ||||||
| RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple | RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple | ||||||
| RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ \ | RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ \ | ||||||
| @@ -15,7 +15,7 @@ RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc | |||||||
|     && python3 setup_cuda.py build && python3 setup_cuda.py install --user \ |     && python3 setup_cuda.py build && python3 setup_cuda.py install --user \ | ||||||
|     && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \ |     && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \ | ||||||
|     && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \ |     && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \ | ||||||
|     && rm -rf PaddleNLP |     && rm -rf /opt/output/PaddleNLP | ||||||
|  |  | ||||||
| RUN cd /opt/output/client && pip install -r requirements.txt && pip install . | RUN cd /opt/output/client && pip install -r requirements.txt && pip install . | ||||||
|  |  | ||||||
| @@ -30,7 +30,5 @@ RUN cd /opt/output/Serving/ \ | |||||||
|     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \ |     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \ | ||||||
|     && rm -rf scripts |     && rm -rf scripts | ||||||
|  |  | ||||||
| RUN python3 -m pip install protobuf==3.20.0 | ENV http_proxy="" | ||||||
|  | ENV https_proxy="" | ||||||
| ENV http_proxy "" |  | ||||||
| ENV https_proxy "" |  | ||||||
|   | |||||||
| @@ -66,7 +66,7 @@ ls /fastdeploy/models/ | |||||||
| git clone https://github.com/PaddlePaddle/FastDeploy.git | git clone https://github.com/PaddlePaddle/FastDeploy.git | ||||||
| cd FastDeploy/llm | cd FastDeploy/llm | ||||||
|  |  | ||||||
| docker build -f ./dockerfiles/Dockerfile_serving_cuda123_cudnn9 -t llm-serving-cu123-self . | docker build --network=host -f ./dockerfiles/Dockerfile_serving_cuda123_cudnn9 -t llm-serving-cu123-self . | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| After building your own image, you can [create a container](#创建容器) based on it. | After building your own image, you can [create a container](#创建容器) based on it. | ||||||
| @@ -196,6 +196,77 @@ for line in res.iter_lines(): | |||||||
|     On error, {'error_msg': xxx, 'error_code': xxx} is returned, where error_msg is non-empty and error_code is non-zero |     On error, {'error_msg': xxx, 'error_code': xxx} is returned, where error_msg is non-empty and error_code is non-zero | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  | ### OpenAI client | ||||||
|  |  | ||||||
|  | Support for the OpenAI client is provided; use it as follows: | ||||||
|  |  | ||||||
|  | Note: using the OpenAI client requires `PUSH_MODE_HTTP_PORT` to be configured! | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | import openai | ||||||
|  |  | ||||||
|  | client = openai.Client(base_url="http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions", api_key="EMPTY_API_KEY") | ||||||
|  |  | ||||||
|  | # Non-streaming completion | ||||||
|  | response = client.completions.create( | ||||||
|  |     model="default", | ||||||
|  |     prompt="Hello, how are you?", | ||||||
|  |     max_tokens=50, | ||||||
|  |     stream=False, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | print(response) | ||||||
|  | print("\n") | ||||||
|  |  | ||||||
|  | # Streaming completion | ||||||
|  | response = client.completions.create( | ||||||
|  |     model="default", | ||||||
|  |     prompt="Hello, how are you?", | ||||||
|  |     max_tokens=100, | ||||||
|  |     stream=True, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | for chunk in response: | ||||||
|  |   if chunk.choices[0] is not None: | ||||||
|  |     print(chunk.choices[0].text, end='') | ||||||
|  | print("\n") | ||||||
|  |  | ||||||
|  | # Chat completion | ||||||
|  | # Non-streaming | ||||||
|  | response = client.chat.completions.create( | ||||||
|  |     model="default", | ||||||
|  |     messages=[ | ||||||
|  |         {"role": "user", "content": "Hello, who are you"}, | ||||||
|  |         {"role": "system", "content": "I'm a helpful AI assistant."}, | ||||||
|  |         {"role": "user", "content": "List 3 countries and their capitals."}, | ||||||
|  |     ], | ||||||
|  |     temperature=0, | ||||||
|  |     max_tokens=64, | ||||||
|  |     stream=False, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | print(response) | ||||||
|  | print("\n") | ||||||
|  |  | ||||||
|  | # Streaming | ||||||
|  | response = client.chat.completions.create( | ||||||
|  |     model="default", | ||||||
|  |     messages=[ | ||||||
|  |         {"role": "user", "content": "Hello, who are you"}, | ||||||
|  |         {"role": "system", "content": "I'm a helpful AI assistant."}, | ||||||
|  |         {"role": "user", "content": "List 3 countries and their capitals."}, | ||||||
|  |     ], | ||||||
|  |     temperature=0, | ||||||
|  |     max_tokens=64, | ||||||
|  |     stream=True, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | for chunk in response: | ||||||
|  |   if chunk.choices[0].delta is not None: | ||||||
|  |     print(chunk.choices[0].delta.content, end='') | ||||||
|  | print("\n") | ||||||
|  | ``` | ||||||
|  |  | ||||||
| ## Model configuration parameters | ## Model configuration parameters | ||||||
|  |  | ||||||
| | Field | Type | Description | Required | Default | Notes | | | Field | Type | Description | Required | Default | Notes | | ||||||
|   | |||||||
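The OpenAI client section above can also be exercised without the `openai` package. The following is a minimal sketch, not part of the commit: it assumes the routes registered in `app.py` later in this diff (mounted under the `/v1/chat/completions` prefix), a running server, and an exported `PUSH_MODE_HTTP_PORT`; the SSE framing (`data: ...` events plus a `[DONE]` marker) matches what `adapter_openai.py` emits.

```python
import json
import os

import requests  # the existing HTTP examples in this doc already use requests

port = os.environ["PUSH_MODE_HTTP_PORT"]
# chat route as registered in app.py: base prefix + /chat/completions
url = f"http://127.0.0.1:{port}/v1/chat/completions/chat/completions"

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "List 3 countries and their capitals."}],
    "max_tokens": 64,
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line:                        # skip blank SSE separator lines
            continue
        data = line.removeprefix("data: ")  # each event is "data: <json>" or "data: [DONE]"
        if data == "[DONE]":
            continue
        chunk = json.loads(data)
        print(chunk["choices"][0]["delta"]["content"], end="")
```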
| @@ -1,5 +1,4 @@ | |||||||
| # model server | # model server | ||||||
| paddlenlp==2.7.2 |  | ||||||
| sentencepiece | sentencepiece | ||||||
| pycryptodome | pycryptodome | ||||||
| tritonclient[all]==2.41.1 | tritonclient[all]==2.41.1 | ||||||
| @@ -10,7 +9,7 @@ transformers | |||||||
| # http server | # http server | ||||||
| fastapi | fastapi | ||||||
| httpx | httpx | ||||||
| openai==1.9.0 | openai==1.44.1 | ||||||
| asyncio | asyncio | ||||||
| uvicorn | uvicorn | ||||||
| shortuuid | shortuuid | ||||||
| @@ -20,4 +19,3 @@ pynvml | |||||||
|  |  | ||||||
| # paddlenlp | # paddlenlp | ||||||
| tiktoken | tiktoken | ||||||
| transformers |  | ||||||
|   | |||||||
| @@ -6,8 +6,7 @@ export PYTHONIOENCODING=utf8 | |||||||
| export LC_ALL=C.UTF-8 | export LC_ALL=C.UTF-8 | ||||||
|  |  | ||||||
| # PaddlePaddle environment variables | # PaddlePaddle environment variables | ||||||
| export FLAGS_allocator_strategy=naive_best_fit | export FLAGS_allocator_strategy=auto_growth | ||||||
| export FLAGS_fraction_of_gpu_memory_to_use=0.96 |  | ||||||
| export FLAGS_dynamic_static_unified_comm=0 | export FLAGS_dynamic_static_unified_comm=0 | ||||||
| export FLAGS_use_xqa_optim=1 | export FLAGS_use_xqa_optim=1 | ||||||
| export FLAGS_gemm_use_half_precision_compute_type=0 | export FLAGS_gemm_use_half_precision_compute_type=0 | ||||||
|   | |||||||
| @@ -40,8 +40,6 @@ def check_basic_params(req_dict): | |||||||
|             error_msg.append("The `input_ids` in input parameters must be a list") |             error_msg.append("The `input_ids` in input parameters must be a list") | ||||||
|         if "messages" in req_dict: |         if "messages" in req_dict: | ||||||
|             msg_len = len(req_dict["messages"]) |             msg_len = len(req_dict["messages"]) | ||||||
|             if msg_len % 2 == 0: |  | ||||||
|                 error_msg.append(f"The number of the message {msg_len} must be odd") |  | ||||||
|             if not all("content" in item for item in req_dict["messages"]): |             if not all("content" in item for item in req_dict["messages"]): | ||||||
|                 error_msg.append("The item in messages must include `content`") |                 error_msg.append("The item in messages must include `content`") | ||||||
|  |  | ||||||
|   | |||||||
| @@ -125,8 +125,8 @@ class DataProcessor(BaseDataProcessor): | |||||||
|  |  | ||||||
|         self.decode_status = dict() |         self.decode_status = dict() | ||||||
|         self.tokenizer = self._load_tokenizer() |         self.tokenizer = self._load_tokenizer() | ||||||
|         data_processor_logger.info(f"tokenizer infomation: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, "+ |         data_processor_logger.info(f"tokenizer infomation: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \ | ||||||
|                     f"eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id}, ") |                                 eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} ") | ||||||
|  |  | ||||||
|     def process_request(self, request, max_seq_len=None): |     def process_request(self, request, max_seq_len=None): | ||||||
|         """ |         """ | ||||||
| @@ -143,14 +143,19 @@ class DataProcessor(BaseDataProcessor): | |||||||
|             request["eos_token_ids"] = [] |             request["eos_token_ids"] = [] | ||||||
|         request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config)) |         request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config)) | ||||||
|  |  | ||||||
|         if "input_ids" in request: |         if "input_ids" not in request or \ | ||||||
|             input_ids = request["input_ids"] |             (isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0): | ||||||
|         else: |             if "text" in request: | ||||||
|             input_ids = self.text2ids(request['text']) |                 request["input_ids"] = self.text2ids(request["text"]) | ||||||
|  |             elif "messages" in request: | ||||||
|  |                 if self.tokenizer.chat_template is None: | ||||||
|  |                     raise ValueError(f"This model does not support chat_template.") | ||||||
|  |                 request["input_ids"] = self.messages2ids(request["messages"]) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.") | ||||||
|  |  | ||||||
|         if max_seq_len is not None and len(input_ids) > max_seq_len: |         if max_seq_len is not None and len(request["input_ids"]) > max_seq_len: | ||||||
|             input_ids = input_ids[:max_seq_len-1] |             request["input_ids"] = request["input_ids"][:max_seq_len-1] | ||||||
|         request["input_ids"] = input_ids |  | ||||||
|         data_processor_logger.info(f"processed request: {request}") |         data_processor_logger.info(f"processed request: {request}") | ||||||
|         return request |         return request | ||||||
|  |  | ||||||
| @@ -221,7 +226,8 @@ class DataProcessor(BaseDataProcessor): | |||||||
|         Returns: |         Returns: | ||||||
|             List[int]: ID sequences |             List[int]: ID sequences | ||||||
|         """ |         """ | ||||||
|         return |         message_result = self.tokenizer.apply_chat_template(messages, return_tensors="pd") | ||||||
|  |         return message_result["input_ids"][0] | ||||||
|  |  | ||||||
|     def ids2tokens(self, token_id, task_id): |     def ids2tokens(self, token_id, task_id): | ||||||
|         """ |         """ | ||||||
|   | |||||||
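The rewritten `process_request` above resolves the prompt from one of three request fields. As a standalone illustration of that priority order (the helper name and arguments are hypothetical; the real logic lives in `DataProcessor`):

```python
# Hypothetical helper mirroring DataProcessor.process_request above:
# an explicit, non-empty `input_ids` wins, then `text` is tokenized, then `messages`
# go through the tokenizer's chat template; anything else is rejected.
def resolve_input_ids(request, text2ids, messages2ids, has_chat_template):
    ids = request.get("input_ids")
    if ids:  # present and non-empty
        return ids
    if "text" in request:
        return text2ids(request["text"])
    if "messages" in request:
        if not has_chat_template:
            raise ValueError("This model does not support chat_template.")
        return messages2ids(request["messages"])
    raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
```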
							
								
								
									
llm/server/server/http_server/adapter_openai.py (new file, 103 lines)
							| @@ -0,0 +1,103 @@ | |||||||
|  | import time | ||||||
|  | import json | ||||||
|  | import queue | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | from typing import Dict | ||||||
|  | from datetime import datetime | ||||||
|  | from functools import partial | ||||||
|  |  | ||||||
|  | import tritonclient.grpc as grpcclient | ||||||
|  | from tritonclient import utils as triton_utils | ||||||
|  | from openai.types.completion_usage import CompletionUsage | ||||||
|  | from openai.types.completion_choice import CompletionChoice | ||||||
|  | from openai.types.completion import Completion | ||||||
|  | from openai.types.chat.chat_completion_chunk import ( | ||||||
|  |     ChoiceDelta, | ||||||
|  |     ChatCompletionChunk, | ||||||
|  |     Choice as ChatCompletionChoice | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | from server.http_server.api import Req, chat_completion_generator | ||||||
|  | from server.utils import http_server_logger | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def format_openai_message_completions(req: Req, result: Dict) -> Completion: | ||||||
|  |     choice_data = CompletionChoice( | ||||||
|  |                 index=0, | ||||||
|  |                 text=result['token'], | ||||||
|  |                 finish_reason=result.get("finish_reason", "stop"), | ||||||
|  |             ) | ||||||
|  |     chunk = Completion( | ||||||
|  |                 id=req.req_id, | ||||||
|  |                 choices=[choice_data], | ||||||
|  |                 model=req.model, | ||||||
|  |                 created=int(time.time()), | ||||||
|  |                 object="text_completion", | ||||||
|  |                 usage=CompletionUsage( | ||||||
|  |                     completion_tokens=result["usage"]["completion_tokens"], | ||||||
|  |                     prompt_tokens=result["usage"]["prompt_tokens"], | ||||||
|  |                     total_tokens=result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"], | ||||||
|  |                 ), | ||||||
|  |             ) | ||||||
|  |     return chunk.model_dump_json(exclude_unset=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def format_openai_message_chat_completions(req: Req, result: Dict) -> ChatCompletionChunk: | ||||||
|  |     choice_data = ChatCompletionChoice( | ||||||
|  |                 index=0, | ||||||
|  |                 delta=ChoiceDelta( | ||||||
|  |                     content=result['token'], | ||||||
|  |                     role="assistant", | ||||||
|  |                 ), | ||||||
|  |                 finish_reason=result.get("finish_reason", "stop"), | ||||||
|  |             ) | ||||||
|  |     chunk = ChatCompletionChunk( | ||||||
|  |                 id=req.req_id, | ||||||
|  |                 choices=[choice_data], | ||||||
|  |                 model=req.model, | ||||||
|  |                 created=int(time.time()), | ||||||
|  |                 object="chat.completion.chunk", | ||||||
|  |                 usage=CompletionUsage( | ||||||
|  |                     completion_tokens=result["usage"]["completion_tokens"], | ||||||
|  |                     prompt_tokens=result["usage"]["prompt_tokens"], | ||||||
|  |                     total_tokens=result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"], | ||||||
|  |                 ), | ||||||
|  |             ) | ||||||
|  |     return chunk.model_dump_json(exclude_unset=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def openai_chat_commpletion_generator(infer_grpc_url: str, req: Req, chat_interface: bool) -> Dict: | ||||||
|  |  | ||||||
|  |     def _openai_format_resp(resp_dict): | ||||||
|  |         return f"data: {resp_dict}\n\n" | ||||||
|  |  | ||||||
|  |     for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False): | ||||||
|  |         if resp.get("is_end") == 1: | ||||||
|  |             yield _openai_format_resp("[DONE]") | ||||||
|  |  | ||||||
|  |         if chat_interface: | ||||||
|  |             yield _openai_format_resp(format_openai_message_chat_completions(req, resp)) | ||||||
|  |         else: | ||||||
|  |             yield _openai_format_resp(format_openai_message_completions(req, resp)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def openai_chat_completion_result(infer_grpc_url: str, req: Req, chat_interface: bool): | ||||||
|  |     result = "" | ||||||
|  |     error_resp = None | ||||||
|  |     for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False): | ||||||
|  |         if resp.get("error_msg") or resp.get("error_code"): | ||||||
|  |             error_resp = resp | ||||||
|  |             error_resp["result"] = "" | ||||||
|  |         else: | ||||||
|  |             result += resp.get("token") | ||||||
|  |         usage = resp.get("usage", None) | ||||||
|  |  | ||||||
|  |     if error_resp: | ||||||
|  |         return error_resp | ||||||
|  |     response = {'token': result, 'error_msg': '', 'error_code': 0, 'usage': usage} | ||||||
|  |  | ||||||
|  |     if chat_interface: | ||||||
|  |         return format_openai_message_chat_completions(req, response) | ||||||
|  |     else: | ||||||
|  |         return format_openai_message_completions(req, response) | ||||||
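To see what this adapter serialises, the functions above can be fed a hand-made result dict. This is a sketch only; the `fake_result` values are invented, but the keys match what `format_openai_message_completions` reads:

```python
from server.http_server.adapter_openai import format_openai_message_completions
from server.http_server.api import Req

req = Req()
req.req_id = "chatcmpl-example"   # normally generated by Req.load_openai_request
req.model = "default"

fake_result = {
    "token": "I'm fine, thank you!",
    "finish_reason": "stop",
    "usage": {"prompt_tokens": 6, "completion_tokens": 7},
}

# Prints the JSON of an openai.types.completion.Completion with a single choice
# and total_tokens computed from the usage fields above.
print(format_openai_message_completions(req, fake_result))
```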
| @@ -16,6 +16,7 @@ import json | |||||||
| import queue | import queue | ||||||
| import time | import time | ||||||
| import uuid | import uuid | ||||||
|  | import shortuuid | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from functools import partial | from functools import partial | ||||||
| from typing import Dict, List, Optional | from typing import Dict, List, Optional | ||||||
| @@ -46,6 +47,7 @@ class Req(BaseModel): | |||||||
|     return_usage: Optional[bool] = False |     return_usage: Optional[bool] = False | ||||||
|     stream: bool = False |     stream: bool = False | ||||||
|     timeout: int = 300 |     timeout: int = 300 | ||||||
|  |     model: str = None | ||||||
|  |  | ||||||
|     def to_dict_for_infer(self): |     def to_dict_for_infer(self): | ||||||
|         """ |         """ | ||||||
| @@ -54,14 +56,37 @@ class Req(BaseModel): | |||||||
|         Returns: |         Returns: | ||||||
|             dict: request parameters in dict format |             dict: request parameters in dict format | ||||||
|         """ |         """ | ||||||
|         self.compatible_with_OpenAI() |  | ||||||
|  |  | ||||||
|         req_dict = {} |         req_dict = {} | ||||||
|         for key, value in self.dict().items(): |         for key, value in self.dict().items(): | ||||||
|             if value is not None: |             if value is not None: | ||||||
|                 req_dict[key] = value |                 req_dict[key] = value | ||||||
|         return req_dict |         return req_dict | ||||||
|  |  | ||||||
|  |     def load_openai_request(self, request_dict: dict): | ||||||
|  |         """ | ||||||
|  |         Convert openai request to Req | ||||||
|  |         official OpenAI API documentation: https://platform.openai.com/docs/api-reference/completions/create | ||||||
|  |         """ | ||||||
|  |         convert_dict = { | ||||||
|  |             "text": "prompt", | ||||||
|  |             "frequency_score": "frequency_penalty", | ||||||
|  |             "max_dec_len": "max_tokens", | ||||||
|  |             "stream": "stream", | ||||||
|  |             "return_all_tokens": "best_of", | ||||||
|  |             "temperature": "temperature", | ||||||
|  |             "topp": "top_p", | ||||||
|  |             "presence_score": "presence_penalty", | ||||||
|  |             "eos_token_ids": "stop", | ||||||
|  |             "req_id": "id", | ||||||
|  |             "model": "model", | ||||||
|  |             "messages": "messages", | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         self.__setattr__("req_id", f"chatcmpl-{shortuuid.random()}") | ||||||
|  |         for key, value in convert_dict.items(): | ||||||
|  |             if request_dict.get(value, None): | ||||||
|  |                 self.__setattr__(key, request_dict.get(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -> Dict: | def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -> Dict: | ||||||
|     """ |     """ | ||||||
|   | |||||||
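`load_openai_request` above is essentially a translation table from OpenAI field names to the internal `Req` fields. An illustrative call (the payload values are made up):

```python
from server.http_server.api import Req

openai_payload = {
    "model": "default",
    "prompt": "Hello, how are you?",   # mapped to req.text
    "max_tokens": 50,                  # mapped to req.max_dec_len
    "top_p": 0.9,                      # mapped to req.topp
    "presence_penalty": 0.1,           # mapped to req.presence_score
}

req = Req()
req.load_openai_request(openai_payload)
# req.req_id is auto-generated as "chatcmpl-<shortuuid>"; keys missing from the
# payload (or falsy ones) simply keep the Req defaults.
print(req.to_dict_for_infer())
```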
| @@ -16,10 +16,14 @@ import argparse | |||||||
| import os | import os | ||||||
|  |  | ||||||
| import uvicorn | import uvicorn | ||||||
| from fastapi import FastAPI | from typing import Dict | ||||||
|  | from fastapi import FastAPI, Request | ||||||
| from fastapi.responses import StreamingResponse | from fastapi.responses import StreamingResponse | ||||||
| from server.http_server.api import (Req, chat_completion_generator, | from server.http_server.api import (Req, chat_completion_generator, | ||||||
|                                     chat_completion_result) |                                     chat_completion_result) | ||||||
|  | from server.http_server.adapter_openai import ( | ||||||
|  |     openai_chat_commpletion_generator, openai_chat_completion_result | ||||||
|  | ) | ||||||
| from server.utils import http_server_logger | from server.utils import http_server_logger | ||||||
|  |  | ||||||
| http_server_logger.info(f"create fastapi app...") | http_server_logger.info(f"create fastapi app...") | ||||||
| @@ -58,6 +62,48 @@ def create_chat_completion(req: Req): | |||||||
|         return resp |         return resp | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.post("/v1/chat/completions/completions") | ||||||
|  | def openai_v1_completions(request: Dict): | ||||||
|  |     return create_openai_completion(request, chat_interface=False) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.post("/v1/chat/completions/chat/completions") | ||||||
|  | def openai_v1_chat_completions(request: Dict): | ||||||
|  |     return create_openai_completion(request, chat_interface=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def create_openai_completion(request: Dict, chat_interface: bool): | ||||||
|  |     try: | ||||||
|  |         req = Req() | ||||||
|  |         req.load_openai_request(request) | ||||||
|  |     except Exception as e: | ||||||
|  |         return {"error_msg": "request body is not a valid json format", "error_code": 400, "result": ''} | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         http_server_logger.info(f"receive request: {req.req_id}") | ||||||
|  |  | ||||||
|  |         grpc_port = int(os.getenv("GRPC_PORT", 0)) | ||||||
|  |         if grpc_port == 0: | ||||||
|  |             return {"error_msg": f"GRPC_PORT ({grpc_port}) for infer service is invalid", | ||||||
|  |                     "error_code": 400} | ||||||
|  |         grpc_url = f"localhost:{grpc_port}" | ||||||
|  |  | ||||||
|  |         if req.stream: | ||||||
|  |             generator = openai_chat_commpletion_generator( | ||||||
|  |                                 infer_grpc_url=grpc_url, | ||||||
|  |                                 req=req, | ||||||
|  |                                 chat_interface=chat_interface, | ||||||
|  |                             ) | ||||||
|  |             resp = StreamingResponse(generator, media_type="text/event-stream") | ||||||
|  |         else: | ||||||
|  |             resp = openai_chat_completion_result(infer_grpc_url=grpc_url, req=req, chat_interface=chat_interface) | ||||||
|  |     except Exception as e: | ||||||
|  |         resp = {'error_msg': str(e), 'error_code': 501} | ||||||
|  |     finally: | ||||||
|  |         http_server_logger.info(f"finish request: {req.req_id}") | ||||||
|  |         return resp | ||||||
|  |  | ||||||
|  |  | ||||||
| def launch_http_server(port: int, workers: int) -> None: | def launch_http_server(port: int, workers: int) -> None: | ||||||
|     """ |     """ | ||||||
|     launch http server |     launch http server | ||||||
|   | |||||||