// Mirror of https://github.com/PaddlePaddle/FastDeploy.git
// (synced 2025-12-24 13:28:13 +08:00; 38 lines, 1.2 KiB, C++).
#pragma once

// Default value for MARLIN_NAMESPACE_NAME. The macro is only defined here when
// nothing earlier (a compiler flag or a previously included header) has
// already defined it, so an outside definition always takes precedence.
// NOTE(review): presumably consumed by the moe_wna16_marlin_utils headers
// included below to name their namespace — confirm against those headers.
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include <cstdint>
#include <string>
#include <vector>

#include "paddle/phi/api/include/api.h"
#include "paddle/phi/core/enforce.h"

#include "moe/moe_wna16_marlin_utils/kernel.h"
#include "moe/moe_wna16_marlin_utils/types.h"

std::vector<paddle::Tensor> MoeWna16MarlinGemmApi(
|
|
const paddle::Tensor& a,
|
|
const paddle::optional<paddle::Tensor>& c_or_none,
|
|
const paddle::Tensor& b_q_weight,
|
|
const paddle::Tensor& b_scales,
|
|
const paddle::optional<paddle::Tensor>& global_scale_or_none,
|
|
const paddle::optional<paddle::Tensor>& b_zeros_or_none,
|
|
const paddle::optional<paddle::Tensor>& g_idx_or_none,
|
|
const paddle::optional<paddle::Tensor>& perm_or_none,
|
|
const paddle::Tensor& workspace,
|
|
const paddle::Tensor& sorted_token_ids,
|
|
const paddle::Tensor& expert_ids,
|
|
const paddle::Tensor& num_tokens_post_padded,
|
|
const paddle::Tensor& topk_weights,
|
|
int64_t moe_block_size,
|
|
int64_t top_k,
|
|
bool mul_topk_weights,
|
|
bool is_ep,
|
|
const std::string& b_q_type_str,
|
|
int64_t size_m,
|
|
int64_t size_n,
|
|
int64_t size_k,
|
|
bool is_k_full,
|
|
bool use_atomic_add,
|
|
bool use_fp32_reduce,
|
|
bool is_zp_float);
|