[Optimization] support mm prefill batch (#5313)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* support mm prefill batch

* update code

* update code

* update code

* update code

* fix encoder cache bug

* update code

* update code

* fix bug

* fix paddle ocr bug

* fix xpu bug

* update code
This commit is contained in:
kevin
2025-12-11 22:21:14 +08:00
committed by GitHub
parent 7116982995
commit 954a145d57
14 changed files with 769 additions and 296 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
import copy
import os
import queue
import time
@@ -28,7 +29,7 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.engine.request import ImagePosition, Request, RequestType
from fastdeploy.model_executor.graph_optimization.utils import (
GPUMemoryChecker,
profile_run_guard,
@@ -382,188 +383,242 @@ class GPUModelRunner(ModelRunnerBase):
schemata_key,
)
def get_chunked_inputs(self, req: Request):
def _process_mm_features(self, request_list: List[Request]):
"""
Get inputs in current chunk
"""
prefill_start_index = req.prefill_start_index
prefill_end_index = req.prefill_end_index
inputs = req.multimodal_inputs
input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index]
token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index]
image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end]
images = inputs["images"][req.image_start : req.image_end]
grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end]
mm_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end]
return (
input_ids,
token_type_ids,
image_type_ids,
images,
grid_thw,
mm_hashes,
)
def batch_uncached_inputs(self, req: Request):
"""
Batch uncached multimodal inputs
"""
(input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req)
image_type_ids_size = grid_thw[:, 0]
image_type_ids_split = np.cumsum(image_type_ids_size)[:-1]
image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0)
images_size = np.prod(grid_thw, axis=1)
images_split = np.cumsum(images_size)[:-1]
images_lst = np.array_split(images, images_split, axis=0)
assert len(image_type_ids_lst) == len(
mm_hashes
), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}"
assert len(images_lst) == len(
mm_hashes
), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}"
uncached_image_type_ids = []
uncached_images = []
uncached_grid_thw = []
uncached_mm_hashes = []
for i, mm_hash in enumerate(mm_hashes):
if mm_hash in self.encoder_cache:
continue
uncached_image_type_ids.append(image_type_ids_lst[i])
uncached_images.append(images_lst[i])
uncached_grid_thw.append(grid_thw[i])
uncached_mm_hashes.append(mm_hash)
uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if len(uncached_mm_hashes) > 0:
uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64)
uncached_images = paddle.to_tensor(
np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
)
uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64)
return (
uncached_input_ids,
uncached_token_type_ids,
uncached_image_type_ids,
uncached_images,
uncached_grid_thw,
uncached_mm_hashes,
)
def scatter_and_cache_features(self, image_features, inputs):
"""
Split batched image features and cache them
"""
merge_size = 2
grid_thw = inputs["grid_thw"]
mm_hashes = inputs["mm_hashes"]
image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist()
image_features_lst = paddle.split(image_features, image_features_size, axis=0)
assert len(image_features_lst) == len(
mm_hashes
), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}"
for i, mm_hash in enumerate(mm_hashes):
self.encoder_cache[mm_hash] = image_features_lst[i].cpu()
def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict):
"""
Apply multimodal inputs to share_inputs
Process and cache vision features from model
- add image_features, extract and cache vision features from model
- add rope_emb, rotate position embeddings
"""
if self.encoder_cache:
evict_mm_hashes = request.get("evict_mm_hashes", None)
if evict_mm_hashes:
for mm_hash in evict_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
if not self.enable_mm:
return
inputs = request.multimodal_inputs
if request.with_image:
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][request.image_start : request.image_end].cuda()
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
if "vit_seqlen" in inputs:
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
if "vit_position_ids" in inputs:
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
self.share_inputs["image_features"] = None
multi_vision_inputs = {
"images_lst": [],
"grid_thw_lst": [],
"vit_position_ids_lst": [],
"cu_seqlens": [0],
"encoder_cache_info": [],
"feature_position_list": [],
}
rope_3d_position_ids = {
"position_ids_idx": [],
"position_ids_lst": [],
"position_ids_offset": [0],
"max_tokens_lst": [],
}
for request in request_list:
if request.task_type.value != RequestType.PREFILL.value:
continue
if self.encoder_cache is not None:
evict_mm_hashes = request.get("evict_mm_hashes", None)
if evict_mm_hashes:
for mm_hash in evict_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(request.idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
if self.is_pooling_model:
rope_3d_position_ids["max_tokens_lst"].append(0)
else:
vision_inputs = inputs
if self.encoder_cache:
(
vision_inputs["input_ids"],
vision_inputs["token_type_ids"],
vision_inputs["image_type_ids"],
vision_inputs["images"],
vision_inputs["grid_thw"],
vision_inputs["mm_hashes"],
) = self.batch_uncached_inputs(request)
if len(vision_inputs["mm_hashes"]) > 0:
# uncached multimodal inputs exist
image_features = self.extract_vision_features(vision_inputs)
self.scatter_and_cache_features(image_features, vision_inputs)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
full_image_features_lst = []
for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]:
feature = self.encoder_cache[mm_hash].cuda()
full_image_features_lst.append(feature)
image_features = paddle.concat(full_image_features_lst, axis=0)
else:
(
input_ids,
token_type_ids,
image_type_ids,
images,
grid_thw,
mm_hashes,
) = self.get_chunked_inputs(request)
vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64)
vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
vision_inputs["images"] = paddle.to_tensor(
images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
if request.with_image:
inputs = request.multimodal_inputs
if self.encoder_cache is not None:
if envs.FD_ENABLE_MAX_PREFILL:
if "vit_seqlen" in inputs:
vit_seqlen_list = inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
if "vit_position_ids" in inputs:
vit_position_ids_list = inputs["vit_position_ids"][
request.num_image_start : request.num_image_end
]
grid_thw_list = inputs["grid_thw"][request.num_image_start : request.num_image_end]
mm_hashes_list = inputs["mm_hashes"][request.num_image_start : request.num_image_end]
feature_positions = self._get_feature_positions(
mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end],
prefill_start_index=request.prefill_start_index,
prefill_end_index=request.prefill_end_index,
)
vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64)
vision_inputs["mm_hashes"] = mm_hashes
image_start_idx = request.num_image_start
image_features = self.extract_vision_features(vision_inputs)
logger.debug(
f"request {request.request_id} start process encoder info, image_start_idx: {image_start_idx} "
f"grid_thw_list: {grid_thw_list}, feature_positions: {feature_positions}, mm_hashes_list: {mm_hashes_list}"
)
for i, mm_hash in enumerate(mm_hashes_list):
image_offset = np.prod(grid_thw_list[i])
logger.debug(
f"run idx {i} with mm_hash {mm_hash} image_offset: {image_offset} grid_thw: {grid_thw_list[i]}"
)
if mm_hash in self.encoder_cache:
multi_vision_inputs["encoder_cache_info"].append((mm_hash, feature_positions[i], True))
continue
# part of the first image may be already cached
if "ernie" in self.model_config.model_type:
actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id)
elif "qwen" in self.model_config.model_type:
actual_image_token_num = paddle.sum(
vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
multi_vision_inputs["encoder_cache_info"].append((mm_hash, feature_positions[i], False))
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][image_start_idx : image_start_idx + image_offset].cuda()
)
multi_vision_inputs["grid_thw_lst"].append(paddle.to_tensor(grid_thw_list[i]))
multi_vision_inputs["cu_seqlens"].append(vit_seqlen_list[i])
multi_vision_inputs["vit_position_ids_lst"].append(vit_position_ids_list[i])
else:
multi_vision_inputs["images_lst"].append(
paddle.to_tensor(
inputs["images"][image_start_idx : image_start_idx + image_offset],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
)
multi_vision_inputs["grid_thw_lst"].append(
paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64)
)
image_start_idx += image_offset
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
self.share_inputs["image_features"] = image_features[-actual_image_token_num:]
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][request.image_start : request.image_end].cuda()
)
multi_vision_inputs["grid_thw_lst"].extend(
paddle.to_tensor(inputs["grid_thw"][request.num_image_start : request.num_image_end])
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
multi_vision_inputs["images_lst"].append(
paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
)
multi_vision_inputs["grid_thw_lst"].extend(
paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end],
dtype=paddle.int64,
)
)
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(request.idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
multi_vision_inputs["feature_position_list"].extend(
self._get_feature_positions(
mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end],
prefill_start_index=request.prefill_start_index,
prefill_end_index=request.prefill_end_index,
)
)
if self.encoder_cache is not None:
if len(multi_vision_inputs["images_lst"]) > 0 or len(multi_vision_inputs["encoder_cache_info"]) > 0:
image_features_output = None
if len(multi_vision_inputs["images_lst"]) > 0:
image_features_output = self.extract_vision_features(multi_vision_inputs)
logger.debug(f"encoder_cache_info: {multi_vision_inputs['encoder_cache_info']}")
merge_image_features, feature_idx, thw_idx = [], 0, 0
for mm_hash, feature_position, use_cache in multi_vision_inputs["encoder_cache_info"]:
if use_cache:
assert mm_hash in self.encoder_cache, f"{mm_hash} not in encoder cache"
mm_feature = self.encoder_cache[mm_hash].cuda()
else:
assert (
image_features_output is not None
), f"image_features_output is None, images_lst length: {len(multi_vision_inputs['images_lst'])}"
mm_token_lenght = paddle.prod(multi_vision_inputs["grid_thw_lst"][thw_idx]) // 4
mm_feature = image_features_output[feature_idx : feature_idx + mm_token_lenght]
# add feature to encoder cache
self.encoder_cache[mm_hash] = mm_feature.detach().cpu()
feature_idx += mm_token_lenght
thw_idx += 1
feature_start = feature_position.offset
feature_end = feature_position.offset + feature_position.length
merge_image_features.append(mm_feature[feature_start:feature_end])
self.share_inputs["image_features"] = paddle.concat(merge_image_features, axis=0)
logger.debug(
f"merge_image_features length: {len(merge_image_features)}, features shape: {self.share_inputs['image_features'].shape}"
)
elif len(multi_vision_inputs["images_lst"]) > 0:
assert len(multi_vision_inputs["feature_position_list"]) == len(
multi_vision_inputs["grid_thw_lst"]
), f"{multi_vision_inputs['feature_position_list']} != {multi_vision_inputs['grid_thw_lst']}"
merge_image_features, feature_idx, thw_idx = [], 0, 0
image_features_output = self.extract_vision_features(multi_vision_inputs)
for feature_position in multi_vision_inputs["feature_position_list"]:
mm_token_lenght = paddle.prod(multi_vision_inputs["grid_thw_lst"][thw_idx]) // 4
mm_feature = image_features_output[feature_idx : feature_idx + mm_token_lenght]
feature_start = feature_position.offset
feature_end = feature_position.offset + feature_position.length
merge_image_features.append(mm_feature[feature_start:feature_end])
feature_idx += mm_token_lenght
thw_idx += 1
self.share_inputs["image_features"] = paddle.concat(merge_image_features, axis=0)
if len(rope_3d_position_ids["position_ids_idx"]) > 0:
packed_position_ids = paddle.to_tensor(
np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
)
rope_3d_lst = self.prepare_rope3d(
packed_position_ids,
rope_3d_position_ids["max_tokens_lst"],
rope_3d_position_ids["position_ids_offset"],
)
for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]):
self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i]
def _get_feature_positions(
self, mm_positions: List[ImagePosition], prefill_start_index: int, prefill_end_index: int
):
"""
Filter and adjust ImagePosition objects that fall within the specified prefill range.
Args:
mm_positions: List of ImagePosition objects to filter
prefill_start_index: Start index of the prefill range
prefill_end_index: End index of the prefill range
Returns:
List of ImagePosition objects that are within or intersect with the prefill range
"""
feature_positions = []
for position in mm_positions:
position_start = position.offset
position_end = position.offset + position.length
if position_end <= prefill_start_index or position_start >= prefill_end_index:
continue
elif position_start >= prefill_start_index and position_end <= prefill_end_index:
new_position = copy.deepcopy(position)
new_position.offset = 0
feature_positions.append(new_position)
else:
new_position = copy.deepcopy(position)
# Adjust offset if it starts before prefill_start_index
if position_start < prefill_start_index:
new_position.offset = prefill_start_index - position_start
new_position.length = min(position_end, prefill_end_index) - prefill_start_index
# Adjust length if it extends beyond prefill_end_index
elif position_end > prefill_end_index:
new_position.offset = 0
new_position.length = prefill_end_index - position_start
feature_positions.append(new_position)
logger.debug(
f"get feature_positions, original positions: {mm_positions}, filtered positions: {feature_positions}"
)
if self.is_pooling_model:
rope_3d_position_ids["max_tokens_lst"].append(0)
else:
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
return feature_positions
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
"""
@@ -580,15 +635,6 @@ class GPUModelRunner(ModelRunnerBase):
has_decode_task = False
batch_pooling_params = []
self.share_inputs["image_features"] = None
multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]}
rope_3d_position_ids = {
"position_ids_idx": [],
"position_ids_lst": [],
"position_ids_offset": [0],
"max_tokens_lst": [],
}
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
@@ -621,9 +667,6 @@ class GPUModelRunner(ModelRunnerBase):
prefill_start_index = request.prefill_start_index
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
if self.enable_mm:
self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids)
if not self.is_pooling_model:
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking
@@ -763,21 +806,7 @@ class GPUModelRunner(ModelRunnerBase):
self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
if len(multi_vision_inputs["images_lst"]) > 0:
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
if len(rope_3d_position_ids["position_ids_idx"]) > 0:
packed_position_ids = paddle.to_tensor(
np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
)
rope_3d_lst = self.prepare_rope3d(
packed_position_ids,
rope_3d_position_ids["max_tokens_lst"],
rope_3d_position_ids["position_ids_offset"],
)
for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]):
self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i]
self._process_mm_features(req_dicts)
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
@@ -2826,21 +2855,19 @@ class GPUModelRunner(ModelRunnerBase):
)
return result
def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
def extract_vision_features_ernie(self, vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
"""
vision feature extractor for ernie-vl
"""
assert len(vision_inputs["images_lst"]) > 0, "at least one image needed"
grid_thw = paddle.to_tensor(vision_inputs["grid_thw_lst"], dtype=paddle.int64)
# ernie-vl has images norm
images = inputs["images"].cast("float32")
images = paddle.concat(vision_inputs["images_lst"]).cast("float32")
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
images = images / self.image_preprocess.image_std_tensor
images = images.cast("bfloat16")
token_type_ids = inputs["token_type_ids"]
token_type_ids_w_video = token_type_ids
input_ids = inputs["input_ids"]
# convert to img patch id
image_mask = input_ids == self.model_config.im_patch_id
image_type_ids = inputs["image_type_ids"]
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
@@ -2857,21 +2884,15 @@ class GPUModelRunner(ModelRunnerBase):
# ernie-vl has resampler_model
image_features = self.model.resampler_model(
image_features,
image_mask,
token_type_ids_w_video,
image_type_ids,
grid_thw,
)
return image_features
def extract_vision_features_qwen(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
if envs.FD_ENABLE_MAX_PREFILL:
images = paddle.concat(inputs["images_lst"]).cast("bfloat16")
grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64")
else:
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
images = inputs["images"]
def extract_vision_features_qwen(self, vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
assert len(vision_inputs["images_lst"]) > 0, "at least one image needed"
grid_thw = paddle.to_tensor(vision_inputs["grid_thw_lst"], dtype=paddle.int64)
images = paddle.concat(vision_inputs["images_lst"]).cast("bfloat16")
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
@@ -2883,7 +2904,7 @@ class GPUModelRunner(ModelRunnerBase):
return image_features
def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
def extract_vision_features_paddleocr(self, inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
if envs.FD_ENABLE_MAX_PREFILL:
inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"])
images = paddle.concat(inputs["images_lst"]).cast("bfloat16")
@@ -2927,14 +2948,14 @@ class GPUModelRunner(ModelRunnerBase):
return image_features
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
def extract_vision_features(self, multi_vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
"""extract_vision_features"""
if "ernie" in self.model_config.model_type:
return self.extract_vision_features_ernie(inputs)
return self.extract_vision_features_ernie(multi_vision_inputs)
elif "qwen" in self.model_config.model_type:
return self.extract_vision_features_qwen(inputs)
return self.extract_vision_features_qwen(multi_vision_inputs)
elif "paddleocr" in self.model_config.model_type:
return self.extract_vision_features_paddleocr(inputs)
return self.extract_vision_features_paddleocr(multi_vision_inputs)
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")