Files
FastDeploy/custom_ops/gpu_ops/noaux_tc.cu
chen 1a6283424e
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
Fix noaux_tc cuda Error 700 in CUDAGraph (#4174)
2025-09-23 18:41:33 +08:00

86 lines
3.4 KiB
Plaintext

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <optional>
#include "helper.h"
#include "noauxtc_kernel.h"
// Group-limited top-k selection ("noaux_tc"), dispatched to invokeNoAuxTc
// (declared in noauxtc_kernel.h) on the stream of `scores_with_bias`.
//
// Args:
//   scores:           float32 tensor with the same [num_tokens, num_experts]
//                     shape as `scores_with_bias`. It is returned as the first
//                     output; presumably the kernel updates it in place --
//                     TODO confirm against noauxtc_kernel.h.
//   scores_with_bias: float32 2-D tensor; its shape supplies num_tokens and
//                     num_experts for the kernel, and its place/stream are
//                     used for all allocations and the launch.
//   n_group, topk_group, topk, renormalize, routed_scaling_factor:
//                     forwarded verbatim to invokeNoAuxTc.
//
// Returns {scores, topk_values [num_tokens, topk] (float32),
//          topk_indices [num_tokens, topk] (int64)}.
//
// Throws (via PD_CHECK) on a non-2-D input, mismatched input shapes,
// non-float32 inputs, n_group <= 0, or topk outside [1, num_experts].
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                    paddle::Tensor& scores_with_bias,
                                    int n_group,
                                    int topk_group,
                                    int topk,
                                    bool renormalize,
                                    float routed_scaling_factor) {
  const auto input_shape = scores_with_bias.shape();
  PD_CHECK(input_shape.size() == 2,
           "noaux_tc: scores_with_bias must be a 2-D tensor.");
  // The kernel receives a single (num_tokens, num_experts) pair for both
  // buffers, so `scores` must have exactly the same layout; otherwise the
  // kernel would index past the end of `scores`.
  PD_CHECK(scores.shape() == input_shape,
           "noaux_tc: scores and scores_with_bias must have the same shape.");
  // Every buffer below is accessed as float; reject other dtypes up front
  // with a clear message instead of failing inside data<float>().
  PD_CHECK(scores.dtype() == paddle::DataType::FLOAT32 &&
               scores_with_bias.dtype() == paddle::DataType::FLOAT32,
           "noaux_tc: scores and scores_with_bias must be float32.");

  const int64_t num_tokens = input_shape[0];
  const int64_t num_experts = input_shape[1];
  // n_group and topk size the allocations below; non-positive values (or
  // selecting more experts than exist) can never be valid.
  PD_CHECK(n_group > 0, "noaux_tc: n_group must be positive.");
  PD_CHECK(topk > 0 && topk <= num_experts,
           "noaux_tc: topk must be in [1, num_experts].");

  const auto dtype = scores_with_bias.dtype();
  const auto place = scores_with_bias.place();
  // Scratch buffer handed to the kernel; it is not part of the outputs.
  auto group_scores = paddle::empty({num_tokens, n_group}, dtype, place);
  auto topk_values = paddle::empty({num_tokens, topk}, dtype, place);
  auto topk_indices =
      paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);

  // data<T>() already yields T*, so no pointer casts are needed.
  invokeNoAuxTc<float, int64_t>(scores.data<float>(),
                                group_scores.data<float>(),
                                topk_values.data<float>(),
                                topk_indices.data<int64_t>(),
                                scores_with_bias.data<float>(),
                                num_tokens,
                                num_experts,
                                n_group,
                                topk_group,
                                topk,
                                renormalize,
                                routed_scaling_factor,
                                scores_with_bias.stream());
  return {scores, topk_values, topk_indices};
}
// Dtype inference for `noaux_tc`.
// The first two outputs ("output_tensor" and "topk_values") take the dtype of
// `scores`; "topk_indices" is always int64. The dtype of `scores_with_bias`
// does not influence the result.
std::vector<paddle::DataType> NoauxTcInferDtype(
    const paddle::DataType& scores_dtype,
    const paddle::DataType& scores_with_bias_dtype) {
  std::vector<paddle::DataType> dtypes(2, scores_dtype);  // output_tensor, topk_values
  dtypes.push_back(paddle::DataType::INT64);              // topk_indices
  return dtypes;
}
// Shape inference for `noaux_tc`.
//
// Paddle's PD_INFER_SHAPE helper binds attributes to this function by
// position, in the order they are declared in the op's .Attrs(...) list
// (n_group, topk_group, topk, ...). Every attribute that precedes `topk`
// must therefore be declared here, even if unused. Otherwise `topk` would
// silently receive the value of `n_group`. Trailing attributes that are not
// needed may be omitted.
//
// Returns {scores_shape, [num_tokens, topk], [num_tokens, topk]}, matching
// the tensors the kernel allocates. Assumes `scores_shape` is rank 2; the
// kernel itself PD_CHECKs the rank at run time.
std::vector<std::vector<int64_t>> NoauxTcInferShape(
    const std::vector<int64_t>& scores_shape,
    const std::vector<int64_t>& /*scores_with_bias_shape*/,
    const int /*n_group*/,
    const int /*topk_group*/,
    const int topk) {
  const int64_t num_tokens = scores_shape[0];
  const std::vector<int64_t> topk_shape{num_tokens, topk};
  // topk_values and topk_indices share the same [num_tokens, topk] shape.
  return {scores_shape, topk_shape, topk_shape};
}
// Registers the `noaux_tc` custom operator.
// Inputs:  "scores" and "scores_with_bias" (both accessed as float by the
//          kernel function NoauxTc).
// Outputs: "output_tensor" -- the `scores` input, returned as-is by NoauxTc;
//          "topk_values"   -- [num_tokens, topk], dtype of `scores`;
//          "topk_indices"  -- [num_tokens, topk], int64.
// NOTE(review): Paddle appears to pass these attributes positionally, in
// the order listed below, to both the kernel and the InferShape function,
// so those signatures must follow this exact order -- confirm when editing.
PD_BUILD_STATIC_OP(noaux_tc)
.Inputs({"scores", "scores_with_bias"})
.Outputs({"output_tensor", "topk_values", "topk_indices"})
.Attrs({"n_group: int",
"topk_group: int",
"topk:int",
"renormalize: bool",
"routed_scaling_factor: float"})
.SetKernelFn(PD_KERNEL(NoauxTc))
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcInferDtype));