// Copyright © 2024 PaddlePaddle Authors. All Rights Reserved.
//
// This code is partially inspired by and references the implementation found in
// FlashInfer. Specifically, the Top-p sampling implementation in this code is
// inspired by the logic of FlashInfer's
// flashinfer.sampling.top_p_sampling_from_probs function. For more details,
// please refer to FlashInfer's documentation:
// https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html#flashinfer-sampling-top-p-sampling-from_probs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cuda_device_runtime_api.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>

#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

/******************* utils *******************/
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)

#ifndef NDEBUG
#define CUDA_CALL(func, ...)                                                   \
  {                                                                            \
    cudaError_t e = (func);                                                    \
    if (e != cudaSuccess) {                                                    \
      std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e        \
                << ") " << __FILE__ << ": line " << __LINE__                   \
                << " at function " << STR(func) << std::endl;                  \
      return e;                                                                \
    }                                                                          \
  }
#else
#define CUDA_CALL(func, ...)                                                   \
  {                                                                            \
    cudaError_t e = (func);                                                    \
    if (e != cudaSuccess) {                                                    \
      return e;                                                                \
    }                                                                          \
  }
#endif

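// Usage sketch (illustrative addition, not part of the upstream header):
// CUDA_CALL early-returns the cudaError_t on failure, so it is meant to be
// used inside functions that themselves return cudaError_t. The helper name
// below is hypothetical.
inline cudaError_t ExampleZeroDeviceBuffer(float *buf, size_t n,
                                           cudaStream_t stream) {
  // Propagates (and, in debug builds, logs) any error from the runtime call.
  CUDA_CALL(cudaMemsetAsync(buf, 0, n * sizeof(float), stream));
  return cudaSuccess;
}
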
#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...)              \
  if (deterministic) {                                                         \
    constexpr bool DETERMINISTIC = true;                                       \
    __VA_ARGS__                                                                \
  } else {                                                                     \
    constexpr bool DETERMINISTIC = false;                                      \
    __VA_ARGS__                                                                \
  }

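// Usage sketch (illustrative addition): DISPATCH_DETERMINISTIC lifts a
// runtime flag into a constexpr bool so a kernel can be specialized on it.
// The kernel and launcher below are hypothetical.
template <bool DETERMINISTIC>
__global__ void ExampleFlagKernel(int *out) {
  // Records which specialization was instantiated.
  *out = DETERMINISTIC ? 1 : 0;
}

inline void ExampleLaunchWithFlag(bool deterministic, int *out_dev) {
  DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
    ExampleFlagKernel<DETERMINISTIC><<<1, 1>>>(out_dev);
  });
}
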
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...)     \
  switch (aligned_vec_size) {                                                  \
  case 16: {                                                                   \
    constexpr size_t ALIGNED_VEC_SIZE = 16;                                    \
    __VA_ARGS__                                                                \
    break;                                                                     \
  }                                                                            \
  case 8: {                                                                    \
    constexpr size_t ALIGNED_VEC_SIZE = 8;                                     \
    __VA_ARGS__                                                                \
    break;                                                                     \
  }                                                                            \
  case 4: {                                                                    \
    constexpr size_t ALIGNED_VEC_SIZE = 4;                                     \
    __VA_ARGS__                                                                \
    break;                                                                     \
  }                                                                            \
  case 2: {                                                                    \
    constexpr size_t ALIGNED_VEC_SIZE = 2;                                     \
    __VA_ARGS__                                                                \
    break;                                                                     \
  }                                                                            \
  case 1: {                                                                    \
    constexpr size_t ALIGNED_VEC_SIZE = 1;                                     \
    __VA_ARGS__                                                                \
    break;                                                                     \
  }                                                                            \
  default: {                                                                   \
    std::ostringstream err_msg;                                                \
    err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size;           \
    throw std::invalid_argument(err_msg.str());                                \
  }                                                                            \
  }

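// Usage sketch (illustrative addition): the dispatcher turns a runtime
// vector width into a constexpr, which is what allows vec_t<float, VEC_SIZE>
// (defined below) to appear inside the dispatched body. All helper names
// here are hypothetical.
inline size_t ExamplePickVecSize(size_t d) {
  // Widest power-of-two width (<= 16) that evenly divides the row length.
  for (size_t w = 16; w > 1; w /= 2) {
    if (d % w == 0) return w;
  }
  return 1;
}

template <size_t VEC_SIZE>
inline void ExampleLaunchVectorized(const float *probs, size_t d) {
  // A real launcher would configure a kernel that reads `probs` in chunks
  // of VEC_SIZE floats per thread.
  (void)probs;
  (void)d;
}

inline void ExampleDispatchOnVecSize(const float *probs, size_t d) {
  DISPATCH_ALIGNED_VEC_SIZE(ExamplePickVecSize(d), VEC_SIZE, {
    ExampleLaunchVectorized<VEC_SIZE>(probs, d);
  });
}
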
/******************* vec_t<float> *******************/
#define SAMPLING_INLINE inline __attribute__((always_inline)) __device__

template <typename float_t, size_t vec_size> struct vec_t {
  SAMPLING_INLINE float_t &operator[](size_t i);
  SAMPLING_INLINE const float_t &operator[](size_t i) const;
  SAMPLING_INLINE void fill(float_t val);
  SAMPLING_INLINE void load(const float_t *ptr);
  SAMPLING_INLINE void store(float_t *ptr) const;
  template <typename T>
  SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src);
  template <typename T> SAMPLING_INLINE void cast_load(const T *ptr);
  template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const;
  SAMPLING_INLINE static void memcpy(float_t *dst, const float_t *src);
  SAMPLING_INLINE float_t *ptr();
};

// float x 1
template <> struct vec_t<float, 1> {
  float data;

  SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
  SAMPLING_INLINE const float &operator[](size_t i) const {
    return ((const float *)(&data))[i];
  }
  SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
  SAMPLING_INLINE void fill(float val);
  SAMPLING_INLINE void load(const float *ptr);
  SAMPLING_INLINE void store(float *ptr) const;
  template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 1> &src) {
    cast_from_impl(*this, src);
  }
  template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
    cast_store_impl(ptr, *this);
  }
  SAMPLING_INLINE static void memcpy(float *dst, const float *src);
};

SAMPLING_INLINE void vec_t<float, 1>::fill(float val) { data = val; }

SAMPLING_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }

SAMPLING_INLINE void vec_t<float, 1>::store(float *ptr) const { *ptr = data; }

SAMPLING_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src) {
  *dst = *src;
}

// float x 2
template <> struct vec_t<float, 2> {
  float2 data;

  SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
  SAMPLING_INLINE const float &operator[](size_t i) const {
    return ((const float *)(&data))[i];
  }
  SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
  SAMPLING_INLINE void fill(float val);
  SAMPLING_INLINE void load(const float *ptr);
  SAMPLING_INLINE void store(float *ptr) const;
  template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 2> &src) {
    cast_from_impl(*this, src);
  }
  template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
    cast_store_impl(ptr, *this);
  }
  SAMPLING_INLINE static void memcpy(float *dst, const float *src);
};

SAMPLING_INLINE void vec_t<float, 2>::fill(float val) {
  data = make_float2(val, val);
}

SAMPLING_INLINE void vec_t<float, 2>::load(const float *ptr) {
  data = *((const float2 *)ptr);
}

SAMPLING_INLINE void vec_t<float, 2>::store(float *ptr) const {
  *((float2 *)ptr) = data;
}

SAMPLING_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src) {
  *((float2 *)dst) = *((const float2 *)src);
}

// float x 4 or more
template <size_t vec_size> struct vec_t<float, vec_size> {
  float4 data[vec_size / 4];

  SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
  SAMPLING_INLINE const float &operator[](size_t i) const {
    return ((const float *)(data))[i];
  }
  SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
  SAMPLING_INLINE void fill(float val) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      data[i] = make_float4(val, val, val, val);
    }
  }
  SAMPLING_INLINE void load(const float *ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      data[i] = ((const float4 *)ptr)[i];
    }
  }
  SAMPLING_INLINE void store(float *ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      ((float4 *)ptr)[i] = data[i];
    }
  }
  template <typename T>
  SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src) {
    cast_from_impl(*this, src);
  }
  template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
    cast_store_impl(ptr, *this);
  }
  SAMPLING_INLINE static void memcpy(float *dst, const float *src) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      ((float4 *)dst)[i] = ((const float4 *)src)[i];
    }
  }
};

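// Usage sketch (illustrative addition): a minimal vectorized copy kernel
// built on vec_t. Each thread moves VEC_SIZE consecutive floats with one
// vectorized load/store pair; `n` is assumed to be a multiple of VEC_SIZE.
// The kernel name is hypothetical.
template <size_t VEC_SIZE>
__global__ void ExampleVecCopyKernel(const float *in, float *out, size_t n) {
  const size_t i =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_SIZE;
  if (i + VEC_SIZE <= n) {
    vec_t<float, VEC_SIZE> v;
    v.load(in + i);   // one float4-wide transaction per 4 elements
    v.store(out + i); // matching vectorized store
  }
}
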
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
SAMPLING_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size> &dst,
                                    const src_float_t *src_ptr) {
  if constexpr (std::is_same_v<src_float_t, tgt_float_t>) {
    dst.load(src_ptr);
  } else {
    vec_t<src_float_t, vec_size> tmp;
    tmp.load(src_ptr);
    dst.cast_from(tmp);
  }
}

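// Note (added commentary): cast_from_impl and cast_store_impl are referenced
// by the vec_t members above but are not defined in this header, so they are
// presumably provided by another header in the project. A minimal sketch of
// compatible definitions, assuming plain element-wise static_cast semantics,
// is left commented out to avoid clashing with the real definitions:
//
// template <typename src_float_t, typename tgt_float_t, size_t vec_size>
// SAMPLING_INLINE void cast_from_impl(vec_t<tgt_float_t, vec_size> &dst,
//                                     const vec_t<src_float_t, vec_size> &src) {
// #pragma unroll
//   for (size_t i = 0; i < vec_size; ++i) {
//     dst[i] = static_cast<tgt_float_t>(src[i]);
//   }
// }
//
// template <typename src_float_t, typename tgt_float_t, size_t vec_size>
// SAMPLING_INLINE void cast_store_impl(tgt_float_t *dst_ptr,
//                                      const vec_t<src_float_t, vec_size> &src) {
//   if constexpr (std::is_same_v<src_float_t, tgt_float_t>) {
//     src.store(dst_ptr);
//   } else {
//     vec_t<tgt_float_t, vec_size> tmp;
//     tmp.cast_from(src);
//     tmp.store(dst_ptr);
//   }
// }
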
inline std::pair<int, int> GetCudaComputeCapability() {
  int device_id = 0;
  cudaGetDevice(&device_id);
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
  return std::make_pair(major, minor);
}

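// Usage sketch (illustrative addition): gate an architecture-specific code
// path on the detected SM version. The helper name is hypothetical.
inline bool ExampleIsAtLeastSM80() {
  const std::pair<int, int> cc = GetCudaComputeCapability();
  return cc.first >= 8;  // compute capability 8.x (Ampere) or newer
}
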
/******************* math *******************/
__forceinline__ __device__ float ptx_rcp(float x) {
#ifdef PADDLE_WITH_COREX
  return __ivcorex_rcpf(x);
#else
  float y;
  asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
#endif
}

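// Usage sketch (illustrative addition): a typical use of the approximate
// reciprocal is scaling by 1/sum, e.g. when renormalizing truncated
// probabilities. The helper is hypothetical; precision is lower than a true
// IEEE division.
__forceinline__ __device__ float ExampleNormalizeProb(float prob, float sum) {
  return prob * ptx_rcp(sum);  // approximates prob / sum via fast reciprocal
}
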
template <typename T1, typename T2>
__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
  return (x + y - 1) / y;
}
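
// Usage sketch (illustrative addition): ceil_div is the usual helper for
// grid sizing, e.g. ceil_div(1000, 256) == 4 blocks to cover 1000 elements
// at 256 threads per block. The helper name is hypothetical.
inline dim3 ExampleGridForElements(size_t n, unsigned threads_per_block) {
  return dim3(static_cast<unsigned>(
      ceil_div(n, static_cast<size_t>(threads_per_block))));
}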
