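"""Example gRPC client that runs synchronous inference on a Triton Inference
Server. It sends an image batch to a served PaddleDetection model ("ppdet")
and prints the detection boxes decoded from the model's JSON output."""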
import json
import logging
from typing import Optional

import cv2
import numpy as np

from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput

LOGGER = logging.getLogger("run_inference_on_triton")

class SyncGRPCTritonRunner:
    DEFAULT_MAX_RESP_WAIT_S = 120

    def __init__(
            self,
            server_url: str,
            model_name: str,
            model_version: str,
            *,
            verbose=False,
            resp_wait_s: Optional[float]=None, ):
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._verbose = verbose
        self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s

        self._client = InferenceServerClient(
            self._server_url, verbose=self._verbose)
        error = self._verify_triton_state(self._client)
        if error:
            raise RuntimeError(
                f"Could not communicate to Triton Server: {error}")

        LOGGER.debug(
            f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
            f"are up and ready!")

        model_config = self._client.get_model_config(self._model_name,
                                                     self._model_version)
        model_metadata = self._client.get_model_metadata(self._model_name,
                                                         self._model_version)
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        # Cache the input/output tensor metadata reported by the server so
        # that requests can be built positionally against self._input_names.
        self._inputs = {tm.name: tm for tm in model_metadata.inputs}
        self._input_names = list(self._inputs)
        self._outputs = {tm.name: tm for tm in model_metadata.outputs}
        self._output_names = list(self._outputs)
        self._outputs_req = [
            InferRequestedOutput(name) for name in self._outputs
        ]

    def Run(self, inputs):
        """
        Args:
            inputs: list, each value corresponds (by position) to an input
                name in self._input_names
        Returns:
            results: dict, {output_name: numpy.array}
        """
        infer_inputs = []
        for idx, data in enumerate(inputs):
            # Use the datatype reported in the model metadata rather than a
            # hard-coded "UINT8", so non-UINT8 inputs also work.
            name = self._input_names[idx]
            infer_input = InferInput(name, data.shape,
                                     self._inputs[name].datatype)
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        results = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t, )
        results = {name: results.as_numpy(name) for name in self._output_names}
        return results

    def _verify_triton_state(self, triton_client):
        # Returns an error message if the server or model is not ready,
        # otherwise None.
        if not triton_client.is_server_live():
            return f"Triton server {self._server_url} is not live"
        elif not triton_client.is_server_ready():
            return f"Triton server {self._server_url} is not ready"
        elif not triton_client.is_model_ready(self._model_name,
                                              self._model_version):
            return f"Model {self._model_name}:{self._model_version} is not ready"
        return None


if __name__ == "__main__":
    model_name = "ppdet"
    model_version = "1"
    url = "localhost:8001"
    runner = SyncGRPCTritonRunner(url, model_name, model_version)
    im = cv2.imread("000000014439.jpg")
    im = np.array([im, ])
    # Batch input: stack several images of identical shape instead, e.g.
    # im = np.array([im, im, im])
    for _ in range(1):  # single inference pass; raise the range to repeat
        result = runner.Run([im, ])
        for name, values in result.items():
            print("output_name:", name)
            # values holds one serialized JSON result per image in the batch
            for value in values:
                value = json.loads(value)
                print(value['boxes'])
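# ---------------------------------------------------------------------------
# Minimal sketch of batched inference. np.array / np.stack only works when all
# images share the same HxWxC shape, so a common pattern is to resize first.
# `image_paths` and the 640x640 target size are illustrative assumptions, not
# requirements of the ppdet model:
#
#   ims = [cv2.resize(cv2.imread(p), (640, 640)) for p in image_paths]
#   batch = np.stack(ims).astype(np.uint8)   # shape: (N, 640, 640, 3)
#   results = runner.Run([batch])            # one JSON result per image
# ---------------------------------------------------------------------------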