mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-01 20:32:52 +08:00
support openai client
@@ -1,5 +1,5 @@
-<h1 align="center"><b><em>FastDeploy: PaddlePaddle's high-performance LLM deployment toolkit</em></b></h1>
+<h1 align="center"><b><em>FastDeploy LLM serving deployment</em></b></h1>
 
 *FastDeploy is a solution built on NVIDIA's Triton framework, designed for serving large language models in server scenarios. It provides gRPC and HTTP service interfaces as well as streaming token output. The underlying inference engine supports acceleration strategies such as continuous batching, weight-only int8, and post-training quantization (PTQ), giving users an easy-to-use, high-performance deployment experience.*
 
@@ -4,7 +4,7 @@ WORKDIR /opt/output/
 COPY ./server/ /opt/output/Serving/
 COPY ./client/ /opt/output/client/
 
-ENV LD_LIBRARY_PATH "/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH"
+ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH"
 
 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
 RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \
@@ -15,7 +15,7 @@ RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc
     && python3 setup_cuda.py build && python3 setup_cuda.py install --user \
     && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
     && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \
-    && rm -rf PaddleNLP
+    && rm -rf /opt/output/PaddleNLP
 
 RUN cd /opt/output/client && pip install -r requirements.txt && pip install .
 
@@ -30,7 +30,5 @@ RUN cd /opt/output/Serving/
     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
     && rm -rf scripts
 
-RUN python3 -m pip install protobuf==3.20.0
-ENV http_proxy ""
-ENV https_proxy ""
+ENV http_proxy=""
+ENV https_proxy=""
@@ -4,7 +4,7 @@ WORKDIR /opt/output/
 COPY ./server/ /opt/output/Serving/
 COPY ./client/ /opt/output/client/
 
-ENV LD_LIBRARY_PATH "/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH"
+ENV LD_LIBRARY_PATH="/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH"
 
 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
 RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ \
@@ -15,7 +15,7 @@ RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc
     && python3 setup_cuda.py build && python3 setup_cuda.py install --user \
     && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
     && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \
-    && rm -rf PaddleNLP
+    && rm -rf /opt/output/PaddleNLP
 
 RUN cd /opt/output/client && pip install -r requirements.txt && pip install .
 
@@ -30,7 +30,5 @@ RUN cd /opt/output/Serving/
     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
     && rm -rf scripts
 
-RUN python3 -m pip install protobuf==3.20.0
-ENV http_proxy ""
-ENV https_proxy ""
+ENV http_proxy=""
+ENV https_proxy=""
@@ -66,7 +66,7 @@ ls /fastdeploy/models/
 git clone https://github.com/PaddlePaddle/FastDeploy.git
 cd FastDeploy/llm
 
-docker build -f ./dockerfiles/Dockerfile_serving_cuda123_cudnn9 -t llm-serving-cu123-self .
+docker build --network=host -f ./dockerfiles/Dockerfile_serving_cuda123_cudnn9 -t llm-serving-cu123-self .
 ```
 
 After building your own image, you can [create a container](#创建容器) based on it.
@@ -196,6 +196,77 @@ for line in res.iter_lines():
 If an error occurs, the response is {'error_msg': xxx, 'error_code': xxx}; the error_msg field is not empty and the error_code field is non-zero.
 ```
 
+### OpenAI Client
+
+We also provide support for the OpenAI client. Usage is as follows:
+
+Note: using the OpenAI client requires `PUSH_MODE_HTTP_PORT` to be configured!
+
+```
+import openai
+
+client = openai.Client(base_url="http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions", api_key="EMPTY_API_KEY")
+
+# Non-streaming response
+response = client.completions.create(
+    model="default",
+    prompt="Hello, how are you?",
+    max_tokens=50,
+    stream=False,
+)
+
+print(response)
+print("\n")
+
+# Streaming response
+response = client.completions.create(
+    model="default",
+    prompt="Hello, how are you?",
+    max_tokens=100,
+    stream=True,
+)
+
+for chunk in response:
+    if chunk.choices[0] is not None:
+        print(chunk.choices[0].text, end='')
+print("\n")
+
+# Chat completion
+# Non-streaming response
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "user", "content": "Hello, who are you"},
+        {"role": "system", "content": "I'm a helpful AI assistant."},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0,
+    max_tokens=64,
+    stream=False,
+)
+
+print(response)
+print("\n")
+
+# Streaming response
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "user", "content": "Hello, who are you"},
+        {"role": "system", "content": "I'm a helpful AI assistant."},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0,
+    max_tokens=64,
+    stream=True,
+)
+
+for chunk in response:
+    if chunk.choices[0].delta is not None:
+        print(chunk.choices[0].delta.content, end='')
+print("\n")
+```
+
 ## Model configuration parameters
 
 | Field | Type | Description | Required | Default | Notes |
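For reference, the endpoints this commit adds can also be exercised without the OpenAI SDK. A minimal sketch (not part of the commit), assuming the HTTP server listens on the port you configured via `PUSH_MODE_HTTP_PORT` (9965 below is only a placeholder) and serves a model named `default`; the body uses the OpenAI field names that `load_openai_request` maps onto the internal request fields:

```
import requests  # illustrative sketch only

PUSH_MODE_HTTP_PORT = 9965  # assumption: replace with the port you configured

# Non-streaming completion against the route registered in the HTTP server
# (/v1/chat/completions/completions).
resp = requests.post(
    f"http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions/completions",
    json={
        "model": "default",
        "prompt": "Hello, how are you?",
        "max_tokens": 50,
        "stream": False,
    },
)
print(resp.text)
```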
@@ -1,5 +1,4 @@
 # model server
-paddlenlp==2.7.2
 sentencepiece
 pycryptodome
 tritonclient[all]==2.41.1
@@ -10,7 +9,7 @@ transformers
 # http server
 fastapi
 httpx
-openai==1.9.0
+openai==1.44.1
 asyncio
 uvicorn
 shortuuid
@@ -20,4 +19,3 @@ pynvml
 
 # paddlenlp
 tiktoken
-transformers
@@ -6,8 +6,7 @@ export PYTHONIOENCODING=utf8
 export LC_ALL=C.UTF-8
 
 # PaddlePaddle environment variables
-export FLAGS_allocator_strategy=naive_best_fit
-export FLAGS_fraction_of_gpu_memory_to_use=0.96
+export FLAGS_allocator_strategy=auto_growth
 export FLAGS_dynamic_static_unified_comm=0
 export FLAGS_use_xqa_optim=1
 export FLAGS_gemm_use_half_precision_compute_type=0
@@ -40,8 +40,6 @@ def check_basic_params(req_dict):
             error_msg.append("The `input_ids` in input parameters must be a list")
     if "messages" in req_dict:
         msg_len = len(req_dict["messages"])
-        if msg_len % 2 == 0:
-            error_msg.append(f"The number of the message {msg_len} must be odd")
         if not all("content" in item for item in req_dict["messages"]):
             error_msg.append("The item in messages must include `content`")
 
@@ -125,8 +125,8 @@ class DataProcessor(BaseDataProcessor):
 
         self.decode_status = dict()
         self.tokenizer = self._load_tokenizer()
-        data_processor_logger.info(f"tokenizer infomation: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, "+
-                                   f"eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id}, ")
+        data_processor_logger.info(f"tokenizer infomation: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
+            eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} ")
 
     def process_request(self, request, max_seq_len=None):
         """
@@ -143,14 +143,19 @@ class DataProcessor(BaseDataProcessor):
             request["eos_token_ids"] = []
         request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
 
-        if "input_ids" in request:
-            input_ids = request["input_ids"]
-        else:
-            input_ids = self.text2ids(request['text'])
+        if "input_ids" not in request or \
+            (isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
+            if "text" in request:
+                request["input_ids"] = self.text2ids(request["text"])
+            elif "messages" in request:
+                if self.tokenizer.chat_template is None:
+                    raise ValueError(f"This model does not support chat_template.")
+                request["input_ids"] = self.messages2ids(request["messages"])
+            else:
+                raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
 
-        if max_seq_len is not None and len(input_ids) > max_seq_len:
-            input_ids = input_ids[:max_seq_len-1]
-        request["input_ids"] = input_ids
+        if max_seq_len is not None and len(request["input_ids"]) > max_seq_len:
+            request["input_ids"] = request["input_ids"][:max_seq_len-1]
         data_processor_logger.info(f"processed request: {request}")
         return request
 
@@ -221,7 +226,8 @@ class DataProcessor(BaseDataProcessor):
         Returns:
             List[int]: ID sequences
         """
-        return
+        message_result = self.tokenizer.apply_chat_template(messages, return_tensors="pd")
+        return message_result["input_ids"][0]
 
     def ids2tokens(self, token_id, task_id):
         """
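For context (not part of the commit): `messages2ids` consumes the same OpenAI-style chat list used by the HTTP API. A rough sketch of the expected input, assuming the tokenizer ships a chat template:

```
# Hedged sketch: the `messages` argument is the OpenAI-style chat list.
messages = [
    {"role": "user", "content": "Hello, who are you"},
    {"role": "system", "content": "I'm a helpful AI assistant."},
    {"role": "user", "content": "List 3 countries and their capitals."},
]
# Inside DataProcessor.messages2ids this becomes:
#   self.tokenizer.apply_chat_template(messages, return_tensors="pd")["input_ids"][0]
```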
llm/server/server/http_server/adapter_openai.py (new file, 103 lines)
@@ -0,0 +1,103 @@
+import time
+import json
+import queue
+
+import numpy as np
+from typing import Dict
+from datetime import datetime
+from functools import partial
+
+import tritonclient.grpc as grpcclient
+from tritonclient import utils as triton_utils
+from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_choice import CompletionChoice
+from openai.types.completion import Completion
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDelta,
+    ChatCompletionChunk,
+    Choice as ChatCompletionChoice
+)
+
+from server.http_server.api import Req, chat_completion_generator
+from server.utils import http_server_logger
+
+
+def format_openai_message_completions(req: Req, result: Dict) -> Completion:
+    choice_data = CompletionChoice(
+        index=0,
+        text=result['token'],
+        finish_reason=result.get("finish_reason", "stop"),
+    )
+    chunk = Completion(
+        id=req.req_id,
+        choices=[choice_data],
+        model=req.model,
+        created=int(time.time()),
+        object="text_completion",
+        usage=CompletionUsage(
+            completion_tokens=result["usage"]["completion_tokens"],
+            prompt_tokens=result["usage"]["prompt_tokens"],
+            total_tokens=result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"],
+        ),
+    )
+    return chunk.model_dump_json(exclude_unset=True)
+
+
+def format_openai_message_chat_completions(req: Req, result: Dict) -> ChatCompletionChunk:
+    choice_data = ChatCompletionChoice(
+        index=0,
+        delta=ChoiceDelta(
+            content=result['token'],
+            role="assistant",
+        ),
+        finish_reason=result.get("finish_reason", "stop"),
+    )
+    chunk = ChatCompletionChunk(
+        id=req.req_id,
+        choices=[choice_data],
+        model=req.model,
+        created=int(time.time()),
+        object="chat.completion.chunk",
+        usage=CompletionUsage(
+            completion_tokens=result["usage"]["completion_tokens"],
+            prompt_tokens=result["usage"]["prompt_tokens"],
+            total_tokens=result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"],
+        ),
+    )
+    return chunk.model_dump_json(exclude_unset=True)
+
+
+def openai_chat_commpletion_generator(infer_grpc_url: str, req: Req, chat_interface: bool) -> Dict:
+
+    def _openai_format_resp(resp_dict):
+        return f"data: {resp_dict}\n\n"
+
+    for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
+        if resp.get("is_end") == 1:
+            yield _openai_format_resp("[DONE]")
+
+        if chat_interface:
+            yield _openai_format_resp(format_openai_message_chat_completions(req, resp))
+        else:
+            yield _openai_format_resp(format_openai_message_completions(req, resp))
+
+
+def openai_chat_completion_result(infer_grpc_url: str, req: Req, chat_interface: bool):
+    result = ""
+    error_resp = None
+    for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
+        if resp.get("error_msg") or resp.get("error_code"):
+            error_resp = resp
+            error_resp["result"] = ""
+        else:
+            result += resp.get("token")
+            usage = resp.get("usage", None)
+
+    if error_resp:
+        return error_resp
+    response = {'token': result, 'error_msg': '', 'error_code': 0, 'usage': usage}
+
+    if chat_interface:
+        return format_openai_message_chat_completions(req, response)
+    else:
+        return format_openai_message_completions(req, response)
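A rough sketch (not part of the commit) of how the streaming helper above is consumed directly; `"localhost:8001"` is an assumed Triton gRPC address, since the HTTP server normally derives it from `GRPC_PORT`:

```
# Hedged sketch: drive the SSE generator without going through FastAPI.
from server.http_server.api import Req
from server.http_server.adapter_openai import openai_chat_commpletion_generator

req = Req()
req.load_openai_request({"prompt": "Hello", "max_tokens": 16, "stream": True, "model": "default"})

for event in openai_chat_commpletion_generator("localhost:8001", req, chat_interface=False):
    # each event is already SSE-framed: "data: {completion json}\n\n" or "data: [DONE]\n\n"
    print(event, end="")
```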
@@ -16,6 +16,7 @@ import json
 import queue
 import time
 import uuid
+import shortuuid
 from datetime import datetime
 from functools import partial
 from typing import Dict, List, Optional
@@ -46,6 +47,7 @@ class Req(BaseModel):
     return_usage: Optional[bool] = False
     stream: bool = False
     timeout: int = 300
+    model: str = None
 
     def to_dict_for_infer(self):
         """
@@ -54,14 +56,37 @@ class Req(BaseModel):
         Returns:
             dict: request parameters in dict format
         """
-        self.compatible_with_OpenAI()
-
         req_dict = {}
         for key, value in self.dict().items():
             if value is not None:
                 req_dict[key] = value
         return req_dict
 
+    def load_openai_request(self, request_dict: dict):
+        """
+        Convert openai request to Req
+        official OpenAI API documentation: https://platform.openai.com/docs/api-reference/completions/create
+        """
+        convert_dict = {
+            "text": "prompt",
+            "frequency_score": "frequency_penalty",
+            "max_dec_len": "max_tokens",
+            "stream": "stream",
+            "return_all_tokens": "best_of",
+            "temperature": "temperature",
+            "topp": "top_p",
+            "presence_score": "presence_penalty",
+            "eos_token_ids": "stop",
+            "req_id": "id",
+            "model": "model",
+            "messages": "messages",
+        }
+
+        self.__setattr__("req_id", f"chatcmpl-{shortuuid.random()}")
+        for key, value in convert_dict.items():
+            if request_dict.get(value, None):
+                self.__setattr__(key, request_dict.get(value))
+
+
 def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -> Dict:
     """
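To make the mapping above concrete, a minimal sketch (illustrative only; the internal fields such as `text`, `max_dec_len`, and `topp` are declared elsewhere in this class and are not shown in this hunk):

```
# Illustrative only: how an OpenAI-style body lands on the internal Req fields.
req = Req()
req.load_openai_request({
    "prompt": "Hello, how are you?",  # stored as req.text
    "max_tokens": 50,                 # stored as req.max_dec_len
    "top_p": 0.8,                     # stored as req.topp
    "stream": True,                   # stored as req.stream
    "model": "default",               # stored as req.model
})
# req.req_id is pre-set to "chatcmpl-<shortuuid>" and overridden only if the body carries an "id".
# Note that falsy values (e.g. temperature=0, stream=False) are skipped by the conversion loop.
```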
@@ -16,10 +16,14 @@ import argparse
 import os
 
 import uvicorn
-from fastapi import FastAPI
+from typing import Dict
+from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 from server.http_server.api import (Req, chat_completion_generator,
                                     chat_completion_result)
+from server.http_server.adapter_openai import (
+    openai_chat_commpletion_generator, openai_chat_completion_result
+)
 from server.utils import http_server_logger
 
 http_server_logger.info(f"create fastapi app...")
@@ -58,6 +62,48 @@ def create_chat_completion(req: Req):
     return resp
 
 
+@app.post("/v1/chat/completions/completions")
+def openai_v1_completions(request: Dict):
+    return create_openai_completion(request, chat_interface=False)
+
+
+@app.post("/v1/chat/completions/chat/completions")
+def openai_v1_chat_completions(request: Dict):
+    return create_openai_completion(request, chat_interface=True)
+
+
+def create_openai_completion(request: Dict, chat_interface: bool):
+    try:
+        req = Req()
+        req.load_openai_request(request)
+    except Exception as e:
+        return {"error_msg": "request body is not a valid json format", "error_code": 400, "result": ''}
+
+    try:
+        http_server_logger.info(f"receive request: {req.req_id}")
+
+        grpc_port = int(os.getenv("GRPC_PORT", 0))
+        if grpc_port == 0:
+            return {"error_msg": f"GRPC_PORT ({grpc_port}) for infer service is invalid",
+                    "error_code": 400}
+        grpc_url = f"localhost:{grpc_port}"
+
+        if req.stream:
+            generator = openai_chat_commpletion_generator(
+                infer_grpc_url=grpc_url,
+                req=req,
+                chat_interface=chat_interface,
+            )
+            resp = StreamingResponse(generator, media_type="text/event-stream")
+        else:
+            resp = openai_chat_completion_result(infer_grpc_url=grpc_url, req=req, chat_interface=chat_interface)
+    except Exception as e:
+        resp = {'error_msg': str(e), 'error_code': 501}
+    finally:
+        http_server_logger.info(f"finish request: {req.req_id}")
+    return resp
+
+
 def launch_http_server(port: int, workers: int) -> None:
     """
     launch http server
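A note on the nested route paths above (my reading, not stated in the commit): the OpenAI Python client appends the endpoint path to whatever `base_url` it is given, so pointing it at `.../v1/chat/completions`, as the README example does, resolves to exactly these routes:

```
# Effective URLs the OpenAI client requests, given the README's base_url.
base_url = "http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions"
print(base_url + "/completions")       # client.completions.create(...)      -> openai_v1_completions
print(base_url + "/chat/completions")  # client.chat.completions.create(...) -> openai_v1_chat_completions
```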