[MM_PROCESS] add _extract_labels (#2879)

Author: LokeZhou (committed by GitHub)
Date: 2025-07-17 14:20:01 +08:00
parent dbb9e2506b
commit f50c25178b


@@ -77,6 +77,7 @@ class DataProcessor:
    CLS_TOKEN = "<|begin_of_sentence|>"
    SEP_TOKEN = "<|end_of_sentence|>"
    EOS_TOKEN = "</s>"
    IMG_START = "<|IMAGE_START|>"
    IMG_END = "<|IMAGE_END|>"
    VID_START = "<|VIDEO_START|>"
@@ -125,6 +126,7 @@ class DataProcessor:
        # Special tokens and IDs
        self.cls_token = self.CLS_TOKEN
        self.sep_token = self.SEP_TOKEN
        self.eos_token = self.EOS_TOKEN
        self.image_start = self.IMG_START
        self.image_end = self.IMG_END
        self.video_start = self.VID_START
@@ -132,6 +134,9 @@ class DataProcessor:
        self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
        self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
        self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
        self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)
        self.token_type_mapping = self._build_token_type_mapping()
        self.is_training = True
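The new sep_token_id and eos_token_id lookups are resolved once at construction so _extract_labels can later compare raw token ids instead of re-tokenizing. A minimal sketch of the same lookup pattern, assuming a Hugging Face-style tokenizer (the checkpoint name is a placeholder, not from this repo):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("your/checkpoint")  # placeholder checkpoint

    sep_token_id = tokenizer.convert_tokens_to_ids("<|end_of_sentence|>")
    eos_token_id = tokenizer.convert_tokens_to_ids("</s>")
    # convert_tokens_to_ids maps tokens missing from the vocab to the unk id,
    # so checking against it at init surfaces a bad special-token set early.
    assert sep_token_id != tokenizer.unk_token_id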
@@ -204,7 +209,7 @@ class DataProcessor:
        return outputs

    def request2ids(self, request: Dict[str, Any]) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
    def request2ids(self, request: Dict[str, Any], tgts: List[str] = None) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
        """
        Convert chat messages into model inputs.
        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
@@ -258,6 +263,10 @@ class DataProcessor:
                self._add_video(frames, outputs)
                image_message_index += 1
        self._add_text(prompt_token_ids[image_start_index:], outputs)

        if self.is_training:
            assert tgts, "training requires tgts!"
            self._extract_labels(outputs, tgts)
        return outputs

    def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
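With the new signature, training-time callers must supply one target string per <|end_of_sentence|> token in the encoded prompt. A hedged usage sketch; the request schema below is illustrative, not taken from this commit:

    processor = DataProcessor(...)  # construction args elided
    processor.is_training = True

    request = {"messages": [{"role": "user", "content": "Describe the image."}]}
    # Exactly one target per SEP token, in prompt order.
    outputs = processor.request2ids(request, tgts=["A cat sitting on a mat."])
    labels = outputs["labels"]  # ignored_index everywhere except the target spans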
@@ -339,6 +348,26 @@ class DataProcessor:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
input_ids = copy.deepcopy(outputs['input_ids'])
labels = [self.tokenizer.ignored_index] * len(input_ids)
tgt_count=input_ids.count(self.sep_token_id)
assert tgt_count==len(tgts),f'len(tgts) != len(src) {len(tgts)} vs {tgt_count}'
tgt_index=0
for i,token_id in enumerate(input_ids):
if token_id==self.sep_token_id:
labels_token = self.tokenizer.tokenize(tgts[tgt_index])
labels_token_id = self.tokenizer.convert_tokens_to_ids(labels_token)
labels[i-len(labels_token_id):i]=labels_token_id
labels[i] = self.eos_token_id #</s>
tgt_index += 1
outputs['labels']=labels
def _load_and_process_video(self, url: str, item: Dict) -> List[Image.Image]:
reader, meta, path = read_video_decord(url, save_to_disk=False)
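For reference, the label layout _extract_labels produces can be reproduced in isolation. The toy sketch below (all token ids hypothetical) shows the scheme: each target's token ids overwrite the label span directly before its SEP token, the SEP position itself gets EOS, and every other position stays at the ignored index:

    IGNORED, SEP, EOS = -100, 2, 3

    # input_ids for "<prompt> <target> SEP": prompt = [10, 11], target = [20, 21]
    input_ids = [10, 11, 20, 21, SEP]
    tgt_token_ids = [[20, 21]]  # pre-tokenized targets, one per SEP

    labels = [IGNORED] * len(input_ids)
    tgt_index = 0
    for i, token_id in enumerate(input_ids):
        if token_id == SEP:
            span = tgt_token_ids[tgt_index]
            labels[i - len(span):i] = span  # supervise the target span
            labels[i] = EOS                 # the model should emit </s> here
            tgt_index += 1

    assert labels == [IGNORED, IGNORED, 20, 21, EOS]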