mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
[Feat] ernie4_5_vl_moe
support CudaGraph (#3226)
* delete dynamic control flow for decode * coda-style * fix scatter/gather typos and use input stream instead default stream * support 0-Size Tensor * update runner and model * using static mem address as input * fix mem leak * refine code * update mm_buffer * fix typo * fix buffersize * fix unk token * refine code * refine * support other arch * open cudagraph in vlci * fix * update * update * update * fix cmd * update --------- Co-authored-by: aquagull <hongyuh@qq.com> Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -414,8 +414,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
|
|||||||
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
|
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
|
||||||
|
|
||||||
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
|
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
|
||||||
const paddle::Tensor &text_input,
|
paddle::Tensor &text_input,
|
||||||
const paddle::Tensor &image_input);
|
paddle::Tensor &image_input);
|
||||||
|
|
||||||
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||||
paddle::Tensor &image_input,
|
paddle::Tensor &image_input,
|
||||||
|
@@ -132,7 +132,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
|||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(get_padding_offset)
|
PD_BUILD_STATIC_OP(get_padding_offset)
|
||||||
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
|
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||||
.Outputs({"x_remove_padding",
|
.Outputs({"x_remove_padding",
|
||||||
"batch_id_per_token",
|
"batch_id_per_token",
|
||||||
"cu_seqlens_q",
|
"cu_seqlens_q",
|
||||||
|
@@ -36,6 +36,9 @@ void MoeDispatchKernel(
|
|||||||
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
|
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
|
||||||
using namespace phi;
|
using namespace phi;
|
||||||
|
|
||||||
|
if (num_rows == 0){
|
||||||
|
return;
|
||||||
|
}
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
@@ -185,6 +188,15 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
auto expert_idx_per_token =
|
auto expert_idx_per_token =
|
||||||
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
||||||
|
|
||||||
|
if (token_rows == 0){
|
||||||
|
return {permute_input,
|
||||||
|
tokens_expert_prefix_sum,
|
||||||
|
permute_indices_per_token,
|
||||||
|
topk_weight,
|
||||||
|
topk_idx,
|
||||||
|
expert_idx_per_token};
|
||||||
|
}
|
||||||
|
|
||||||
switch (input_type) {
|
switch (input_type) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
|
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
|
||||||
|
@@ -412,7 +412,9 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
|||||||
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
||||||
permute_input.dtype();
|
permute_input.dtype();
|
||||||
auto ffn_out = paddle::empty_like(permute_input, t_type);
|
auto ffn_out = paddle::empty_like(permute_input, t_type);
|
||||||
|
if(permute_input.numel() == 0){
|
||||||
|
return ffn_out;
|
||||||
|
}
|
||||||
switch (t_type) {
|
switch (t_type) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
|
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
|
||||||
|
@@ -59,6 +59,10 @@ paddle::Tensor MoeExpertReduceFunc(
|
|||||||
|
|
||||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||||
|
|
||||||
|
if(num_rows == 0){
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
switch (input_type) {
|
switch (input_type) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||||
|
@@ -59,7 +59,7 @@ __global__ void text_image_scatter_kernel(
|
|||||||
constexpr int HalfVecSize = VecSize / 2;
|
constexpr int HalfVecSize = VecSize / 2;
|
||||||
using T_Vec = AlignedVector<T, VecSize>;
|
using T_Vec = AlignedVector<T, VecSize>;
|
||||||
T_Vec input_ptr_vec;
|
T_Vec input_ptr_vec;
|
||||||
T_Vec text_imgaes_vec;
|
T_Vec text_images_vec;
|
||||||
|
|
||||||
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
|
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const int64_t step = blockDim.x * gridDim.x * VecSize;
|
const int64_t step = blockDim.x * gridDim.x * VecSize;
|
||||||
@@ -76,16 +76,20 @@ __global__ void text_image_scatter_kernel(
|
|||||||
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
|
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for(int vi = 0; vi < VecSize; ++vi) {
|
for(int vi = 0; vi < VecSize; ++vi) {
|
||||||
text_imgaes_vec[vi] = input_ptr_vec[vi];
|
text_images_vec[vi] = input_ptr_vec[vi];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (token_type_ids_num == 0) {
|
if (token_type_ids_num == 0) {
|
||||||
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
|
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
|
||||||
Store<T,VecSize>(text_imgaes_vec, text_gather_ptr + text_load_offset);
|
Store<T,VecSize>(text_images_vec, text_gather_ptr + text_load_offset);
|
||||||
|
|
||||||
|
} else if(token_type_ids_num == 1){
|
||||||
|
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
||||||
|
Store<T,VecSize>(text_images_vec, image_gather_ptr + image_load_offset);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
// skip cuda graph padding value
|
||||||
Store<T,VecSize>(text_imgaes_vec, image_gather_ptr + image_load_offset);
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -120,9 +124,12 @@ __global__ void text_image_gather_kernel(
|
|||||||
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
|
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
|
||||||
Load<T,VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);
|
Load<T,VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);
|
||||||
|
|
||||||
} else {
|
} else if (token_type_ids_num == 1){
|
||||||
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
||||||
Load<T,VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
|
Load<T,VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
|
||||||
|
} else {
|
||||||
|
// skip cuda graph padding value
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -154,7 +161,6 @@ void LaunchTextImageGatherScatter(
|
|||||||
const int64_t token_num = in_dims[0];
|
const int64_t token_num = in_dims[0];
|
||||||
const int64_t hidden_size = in_dims[1];
|
const int64_t hidden_size = in_dims[1];
|
||||||
|
|
||||||
|
|
||||||
const int VecSize = 16 / sizeof(data_t);
|
const int VecSize = 16 / sizeof(data_t);
|
||||||
const int64_t tot_element_num = token_num * hidden_size;
|
const int64_t tot_element_num = token_num * hidden_size;
|
||||||
|
|
||||||
@@ -168,7 +174,7 @@ void LaunchTextImageGatherScatter(
|
|||||||
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
|
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
|
||||||
dim3 grid_dim = dim3(grid_size_x, 1, 1);
|
dim3 grid_dim = dim3(grid_size_x, 1, 1);
|
||||||
if (is_scatter) {
|
if (is_scatter) {
|
||||||
text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
|
text_image_scatter_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
|
||||||
reinterpret_cast<DataType_*>(input.data<data_t>()),
|
reinterpret_cast<DataType_*>(input.data<data_t>()),
|
||||||
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
|
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
|
||||||
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
|
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
|
||||||
@@ -179,7 +185,7 @@ void LaunchTextImageGatherScatter(
|
|||||||
tot_element_num
|
tot_element_num
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
|
text_image_gather_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
|
||||||
reinterpret_cast<DataType_*>(input.data<data_t>()),
|
reinterpret_cast<DataType_*>(input.data<data_t>()),
|
||||||
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
|
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
|
||||||
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
|
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
|
||||||
|
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
template <int VecSize>
|
template <int VecSize>
|
||||||
__global__ void text_image_index_out_kernel(
|
__global__ void text_image_index_out_kernel(
|
||||||
int32_t* token_type_ids,
|
const int32_t* token_type_ids,
|
||||||
int32_t* text_index,
|
int32_t* text_index,
|
||||||
int32_t* image_index,
|
int32_t* image_index,
|
||||||
const int64_t token_num
|
const int64_t token_num
|
||||||
@@ -31,23 +31,27 @@ __global__ void text_image_index_out_kernel(
|
|||||||
if (token_type_ids[i] == 0) {
|
if (token_type_ids[i] == 0) {
|
||||||
text_index[i] = text_count;
|
text_index[i] = text_count;
|
||||||
text_count += 1;
|
text_count += 1;
|
||||||
} else {
|
} else if (token_type_ids[i] == 1) {
|
||||||
image_index[i] = images_count;
|
image_index[i] = images_count;
|
||||||
images_count += 1;
|
images_count += 1;
|
||||||
|
} else {
|
||||||
|
// skip cuda graph padding value
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TextImageIndexOut(
|
void TextImageIndexOut(
|
||||||
const paddle::Tensor& token_type_ids,
|
const paddle::Tensor& token_type_ids,
|
||||||
const paddle::Tensor& text_index,
|
paddle::Tensor& text_index,
|
||||||
const paddle::Tensor& image_index) {
|
paddle::Tensor& image_index) {
|
||||||
|
|
||||||
const int64_t token_num = token_type_ids.shape()[0];
|
const int64_t token_num = token_type_ids.shape()[0];
|
||||||
text_image_index_out_kernel<1><<<1, 1>>>(
|
auto stream = token_type_ids.stream();
|
||||||
const_cast<int32_t*>(token_type_ids.data<int32_t>()),
|
text_image_index_out_kernel<1><<<1, 1, 0, stream>>>(
|
||||||
const_cast<int32_t*>(text_index.data<int32_t>()),
|
token_type_ids.data<int32_t>(),
|
||||||
const_cast<int32_t*>(image_index.data<int32_t>()),
|
text_index.data<int32_t>(),
|
||||||
|
image_index.data<int32_t>(),
|
||||||
token_num
|
token_num
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@@ -99,3 +99,35 @@ class GraphOptWrapper:
|
|||||||
fd_config.graph_opt_config.graph_opt_level < 1
|
fd_config.graph_opt_config.graph_opt_level < 1
|
||||||
), "Currently unable to update weights in static graph mode."
|
), "Currently unable to update weights in static graph mode."
|
||||||
self.graph_opt_backend.clear_cudagraph_piecewise_backend()
|
self.graph_opt_backend.clear_cudagraph_piecewise_backend()
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_graph_buffers(buffer_meta):
|
||||||
|
def decorator(cls):
|
||||||
|
original_init = cls.__init__
|
||||||
|
|
||||||
|
def __init__(self, fd_config: FDConfig, **kwargs):
|
||||||
|
original_init(self, fd_config=fd_config, **kwargs)
|
||||||
|
|
||||||
|
def _resolve_path(root, path: str):
|
||||||
|
cur = root
|
||||||
|
for p in path.split("."):
|
||||||
|
cur = getattr(cur, p)
|
||||||
|
return cur
|
||||||
|
|
||||||
|
if not hasattr(self, "_mm_buffers"):
|
||||||
|
self._mm_buffers = {}
|
||||||
|
for name, meta in buffer_meta.items():
|
||||||
|
shape = [_resolve_path(fd_config, s) if isinstance(s, str) else s for s in meta["shape"]]
|
||||||
|
dtype = meta["dtype"]
|
||||||
|
if "." in meta["dtype"]:
|
||||||
|
dtype = _resolve_path(fd_config, meta["dtype"])
|
||||||
|
self._mm_buffers[name] = paddle.full(
|
||||||
|
shape=shape,
|
||||||
|
dtype=dtype,
|
||||||
|
fill_value=meta.get("value", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
cls.__init__ = __init__
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
@@ -32,6 +32,7 @@ from paddleformers.utils.log import logger
|
|||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||||
|
cuda_graph_buffers,
|
||||||
support_graph_optimization,
|
support_graph_optimization,
|
||||||
)
|
)
|
||||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||||
@@ -66,12 +67,23 @@ class Ernie4_5_VLAttention(Ernie4_5_Attention):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VLMoEMeta:
|
class VLMoEMeta:
|
||||||
image_input: Optional[paddle.Tensor] = None
|
image_input: paddle.Tensor
|
||||||
text_input: Optional[paddle.Tensor] = None
|
text_input: paddle.Tensor
|
||||||
text_index: Optional[paddle.Tensor] = None
|
text_index: paddle.Tensor
|
||||||
image_index: Optional[paddle.Tensor] = None
|
image_index: paddle.Tensor
|
||||||
token_type_ids: Optional[paddle.Tensor] = None
|
token_type_ids: paddle.Tensor
|
||||||
fake_hidden_states: Optional[paddle.Tensor] = None
|
image_token_num: paddle.Tensor
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return (
|
||||||
|
f"VLMoEMeta(\n"
|
||||||
|
f" image_input: {self.image_input}, pointer: {self.image_input.data_ptr()}\n"
|
||||||
|
f" text_input: {self.text_input}, pointer: {self.text_input.data_ptr()}\n"
|
||||||
|
f" text_index: {self.text_index}, pointer: {self.text_index.data_ptr()}\n"
|
||||||
|
f" image_index: {self.image_index}, pointer: {self.image_index.data_ptr()}\n"
|
||||||
|
f" token_type_ids: {self.token_type_ids}, pointer: {self.token_type_ids.data_ptr()}\n\n"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Ernie4_5_VLMoeBlock(nn.Layer):
|
class Ernie4_5_VLMoeBlock(nn.Layer):
|
||||||
@@ -266,31 +278,26 @@ class Ernie4_5_VLMoE(nn.Layer):
|
|||||||
def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
|
def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
|
||||||
if self.num_shared_experts > 0:
|
if self.num_shared_experts > 0:
|
||||||
shared_experts_out = self.shared_experts(hidden_states)
|
shared_experts_out = self.shared_experts(hidden_states)
|
||||||
if vl_moe_meta.image_input is not None:
|
text_image_gather_scatter(
|
||||||
text_image_gather_scatter(
|
hidden_states,
|
||||||
hidden_states,
|
vl_moe_meta.text_input,
|
||||||
vl_moe_meta.text_input,
|
vl_moe_meta.image_input,
|
||||||
vl_moe_meta.image_input,
|
vl_moe_meta.token_type_ids,
|
||||||
vl_moe_meta.token_type_ids,
|
vl_moe_meta.text_index,
|
||||||
vl_moe_meta.text_index,
|
vl_moe_meta.image_index,
|
||||||
vl_moe_meta.image_index,
|
True,
|
||||||
True,
|
)
|
||||||
)
|
text_out = self.text_fused_moe(vl_moe_meta.text_input)
|
||||||
text_out = self.text_fused_moe(vl_moe_meta.text_input)
|
image_out = self.image_fused_moe(vl_moe_meta.image_input)
|
||||||
image_out = self.image_fused_moe(vl_moe_meta.image_input)
|
text_image_gather_scatter(
|
||||||
text_image_gather_scatter(
|
hidden_states,
|
||||||
hidden_states,
|
text_out,
|
||||||
text_out,
|
image_out,
|
||||||
image_out,
|
vl_moe_meta.token_type_ids,
|
||||||
vl_moe_meta.token_type_ids,
|
vl_moe_meta.text_index,
|
||||||
vl_moe_meta.text_index,
|
vl_moe_meta.image_index,
|
||||||
vl_moe_meta.image_index,
|
False,
|
||||||
False,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states = self.text_fused_moe(hidden_states)
|
|
||||||
if vl_moe_meta.fake_hidden_states is not None:
|
|
||||||
self.image_fused_moe(vl_moe_meta.fake_hidden_states)
|
|
||||||
if self.num_shared_experts > 0:
|
if self.num_shared_experts > 0:
|
||||||
hidden_states += shared_experts_out
|
hidden_states += shared_experts_out
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@@ -394,6 +401,40 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
|
|||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
@cuda_graph_buffers(
|
||||||
|
{
|
||||||
|
"text_input": {
|
||||||
|
"shape": ["parallel_config.max_model_len", "model_config.hidden_size"],
|
||||||
|
"dtype": "model_config.dtype",
|
||||||
|
"value": 1,
|
||||||
|
},
|
||||||
|
"image_input": {
|
||||||
|
"shape": ["parallel_config.max_model_len", "model_config.hidden_size"],
|
||||||
|
"dtype": "model_config.dtype",
|
||||||
|
"value": 1,
|
||||||
|
},
|
||||||
|
"text_index": {
|
||||||
|
"shape": ["parallel_config.max_model_len"],
|
||||||
|
"dtype": "int32",
|
||||||
|
"value": 0,
|
||||||
|
},
|
||||||
|
"image_index": {
|
||||||
|
"shape": ["parallel_config.max_model_len"],
|
||||||
|
"dtype": "int32",
|
||||||
|
"value": 0,
|
||||||
|
},
|
||||||
|
"token_type_ids": {
|
||||||
|
"shape": ["parallel_config.max_model_len"],
|
||||||
|
"dtype": "int32",
|
||||||
|
"value": -1,
|
||||||
|
},
|
||||||
|
"image_token_num": {
|
||||||
|
"shape": [1],
|
||||||
|
"dtype": "int64",
|
||||||
|
"value": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
@support_graph_optimization
|
@support_graph_optimization
|
||||||
class Ernie4_5_VLModel(nn.Layer):
|
class Ernie4_5_VLModel(nn.Layer):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -454,59 +495,46 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
logger.info(f"Start load layer {i}")
|
logger.info(f"Start load layer {i}")
|
||||||
self.layers[i].load_state_dict(state_dict)
|
self.layers[i].load_state_dict(state_dict)
|
||||||
|
|
||||||
def forward(
|
def prepare_vl_moe_meta(
|
||||||
self,
|
self,
|
||||||
ids_remove_padding: paddle.Tensor,
|
ids_remove_padding: paddle.Tensor,
|
||||||
image_features: Optional[paddle.Tensor],
|
) -> VLMoEMeta:
|
||||||
forward_meta: ForwardMeta,
|
|
||||||
):
|
|
||||||
text_input = None
|
|
||||||
image_input = None
|
|
||||||
text_index = None
|
|
||||||
image_index = None
|
|
||||||
fake_hidden_states = None
|
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
|
||||||
token_num, hidden_dim = hidden_states.shape
|
|
||||||
|
|
||||||
# -----------------------
|
|
||||||
image_mask = ids_remove_padding == self.im_patch_id
|
image_mask = ids_remove_padding == self.im_patch_id
|
||||||
|
token_type_ids = image_mask.cast("int32")
|
||||||
image_token_num = image_mask.sum()
|
image_token_num = image_mask.sum()
|
||||||
|
token_num = ids_remove_padding.shape[0]
|
||||||
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
||||||
|
|
||||||
token_type_ids = image_mask.cast("int32")
|
# The scenario requiring padding is CUDA graph, thus we only need to pad the maximum capture size.
|
||||||
if self.fd_config.parallel_config.use_ep is True:
|
self._mm_buffers["token_type_ids"][: self.fd_config.graph_opt_config.max_capture_size].fill_(-1)
|
||||||
fake_hidden_states = paddle.empty(
|
self._mm_buffers["token_type_ids"].copy_(token_type_ids, False)
|
||||||
shape=[0, self.fd_config.model_config.hidden_size],
|
self._mm_buffers["image_token_num"].copy_(image_token_num, False)
|
||||||
dtype=paddle.get_default_dtype(),
|
|
||||||
)
|
|
||||||
text_input = fake_hidden_states
|
|
||||||
|
|
||||||
if image_token_num > 0:
|
return VLMoEMeta(
|
||||||
hidden_states[image_mask] = image_features.cast(self._dtype)
|
text_input=self._mm_buffers["text_input"][:text_token_num],
|
||||||
text_input = paddle.ones(
|
image_input=self._mm_buffers["image_input"][:image_token_num],
|
||||||
shape=[text_token_num, hidden_dim],
|
text_index=self._mm_buffers["text_index"][:token_num],
|
||||||
dtype=self._dtype,
|
image_index=self._mm_buffers["image_index"][:token_num],
|
||||||
)
|
token_type_ids=self._mm_buffers["token_type_ids"][:token_num],
|
||||||
image_input = paddle.ones(
|
image_token_num=self._mm_buffers["image_token_num"],
|
||||||
shape=[image_token_num, hidden_dim],
|
|
||||||
dtype=self._dtype,
|
|
||||||
)
|
|
||||||
text_index = paddle.zeros_like(image_mask, dtype="int32")
|
|
||||||
image_index = paddle.zeros_like(image_mask, dtype="int32")
|
|
||||||
text_image_index_out(token_type_ids, text_index, image_index)
|
|
||||||
|
|
||||||
vl_moe_meta = VLMoEMeta(
|
|
||||||
text_input=text_input,
|
|
||||||
image_input=image_input,
|
|
||||||
text_index=text_index,
|
|
||||||
image_index=image_index,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
fake_hidden_states=fake_hidden_states,
|
|
||||||
)
|
)
|
||||||
# -----------------------
|
|
||||||
|
|
||||||
|
def get_input_embeddings(self, ids_remove_padding: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
return self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_embeddings: paddle.Tensor,
|
||||||
|
ids_remove_padding: paddle.Tensor,
|
||||||
|
forward_meta: ForwardMeta,
|
||||||
|
vl_moe_meta: VLMoEMeta,
|
||||||
|
):
|
||||||
|
text_image_index_out(vl_moe_meta.token_type_ids, vl_moe_meta.text_index, vl_moe_meta.image_index)
|
||||||
|
|
||||||
|
hidden_states = input_embeddings
|
||||||
residual = None
|
residual = None
|
||||||
|
|
||||||
for i in range(self.num_layers):
|
for i in range(self.num_layers):
|
||||||
hidden_states, residual = self.layers[i](
|
hidden_states, residual = self.layers[i](
|
||||||
forward_meta,
|
forward_meta,
|
||||||
@@ -517,17 +545,15 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
|
|
||||||
hidden_states = hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
# -----------------------
|
|
||||||
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
|
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
|
||||||
hidden_states = extract_text_token_output(
|
hidden_states = extract_text_token_output(
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_seq_len_index.cast("int32"),
|
max_seq_len_index.cast("int32"),
|
||||||
image_token_num.cast("int32"),
|
vl_moe_meta.image_token_num.cast("int32"),
|
||||||
forward_meta.seq_lens_this_time,
|
forward_meta.seq_lens_this_time,
|
||||||
forward_meta.cu_seqlens_q,
|
forward_meta.cu_seqlens_q,
|
||||||
hidden_states.cast("float32"),
|
hidden_states.cast("float32"),
|
||||||
).cast(self._dtype)
|
).cast(self._dtype)
|
||||||
# -----------------------
|
|
||||||
|
|
||||||
out = self.norm(hidden_states)
|
out = self.norm(hidden_states)
|
||||||
|
|
||||||
@@ -552,6 +578,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
|||||||
# ernie
|
# ernie
|
||||||
self.ernie = Ernie4_5_VLModel(fd_config=fd_config)
|
self.ernie = Ernie4_5_VLModel(fd_config=fd_config)
|
||||||
|
|
||||||
|
# Persistent buffers for CUDA graphs.
|
||||||
|
self._input_embeddings = paddle.zeros(
|
||||||
|
[fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
|
||||||
|
dtype=fd_config.model_config.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
|
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
|
||||||
|
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
@@ -733,16 +765,33 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
|||||||
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
||||||
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
|
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
|
||||||
|
|
||||||
|
def get_input_embeddings(
|
||||||
|
self,
|
||||||
|
ids_remove_padding: paddle.Tensor,
|
||||||
|
image_features: Optional[paddle.Tensor] = None,
|
||||||
|
) -> paddle.Tensor:
|
||||||
|
input_embeddings = self.ernie.get_input_embeddings(ids_remove_padding=ids_remove_padding)
|
||||||
|
if image_features is not None and len(image_features) > 0:
|
||||||
|
input_embeddings[ids_remove_padding == self.ernie.im_patch_id] = image_features.cast(self.ernie._dtype)
|
||||||
|
return input_embeddings
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ids_remove_padding: paddle.Tensor,
|
ids_remove_padding: paddle.Tensor,
|
||||||
image_features: Optional[paddle.Tensor],
|
image_features: Optional[paddle.Tensor],
|
||||||
forward_meta: ForwardMeta,
|
forward_meta: ForwardMeta,
|
||||||
):
|
):
|
||||||
|
input_embeddings = self.get_input_embeddings(
|
||||||
|
ids_remove_padding=ids_remove_padding, image_features=image_features
|
||||||
|
)
|
||||||
|
self._input_embeddings.copy_(input_embeddings, False)
|
||||||
|
vl_moe_meta = self.ernie.prepare_vl_moe_meta(ids_remove_padding=ids_remove_padding)
|
||||||
|
|
||||||
hidden_states = self.ernie(
|
hidden_states = self.ernie(
|
||||||
|
input_embeddings=self._input_embeddings,
|
||||||
ids_remove_padding=ids_remove_padding,
|
ids_remove_padding=ids_remove_padding,
|
||||||
image_features=image_features,
|
|
||||||
forward_meta=forward_meta,
|
forward_meta=forward_meta,
|
||||||
|
vl_moe_meta=vl_moe_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
Reference in New Issue
Block a user