Files
FastDeploy/custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.h
2025-06-29 23:29:37 +00:00

38 lines
1.2 KiB
C++

#pragma once
// Default the Marlin namespace name to `marlin_moe_wna16` unless the including
// translation unit (or the build) has already defined MARLIN_NAMESPACE_NAME.
// NOTE(review): presumably consumed by the moe_wna16_marlin_utils headers
// included below, which is why it is defined before them — confirm.
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/core/enforce.h"
#include "moe/moe_wna16_marlin_utils/kernel.h"
#include "moe/moe_wna16_marlin_utils/types.h"
/// @brief Forward declaration of the MoE WNA16 Marlin GEMM entry point.
///
/// Only the signature is visible in this header; the definition lives
/// elsewhere. Arguments typed `paddle::optional<paddle::Tensor>` may be absent.
///
/// NOTE(review): every parameter description below is inferred from the
/// parameter's name and type only — confirm against the implementation before
/// relying on it.
///
/// @param a                      Input tensor (presumably the activations).
/// @param c_or_none              Optional tensor (presumably a caller-provided
///                               output buffer).
/// @param b_q_weight             Tensor (presumably the quantized weights).
/// @param b_scales               Tensor (presumably the quantization scales).
/// @param global_scale_or_none   Optional tensor (presumably a global scale).
/// @param b_zeros_or_none        Optional tensor (presumably zero points).
/// @param g_idx_or_none          Optional tensor (presumably group indices).
/// @param perm_or_none           Optional tensor (presumably a permutation).
/// @param workspace              Tensor (presumably kernel scratch space).
/// @param sorted_token_ids       Tensor (presumably token ids sorted by expert).
/// @param expert_ids             Tensor (presumably per-block expert ids).
/// @param num_tokens_post_padded Tensor (presumably the padded token count).
/// @param topk_weights           Tensor (presumably the top-k routing weights).
/// @param moe_block_size         Integer (presumably the MoE block size).
/// @param top_k                  Integer (presumably experts selected per token).
/// @param mul_topk_weights       Flag (presumably: multiply by topk_weights).
/// @param is_ep                  Flag (presumably: expert-parallel mode).
/// @param b_q_type_str           String (presumably names the quantized weight
///                               type).
/// @param size_m                 Integer (presumably the GEMM M dimension).
/// @param size_n                 Integer (presumably the GEMM N dimension).
/// @param size_k                 Integer (presumably the GEMM K dimension).
/// @param is_k_full              Flag; meaning not determinable from this header.
/// @param use_atomic_add         Flag (presumably: use atomic adds to
///                               accumulate).
/// @param use_fp32_reduce        Flag (presumably: perform reductions in fp32).
/// @param is_zp_float            Flag (presumably: zero points are
///                               floating-point).
/// @return A vector of `paddle::Tensor`; the number and meaning of the returned
///         tensors are not visible from this declaration.
std::vector<paddle::Tensor> MoeWna16MarlinGemmApi(
const paddle::Tensor& a,
const paddle::optional<paddle::Tensor>& c_or_none,
const paddle::Tensor& b_q_weight,
const paddle::Tensor& b_scales,
const paddle::optional<paddle::Tensor>& global_scale_or_none,
const paddle::optional<paddle::Tensor>& b_zeros_or_none,
const paddle::optional<paddle::Tensor>& g_idx_or_none,
const paddle::optional<paddle::Tensor>& perm_or_none,
const paddle::Tensor& workspace,
const paddle::Tensor& sorted_token_ids,
const paddle::Tensor& expert_ids,
const paddle::Tensor& num_tokens_post_padded,
const paddle::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
const std::string& b_q_type_str,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);