Files
FastDeploy/examples/vision/segmentation/paddleseg/serving/paddleseg_grpc_client.py
huangjianhui 294607fc4a [Serving] PaddleSeg add triton serving && simple serving example (#1171)
* Update keypointdetection result docs

* Update im.copy() to im in examples

* Update new Api, fastdeploy::vision::Visualize to fastdeploy::vision

* Update SwapBackgroundSegmentation && SwapBackgroundMatting to SwapBackground

* Update README_CN.md

* Update README_CN.md

* Update preprocessor.h

* PaddleSeg supports triton serving

* Add PaddleSeg simple serving example

* Add PaddleSeg triton serving client code

* Update triton serving runtime config.pbtxt

* Update paddleseg grpc client

* Add paddle serving README
2023-01-30 09:34:38 +08:00

113 lines
4.1 KiB
Python

import logging
import numpy as np
import time
from typing import Optional
import cv2
import json
from tritonclient import utils as client_utils
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2
LOGGER = logging.getLogger("run_inference_on_triton")
class SyncGRPCTritonRunner:
    """Synchronous gRPC client bound to one model on a Triton server.

    Constructing the runner connects to the server, checks that the server
    and the requested model are ready, and reads the model metadata so that
    the input and output tensor names are known before the first request.
    """

    # Fallback for ``resp_wait_s``; handed to ``infer`` as ``client_timeout``
    # (seconds, going by the ``_S`` suffix).
    DEFAULT_MAX_RESP_WAIT_S = 120

    def __init__(
            self,
            server_url: str,
            model_name: str,
            model_version: str,
            *,
            verbose=False,
            resp_wait_s: Optional[float]=None, ):
        """Connect to the server and cache the model's tensor metadata.

        Args:
            server_url: address given to ``InferenceServerClient``.
            model_name: name of the model to query.
            model_version: version of the model to query.
            verbose: forwarded to ``InferenceServerClient``.
            resp_wait_s: per-request timeout; ``DEFAULT_MAX_RESP_WAIT_S``
                is used when it is None.

        Raises:
            RuntimeError: if the server is not live/ready or the model is
                not ready.
        """
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._verbose = verbose
        self._response_wait_t = (self.DEFAULT_MAX_RESP_WAIT_S
                                 if resp_wait_s is None else resp_wait_s)

        self._client = InferenceServerClient(
            self._server_url, verbose=self._verbose)
        error = self._verify_triton_state(self._client)
        if error:
            raise RuntimeError(
                f"Could not communicate to Triton Server: {error}")

        LOGGER.debug(
            "Triton server %s and model %s:%s are up and ready!",
            self._server_url, self._model_name, self._model_version)

        model_config = self._client.get_model_config(self._model_name,
                                                     self._model_version)
        model_metadata = self._client.get_model_metadata(self._model_name,
                                                         self._model_version)
        LOGGER.info("Model config %s", model_config)
        LOGGER.info("Model metadata %s", model_metadata)

        # Echo each input's metadata to stdout (kept from the original).
        for tm in model_metadata.inputs:
            print("tm:", tm)

        # Name -> metadata maps; the name lists keep the metadata order and
        # define the positional order expected by Run().
        self._inputs = {tm.name: tm for tm in model_metadata.inputs}
        self._input_names = list(self._inputs)
        self._outputs = {tm.name: tm for tm in model_metadata.outputs}
        self._output_names = list(self._outputs)
        # Every model output is requested on each call.
        self._outputs_req = [
            InferRequestedOutput(name) for name in self._outputs
        ]

    def Run(self, inputs):
        """Send one synchronous inference request.

        Args:
            inputs: list of numpy arrays; the i-th array is sent as the
                input named ``self._input_names[i]``.

        Returns:
            results: dict, {output name: numpy.array}
        """
        infer_inputs = []
        for idx, data in enumerate(inputs):
            # Derive the Triton datatype from the array instead of assuming
            # UINT8. np_to_triton_dtype returns None for unsupported dtypes;
            # keep the historical "UINT8" then so set_data_from_numpy
            # reports the mismatch exactly as it used to.
            triton_dtype = client_utils.np_to_triton_dtype(
                data.dtype) or "UINT8"
            infer_input = InferInput(self._input_names[idx], data.shape,
                                     triton_dtype)
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        response = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t, )
        return {name: response.as_numpy(name) for name in self._output_names}

    def _verify_triton_state(self, triton_client):
        """Return an error message if server or model is unusable, else None.

        Checks, in order: server liveness, server readiness, model readiness.
        The first failing check determines the returned message.
        """
        if not triton_client.is_server_live():
            return f"Triton server {self._server_url} is not live"
        if not triton_client.is_server_ready():
            return f"Triton server {self._server_url} is not ready"
        if not triton_client.is_model_ready(self._model_name,
                                            self._model_version):
            return f"Model {self._model_name}:{self._model_version} is not ready"
        return None
if __name__ == "__main__":
    # Demo: send one image to the "paddleseg" model on a local Triton server
    # and print a truncated view of each returned segmentation result.
    model_name = "paddleseg"
    model_version = "1"
    url = "localhost:8001"
    runner = SyncGRPCTritonRunner(url, model_name, model_version)

    image_path = "cityscapes_demo.png"
    im = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; stop
    # here rather than sending an object-dtype array to the server.
    if im is None:
        raise RuntimeError(
            f"Could not read image '{image_path}' "
            "(missing file or unsupported format)")

    # Wrap the single image in a leading batch axis.
    im = np.array([im, ])
    # batch input
    # im = np.array([im, im, im])
    for _ in range(1):
        result = runner.Run([im, ])
        for name, values in result.items():
            print("output_name:", name)
            # values is batch
            for value in values:
                # Each batch element is parsed as JSON and indexed by
                # "label_map" below.
                value = json.loads(value)
                print(
                    "Only print the first 20 labels in label_map of SEG_RESULT")
                value["label_map"] = value["label_map"][:20]
                print(value)