[LLM] First commit the llm deployment code
@@ -0,0 +1,250 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Templates for visiting scales to be used when dequantizing the weights for
    weight-only GEMM quantization.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace transform
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

template <typename Shape, typename Element, typename Layout, int AdvanceRank, int Alignment>
class FineGrainedScaleZeroIterator;

template <typename Shape_, typename Element_, int Alignment_>
class FineGrainedScaleZeroIterator<Shape_, Element_, layout::RowMajor, 0, Alignment_>
{
public:
    using Shape = Shape_;
    using Element = Element_;
    using Layout = layout::RowMajor;
    static int const kAdvanceRank = 0;
    static int const kAlignment = Alignment_;

    static int const kAccessesPerVector = 1;

    /// Row index of scales corresponding to the groupsize of 64
    int row_groupsize64_;
    int group_size_;

    using Index = typename Layout::Index;
    using LongIndex = typename Layout::LongIndex;

    using TensorRef = TensorRef<Element, Layout>;
    using TensorView = TensorView<Element, Layout>;
    using TensorCoord = typename Layout::TensorCoord;
    using Pointer = Element*;
    using NonConstPointer = typename platform::remove_const<Element>::type*;

    using AccessType = AlignedArray<Element, kAlignment>;

    using Fragment = cutlass::Array<Element, kAlignment>;

    // For compatibility with the existing iterator interface
    struct Params
    {
        LongIndex stride_ = 0;

        /// Amount (in bytes) to increment the pointer from the first access of the
        /// current tile to the first access of the next tile
        LongIndex inc_advance_ = 0;

        // Default ctor
        CUTLASS_HOST_DEVICE
        Params() {}

        /// Construct the Params object given a pitch-linear tensor's layout
        CUTLASS_HOST_DEVICE
        Params(Layout const& layout)
            : stride_(layout.stride(0))
        {
            inc_advance_ = Shape::kRow * stride_ * sizeof_bits<Element>::value / 8;
        }
    };
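
    // Worked example for Params::inc_advance_ above (hypothetical values, not taken
    // from this commit): with Shape::kRow = 64, stride_ = 4096, and a 16-bit Element,
    // one tile advance moves the pointer by 64 * 4096 * 16 / 8 = 524288 bytes,
    // i.e. 64 rows of 4096 half-precision scales.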

private:
    /// Internal pointer type permits fast address arithmetic
    using BytePointer = char*;

private:
    //
    // Data members
    //

    /// Parameters object with precomputed internal state
    Params const params_;

    /// Internal pointers to the first access of the tile
    BytePointer pointer_scale_;
    BytePointer pointer_zero_;

    bool is_valid_ = false;

public:
    /// Constructs a TileIterator from its precomputed state, threadblock offset,
    /// and thread ID
    CUTLASS_DEVICE
    FineGrainedScaleZeroIterator(
        ///< Precomputed parameters object
        Params const& params,
        ///< Pointer to start of scale tensor
        Pointer pointer_scale,
        ///< Pointer to start of zero tensor
        Pointer pointer_zero,
        ///< Extent of the scale and bias
        TensorCoord extent,
        ///< ID of each participating thread
        int thread_id,
        ///< Initial offset of threadblock
        TensorCoord const& threadblock_offset,
        ///< Group size
        int group_size)
        : params_(params)
        , pointer_scale_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_scale)))
        , pointer_zero_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_zero)))
    {
        row_groupsize64_ = threadblock_offset.row();
        group_size_ = group_size;
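
        // Note: the row component of threadblock_offset is expressed in groupsize-64
        // units (hence the name row_groupsize64_); the division by (group_size / 64)
        // below maps it to a row of the scale tensor, which holds one row of scales
        // per group.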

        const LongIndex tb_row_byte_offset
            = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits<Element>::value / 8;
        const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits<Element>::value / 8;
        pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset);

        if (pointer_zero_ != nullptr)
        {
            pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset);
        }

        static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;

        int const thread_row = thread_id / THREADS_PER_ROW;
        int const thread_col = thread_id % THREADS_PER_ROW;

        const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
        const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
        pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset);
        if (pointer_zero_ != nullptr)
        {
            pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset);
        }

        // For the rows, we must check that we are within both the extent AND the tile
        // to avoid extra reads on a given iteration. The same threads are responsible
        // for issuing reads since the number of scales read per iteration is a
        // constant, so is_valid_ never has to be updated outside of the constructor.
        int const global_row = threadblock_offset.row() + thread_row;
        int const global_col = threadblock_offset.column() + thread_col * kAlignment;

        bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
        bool const col_in_bounds = global_col < extent.column();

        is_valid_ = row_in_bounds && col_in_bounds;
    }

    /// Constructs a FineGrainedScaleZeroIterator with zero threadblock offset
    CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object
        Pointer pointer_scale,                                             ///< Pointer to start of scale tensor
        Pointer pointer_zero,                                              ///< Pointer to start of zero tensor
        TensorCoord extent,                                                ///< Extent of tensor
        int thread_id,                                                     ///< ID of each participating thread
        int group_size)
        : FineGrainedScaleZeroIterator(
            params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size)
    {
    }

    CUTLASS_DEVICE
    void add_tile_offset(TensorCoord const& tile_offset)
    {
        const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_;
        const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits<Element>::value / 8;
        pointer_scale_ += row_byte_offset + col_byte_offset;
        if (pointer_zero_ != nullptr)
        {
            pointer_zero_ += row_byte_offset + col_byte_offset;
        }
    }

    /// Clears the predicate set efficiently
    CUTLASS_HOST_DEVICE void clear_mask(bool enable = true)
    {
        is_valid_ &= (!enable);
    }
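
    // Note on clear_mask() above: with the default argument it invalidates every
    // subsequent access (valid() returns false from then on), while clear_mask(false)
    // leaves the predicate untouched.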

    /// Returns whether the access is valid or not
    CUTLASS_HOST_DEVICE
    bool valid() const
    {
        return is_valid_;
    }

    /// Returns a pointer to the current vector of scales
    CUTLASS_HOST_DEVICE
    AccessType* get_scale() const
    {
        return reinterpret_cast<AccessType*>(pointer_scale_);
    }

    /// Returns a pointer to the current vector of zero-points (nullptr if none was provided)
    CUTLASS_HOST_DEVICE
    AccessType* get_zero() const
    {
        return reinterpret_cast<AccessType*>(pointer_zero_);
    }
};

} // namespace threadblock
} // namespace transform
} // namespace cutlass
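
For orientation, here is a minimal sketch of how the iterator above might be driven from a device-side mainloop. Every name below (the tile shape, alignment, visit_scales, scale_ptr, zero_ptr, extent, layout, group_size) is a hypothetical stand-in for illustration, not something taken from this commit.

// Sketch only: assumes CUDA device code and the header above.
using TileShape = cutlass::MatrixShape<64, 64>;
using ScaleIterator = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<
    TileShape, cutlass::half_t, cutlass::layout::RowMajor, 0, /*Alignment=*/8>;

__device__ void visit_scales(cutlass::half_t* scale_ptr, cutlass::half_t* zero_ptr,
    cutlass::MatrixCoord extent, cutlass::layout::RowMajor layout, int group_size)
{
    ScaleIterator::Params params(layout); // precomputes stride_ and inc_advance_

    ScaleIterator iter(params, scale_ptr, zero_ptr, extent, threadIdx.x,
        cutlass::make_Coord(0, 0), group_size);

    if (iter.valid())
    {
        ScaleIterator::AccessType* scales = iter.get_scale(); // kAlignment scales per access
        ScaleIterator::AccessType* zeros = iter.get_zero();   // nullptr when zero_ptr is nullptr
        // ... hand scales/zeros to the dequantizing mainloop here ...
    }

    iter.add_tile_offset({1, 0}); // step one tile of rows to the next group of scales
}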