mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Serving]add an serving example of tts (#384)
* add tts example * update example Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
This commit is contained in:
77
examples/audio/pp-tts/serving/README.md
Normal file
77
examples/audio/pp-tts/serving/README.md
Normal file
@@ -0,0 +1,77 @@
|
||||
([简体中文](./README_cn.md)|English)
|
||||
|
||||
# PP-TTS Streaming Text-to-Speech Serving
|
||||
|
||||
## Introduction
|
||||
This demo is an implementation of starting the streaming speech synthesis service and accessing the service.
|
||||
|
||||
`Server` must be started in the docker, while `Client` does not have to be in the docker.
|
||||
|
||||
**The streaming_pp_tts under the path of this article ($PWD) contains the configuration and code of the model, which needs to be mapped to the docker for use.**
|
||||
|
||||
## Usage
|
||||
### 1. Server
|
||||
#### 1.1 Docker
|
||||
|
||||
```bash
|
||||
docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
|
||||
docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
|
||||
docker exec -it -u root fastdeploy bash
|
||||
```
|
||||
|
||||
#### 1.2 Installation (inside the docker)
|
||||
```bash
|
||||
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
|
||||
pip3 install paddlespeech
|
||||
export LC_ALL="zh_CN.UTF-8"
|
||||
export LANG="zh_CN.UTF-8"
|
||||
export LANGUAGE="zh_CN:zh:en_US:en"
|
||||
```
|
||||
|
||||
#### 1.3 Download models (inside the docker)
|
||||
```bash
|
||||
cd /models/streaming_pp_tts/1
|
||||
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
||||
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip
|
||||
unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
||||
unzip mb_melgan_csmsc_onnx_0.2.0.zip
|
||||
```
|
||||
**For the convenience of users, we recommend that you use the command `docker -v` to map $PWD (streaming_pp_tts and the configuration and code of the model contained therein) to the docker path `/models`. You can also use other methods, but regardless of which method you use, the final model directory and structure in the docker are shown in the following figure.**
|
||||
|
||||
```
|
||||
/models
|
||||
│
|
||||
└───streaming_pp_tts #Directory of the entire service model
|
||||
│ config.pbtxt #Configuration file of service model
|
||||
│ stream_client.py #Code of Client
|
||||
│
|
||||
└───1 #Model version number
|
||||
│ model.py #Code to start the model
|
||||
└───fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0 #Model file required by code
|
||||
└───mb_melgan_csmsc_onnx_0.2.0 #Model file required by code
|
||||
|
||||
```
|
||||
|
||||
#### 1.4 Start the server (inside the docker)
|
||||
|
||||
```bash
|
||||
fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_pp_tts
|
||||
```
|
||||
Arguments:
|
||||
- `model-repository`(required): Path of model storage.
|
||||
- `model-control-mode`(required): The mode of loading the model. At present, you can use 'explicit'.
|
||||
- `load-model`(required): Name of the model to be loaded.
|
||||
- `http-port`(optional): Port for http service. Default: `8000`. This is not used in our example.
|
||||
- `grpc-port`(optional): Port for grpc service. Default: `8001`.
|
||||
- `metrics-port`(optional): Port for metrics service. Default: `8002`. This is not used in our example.
|
||||
|
||||
### 2. Client
|
||||
#### 2.1 Installation
|
||||
```bash
|
||||
pip3 install tritonclient[all]
|
||||
```
|
||||
|
||||
#### 2.2 Send request
|
||||
```bash
|
||||
python3 /models/streaming_pp_tts/stream_client.py
|
||||
```
|
76
examples/audio/pp-tts/serving/README_cn.md
Normal file
76
examples/audio/pp-tts/serving/README_cn.md
Normal file
@@ -0,0 +1,76 @@
|
||||
(简体中文|[English](./README.md))
|
||||
|
||||
# PP-TTS流式语音合成服务化部署
|
||||
|
||||
## 介绍
|
||||
本文介绍了使用FastDeploy搭建流式语音合成服务的方法。
|
||||
|
||||
服务端必须在docker内启动,而客户端不是必须在docker容器内.
|
||||
|
||||
**本文所在路径($PWD)下的streaming_pp_tts里包含模型的配置和代码(服务端会加载模型和代码以启动服务),需要将其映射到docker中使用。**
|
||||
|
||||
## 使用
|
||||
### 1. 服务端
|
||||
#### 1.1 Docker
|
||||
```bash
|
||||
docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
|
||||
docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
|
||||
docker exec -it -u root fastdeploy bash
|
||||
```
|
||||
|
||||
#### 1.2 安装(在docker内)
|
||||
```bash
|
||||
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
|
||||
pip3 install paddlespeech
|
||||
export LC_ALL="zh_CN.UTF-8"
|
||||
export LANG="zh_CN.UTF-8"
|
||||
export LANGUAGE="zh_CN:zh:en_US:en"
|
||||
```
|
||||
|
||||
#### 1.3 下载模型(在docker内)
|
||||
```bash
|
||||
cd /models/streaming_pp_tts/1
|
||||
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
||||
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip
|
||||
unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
||||
unzip mb_melgan_csmsc_onnx_0.2.0.zip
|
||||
```
|
||||
**为了方便用户使用,我们推荐用户使用1.1中的`docker -v`命令将$PWD(streaming_pp_tts及里面包含的模型的配置和代码)映射到了docker内的`/models`路径,用户也可以使用其他办法,但无论使用哪种方法,最终在docker内的模型目录及结构如下图所示。**
|
||||
|
||||
```
|
||||
/models
|
||||
│
|
||||
└───streaming_pp_tts #整个服务模型文件夹
|
||||
│ config.pbtxt #服务模型配置文件
|
||||
│ stream_client.py #客户端代码
|
||||
│
|
||||
└───1 #模型版本号,此处为1
|
||||
│ model.py #模型启动代码
|
||||
└───fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0 #启动代码所需的模型文件
|
||||
└───mb_melgan_csmsc_onnx_0.2.0 #启动代码所需的模型文件
|
||||
|
||||
```
|
||||
|
||||
#### 1.4 启动服务端(在docker内)
|
||||
```bash
|
||||
fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_pp_tts
|
||||
```
|
||||
|
||||
参数:
|
||||
- `model-repository`(required): 整套模型streaming_pp_tts存放的路径.
|
||||
- `model-control-mode`(required): 模型加载的方式,现阶段, 使用'explicit'即可.
|
||||
- `load-model`(required): 需要加载的模型的名称.
|
||||
- `http-port`(optional): HTTP服务的端口号. 默认: `8000`. 本示例中未使用该端口.
|
||||
- `grpc-port`(optional): GRPC服务的端口号. 默认: `8001`.
|
||||
- `metrics-port`(optional): 服务端指标的端口号. 默认: `8002`. 本示例中未使用该端口.
|
||||
|
||||
### 2. 客户端
|
||||
#### 2.1 安装
|
||||
```bash
|
||||
pip3 install tritonclient[all]
|
||||
```
|
||||
|
||||
#### 2.2 发送请求
|
||||
```bash
|
||||
python3 /models/streaming_pp_tts/stream_client.py
|
||||
```
|
303
examples/audio/pp-tts/serving/streaming_pp_tts/1/model.py
Normal file
303
examples/audio/pp-tts/serving/streaming_pp_tts/1/model.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import codecs
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import triton_python_backend_utils as pb_utils
|
||||
|
||||
from paddlespeech.server.utils.util import denorm
|
||||
from paddlespeech.server.utils.util import get_chunks
|
||||
from paddlespeech.t2s.frontend.zh_frontend import Frontend
|
||||
|
||||
voc_block = 36
|
||||
voc_pad = 14
|
||||
am_block = 72
|
||||
am_pad = 12
|
||||
voc_upsample = 300
|
||||
|
||||
# 模型路径
|
||||
dir_name = "/models/streaming_tts_serving/1/"
|
||||
phones_dict = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/phone_id_map.txt"
|
||||
am_stat_path = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/speech_stats.npy"
|
||||
|
||||
onnx_am_encoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_encoder_infer.onnx"
|
||||
onnx_am_decoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_decoder.onnx"
|
||||
onnx_am_postnet = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_postnet.onnx"
|
||||
onnx_voc_melgan = dir_name + "mb_melgan_csmsc_onnx_0.2.0/mb_melgan_csmsc.onnx"
|
||||
|
||||
frontend = Frontend(phone_vocab_path=phones_dict, tone_vocab_path=None)
|
||||
am_mu, am_std = np.load(am_stat_path)
|
||||
|
||||
# 用CPU推理
|
||||
providers = ['CPUExecutionProvider']
|
||||
|
||||
# 配置ort session
|
||||
sess_options = ort.SessionOptions()
|
||||
|
||||
# 创建session
|
||||
am_encoder_infer_sess = ort.InferenceSession(
|
||||
onnx_am_encoder, providers=providers, sess_options=sess_options)
|
||||
am_decoder_sess = ort.InferenceSession(
|
||||
onnx_am_decoder, providers=providers, sess_options=sess_options)
|
||||
am_postnet_sess = ort.InferenceSession(
|
||||
onnx_am_postnet, providers=providers, sess_options=sess_options)
|
||||
voc_melgan_sess = ort.InferenceSession(
|
||||
onnx_voc_melgan, providers=providers, sess_options=sess_options)
|
||||
|
||||
|
||||
def depadding(data, chunk_num, chunk_id, block, pad, upsample):
|
||||
"""
|
||||
Streaming inference removes the result of pad inference
|
||||
"""
|
||||
front_pad = min(chunk_id * block, pad)
|
||||
# first chunk
|
||||
if chunk_id == 0:
|
||||
data = data[:block * upsample]
|
||||
# last chunk
|
||||
elif chunk_id == chunk_num - 1:
|
||||
data = data[front_pad * upsample:]
|
||||
# middle chunk
|
||||
else:
|
||||
data = data[front_pad * upsample:(front_pad + block) * upsample]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Your Python model must use the same class name. Every Python model
|
||||
that is created must have "TritonPythonModel" as the class name.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""`initialize` is called only once when the model is being loaded.
|
||||
Implementing `initialize` function is optional. This function allows
|
||||
the model to intialize any state associated with this model.
|
||||
Parameters
|
||||
----------
|
||||
args : dict
|
||||
Both keys and values are strings. The dictionary keys and values are:
|
||||
* model_config: A JSON string containing the model configuration
|
||||
* model_instance_kind: A string containing model instance kind
|
||||
* model_instance_device_id: A string containing model instance device ID
|
||||
* model_repository: Model repository path
|
||||
* model_version: Model version
|
||||
* model_name: Model name
|
||||
"""
|
||||
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
|
||||
print(sys.getdefaultencoding())
|
||||
# You must parse model_config. JSON string is not parsed here
|
||||
self.model_config = model_config = json.loads(args['model_config'])
|
||||
print("model_config:", self.model_config)
|
||||
|
||||
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
|
||||
model_config)
|
||||
|
||||
if not using_decoupled:
|
||||
raise pb_utils.TritonModelException(
|
||||
"""the model `{}` can generate any number of responses per request,
|
||||
enable decoupled transaction policy in model configuration to
|
||||
serve this model""".format(args['model_name']))
|
||||
|
||||
self.input_names = []
|
||||
for input_config in self.model_config["input"]:
|
||||
self.input_names.append(input_config["name"])
|
||||
print("input:", self.input_names)
|
||||
|
||||
self.output_names = []
|
||||
self.output_dtype = []
|
||||
for output_config in self.model_config["output"]:
|
||||
self.output_names.append(output_config["name"])
|
||||
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
||||
self.output_dtype.append(dtype)
|
||||
print("output:", self.output_names)
|
||||
|
||||
# To keep track of response threads so that we can delay
|
||||
# the finalizing the model until all response threads
|
||||
# have completed.
|
||||
self.inflight_thread_count = 0
|
||||
self.inflight_thread_count_lck = threading.Lock()
|
||||
|
||||
def execute(self, requests):
|
||||
"""`execute` must be implemented in every Python model. `execute`
|
||||
function receives a list of pb_utils.InferenceRequest as the only
|
||||
argument. This function is called when an inference is requested
|
||||
for this model. Depending on the batching configuration (e.g. Dynamic
|
||||
Batching) used, `requests` may contain multiple requests. Every
|
||||
Python model, must create one pb_utils.InferenceResponse for every
|
||||
pb_utils.InferenceRequest in `requests`. If there is an error, you can
|
||||
set the error argument when creating a pb_utils.InferenceResponse.
|
||||
Parameters
|
||||
----------
|
||||
requests : list
|
||||
A list of pb_utils.InferenceRequest
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list of pb_utils.InferenceResponse. The length of this list must
|
||||
be the same as `requests`
|
||||
"""
|
||||
|
||||
# This model does not support batching, so 'request_count' should always
|
||||
# be 1.
|
||||
if len(requests) != 1:
|
||||
raise pb_utils.TritonModelException("unsupported batch size " + len(
|
||||
requests))
|
||||
|
||||
input_data = []
|
||||
for idx in range(len(self.input_names)):
|
||||
data = pb_utils.get_input_tensor_by_name(requests[0],
|
||||
self.input_names[idx])
|
||||
data = data.as_numpy()
|
||||
data = data[0].decode('utf-8')
|
||||
input_data.append(data)
|
||||
text = input_data[0]
|
||||
|
||||
# Start a separate thread to send the responses for the request. The
|
||||
# sending back the responses is delegated to this thread.
|
||||
thread = threading.Thread(
|
||||
target=self.response_thread,
|
||||
args=(requests[0].get_response_sender(), text))
|
||||
thread.daemon = True
|
||||
with self.inflight_thread_count_lck:
|
||||
self.inflight_thread_count += 1
|
||||
|
||||
thread.start()
|
||||
# Unlike in non-decoupled model transaction policy, execute function
|
||||
# here returns no response. A return from this function only notifies
|
||||
# Triton that the model instance is ready to receive another request. As
|
||||
# we are not waiting for the response thread to complete here, it is
|
||||
# possible that at any give time the model may be processing multiple
|
||||
# requests. Depending upon the request workload, this may lead to a lot
|
||||
# of requests being processed by a single model instance at a time. In
|
||||
# real-world models, the developer should be mindful of when to return
|
||||
# from execute and be willing to accept next request.
|
||||
return None
|
||||
|
||||
def response_thread(self, response_sender, text):
|
||||
input_ids = frontend.get_input_ids(
|
||||
text, merge_sentences=False, get_tone_ids=False)
|
||||
phone_ids = input_ids["phone_ids"]
|
||||
for i in range(len(phone_ids)):
|
||||
part_phone_ids = phone_ids[i].numpy()
|
||||
voc_chunk_id = 0
|
||||
|
||||
orig_hs = am_encoder_infer_sess.run(
|
||||
None, input_feed={'text': part_phone_ids})
|
||||
orig_hs = orig_hs[0]
|
||||
|
||||
# streaming voc chunk info
|
||||
mel_len = orig_hs.shape[1]
|
||||
voc_chunk_num = math.ceil(mel_len / voc_block)
|
||||
start = 0
|
||||
end = min(voc_block + voc_pad, mel_len)
|
||||
|
||||
# streaming am
|
||||
hss = get_chunks(orig_hs, am_block, am_pad, "am")
|
||||
am_chunk_num = len(hss)
|
||||
for i, hs in enumerate(hss):
|
||||
am_decoder_output = am_decoder_sess.run(
|
||||
None, input_feed={'xs': hs})
|
||||
am_postnet_output = am_postnet_sess.run(
|
||||
None,
|
||||
input_feed={
|
||||
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
|
||||
})
|
||||
am_output_data = am_decoder_output + np.transpose(
|
||||
am_postnet_output[0], (0, 2, 1))
|
||||
normalized_mel = am_output_data[0][0]
|
||||
|
||||
sub_mel = denorm(normalized_mel, am_mu, am_std)
|
||||
sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad,
|
||||
1)
|
||||
|
||||
if i == 0:
|
||||
mel_streaming = sub_mel
|
||||
else:
|
||||
mel_streaming = np.concatenate(
|
||||
(mel_streaming, sub_mel), axis=0)
|
||||
|
||||
# streaming voc
|
||||
# 当流式AM推理的mel帧数大于流式voc推理的chunk size,开始进行流式voc 推理
|
||||
while (mel_streaming.shape[0] >= end and
|
||||
voc_chunk_id < voc_chunk_num):
|
||||
voc_chunk = mel_streaming[start:end, :]
|
||||
|
||||
sub_wav = voc_melgan_sess.run(
|
||||
output_names=None, input_feed={'logmel': voc_chunk})
|
||||
sub_wav = depadding(sub_wav[0], voc_chunk_num, voc_chunk_id,
|
||||
voc_block, voc_pad, voc_upsample)
|
||||
|
||||
output_np = np.array(sub_wav, dtype=self.output_dtype[0])
|
||||
out_tensor1 = pb_utils.Tensor(self.output_names[0],
|
||||
output_np)
|
||||
|
||||
status = 0 if voc_chunk_id != (voc_chunk_num - 1) else 1
|
||||
output_status = np.array(
|
||||
[status], dtype=self.output_dtype[1])
|
||||
out_tensor2 = pb_utils.Tensor(self.output_names[1],
|
||||
output_status)
|
||||
|
||||
inference_response = pb_utils.InferenceResponse(
|
||||
output_tensors=[out_tensor1, out_tensor2])
|
||||
|
||||
#yield sub_wav
|
||||
response_sender.send(inference_response)
|
||||
|
||||
voc_chunk_id += 1
|
||||
start = max(0, voc_chunk_id * voc_block - voc_pad)
|
||||
end = min((voc_chunk_id + 1) * voc_block + voc_pad, mel_len)
|
||||
|
||||
# We must close the response sender to indicate to Triton that we are
|
||||
# done sending responses for the corresponding request. We can't use the
|
||||
# response sender after closing it. The response sender is closed by
|
||||
# setting the TRITONSERVER_RESPONSE_COMPLETE_FINAL.
|
||||
response_sender.send(
|
||||
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
||||
|
||||
with self.inflight_thread_count_lck:
|
||||
self.inflight_thread_count -= 1
|
||||
|
||||
def finalize(self):
|
||||
"""`finalize` is called only once when the model is being unloaded.
|
||||
Implementing `finalize` function is OPTIONAL. This function allows
|
||||
the model to perform any necessary clean ups before exit.
|
||||
Here we will wait for all response threads to complete sending
|
||||
responses.
|
||||
"""
|
||||
print('Finalize invoked')
|
||||
|
||||
inflight_threads = True
|
||||
cycles = 0
|
||||
logging_time_sec = 5
|
||||
sleep_time_sec = 0.1
|
||||
cycle_to_log = (logging_time_sec / sleep_time_sec)
|
||||
while inflight_threads:
|
||||
with self.inflight_thread_count_lck:
|
||||
inflight_threads = (self.inflight_thread_count != 0)
|
||||
if (cycles % cycle_to_log == 0):
|
||||
print(
|
||||
f"Waiting for {self.inflight_thread_count} response threads to complete..."
|
||||
)
|
||||
if inflight_threads:
|
||||
time.sleep(sleep_time_sec)
|
||||
cycles += 1
|
||||
|
||||
print('Finalize complete...')
|
33
examples/audio/pp-tts/serving/streaming_pp_tts/config.pbtxt
Normal file
33
examples/audio/pp-tts/serving/streaming_pp_tts/config.pbtxt
Normal file
@@ -0,0 +1,33 @@
|
||||
name: "streaming_pp_tts"
|
||||
backend: "python"
|
||||
max_batch_size: 0
|
||||
model_transaction_policy {
|
||||
decoupled: True
|
||||
}
|
||||
input [
|
||||
{
|
||||
name: "INPUT_0"
|
||||
data_type: TYPE_STRING
|
||||
dims: [ 1 ]
|
||||
}
|
||||
]
|
||||
|
||||
output [
|
||||
{
|
||||
name: "OUTPUT_0"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1, 1 ]
|
||||
},
|
||||
{
|
||||
name: "status"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind: KIND_CPU
|
||||
}
|
||||
]
|
130
examples/audio/pp-tts/serving/streaming_pp_tts/stream_client.py
Normal file
130
examples/audio/pp-tts/serving/streaming_pp_tts/stream_client.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import queue
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from tritonclient.utils import *
|
||||
|
||||
FLAGS = None
|
||||
|
||||
|
||||
class UserData:
|
||||
def __init__(self):
|
||||
self._completed_requests = queue.Queue()
|
||||
|
||||
|
||||
# Define the callback function. Note the last two parameters should be
|
||||
# result and error. InferenceServerClient would povide the results of an
|
||||
# inference as grpcclient.InferResult in result. For successful
|
||||
# inference, error will be None, otherwise it will be an object of
|
||||
# tritonclientutils.InferenceServerException holding the error details
|
||||
def callback(user_data, result, error):
|
||||
if error:
|
||||
user_data._completed_requests.put(error)
|
||||
else:
|
||||
user_data._completed_requests.put(result)
|
||||
|
||||
|
||||
def async_stream_send(triton_client, values, request_id, model_name):
|
||||
|
||||
infer_inputs = []
|
||||
outputs = []
|
||||
for idx, data in enumerate(values):
|
||||
data = np.array([data.encode('utf-8')], dtype=np.object_)
|
||||
infer_input = grpcclient.InferInput('INPUT_0', [len(data)], "BYTES")
|
||||
infer_input.set_data_from_numpy(data)
|
||||
infer_inputs.append(infer_input)
|
||||
|
||||
outputs.append(grpcclient.InferRequestedOutput('OUTPUT_0'))
|
||||
# Issue the asynchronous sequence inference.
|
||||
triton_client.async_stream_infer(
|
||||
model_name=model_name,
|
||||
inputs=infer_inputs,
|
||||
outputs=outputs,
|
||||
request_id=request_id)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'-v',
|
||||
'--verbose',
|
||||
action="store_true",
|
||||
required=False,
|
||||
default=False,
|
||||
help='Enable verbose output')
|
||||
parser.add_argument(
|
||||
'-u',
|
||||
'--url',
|
||||
type=str,
|
||||
required=False,
|
||||
default='localhost:8001',
|
||||
help='Inference server URL and it gRPC port. Default is localhost:8001.')
|
||||
|
||||
FLAGS = parser.parse_args()
|
||||
|
||||
# We use custom "sequence" models which take 1 input
|
||||
# value. The output is the accumulated value of the inputs. See
|
||||
# src/custom/sequence.
|
||||
model_name = "streaming_pp_tts"
|
||||
|
||||
values = ["哈哈哈哈"]
|
||||
|
||||
request_id = "0"
|
||||
|
||||
string_result0_list = []
|
||||
|
||||
user_data = UserData()
|
||||
|
||||
# It is advisable to use client object within with..as clause
|
||||
# when sending streaming requests. This ensures the client
|
||||
# is closed when the block inside with exits.
|
||||
with grpcclient.InferenceServerClient(
|
||||
url=FLAGS.url, verbose=FLAGS.verbose) as triton_client:
|
||||
try:
|
||||
# Establish stream
|
||||
triton_client.start_stream(callback=partial(callback, user_data))
|
||||
# Now send the inference sequences...
|
||||
async_stream_send(triton_client, values, request_id, model_name)
|
||||
except InferenceServerException as error:
|
||||
print(error)
|
||||
sys.exit(1)
|
||||
|
||||
# Retrieve results...
|
||||
recv_count = 0
|
||||
result_dict = {}
|
||||
status = True
|
||||
while True:
|
||||
data_item = user_data._completed_requests.get()
|
||||
if type(data_item) == InferenceServerException:
|
||||
raise data_item
|
||||
else:
|
||||
this_id = data_item.get_response().id
|
||||
if this_id not in result_dict.keys():
|
||||
result_dict[this_id] = []
|
||||
result_dict[this_id].append((recv_count, data_item))
|
||||
sub_wav = data_item.as_numpy('OUTPUT_0')
|
||||
status = data_item.as_numpy('status')
|
||||
print('sub_wav = ', sub_wav, "subwav.shape = ", sub_wav.shape)
|
||||
print('status = ', status)
|
||||
if status[0] == 1:
|
||||
break
|
||||
recv_count += 1
|
||||
|
||||
print("PASS: stream_client")
|
Reference in New Issue
Block a user