FDTensor support GPU device (#190)

* fdtensor support GPU

* TRT backend support GPU FDTensor

* FDHostAllocator add FASTDEPLOY_DECL

* fix FDTensor Data

* fix FDTensor dtype

Co-authored-by: Jason <jiangjiajun@baidu.com>
Author: heliqi
Date: 2022-09-08 03:53:08 -05:00
Committed by: GitHub
Parent: bc8e9e4dae
Commit: 4d1f264d01

17 changed files with 432 additions and 153 deletions

View File

@@ -13,7 +13,9 @@
 // limitations under the License.
 #include "fastdeploy/backends/ort/ort_backend.h"
 #include <memory>
 #include "fastdeploy/backends/ort/ops/multiclass_nms.h"
 #include "fastdeploy/backends/ort/utils.h"
 #include "fastdeploy/utils/utils.h"
@@ -164,33 +166,34 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
   return true;
 }
-void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name) {
+void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
+                           const std::string& name) {
   const auto info = value.GetTensorTypeAndShapeInfo();
   const auto data_type = info.GetElementType();
   size_t numel = info.GetElementCount();
+  auto shape = info.GetShape();
+  FDDataType dtype;
   if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
-    tensor->Allocate(info.GetShape(), FDDataType::FP32, name);
-    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
-           numel * sizeof(float));
+    dtype = FDDataType::FP32;
+    numel *= sizeof(float);
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
-    tensor->Allocate(info.GetShape(), FDDataType::INT32, name);
-    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
-           numel * sizeof(int32_t));
+    dtype = FDDataType::INT32;
+    numel *= sizeof(int32_t);
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
-    tensor->Allocate(info.GetShape(), FDDataType::INT64, name);
-    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
-           numel * sizeof(int64_t));
+    dtype = FDDataType::INT64;
+    numel *= sizeof(int64_t);
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
-    tensor->Allocate(info.GetShape(), FDDataType::FP64, name);
-    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
-           numel * sizeof(double));
+    dtype = FDDataType::FP64;
+    numel *= sizeof(double);
   } else {
     FDASSERT(
         false,
         "Unrecognized data type of %d while calling OrtBackend::CopyToCpu().",
         data_type);
   }
+  tensor->Resize(shape, dtype, name);
+  memcpy(tensor->MutableData(), value.GetTensorData<void*>(), numel);
 }
 bool OrtBackend::Infer(std::vector<FDTensor>& inputs,

View File

@@ -88,6 +88,7 @@ class OrtBackend : public BaseBackend {
   Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
 #endif
   OrtBackendOption option_;
-  void CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name);
+  void CopyToCpu(const Ort::Value& value, FDTensor* tensor,
+                 const std::string& name);
 };
 }  // namespace fastdeploy

View File

@@ -79,16 +79,23 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
 }
 TensorInfo PaddleBackend::GetInputInfo(int index) {
-  FDASSERT(index < NumInputs(), "The index: %d should less than the number of inputs: %d.", index, NumInputs());
+  FDASSERT(index < NumInputs(),
+           "The index: %d should less than the number of inputs: %d.", index,
+           NumInputs());
   return inputs_desc_[index];
 }
+std::vector<TensorInfo> PaddleBackend::GetInputInfo() { return inputs_desc_; }
+
 TensorInfo PaddleBackend::GetOutputInfo(int index) {
   FDASSERT(index < NumOutputs(),
-           "The index: %d should less than the number of outputs %d.", index, NumOutputs());
+           "The index: %d should less than the number of outputs %d.", index,
+           NumOutputs());
   return outputs_desc_[index];
 }
+std::vector<TensorInfo> PaddleBackend::GetOutputInfo() { return outputs_desc_; }
+
 bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
                           std::vector<FDTensor>* outputs) {
   if (inputs.size() != inputs_desc_.size()) {
@@ -100,7 +107,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto handle = predictor_->GetInputHandle(inputs[i].name);
-    ShareTensorFromCpu(handle.get(), inputs[i]);
+    ShareTensorFromFDTensor(handle.get(), inputs[i]);
   }
   predictor_->Run();

View File

@@ -44,8 +44,11 @@ struct PaddleBackendOption {
   std::vector<std::string> delete_pass_names = {};
 };
+// convert FD device to paddle place type
+paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);
+
 // Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
-void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
+void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
 // Copy memory data from paddle_infer::Tensor to fastdeploy::FDTensor
 void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -72,6 +75,8 @@ class PaddleBackend : public BaseBackend {
   TensorInfo GetInputInfo(int index);
   TensorInfo GetOutputInfo(int index);
+  std::vector<TensorInfo> GetInputInfo();
+  std::vector<TensorInfo> GetOutputInfo();
  private:
   paddle_infer::Config config_;

View File

@@ -15,23 +15,33 @@
 #include "fastdeploy/backends/paddle/paddle_backend.h"
 namespace fastdeploy {
-void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor) {
+paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device) {
+  if (device == Device::GPU) {
+    return paddle_infer::PlaceType::kGPU;
+  }
+  return paddle_infer::PlaceType::kCPU;
+}
+
+void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
+                             FDTensor& fd_tensor) {
   std::vector<int> shape(fd_tensor.shape.begin(), fd_tensor.shape.end());
   tensor->Reshape(shape);
+  auto place = ConvertFDDeviceToPlace(fd_tensor.device);
   if (fd_tensor.dtype == FDDataType::FP32) {
     tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   } else if (fd_tensor.dtype == FDDataType::INT32) {
     tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   } else if (fd_tensor.dtype == FDDataType::INT64) {
     tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   }
-  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor.dtype).c_str());
+  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
+           Str(fd_tensor.dtype).c_str());
 }
 void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -51,7 +61,8 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
     tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
     return;
   }
-  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor->dtype).c_str());
+  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
+           Str(fd_tensor->dtype).c_str());
 }
 FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
@@ -65,7 +76,10 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
   } else if (dtype == paddle_infer::UINT8) {
     fd_dtype = FDDataType::UINT8;
   } else {
-    FDASSERT(false, "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.", int(dtype));
+    FDASSERT(
+        false,
+        "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.",
+        int(dtype));
   }
   return fd_dtype;
 }
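The new device-aware sharing path is easiest to see from the caller's side. A minimal sketch, assuming a valid paddle_infer::Tensor handle obtained from a predictor elsewhere and a caller-owned CUDA buffer d_input; the shape is a placeholder, not taken from the commit:

#include "fastdeploy/backends/paddle/paddle_backend.h"

// Hedged sketch: feed GPU-resident FP32 data to Paddle Inference without a
// host round trip. `paddle_tensor` and `d_input` are assumed to be provided
// by the caller; nothing here is copied.
void FeedGpuBuffer(paddle_infer::Tensor* paddle_tensor, void* d_input) {
  fastdeploy::FDTensor fd_tensor;
  // Mark the external buffer as GPU memory (shape {1, 3, 224, 224} is only
  // an example); SetExternalData records the pointer, it does not copy.
  fd_tensor.SetExternalData({1, 3, 224, 224}, fastdeploy::FDDataType::FP32,
                            d_input, fastdeploy::Device::GPU);
  // ShareTensorFromFDTensor now resolves PlaceType::kGPU through
  // ConvertFDDeviceToPlace, so Paddle reads the device pointer directly.
  fastdeploy::ShareTensorFromFDTensor(paddle_tensor, fd_tensor);
}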

View File

@@ -13,9 +13,11 @@
 // limitations under the License.
 #include "fastdeploy/backends/tensorrt/trt_backend.h"
+#include <cstring>
 #include "NvInferSafeRuntime.h"
 #include "fastdeploy/utils/utils.h"
-#include <cstring>
 #ifdef ENABLE_PADDLE_FRONTEND
 #include "paddle2onnx/converter.h"
 #endif
@@ -210,9 +212,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
   outputs_desc_.resize(onnx_reader.num_outputs);
   for (int i = 0; i < onnx_reader.num_inputs; ++i) {
     std::string name(onnx_reader.inputs[i].name);
-    std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
-                               onnx_reader.inputs[i].shape +
-                               onnx_reader.inputs[i].rank);
+    std::vector<int64_t> shape(
+        onnx_reader.inputs[i].shape,
+        onnx_reader.inputs[i].shape + onnx_reader.inputs[i].rank);
     inputs_desc_[i].name = name;
     inputs_desc_[i].shape.assign(shape.begin(), shape.end());
     inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
@@ -231,9 +233,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
   for (int i = 0; i < onnx_reader.num_outputs; ++i) {
     std::string name(onnx_reader.outputs[i].name);
-    std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
-                               onnx_reader.outputs[i].shape +
-                               onnx_reader.outputs[i].rank);
+    std::vector<int64_t> shape(
+        onnx_reader.outputs[i].shape,
+        onnx_reader.outputs[i].shape + onnx_reader.outputs[i].rank);
     outputs_desc_[i].name = name;
     outputs_desc_[i].shape.assign(shape.begin(), shape.end());
     outputs_desc_[i].dtype =
@@ -286,24 +288,8 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
     BuildTrtEngine();
   }
-  AllocateBufferInDynamicShape(inputs, outputs);
-  std::vector<void*> input_binds(inputs.size());
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    if (inputs[i].dtype == FDDataType::INT64) {
-      int64_t* data = static_cast<int64_t*>(inputs[i].Data());
-      std::vector<int32_t> casted_data(data, data + inputs[i].Numel());
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
-                               static_cast<void*>(casted_data.data()),
-                               inputs[i].Nbytes() / 2, cudaMemcpyHostToDevice,
-                               stream_) == 0,
-               "[ERROR] Error occurs while copy memory from CPU to GPU.");
-    } else {
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
-                               inputs[i].Data(), inputs[i].Nbytes(),
-                               cudaMemcpyHostToDevice, stream_) == 0,
-               "[ERROR] Error occurs while copy memory from CPU to GPU.");
-    }
-  }
+  SetInputs(inputs);
+  AllocateOutputsBuffer(outputs);
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
     return false;
@@ -339,18 +325,50 @@ void TrtBackend::GetInputOutputInfo() {
   bindings_.resize(num_binds);
 }
-void TrtBackend::AllocateBufferInDynamicShape(
-    const std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) {
+void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
   for (const auto& item : inputs) {
     auto idx = engine_->getBindingIndex(item.name.c_str());
     std::vector<int> shape(item.shape.begin(), item.shape.end());
     auto dims = ToDims(shape);
     context_->setBindingDimensions(idx, dims);
-    if (item.Nbytes() > inputs_buffer_[item.name].nbBytes()) {
+    if (item.device == Device::GPU) {
+      if (item.dtype == FDDataType::INT64) {
+        // TODO(liqi): cast int64 to int32
+        // TRT don't support INT64
+        FDASSERT(false,
+                 "TRT don't support INT64 input on GPU, "
+                 "please use INT32 input");
+      } else {
+        // no copy
+        inputs_buffer_[item.name].SetExternalData(dims, item.Data());
+      }
+    } else {
+      // Allocate input buffer memory
       inputs_buffer_[item.name].resize(dims);
-      bindings_[idx] = inputs_buffer_[item.name].data();
+      // copy from cpu to gpu
+      if (item.dtype == FDDataType::INT64) {
+        int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
+        std::vector<int32_t> casted_data(data, data + item.Numel());
+        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
                                 static_cast<void*>(casted_data.data()),
+                                 item.Nbytes() / 2, cudaMemcpyHostToDevice,
+                                 stream_) == 0,
+                 "Error occurs while copy memory from CPU to GPU.");
+      } else {
+        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
+                                 item.Nbytes(), cudaMemcpyHostToDevice,
+                                 stream_) == 0,
+                 "Error occurs while copy memory from CPU to GPU.");
+      }
     }
+    // binding input buffer
+    bindings_[idx] = inputs_buffer_[item.name].data();
   }
+}
+void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
   if (outputs->size() != outputs_desc_.size()) {
     outputs->resize(outputs_desc_.size());
   }
@@ -365,13 +383,15 @@ void TrtBackend::AllocateBufferInDynamicShape(
         "Cannot find output: %s of tensorrt network from the original model.",
         outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
-    std::vector<int64_t> shape(output_dims.d, output_dims.d + output_dims.nbDims);
-    (*outputs)[ori_idx].Allocate(shape, GetFDDataType(outputs_desc_[i].dtype), outputs_desc_[i].name);
-    if ((*outputs)[ori_idx].Nbytes() >
-        outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
-      outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
-      bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
-    }
+    // set user's outputs info
+    std::vector<int64_t> shape(output_dims.d,
+                               output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
+                               outputs_desc_[i].name);
+    // Allocate output buffer memory
+    outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
+    // binding output buffer
+    bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
   }
 }
@@ -580,4 +600,4 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
   info.dtype = GetFDDataType(outputs_desc_[index].dtype);
   return info;
 }
 }  // namespace fastdeploy
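From the caller's perspective, SetInputs() now gives two input paths. A rough sketch under stated assumptions: `backend` is an already initialized TrtBackend whose input binding happens to be named "x", and `d_ptr` is a caller-owned CUDA buffer; both the name and the shapes are placeholders, not from the commit.

#include <vector>
#include "fastdeploy/backends/tensorrt/trt_backend.h"

// Hedged sketch of the two input paths added by this commit.
void RunTrt(fastdeploy::TrtBackend& backend, void* d_ptr) {
  std::vector<fastdeploy::FDTensor> outputs;

  // Path 1: data already on the GPU. SetInputs() wraps the pointer through
  // FDGenericBuffer::SetExternalData, so no cudaMemcpyAsync is issued.
  std::vector<fastdeploy::FDTensor> gpu_inputs(1);
  gpu_inputs[0].name = "x";
  gpu_inputs[0].SetExternalData({1, 3, 640, 640}, fastdeploy::FDDataType::FP32,
                                d_ptr, fastdeploy::Device::GPU);
  backend.Infer(gpu_inputs, &outputs);

  // Path 2: data on the host. SetInputs() resizes the owned device buffer and
  // copies with cudaMemcpyAsync (INT64 host inputs are narrowed to INT32).
  std::vector<fastdeploy::FDTensor> cpu_inputs(1);
  cpu_inputs[0].Resize({1, 3, 640, 640}, fastdeploy::FDDataType::FP32, "x");
  // ... fill cpu_inputs[0].MutableData() with host data here ...
  backend.Infer(cpu_inputs, &outputs);
}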

View File

@@ -14,6 +14,8 @@
 #pragma once
+#include <cuda_runtime_api.h>
+
 #include <iostream>
 #include <map>
 #include <string>
@@ -23,7 +25,6 @@
 #include "NvOnnxParser.h"
 #include "fastdeploy/backends/backend.h"
 #include "fastdeploy/backends/tensorrt/utils.h"
-#include <cuda_runtime_api.h>
 namespace fastdeploy {
@@ -109,12 +110,12 @@ class TrtBackend : public BaseBackend {
   std::map<std::string, ShapeRangeInfo> shape_range_info_;
   void GetInputOutputInfo();
-  void AllocateBufferInDynamicShape(const std::vector<FDTensor>& inputs,
-                                    std::vector<FDTensor>* outputs);
   bool CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer);
   bool BuildTrtEngine();
   bool LoadTrtCache(const std::string& trt_engine_file);
   int ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs);
+  void SetInputs(const std::vector<FDTensor>& inputs);
+  void AllocateOutputsBuffer(std::vector<FDTensor>* outputs);
 };
 }  // namespace fastdeploy

View File

@@ -14,11 +14,9 @@
 #pragma once
-#include "NvInfer.h"
-#include "fastdeploy/core/fd_tensor.h"
-#include "fastdeploy/utils/utils.h"
-#include <algorithm>
 #include <cuda_runtime_api.h>
+#include <algorithm>
 #include <iostream>
 #include <map>
 #include <memory>
@@ -26,17 +24,24 @@
 #include <string>
 #include <vector>
+#include "NvInfer.h"
+#include "fastdeploy/core/allocate.h"
+#include "fastdeploy/core/fd_tensor.h"
+#include "fastdeploy/utils/utils.h"
 namespace fastdeploy {
 struct FDInferDeleter {
-  template <typename T> void operator()(T* obj) const {
+  template <typename T>
+  void operator()(T* obj) const {
     if (obj) {
       obj->destroy();
     }
   }
 };
-template <typename T> using FDUniquePtr = std::unique_ptr<T, FDInferDeleter>;
+template <typename T>
+using FDUniquePtr = std::unique_ptr<T, FDInferDeleter>;
 int64_t Volume(const nvinfer1::Dims& d);
@@ -64,13 +69,18 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& vec) {
   return out;
 }
-template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
+template <typename AllocFunc, typename FreeFunc>
+class FDGenericBuffer {
  public:
   //!
   //! \brief Construct an empty buffer.
   //!
   explicit FDGenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT)
-      : mSize(0), mCapacity(0), mType(type), mBuffer(nullptr) {}
+      : mSize(0),
+        mCapacity(0),
+        mType(type),
+        mBuffer(nullptr),
+        mExternal_buffer(nullptr) {}
   //!
   //! \brief Construct a buffer with the specified allocation size in bytes.
@@ -82,8 +92,18 @@ template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
     }
   }
+  //!
+  //! \brief This use to skip memory copy step.
+  //!
+  FDGenericBuffer(size_t size, nvinfer1::DataType type, void* buffer)
+      : mSize(size), mCapacity(size), mType(type) {
+    mExternal_buffer = buffer;
+  }
+
   FDGenericBuffer(FDGenericBuffer&& buf)
-      : mSize(buf.mSize), mCapacity(buf.mCapacity), mType(buf.mType),
+      : mSize(buf.mSize),
+        mCapacity(buf.mCapacity),
+        mType(buf.mType),
         mBuffer(buf.mBuffer) {
     buf.mSize = 0;
     buf.mCapacity = 0;
@@ -109,12 +129,18 @@ template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
   //!
   //! \brief Returns pointer to underlying array.
   //!
-  void* data() { return mBuffer; }
+  void* data() {
+    if (mExternal_buffer != nullptr) return mExternal_buffer;
+    return mBuffer;
+  }
   //!
   //! \brief Returns pointer to underlying array.
   //!
-  const void* data() const { return mBuffer; }
+  const void* data() const {
+    if (mExternal_buffer != nullptr) return mExternal_buffer;
+    return mBuffer;
+  }
   //!
   //! \brief Returns the size (in number of elements) of the buffer.
@@ -126,11 +152,29 @@ template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
   //!
   size_t nbBytes() const { return this->size() * TrtDataTypeSize(mType); }
+  //!
+  //! \brief Set user memory buffer for TRT Buffer
+  //!
+  void SetExternalData(size_t size, nvinfer1::DataType type, void* buffer) {
+    mSize = mCapacity = size;
+    mType = type;
+    mExternal_buffer = const_cast<void*>(buffer);
+  }
+
+  //!
+  //! \brief Set user memory buffer for TRT Buffer
+  //!
+  void SetExternalData(const nvinfer1::Dims& dims, const void* buffer) {
+    mSize = mCapacity = Volume(dims);
+    mExternal_buffer = const_cast<void*>(buffer);
+  }
+
   //!
   //! \brief Resizes the buffer. This is a no-op if the new size is smaller than
   //! or equal to the current capacity.
   //!
   void resize(size_t newSize) {
+    mExternal_buffer = nullptr;
     mSize = newSize;
     if (mCapacity < newSize) {
       freeFn(mBuffer);
@@ -146,28 +190,20 @@ template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
   //!
   void resize(const nvinfer1::Dims& dims) { return this->resize(Volume(dims)); }
-  ~FDGenericBuffer() { freeFn(mBuffer); }
+  ~FDGenericBuffer() {
+    mExternal_buffer = nullptr;
+    freeFn(mBuffer);
+  }
  private:
   size_t mSize{0}, mCapacity{0};
   nvinfer1::DataType mType;
   void* mBuffer;
+  void* mExternal_buffer;
   AllocFunc allocFn;
   FreeFunc freeFn;
 };
-class FDDeviceAllocator {
- public:
-  bool operator()(void** ptr, size_t size) const {
-    return cudaMalloc(ptr, size) == cudaSuccess;
-  }
-};
-
-class FDDeviceFree {
- public:
-  void operator()(void* ptr) const { cudaFree(ptr); }
-};
-
 using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;
 class FDTrtLogger : public nvinfer1::ILogger {
@@ -197,7 +233,7 @@ class FDTrtLogger : public nvinfer1::ILogger {
 };
 struct ShapeRangeInfo {
-  ShapeRangeInfo(const std::vector<int64_t>& new_shape) {
+  explicit ShapeRangeInfo(const std::vector<int64_t>& new_shape) {
     shape.assign(new_shape.begin(), new_shape.end());
     min.resize(new_shape.size());
     max.resize(new_shape.size());
@@ -239,4 +275,4 @@ struct ShapeRangeInfo {
   }
 };
 }  // namespace fastdeploy
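The buffer change above is what makes the zero-copy input path possible: an FDGenericBuffer can now either own memory through its allocator or merely wrap a caller-provided pointer. A small sketch of the intended semantics, assuming d_ptr is an existing CUDA allocation owned by the caller (the dimensions are arbitrary):

#include "fastdeploy/backends/tensorrt/utils.h"

// Hedged sketch of FDGenericBuffer's external-data behaviour.
void WrapExternal(void* d_ptr) {
  fastdeploy::FDDeviceBuffer buffer(nvinfer1::DataType::kFLOAT);

  nvinfer1::Dims dims;
  dims.nbDims = 2;
  dims.d[0] = 4;
  dims.d[1] = 8;

  // data() now returns d_ptr; the buffer neither owns nor frees it.
  buffer.SetExternalData(dims, d_ptr);

  // resize() clears the external pointer and falls back to memory obtained
  // through FDDeviceAllocator, which the destructor releases.
  buffer.resize(dims);
}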

View File

@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif
#include "fastdeploy/core/allocate.h"
namespace fastdeploy {
bool FDHostAllocator::operator()(void** ptr, size_t size) const {
*ptr = malloc(size);
return *ptr != nullptr;
}
void FDHostFree::operator()(void* ptr) const { free(ptr); }
#ifdef WITH_GPU
bool FDDeviceAllocator::operator()(void** ptr, size_t size) const {
return cudaMalloc(ptr, size) == cudaSuccess;
}
void FDDeviceFree::operator()(void* ptr) const { cudaFree(ptr); }
#endif
} // namespace fastdeploy

View File

@@ -0,0 +1,50 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <new>
#include <numeric>
#include <string>
#include <vector>
#include "fastdeploy/utils/utils.h"
namespace fastdeploy {
class FASTDEPLOY_DECL FDHostAllocator {
public:
bool operator()(void** ptr, size_t size) const;
};
class FASTDEPLOY_DECL FDHostFree {
public:
void operator()(void* ptr) const;
};
#ifdef WITH_GPU
class FASTDEPLOY_DECL FDDeviceAllocator {
public:
bool operator()(void** ptr, size_t size) const;
};
class FASTDEPLOY_DECL FDDeviceFree {
public:
void operator()(void* ptr) const;
};
#endif
} // namespace fastdeploy
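Since the allocators declared above are plain functors, host and device memory follow the same call pattern. A minimal usage sketch (the 1024-byte size is arbitrary; the device pair only exists when the library is built with -DWITH_GPU=ON):

#include "fastdeploy/core/allocate.h"

// Hedged sketch of the allocator/free functor pairs.
void AllocatorExample() {
  void* host_ptr = nullptr;
  if (fastdeploy::FDHostAllocator()(&host_ptr, 1024)) {
    // ... use 1024 bytes of host memory ...
    fastdeploy::FDHostFree()(host_ptr);
  }

#ifdef WITH_GPU
  void* device_ptr = nullptr;
  if (fastdeploy::FDDeviceAllocator()(&device_ptr, 1024)) {
    // ... use 1024 bytes of CUDA device memory ...
    fastdeploy::FDDeviceFree()(device_ptr);
  }
#endif
}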

View File

@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "fastdeploy/core/fd_tensor.h"
 #include "fastdeploy/utils/utils.h"
 #ifdef WITH_GPU
@@ -25,55 +26,69 @@ void* FDTensor::MutableData() {
   if (external_data_ptr != nullptr) {
     return external_data_ptr;
   }
-  return data.data();
+  return buffer_;
 }
 void* FDTensor::Data() {
   if (external_data_ptr != nullptr) {
-    if (device == Device::GPU) {
-#ifdef WITH_GPU
-      // need to copy cuda mem to cpu first
-      temporary_cpu_buffer.resize(Nbytes());
-      FDASSERT(cudaMemcpy(temporary_cpu_buffer.data(), external_data_ptr,
-                          Nbytes(), cudaMemcpyDeviceToHost) == 0,
-               "[ERROR] Error occurs while copy memory from GPU to CPU");
-      return temporary_cpu_buffer.data();
-#else
-      FDASSERT(false,
-               "The FastDeploy didn't compile under -DWITH_GPU=ON, so this is "
-               "an unexpected problem happend.");
-#endif
-    } else {
-      return external_data_ptr;
-    }
+    return external_data_ptr;
   }
-  return data.data();
+  return buffer_;
 }
 const void* FDTensor::Data() const {
   if (external_data_ptr != nullptr) {
     return external_data_ptr;
   }
-  return data.data();
+  return buffer_;
+}
+const void* FDTensor::CpuData() const {
+  if (device == Device::GPU) {
+#ifdef WITH_GPU
+    auto* cpu_ptr = const_cast<std::vector<int8_t>*>(&temporary_cpu_buffer);
+    cpu_ptr->resize(Nbytes());
+    // need to copy cuda mem to cpu first
+    if (external_data_ptr != nullptr) {
+      FDASSERT(cudaMemcpy(cpu_ptr->data(), external_data_ptr, Nbytes(),
+                          cudaMemcpyDeviceToHost) == 0,
+               "[ERROR] Error occurs while copy memory from GPU to CPU");
+    } else {
+      FDASSERT(cudaMemcpy(cpu_ptr->data(), buffer_, Nbytes(),
+                          cudaMemcpyDeviceToHost) == 0,
+               "[ERROR] Error occurs while buffer copy memory from GPU to CPU");
+    }
+    return cpu_ptr->data();
+#else
+    FDASSERT(false,
+             "The FastDeploy didn't compile under -DWITH_GPU=ON, so this is "
+             "an unexpected problem happend.");
+#endif
+  }
+  return Data();
 }
 void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
-                               const FDDataType& data_type, void* data_buffer) {
+                               const FDDataType& data_type, void* data_buffer,
+                               const Device& new_device) {
   dtype = data_type;
   shape.assign(new_shape.begin(), new_shape.end());
   external_data_ptr = data_buffer;
+  device = new_device;
 }
 void FDTensor::Allocate(const std::vector<int64_t>& new_shape,
                         const FDDataType& data_type,
-                        const std::string& tensor_name) {
+                        const std::string& tensor_name,
+                        const Device& new_device) {
   dtype = data_type;
   name = tensor_name;
   shape.assign(new_shape.begin(), new_shape.end());
-  int unit = FDDataTypeSize(data_type);
-  int total_size =
-      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
-  data.resize(total_size * unit);
+  device = new_device;
+  size_t nbytes = Nbytes();
+  FDASSERT(AllocFn(nbytes),
+           "The FastDeploy FDTensor allocate cpu memory error");
 }
 int FDTensor::Nbytes() const { return Numel() * FDDataTypeSize(dtype); }
@@ -82,6 +97,44 @@ int FDTensor::Numel() const {
   return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
 }
+void FDTensor::Resize(size_t new_nbytes) {
+  size_t nbytes = Nbytes();
+  if (new_nbytes > nbytes) {
+    FreeFn();
+    AllocFn(new_nbytes);
+  }
+}
+void FDTensor::Resize(const std::vector<int64_t>& new_shape) {
+  int numel = Numel();
+  int new_numel = std::accumulate(new_shape.begin(), new_shape.end(), 1,
+                                  std::multiplies<int>());
+  shape.assign(new_shape.begin(), new_shape.end());
+  if (new_numel > numel) {
+    FreeFn();
+    size_t nbytes = new_numel * FDDataTypeSize(dtype);
+    AllocFn(nbytes);
+  }
+}
+void FDTensor::Resize(const std::vector<int64_t>& new_shape,
+                      const FDDataType& data_type,
+                      const std::string& tensor_name,
+                      const Device& new_device) {
+  name = tensor_name;
+  device = new_device;
+  size_t nbytes = Nbytes();
+  shape.assign(new_shape.begin(), new_shape.end());
+  dtype = data_type;
+  int new_nbytes = std::accumulate(new_shape.begin(), new_shape.end(), 1,
+                                   std::multiplies<int>()) *
+                   FDDataTypeSize(data_type);
+  if (new_nbytes > nbytes) {
+    FreeFn();
+    AllocFn(new_nbytes);
+  }
+}
 template <typename T>
 void CalculateStatisInfo(void* src_ptr, int size, double* mean, double* max,
                          double* min) {

View File

@@ -18,15 +18,17 @@
 #include <string>
 #include <vector>
+#include "fastdeploy/core/allocate.h"
 #include "fastdeploy/core/fd_type.h"
 namespace fastdeploy {
 struct FASTDEPLOY_DECL FDTensor {
-  std::vector<int8_t> data;
-  std::vector<int64_t> shape;
+  // std::vector<int8_t> data;
+  void* buffer_ = nullptr;
+  std::vector<int64_t> shape = {0};
   std::string name = "";
-  FDDataType dtype;
+  FDDataType dtype = FDDataType::INT8;
   // This use to skip memory copy step
   // the external_data_ptr will point to the user allocated memory
@@ -46,28 +48,32 @@ struct FASTDEPLOY_DECL FDTensor {
   // Get data buffer pointer
   void* MutableData();
-  // Use this data to get the tensor data to process
-  // Since the most senario is process data in CPU
-  // this function weill return a pointer to cpu memory
-  // buffer.
-  // If the original data is on other device, the data
-  // will copy to cpu store in `temporary_cpu_buffer`
   void* Data();
   const void* Data() const;
+  // Use this data to get the tensor data to process
+  // Since the most senario is process data in CPU
+  // this function will return a pointer to cpu memory
+  // buffer.
+  // If the original data is on other device, the data
+  // will copy to cpu store in `temporary_cpu_buffer`
+  const void* CpuData() const;
+
   // Set user memory buffer for Tensor, the memory is managed by
   // the user it self, but the Tensor will share the memory with user
   // So take care with the user buffer
   void SetExternalData(const std::vector<int64_t>& new_shape,
-                       const FDDataType& data_type, void* data_buffer);
+                       const FDDataType& data_type, void* data_buffer,
+                       const Device& new_device = Device::CPU);
   // Initialize Tensor
   // Include setting attribute for tensor
   // and allocate cpu memory buffer
   void Allocate(const std::vector<int64_t>& new_shape,
                 const FDDataType& data_type,
-                const std::string& tensor_name = "");
+                const std::string& tensor_name = "",
+                const Device& new_device = Device::CPU);
   // Total size of tensor memory buffer in bytes
   int Nbytes() const;
@@ -75,13 +81,51 @@ struct FASTDEPLOY_DECL FDTensor {
   // Total number of elements in this tensor
   int Numel() const;
+  void Resize(size_t nbytes);
+
+  void Resize(const std::vector<int64_t>& new_shape);
+
+  void Resize(const std::vector<int64_t>& new_shape,
+              const FDDataType& data_type, const std::string& tensor_name = "",
+              const Device& new_device = Device::CPU);
+
   // Debug function
   // Use this function to print shape, dtype, mean, max, min
   // prefix will also be printed as tag
   void PrintInfo(const std::string& prefix = "TensorInfo: ");
+  bool AllocFn(size_t nbytes) {
+    if (device == Device::GPU) {
+#ifdef WITH_GPU
+      return FDDeviceAllocator()(&buffer_, nbytes);
+#else
+      FDASSERT(false,
+               "The FastDeploy FDTensor allocator didn't compile under "
+               "-DWITH_GPU=ON,"
+               "so this is an unexpected problem happend.");
+#endif
+    }
+    return FDHostAllocator()(&buffer_, nbytes);
+  }
+
+  void FreeFn() {
+    if (external_data_ptr != nullptr) external_data_ptr = nullptr;
+    if (buffer_ != nullptr) {
+      if (device == Device::GPU) {
+#ifdef WITH_GPU
+        FDDeviceFree()(buffer_);
+#endif
+      } else {
+        FDHostFree()(buffer_);
+      }
+      buffer_ = nullptr;
+    }
+  }
+
   FDTensor() {}
   explicit FDTensor(const std::string& tensor_name);
+  ~FDTensor() { FreeFn(); }
 };
 }  // namespace fastdeploy
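Taken together, the header changes give FDTensor a device-aware lifecycle: buffer_ is allocated by AllocFn on the tensor's device, released by FreeFn in the new destructor, and CpuData() provides a host-readable view of GPU data. A minimal sketch, assuming a GPU build; shapes and names are placeholders:

#include <cstring>
#include "fastdeploy/core/fd_tensor.h"

// Hedged sketch of the new FDTensor memory model.
void TensorLifecycle() {
  fastdeploy::FDTensor cpu_tensor;
  // Resize allocates 6 * sizeof(float) bytes of host memory via FDHostAllocator.
  cpu_tensor.Resize({2, 3}, fastdeploy::FDDataType::FP32, "cpu_tensor");
  float values[6] = {0, 1, 2, 3, 4, 5};
  std::memcpy(cpu_tensor.MutableData(), values, cpu_tensor.Nbytes());

#ifdef WITH_GPU
  fastdeploy::FDTensor gpu_tensor;
  // The same call with Device::GPU allocates CUDA memory via FDDeviceAllocator.
  gpu_tensor.Resize({2, 3}, fastdeploy::FDDataType::FP32, "gpu_tensor",
                    fastdeploy::Device::GPU);
  // Data() still returns the device pointer; CpuData() copies the contents
  // into temporary_cpu_buffer and returns a host pointer instead.
  const float* host_view = static_cast<const float*>(gpu_tensor.CpuData());
  (void)host_view;
#endif
  // Both tensors free their buffers in ~FDTensor() via FreeFn().
}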

View File

@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "fastdeploy/core/fd_type.h"
 #include "fastdeploy/utils/utils.h"
 namespace fastdeploy {
@@ -33,6 +34,8 @@ int FDDataTypeSize(const FDDataType& data_type) {
     return sizeof(double);
   } else if (data_type == FDDataType::UINT8) {
     return sizeof(uint8_t);
+  } else if (data_type == FDDataType::INT8) {
+    return sizeof(int8_t);
   } else {
     FDASSERT(false, "Unexpected data type: %s", Str(data_type).c_str());
   }
@@ -42,9 +45,6 @@ int FDDataTypeSize(const FDDataType& data_type) {
 std::string Str(const Device& d) {
   std::string out;
   switch (d) {
-    case Device::DEFAULT:
-      out = "Device::DEFAULT";
-      break;
     case Device::CPU:
       out = "Device::CPU";
       break;

View File

@@ -22,7 +22,7 @@
 namespace fastdeploy {
-enum FASTDEPLOY_DECL Device { DEFAULT, CPU, GPU };
+enum FASTDEPLOY_DECL Device { CPU, GPU };
 FASTDEPLOY_DECL std::string Str(const Device& d);

View File

@@ -71,14 +71,13 @@ void BindRuntime(pybind11::module& m) {
         std::vector<FDTensor> inputs(data.size());
         int index = 0;
         for (auto iter = data.begin(); iter != data.end(); ++iter) {
-          inputs[index].dtype =
-              NumpyDataTypeToFDDataType(iter->second.dtype());
-          inputs[index].shape.insert(
-              inputs[index].shape.begin(), iter->second.shape(),
-              iter->second.shape() + iter->second.ndim());
+          std::vector<int64_t> data_shape;
+          data_shape.insert(data_shape.begin(), iter->second.shape(),
+                            iter->second.shape() + iter->second.ndim());
+          auto dtype = NumpyDataTypeToFDDataType(iter->second.dtype());
           // TODO(jiangjiajun) Maybe skip memory copy is a better choice
           // use SetExternalData
-          inputs[index].data.resize(iter->second.nbytes());
+          inputs[index].Resize(data_shape, dtype);
           memcpy(inputs[index].MutableData(), iter->second.mutable_data(),
                  iter->second.nbytes());
           inputs[index].name = iter->first;
@@ -134,4 +133,4 @@ void BindRuntime(pybind11::module& m) {
   m.def("get_available_backends", []() { return GetAvailableBackends(); });
 }
 }  // namespace fastdeploy

View File

@@ -59,13 +59,15 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
 void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
                      bool share_buffer) {
-  tensor->dtype = NumpyDataTypeToFDDataType(pyarray.dtype());
-  tensor->shape.insert(tensor->shape.begin(), pyarray.shape(),
-                       pyarray.shape() + pyarray.ndim());
+  auto dtype = NumpyDataTypeToFDDataType(pyarray.dtype());
+  std::vector<int64_t> data_shape;
+  data_shape.insert(data_shape.begin(), pyarray.shape(),
+                    pyarray.shape() + pyarray.ndim());
   if (share_buffer) {
-    tensor->external_data_ptr = pyarray.mutable_data();
+    tensor->SetExternalData(data_shape, dtype, pyarray.mutable_data());
   } else {
-    tensor->data.resize(pyarray.nbytes());
+    tensor->Resize(data_shape, dtype);
     memcpy(tensor->MutableData(), pyarray.mutable_data(), pyarray.nbytes());
   }
 }

View File

@@ -17,6 +17,7 @@
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <type_traits>
 #include "fastdeploy/fastdeploy_runtime.h"
@@ -42,7 +43,8 @@ pybind11::array TensorToPyArray(const FDTensor& tensor);
 cv::Mat PyArrayToCvMat(pybind11::array& pyarray);
 #endif
-template <typename T> FDDataType CTypeToFDDataType() {
+template <typename T>
+FDDataType CTypeToFDDataType() {
   if (std::is_same<T, int32_t>::value) {
     return FDDataType::INT32;
   } else if (std::is_same<T, int64_t>::value) {
@@ -58,16 +60,17 @@ template <typename T> FDDataType CTypeToFDDataType() {
 }
 template <typename T>
-std::vector<pybind11::array>
-PyBackendInfer(T& self, const std::vector<std::string>& names,
-               std::vector<pybind11::array>& data) {
+std::vector<pybind11::array> PyBackendInfer(
+    T& self, const std::vector<std::string>& names,
+    std::vector<pybind11::array>& data) {
   std::vector<FDTensor> inputs(data.size());
   for (size_t i = 0; i < data.size(); ++i) {
     // TODO(jiangjiajun) here is considered to use user memory directly
-    inputs[i].dtype = NumpyDataTypeToFDDataType(data[i].dtype());
-    inputs[i].shape.insert(inputs[i].shape.begin(), data[i].shape(),
-                           data[i].shape() + data[i].ndim());
-    inputs[i].data.resize(data[i].nbytes());
+    auto dtype = NumpyDataTypeToFDDataType(data[i].dtype());
+    std::vector<int64_t> data_shape;
+    data_shape.insert(data_shape.begin(), data[i].shape(),
+                      data[i].shape() + data[i].ndim());
+    inputs[i].Resize(data_shape, dtype);
     memcpy(inputs[i].MutableData(), data[i].mutable_data(), data[i].nbytes());
     inputs[i].name = names[i];
   }
@@ -86,4 +89,4 @@ PyBackendInfer(T& self, const std::vector<std::string>& names,
   return results;
 }
 }  // namespace fastdeploy
} // namespace fastdeploy } // namespace fastdeploy