mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Bug Fix] fix a bug in paddle2coreml tool (#1481)
fix a bug in paddle2coreml tool
This commit is contained in:
@@ -6,7 +6,8 @@ import uvicorn
|
|||||||
def argsparser():
|
def argsparser():
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'tools', choices=['compress', 'convert', 'simple_serving', 'paddle2coreml'])
|
'tools',
|
||||||
|
choices=['compress', 'convert', 'simple_serving', 'paddle2coreml'])
|
||||||
## argumentments for auto compression
|
## argumentments for auto compression
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--config_path',
|
'--config_path',
|
||||||
@@ -89,43 +90,33 @@ def argsparser():
|
|||||||
"--p2c_paddle_model_dir",
|
"--p2c_paddle_model_dir",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=True,
|
|
||||||
help="define paddle model path")
|
help="define paddle model path")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--p2c_coreml_model_dir",
|
"--p2c_coreml_model_dir",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=True,
|
|
||||||
help="define generated coreml model path")
|
help="define generated coreml model path")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--p2c_coreml_model_name",
|
"--p2c_coreml_model_name",
|
||||||
type=str,
|
type=str,
|
||||||
default="coreml_model",
|
default="coreml_model",
|
||||||
required=False,
|
|
||||||
help="define generated coreml model name")
|
help="define generated coreml model name")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--p2c_input_names",
|
"--p2c_input_names", type=str, default=None, help="define input names")
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
required=True,
|
|
||||||
help="define input names")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--p2c_input_dtypes",
|
"--p2c_input_dtypes",
|
||||||
type=str,
|
type=str,
|
||||||
default="float32",
|
default="float32",
|
||||||
required=True,
|
|
||||||
help="define input dtypes")
|
help="define input dtypes")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--p2c_input_shapes",
|
"--p2c_input_shapes",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=True,
|
|
||||||
help="define input shapes")
|
help="define input shapes")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--p2c_output_names",
|
"--p2c_output_names",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=True,
|
|
||||||
help="define output names")
|
help="define output names")
|
||||||
## arguments for other tools
|
## arguments for other tools
|
||||||
return parser
|
return parser
|
||||||
@@ -214,9 +205,19 @@ def main():
|
|||||||
app_dir='.',
|
app_dir='.',
|
||||||
log_config=custom_logging_config)
|
log_config=custom_logging_config)
|
||||||
if args.tools == "paddle2coreml":
|
if args.tools == "paddle2coreml":
|
||||||
|
if any([
|
||||||
|
args.p2c_paddle_model_dir is None,
|
||||||
|
args.p2c_coreml_model_dir is None,
|
||||||
|
args.p2c_input_names is None, args.p2c_input_shapes is None,
|
||||||
|
args.p2c_output_names is None
|
||||||
|
]):
|
||||||
|
raise Exception(
|
||||||
|
"paddle2coreml need to define --p2c_paddle_model_dir, --p2c_coreml_model_dir, --p2c_input_names, --p2c_input_shapes, --p2c_output_names"
|
||||||
|
)
|
||||||
import coremltools as ct
|
import coremltools as ct
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
def type_to_np_dtype(dtype):
|
def type_to_np_dtype(dtype):
|
||||||
if dtype == 'float32':
|
if dtype == 'float32':
|
||||||
return np.float32
|
return np.float32
|
||||||
@@ -240,24 +241,29 @@ def main():
|
|||||||
return np.int16
|
return np.int16
|
||||||
else:
|
else:
|
||||||
raise Exception("Unsupported dtype: {}".format(dtype))
|
raise Exception("Unsupported dtype: {}".format(dtype))
|
||||||
|
|
||||||
input_names = args.p2c_input_names.split(' ')
|
input_names = args.p2c_input_names.split(' ')
|
||||||
input_shapes = [[int(i) for i in shape.split(',')] for shape in args.p2c_input_shapes.split(' ')]
|
input_shapes = [[int(i) for i in shape.split(',')]
|
||||||
|
for shape in args.p2c_input_shapes.split(' ')]
|
||||||
input_dtypes = map(type_to_np_dtype, args.p2c_input_dtypes.split(' '))
|
input_dtypes = map(type_to_np_dtype, args.p2c_input_dtypes.split(' '))
|
||||||
output_names = args.p2c_output_names.split(' ')
|
output_names = args.p2c_output_names.split(' ')
|
||||||
sample_input = [ct.TensorType(
|
sample_input = [
|
||||||
name=k,
|
ct.TensorType(
|
||||||
shape=s,
|
name=k,
|
||||||
dtype=d,
|
shape=s,
|
||||||
) for k, s, d in zip(input_names, input_shapes, input_dtypes)]
|
dtype=d, )
|
||||||
|
for k, s, d in zip(input_names, input_shapes, input_dtypes)
|
||||||
|
]
|
||||||
|
|
||||||
coreml_model = ct.convert(
|
coreml_model = ct.convert(
|
||||||
args.p2c_paddle_model_dir,
|
args.p2c_paddle_model_dir,
|
||||||
convert_to="mlprogram",
|
convert_to="mlprogram",
|
||||||
minimum_deployment_target=ct.target.macOS13,
|
minimum_deployment_target=ct.target.macOS13,
|
||||||
inputs=sample_input,
|
inputs=sample_input,
|
||||||
outputs=[ct.TensorType(name=name) for name in output_names],
|
outputs=[ct.TensorType(name=name) for name in output_names], )
|
||||||
)
|
coreml_model.save(
|
||||||
coreml_model.save(os.path.join(args.p2c_coreml_model_dir, args.p2c_coreml_model_name))
|
os.path.join(args.p2c_coreml_model_dir,
|
||||||
|
args.p2c_coreml_model_name))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@@ -7,7 +7,7 @@ install_requires = ['uvicorn==0.16.0']
|
|||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name="fastdeploy-tools", # name of package
|
name="fastdeploy-tools", # name of package
|
||||||
version="0.0.4", #version of package
|
version="0.0.5", #version of package
|
||||||
description="A toolkit for FastDeploy.",
|
description="A toolkit for FastDeploy.",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/plain",
|
long_description_content_type="text/plain",
|
||||||
|
Reference in New Issue
Block a user