mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Add tts python example and change onnx to paddle (#420)
* add tts example * update example * update use fd engine * add tts python example * add readme * fix comment * change paddle model * fix readme style Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
9
examples/audio/pp-tts/README.md
Normal file
9
examples/audio/pp-tts/README.md
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# PaddleSpeech 流式语音合成
|
||||||
|
|
||||||
|
|
||||||
|
- 本文示例的实现来自[PaddleSpeech 流式语音合成](https://github.com/PaddlePaddle/PaddleSpeech/tree/r1.2).
|
||||||
|
|
||||||
|
## 详细部署文档
|
||||||
|
|
||||||
|
- [Python部署](python)
|
||||||
|
- [Serving部署](serving)
|
27
examples/audio/pp-tts/python/README.md
Normal file
27
examples/audio/pp-tts/python/README.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
([简体中文](./README_cn.md)|English)
|
||||||
|
|
||||||
|
# PP-TTS Streaming Text-to-Speech Python Example
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
This demo shows how to run streaming speech synthesis (text-to-speech) with FastDeploy.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### 1. Installation
|
||||||
|
```bash
|
||||||
|
apt-get install libsndfile1 wget zip
|
||||||
|
# For CentOS, use instead: yum install libsndfile-devel wget zip
|
||||||
|
python3 -m pip install --upgrade pip
|
||||||
|
pip3 install -U fastdeploy-python -f https://www.paddlepaddle.org.cn/whl/fastdeploy.html
|
||||||
|
pip3 install -U paddlespeech paddlepaddle soundfile matplotlib
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Run the example
|
||||||
|
```bash
|
||||||
|
python3 stream_play_tts.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Result
|
||||||
|
The complete voice synthesis audio is saved as `demo_stream.wav`.
|
||||||
|
|
||||||
|
Users can install `pyaudio` on their own machine to play the synthesized speech in real time. The relevant code is present but commented out in `stream_play_tts.py`; you can enable and debug it yourself.
|
26
examples/audio/pp-tts/python/README_cn.md
Normal file
26
examples/audio/pp-tts/python/README_cn.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
(简体中文|[English](./README.md))
|
||||||
|
|
||||||
|
# PP-TTS流式语音合成Python示例
|
||||||
|
|
||||||
|
## 介绍
|
||||||
|
本文介绍了使用FastDeploy运行流式语音合成的示例.
|
||||||
|
|
||||||
|
## 使用
|
||||||
|
### 1. 安装
|
||||||
|
```bash
|
||||||
|
apt-get install libsndfile1 wget zip
|
||||||
|
# 对于CentOS系统, 请改用: yum install libsndfile-devel wget zip
|
||||||
|
python3 -m pip install --upgrade pip
|
||||||
|
pip3 install -U fastdeploy-python -f https://www.paddlepaddle.org.cn/whl/fastdeploy.html
|
||||||
|
pip3 install -U paddlespeech paddlepaddle soundfile matplotlib
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 运行示例
|
||||||
|
```bash
|
||||||
|
python3 stream_play_tts.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 运行效果
|
||||||
|
完整的语音合成音频被保存为`demo_stream.wav`文件.
|
||||||
|
|
||||||
|
用户可以在自己的终端上安装`pyaudio`, 对语音合成的结果进行实时播放. 相关代码在`stream_play_tts.py`中处于注释状态, 用户可自行调试运行.
|
214
examples/audio/pp-tts/python/stream_play_tts.py
Normal file
214
examples/audio/pp-tts/python/stream_play_tts.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import fastdeploy as fd
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
from paddlespeech.server.utils.util import denorm
|
||||||
|
from paddlespeech.server.utils.util import get_chunks
|
||||||
|
from paddlespeech.t2s.frontend.zh_frontend import Frontend
|
||||||
|
|
||||||
|
# Directory names, archive file names and download URLs of the two released
# model packages (FastSpeech2 and MB-MelGAN) used by this example.
model_name_fastspeech2 = "fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0"
model_zip_fastspeech2 = model_name_fastspeech2 + ".zip"
model_url_fastspeech2 = "https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/" + model_zip_fastspeech2
model_name_mb_melgan = "mb_melgan_csmsc_static_0.1.1"
model_zip_mb_melgan = model_name_mb_melgan + ".zip"
model_url_mb_melgan = "https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/" + model_zip_mb_melgan

# Directory that contains this script, with a trailing "/". Models are
# downloaded into it and every model path is built from it.
dir_name = os.path.dirname(os.path.realpath(__file__)) + "/"


def _remove_if_exists(path):
    """Delete the file at ``path`` if it exists; otherwise do nothing."""
    if os.path.exists(path):
        os.remove(path)


def _ensure_model(model_name, model_zip, model_url):
    """Download and unpack one model package into ``dir_name`` if missing.

    All checks and removals use paths anchored at ``dir_name`` (the download
    target), not at the current working directory, so the script behaves the
    same no matter where it is launched from.

    Args:
        model_name: Name of the unpacked model directory.
        model_zip: File name of the model archive.
        model_url: URL the archive is downloaded from.
    """
    if os.path.exists(dir_name + model_name):
        return
    archive = dir_name + model_zip
    # Drop a stale (possibly partial) archive before downloading again.
    _remove_if_exists(archive)
    fd.download_and_decompress(model_url, path=dir_name)
    # Clean up the archive if the download helper left it behind
    # (presumably it does - TODO confirm against the FastDeploy helper).
    _remove_if_exists(archive)


_ensure_model(model_name_fastspeech2, model_zip_fastspeech2,
              model_url_fastspeech2)
_ensure_model(model_name_mb_melgan, model_zip_mb_melgan, model_url_mb_melgan)
|
||||||
|
|
||||||
|
# Streaming chunk geometry: block size and padding handed to get_chunks() /
# depadding() for the acoustic model ("am") and the vocoder ("voc").
am_block = 72
am_pad = 12
voc_block = 36
voc_pad = 14
# Upsample factor applied to the vocoder output when trimming padding
# (presumably waveform samples per mel frame - TODO confirm).
voc_upsample = 300

# Model paths
_am_dir = f"{dir_name}{model_name_fastspeech2}"
_voc_dir = f"{dir_name}{model_name_mb_melgan}"

phones_dict = f"{_am_dir}/phone_id_map.txt"
am_stat_path = f"{_am_dir}/speech_stats.npy"

# Each network is stored as a .pdmodel file plus a .pdiparams file.
am_encoder_model = f"{_am_dir}/fastspeech2_csmsc_am_encoder_infer.pdmodel"
am_encoder_para = f"{_am_dir}/fastspeech2_csmsc_am_encoder_infer.pdiparams"

am_decoder_model = f"{_am_dir}/fastspeech2_csmsc_am_decoder.pdmodel"
am_decoder_para = f"{_am_dir}/fastspeech2_csmsc_am_decoder.pdiparams"

am_postnet_model = f"{_am_dir}/fastspeech2_csmsc_am_postnet.pdmodel"
am_postnet_para = f"{_am_dir}/fastspeech2_csmsc_am_postnet.pdiparams"

voc_melgan_model = f"{_voc_dir}/mb_melgan_csmsc.pdmodel"
voc_melgan_para = f"{_voc_dir}/mb_melgan_csmsc.pdiparams"

# Text frontend built from the phone vocabulary; no tone vocabulary is used.
frontend = Frontend(phone_vocab_path=phones_dict, tone_vocab_path=None)
# The stats file unpacks into two values, later passed to denorm().
am_mu, am_std = np.load(am_stat_path)
|
||||||
|
|
||||||
|
def _cpu_ort_option(model_file, params_file):
    """Build a RuntimeOption for one model: CPU, ORT backend, 12 threads.

    Args:
        model_file: Path to the model file.
        params_file: Path to the matching parameters file.

    Returns:
        The configured ``fd.RuntimeOption``.
    """
    opt = fd.RuntimeOption()
    opt.set_model_path(model_file, params_file)
    opt.use_cpu()
    opt.use_ort_backend()
    opt.set_cpu_thread_num(12)
    return opt


# One runtime per network, created in pipeline order:
# encoder -> decoder -> postnet -> vocoder.
option_1 = _cpu_ort_option(am_encoder_model, am_encoder_para)
am_encoder_runtime = fd.Runtime(option_1)

option_2 = _cpu_ort_option(am_decoder_model, am_decoder_para)
am_decoder_runtime = fd.Runtime(option_2)

option_3 = _cpu_ort_option(am_postnet_model, am_postnet_para)
am_postnet_runtime = fd.Runtime(option_3)

option_4 = _cpu_ort_option(voc_melgan_model, voc_melgan_para)
voc_melgan_runtime = fd.Runtime(option_4)
|
||||||
|
|
||||||
|
|
||||||
|
def depadding(data, chunk_num, chunk_id, block, pad, upsample):
    """Trim the padded context from one chunk of streaming inference output.

    Args:
        data: Sliceable output of one chunk (sliced along its first axis).
        chunk_num: Total number of chunks.
        chunk_id: Zero-based index of this chunk.
        block: Chunk length without padding, in input frames.
        pad: Maximum number of context frames prepended to a chunk.
        upsample: Number of output elements per input frame.

    Returns:
        The slice of ``data`` that belongs to this chunk's own block.
    """
    # First chunk: there is no leading padding, keep only the block itself.
    if chunk_id == 0:
        return data[:block * upsample]

    # Leading padding is capped by how many frames precede this chunk.
    lead = min(chunk_id * block, pad) * upsample

    # Last chunk: keep everything after the leading padding.
    if chunk_id == chunk_num - 1:
        return data[lead:]

    # Middle chunk: keep exactly one block after the leading padding.
    return data[lead:lead + block * upsample]
|
||||||
|
|
||||||
|
|
||||||
|
def inference_stream(text):
    """Synthesize ``text`` in a streaming fashion, yielding audio piece by piece.

    The text is converted to phone ids by the module-level ``frontend``. Each
    entry of the resulting ``phone_ids`` is run through the encoder once; the
    encoder output is then decoded chunk by chunk, and as soon as enough mel
    frames are available the vocoder is run on the next vocoder chunk.

    Args:
        text: Input text to synthesize.

    Yields:
        The de-padded vocoder output for one vocoder chunk.
    """
    frontend_out = frontend.get_input_ids(
        text, merge_sentences=False, get_tone_ids=False)
    phone_ids = frontend_out["phone_ids"]

    for sent_idx in range(len(phone_ids)):
        sent_phone_ids = phone_ids[sent_idx].numpy()
        encoder_out = am_encoder_runtime.infer({
            'text': sent_phone_ids.astype("int64")
        })
        orig_hs = encoder_out[0]

        # Vocoder chunk bookkeeping; axis 1 of the encoder output is used as
        # the total number of mel frames for this entry.
        mel_len = orig_hs.shape[1]
        voc_chunk_num = math.ceil(mel_len / voc_block)
        voc_chunk_id = 0
        start = 0
        end = min(voc_block + voc_pad, mel_len)

        # Streaming acoustic model: decode the encoder output chunk by chunk.
        hss = get_chunks(orig_hs, am_block, am_pad, "am")
        am_chunk_num = len(hss)
        mel_streaming = None
        for am_chunk_id, hs in enumerate(hss):
            decoder_out = am_decoder_runtime.infer({
                'xs': hs.astype("float32")
            })

            # The postnet takes the decoder output with its last two axes
            # swapped; its result is swapped back before the residual add.
            postnet_in = np.transpose(decoder_out[0], (0, 2, 1))
            postnet_out = am_postnet_runtime.infer({'xs': postnet_in})
            residual = np.transpose(postnet_out[0], (0, 2, 1))

            # NOTE(review): the whole decoder result (not element [0]) is
            # added here, presumably relying on NumPy to convert it; the
            # [0][0] below then strips the two leading axes - confirm.
            am_output_data = decoder_out + residual
            normalized_mel = am_output_data[0][0]

            sub_mel = depadding(
                denorm(normalized_mel, am_mu, am_std), am_chunk_num,
                am_chunk_id, am_block, am_pad, 1)

            # Append this chunk's mel frames to the running buffer.
            if mel_streaming is None:
                mel_streaming = sub_mel
            else:
                mel_streaming = np.concatenate(
                    (mel_streaming, sub_mel), axis=0)

            # Streaming vocoder: once the acoustic model has produced at
            # least as many mel frames as the next vocoder chunk needs,
            # run the vocoder on that chunk.
            while (voc_chunk_id < voc_chunk_num and
                   mel_streaming.shape[0] >= end):
                voc_chunk = mel_streaming[start:end, :]
                voc_out = voc_melgan_runtime.infer({
                    'logmel': voc_chunk.astype("float32")
                })
                sub_wav = depadding(voc_out[0], voc_chunk_num, voc_chunk_id,
                                    voc_block, voc_pad, voc_upsample)

                yield sub_wav

                # Advance to the next vocoder chunk, including its padding.
                voc_chunk_id += 1
                start = max(0, voc_chunk_id * voc_block - voc_pad)
                end = min((voc_chunk_id + 1) * voc_block + voc_pad, mel_len)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    text = "欢迎使用飞桨语音合成系统,测试一下合成效果。"
    # Warm-up note: the first onnxruntime inference takes longer, so warming
    # up beforehand is recommended. This script does not run a warm-up pass.
    '''
    # pyaudio 播放
    p = pyaudio.PyAudio()
    stream = p.open(
        format=p.get_format_from_width(2),  # int16
        channels=1,
        rate=24000,
        output=True)
    '''
    # Time each streamed chunk: the printed value is the latency since the
    # previous chunk (or since the start, for the first chunk).
    pieces = []
    last_tick = time.time()
    for sub_wav in inference_stream(text):
        print("响应时间:", time.time() - last_tick)
        last_tick = time.time()
        pieces.append(sub_wav.flatten())
        # Optional real-time playback (disabled):
        # float32 to int16
        #wav = float2pcm(sub_wav)
        # to bytes
        #wav_bytes = wav.tobytes()
        #stream.write(wav_bytes)

    # Shut down the pyaudio player (disabled)
    #stream.stop_stream()
    #stream.close()
    #p.terminate()

    # Export the full streamed result as a single 24 kHz audio file.
    full_wav = np.concatenate(pieces)
    sf.write("demo_stream.wav", data=full_wav, samplerate=24000)
|
@@ -22,13 +22,16 @@ docker exec -it -u root fastdeploy bash
|
|||||||
#### 1.2 Installation (inside the docker)
|
#### 1.2 Installation (inside the docker)
|
||||||
```bash
|
```bash
|
||||||
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
|
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
|
||||||
pip3 install paddlespeech
|
python3 -m pip install --upgrade pip
|
||||||
|
pip3 install -U fastdeploy-python -f https://www.paddlepaddle.org.cn/whl/fastdeploy.html
|
||||||
|
pip3 install -U paddlespeech paddlepaddle
|
||||||
export LC_ALL="zh_CN.UTF-8"
|
export LC_ALL="zh_CN.UTF-8"
|
||||||
export LANG="zh_CN.UTF-8"
|
export LANG="zh_CN.UTF-8"
|
||||||
export LANGUAGE="zh_CN:zh:en_US:en"
|
export LANGUAGE="zh_CN:zh:en_US:en"
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 1.3 Download models (inside the docker)
|
#### 1.3 Download models (inside the docker, skippable)
|
||||||
|
The model file will be downloaded and decompressed automatically. If you want to download manually, please use the following command.
|
||||||
```bash
|
```bash
|
||||||
cd /models/streaming_pp_tts/1
|
cd /models/streaming_pp_tts/1
|
||||||
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
||||||
|
@@ -3,11 +3,11 @@
|
|||||||
# PP-TTS流式语音合成服务化部署
|
# PP-TTS流式语音合成服务化部署
|
||||||
|
|
||||||
## 介绍
|
## 介绍
|
||||||
本文介绍了使用FastDeploy搭建流式语音合成服务的方法。
|
本文介绍了使用FastDeploy搭建流式语音合成服务的方法.
|
||||||
|
|
||||||
服务端必须在docker内启动,而客户端不是必须在docker容器内.
|
服务端必须在docker内启动,而客户端不是必须在docker容器内.
|
||||||
|
|
||||||
**本文所在路径($PWD)下的streaming_pp_tts里包含模型的配置和代码(服务端会加载模型和代码以启动服务),需要将其映射到docker中使用。**
|
**本文所在路径($PWD)下的streaming_pp_tts里包含模型的配置和代码(服务端会加载模型和代码以启动服务), 需要将其映射到docker中使用.**
|
||||||
|
|
||||||
## 使用
|
## 使用
|
||||||
### 1. 服务端
|
### 1. 服务端
|
||||||
@@ -21,13 +21,18 @@ docker exec -it -u root fastdeploy bash
|
|||||||
#### 1.2 安装(在docker内)
|
#### 1.2 安装(在docker内)
|
||||||
```bash
|
```bash
|
||||||
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
|
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
|
||||||
pip3 install paddlespeech
|
python3 -m pip install --upgrade pip
|
||||||
|
pip3 install -U fastdeploy-python -f https://www.paddlepaddle.org.cn/whl/fastdeploy.html
|
||||||
|
pip3 install -U paddlespeech paddlepaddle
|
||||||
export LC_ALL="zh_CN.UTF-8"
|
export LC_ALL="zh_CN.UTF-8"
|
||||||
export LANG="zh_CN.UTF-8"
|
export LANG="zh_CN.UTF-8"
|
||||||
export LANGUAGE="zh_CN:zh:en_US:en"
|
export LANGUAGE="zh_CN:zh:en_US:en"
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 1.3 下载模型(在docker内)
|
#### 1.3 下载模型(在docker内,可跳过)
|
||||||
|
|
||||||
|
模型文件会自动下载并解压缩, 如果您想要手动下载, 请使用下面的命令.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd /models/streaming_pp_tts/1
|
cd /models/streaming_pp_tts/1
|
||||||
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
||||||
@@ -35,7 +40,7 @@ wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_me
|
|||||||
unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
|
||||||
unzip mb_melgan_csmsc_onnx_0.2.0.zip
|
unzip mb_melgan_csmsc_onnx_0.2.0.zip
|
||||||
```
|
```
|
||||||
**为了方便用户使用,我们推荐用户使用1.1中的`docker -v`命令将$PWD(streaming_pp_tts及里面包含的模型的配置和代码)映射到了docker内的`/models`路径,用户也可以使用其他办法,但无论使用哪种方法,最终在docker内的模型目录及结构如下图所示。**
|
**为了方便用户使用, 我们推荐用户使用1.1中的`docker -v`命令将$PWD(streaming_pp_tts及里面包含的模型的配置和代码)映射到了docker内的`/models`路径, 用户也可以使用其他办法, 但无论使用哪种方法, 最终在docker内的模型目录及结构如下图所示.**
|
||||||
|
|
||||||
```
|
```
|
||||||
/models
|
/models
|
||||||
|
@@ -11,7 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
@@ -19,14 +18,34 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import fastdeploy as fd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
|
||||||
import triton_python_backend_utils as pb_utils
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
from paddlespeech.server.utils.util import denorm
|
from paddlespeech.server.utils.util import denorm
|
||||||
from paddlespeech.server.utils.util import get_chunks
|
from paddlespeech.server.utils.util import get_chunks
|
||||||
from paddlespeech.t2s.frontend.zh_frontend import Frontend
|
from paddlespeech.t2s.frontend.zh_frontend import Frontend
|
||||||
|
|
||||||
|
model_name_fastspeech2 = "fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0"
|
||||||
|
model_zip_fastspeech2 = model_name_fastspeech2 + ".zip"
|
||||||
|
model_url_fastspeech2 = "https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/" + model_zip_fastspeech2
|
||||||
|
model_name_mb_melgan = "mb_melgan_csmsc_static_0.1.1"
|
||||||
|
model_zip_mb_melgan = model_name_mb_melgan + ".zip"
|
||||||
|
model_url_mb_melgan = "https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/" + model_zip_mb_melgan
|
||||||
|
|
||||||
|
dir_name = os.path.dirname(os.path.realpath(__file__)) + "/"
|
||||||
|
|
||||||
|
if not os.path.exists(model_name_fastspeech2):
|
||||||
|
if os.path.exists(model_zip_fastspeech2):
|
||||||
|
os.remove(model_zip_fastspeech2)
|
||||||
|
fd.download_and_decompress(model_url_fastspeech2, path=dir_name)
|
||||||
|
os.remove(model_zip_fastspeech2)
|
||||||
|
if not os.path.exists(model_name_mb_melgan):
|
||||||
|
if os.path.exists(model_zip_mb_melgan):
|
||||||
|
os.remove(model_zip_mb_melgan)
|
||||||
|
fd.download_and_decompress(model_url_mb_melgan, path=dir_name)
|
||||||
|
os.remove(model_zip_mb_melgan)
|
||||||
|
|
||||||
voc_block = 36
|
voc_block = 36
|
||||||
voc_pad = 14
|
voc_pad = 14
|
||||||
am_block = 72
|
am_block = 72
|
||||||
@@ -34,33 +53,49 @@ am_pad = 12
|
|||||||
voc_upsample = 300
|
voc_upsample = 300
|
||||||
|
|
||||||
# 模型路径
|
# 模型路径
|
||||||
dir_name = "/models/streaming_tts_serving/1/"
|
phones_dict = dir_name + model_name_fastspeech2 + "/phone_id_map.txt"
|
||||||
phones_dict = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/phone_id_map.txt"
|
am_stat_path = dir_name + model_name_fastspeech2 + "/speech_stats.npy"
|
||||||
am_stat_path = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/speech_stats.npy"
|
|
||||||
|
|
||||||
onnx_am_encoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_encoder_infer.onnx"
|
am_encoder_model = dir_name + model_name_fastspeech2 + "/fastspeech2_csmsc_am_encoder_infer.pdmodel"
|
||||||
onnx_am_decoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_decoder.onnx"
|
am_decoder_model = dir_name + model_name_fastspeech2 + "/fastspeech2_csmsc_am_decoder.pdmodel"
|
||||||
onnx_am_postnet = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_postnet.onnx"
|
am_postnet_model = dir_name + model_name_fastspeech2 + "/fastspeech2_csmsc_am_postnet.pdmodel"
|
||||||
onnx_voc_melgan = dir_name + "mb_melgan_csmsc_onnx_0.2.0/mb_melgan_csmsc.onnx"
|
voc_melgan_model = dir_name + model_name_mb_melgan + "/mb_melgan_csmsc.pdmodel"
|
||||||
|
|
||||||
|
am_encoder_para = dir_name + model_name_fastspeech2 + "/fastspeech2_csmsc_am_encoder_infer.pdiparams"
|
||||||
|
am_decoder_para = dir_name + model_name_fastspeech2 + "/fastspeech2_csmsc_am_decoder.pdiparams"
|
||||||
|
am_postnet_para = dir_name + model_name_fastspeech2 + "/fastspeech2_csmsc_am_postnet.pdiparams"
|
||||||
|
voc_melgan_para = dir_name + model_name_mb_melgan + "/mb_melgan_csmsc.pdiparams"
|
||||||
|
|
||||||
frontend = Frontend(phone_vocab_path=phones_dict, tone_vocab_path=None)
|
frontend = Frontend(phone_vocab_path=phones_dict, tone_vocab_path=None)
|
||||||
am_mu, am_std = np.load(am_stat_path)
|
am_mu, am_std = np.load(am_stat_path)
|
||||||
|
|
||||||
# 用CPU推理
|
option_1 = fd.RuntimeOption()
|
||||||
providers = ['CPUExecutionProvider']
|
option_1.set_model_path(am_encoder_model, am_encoder_para)
|
||||||
|
option_1.use_cpu()
|
||||||
|
option_1.use_ort_backend()
|
||||||
|
option_1.set_cpu_thread_num(12)
|
||||||
|
am_encoder_runtime = fd.Runtime(option_1)
|
||||||
|
|
||||||
# 配置ort session
|
option_2 = fd.RuntimeOption()
|
||||||
sess_options = ort.SessionOptions()
|
option_2.set_model_path(am_decoder_model, am_decoder_para)
|
||||||
|
option_2.use_cpu()
|
||||||
|
option_2.use_ort_backend()
|
||||||
|
option_2.set_cpu_thread_num(12)
|
||||||
|
am_decoder_runtime = fd.Runtime(option_2)
|
||||||
|
|
||||||
# 创建session
|
option_3 = fd.RuntimeOption()
|
||||||
am_encoder_infer_sess = ort.InferenceSession(
|
option_3.set_model_path(am_postnet_model, am_postnet_para)
|
||||||
onnx_am_encoder, providers=providers, sess_options=sess_options)
|
option_3.use_cpu()
|
||||||
am_decoder_sess = ort.InferenceSession(
|
option_3.use_ort_backend()
|
||||||
onnx_am_decoder, providers=providers, sess_options=sess_options)
|
option_3.set_cpu_thread_num(12)
|
||||||
am_postnet_sess = ort.InferenceSession(
|
am_postnet_runtime = fd.Runtime(option_3)
|
||||||
onnx_am_postnet, providers=providers, sess_options=sess_options)
|
|
||||||
voc_melgan_sess = ort.InferenceSession(
|
option_4 = fd.RuntimeOption()
|
||||||
onnx_voc_melgan, providers=providers, sess_options=sess_options)
|
option_4.set_model_path(voc_melgan_model, voc_melgan_para)
|
||||||
|
option_4.use_cpu()
|
||||||
|
option_4.use_ort_backend()
|
||||||
|
option_4.set_cpu_thread_num(12)
|
||||||
|
voc_melgan_runtime = fd.Runtime(option_4)
|
||||||
|
|
||||||
|
|
||||||
def depadding(data, chunk_num, chunk_id, block, pad, upsample):
|
def depadding(data, chunk_num, chunk_id, block, pad, upsample):
|
||||||
@@ -199,8 +234,10 @@ class TritonPythonModel:
|
|||||||
part_phone_ids = phone_ids[i].numpy()
|
part_phone_ids = phone_ids[i].numpy()
|
||||||
voc_chunk_id = 0
|
voc_chunk_id = 0
|
||||||
|
|
||||||
orig_hs = am_encoder_infer_sess.run(
|
orig_hs = am_encoder_runtime.infer({
|
||||||
None, input_feed={'text': part_phone_ids})
|
'text':
|
||||||
|
part_phone_ids.astype("int64")
|
||||||
|
})
|
||||||
orig_hs = orig_hs[0]
|
orig_hs = orig_hs[0]
|
||||||
|
|
||||||
# streaming voc chunk info
|
# streaming voc chunk info
|
||||||
@@ -213,13 +250,16 @@ class TritonPythonModel:
|
|||||||
hss = get_chunks(orig_hs, am_block, am_pad, "am")
|
hss = get_chunks(orig_hs, am_block, am_pad, "am")
|
||||||
am_chunk_num = len(hss)
|
am_chunk_num = len(hss)
|
||||||
for i, hs in enumerate(hss):
|
for i, hs in enumerate(hss):
|
||||||
am_decoder_output = am_decoder_sess.run(
|
|
||||||
None, input_feed={'xs': hs})
|
am_decoder_output = am_decoder_runtime.infer({
|
||||||
am_postnet_output = am_postnet_sess.run(
|
'xs':
|
||||||
None,
|
hs.astype("float32")
|
||||||
input_feed={
|
})
|
||||||
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
|
|
||||||
})
|
am_postnet_output = am_postnet_runtime.infer({
|
||||||
|
'xs':
|
||||||
|
np.transpose(am_decoder_output[0], (0, 2, 1))
|
||||||
|
})
|
||||||
am_output_data = am_decoder_output + np.transpose(
|
am_output_data = am_decoder_output + np.transpose(
|
||||||
am_postnet_output[0], (0, 2, 1))
|
am_postnet_output[0], (0, 2, 1))
|
||||||
normalized_mel = am_output_data[0][0]
|
normalized_mel = am_output_data[0][0]
|
||||||
@@ -239,9 +279,10 @@ class TritonPythonModel:
|
|||||||
while (mel_streaming.shape[0] >= end and
|
while (mel_streaming.shape[0] >= end and
|
||||||
voc_chunk_id < voc_chunk_num):
|
voc_chunk_id < voc_chunk_num):
|
||||||
voc_chunk = mel_streaming[start:end, :]
|
voc_chunk = mel_streaming[start:end, :]
|
||||||
|
sub_wav = voc_melgan_runtime.infer({
|
||||||
sub_wav = voc_melgan_sess.run(
|
'logmel':
|
||||||
output_names=None, input_feed={'logmel': voc_chunk})
|
voc_chunk.astype("float32")
|
||||||
|
})
|
||||||
sub_wav = depadding(sub_wav[0], voc_chunk_num, voc_chunk_id,
|
sub_wav = depadding(sub_wav[0], voc_chunk_num, voc_chunk_id,
|
||||||
voc_block, voc_pad, voc_upsample)
|
voc_block, voc_pad, voc_upsample)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user