Add argmax, argmin function (#104)

* Add argmax argmin function

* Add unittest for argmax, argmin
This commit is contained in:
Jack Zhou
2022-08-12 20:22:11 +08:00
committed by GitHub
parent 679f39ae9f
commit b6247238f5
4 changed files with 266 additions and 3 deletions

View File

@@ -96,5 +96,33 @@ FASTDEPLOY_DECL void Prod(const FDTensor& x, FDTensor* out,
const std::vector<int64_t>& dims,
bool keep_dim = false, bool reduce_all = false);
/** Excute the argmax operation for input FDTensor along given dims.
@param x The input tensor.
@param out The output tensor which stores the result.
@param axis The axis which will be reduced.
@param output_dtype The data type of output FDTensor, INT64 or INT32,
default to INT64.
@param keep_dim Whether to keep the reduced dims, default false.
@param flatten Whether to flatten FDTensor to get the argmin index, default
false.
*/
FASTDEPLOY_DECL void ArgMax(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype = FDDataType::INT64,
bool keep_dim = false, bool flatten = false);
/** Excute the argmin operation for input FDTensor along given dims.
@param x The input tensor.
@param out The output tensor which stores the result.
@param axis The axis which will be reduced.
@param output_dtype The data type of output FDTensor, INT64 or INT32,
default to INT64.
@param keep_dim Whether to keep the reduced dims, default false.
@param flatten Whether to flatten FDTensor to get the argmin index, default
false.
*/
FASTDEPLOY_DECL void ArgMin(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype = FDDataType::INT64,
bool keep_dim = false, bool flatten = false);
#endif
} // namespace fastdeploy