mirror of
https://huggingface.co/spaces/H-Liu1997/TANGO
synced 2025-09-26 23:45:52 +08:00
498 lines
20 KiB
Python
498 lines
20 KiB
Python
"""
|
|
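Build a frame-level motion graph from paired videos and SMPL-X motion files.

Input: a JSON file containing a list of clips; each entry gives ``video_id``
plus the ``video_path`` and ``motion_path`` directories that hold
``<video_id>.mp4`` and ``<video_id>.npz`` (audio paths are currently unused).

Output: a directed igraph object with one node per frame. Each node stores
idx, name (the video id), motion, position, velocity, axis_angle, trans,
video (the decoded video frame), previous, next, frame, and fps.
Audio is not stored on the nodes.

Preprocessing:
1. Put the videos for one speaker in a folder, e.g.
   -- video_a.mp4
   -- video_b.mp4
2. Run process_video.py to extract frames and audio.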
|
|
"""
|
|
|
|
import os
|
|
import smplx
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
import librosa
|
|
import igraph
|
|
import json
|
|
import utils.rotation_conversions as rc
|
|
from moviepy.editor import VideoClip, AudioFileClip, VideoFileClip
|
|
from tqdm import tqdm
|
|
import imageio
|
|
import tempfile
|
|
import argparse
|
|
|
|
|
|
def get_motion_reps_tensor(motion_tensor, smplx_model, pose_fps=30, device='cuda'):
    """Compute SMPL-X joint features for a batch of pose sequences (torch outputs).

    Args:
        motion_tensor: tensor of shape ``(bs, n, D)``; it is reshaped to
            ``(bs * n, 165)`` and ``(bs, n, 55, 3)``, i.e. 55 joints x 3
            axis-angle values per frame.
        smplx_model: SMPL-X layer called on ``bs * n`` poses (300 betas,
            100 expression coefficients) that returns ``joints``.
        pose_fps: frame rate used to turn frame differences into velocities.
        device: device for the input and the zero placeholder tensors.

    Returns:
        dict with ``position`` and ``velocity`` ``(bs, n, 55, 3)``,
        ``rotation`` (6D) ``(bs, n, 55, 6)``, ``axis_angle`` (the float input
        on ``device``), ``angular_velocity`` ``(bs, n, 55, 3)`` and
        ``rep15d`` ``(bs, n, 55 * 15)``.
    """
    batch, frames, _ = motion_tensor.shape
    total = batch * frames
    step = 1 / pose_fps

    def _time_derivative(seq):
        # One-sided differences at both ends, central differences in between (time is dim 1).
        head = (seq[:, 1:2] - seq[:, 0:1]) / step
        body = (seq[:, 2:] - seq[:, :-2]) / (2 * step)
        tail = (seq[:, -1:] - seq[:, -2:-1]) / step
        return torch.cat([head, body, tail], dim=1)

    def _zeros(width):
        return torch.zeros(total, width, device=device)

    motion_tensor = motion_tensor.float().to(device)
    flat_pose = motion_tensor.reshape(total, 165)

    # Only body (joints 1-21) and hand (joints 25-54) rotations come from the
    # input; root orientation, jaw, eyes, shape, expression and translation are zero.
    smplx_out = smplx_model(
        betas=_zeros(300),
        transl=_zeros(3),
        expression=_zeros(100),
        jaw_pose=_zeros(3),
        global_orient=_zeros(3),
        body_pose=flat_pose[:, 3:21 * 3 + 3],
        left_hand_pose=flat_pose[:, 25 * 3:40 * 3],
        right_hand_pose=flat_pose[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=_zeros(3),
        reye_pose=_zeros(3),
    )

    # The model returns 127 joints per frame; keep the first 55.
    position = smplx_out['joints'].reshape(batch, frames, 127, 3)[:, :, :55, :]
    velocity = _time_derivative(position)

    rotation_6d = rc.matrix_to_rotation_6d(
        rc.axis_angle_to_matrix(motion_tensor.reshape(batch, frames, 55, 3))
    ).reshape(batch, frames, 55, 6)

    angular_velocity = _time_derivative(motion_tensor).reshape(batch, frames, 55, 3)

    rep15d = torch.cat(
        [position, velocity, rotation_6d, angular_velocity], dim=3
    ).reshape(batch, frames, 55 * 15)

    return {
        "position": position,
        "velocity": velocity,
        "rotation": rotation_6d,
        "axis_angle": motion_tensor,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }
|
|
|
|
|
|
|
|
def get_motion_reps(motion, smplx_model, pose_fps=30, device=None):
    """Compute per-frame SMPL-X joint features for one motion clip (numpy outputs).

    Args:
        motion: mapping (e.g. the object returned by ``np.load`` on an ``.npz``)
            with ``"poses"`` — an array whose rows reshape to ``(55, 3)``
            axis-angle values — and ``"trans"``, which is returned unchanged.
        smplx_model: SMPL-X ``torch.nn.Module`` called on ``n`` poses
            (300 betas, 100 expression coefficients) that returns ``joints``.
        pose_fps: frame rate used to turn frame differences into velocities.
        device: torch device for the SMPL-X forward pass. ``None`` (default)
            uses the device of the model's first parameter/buffer, or CPU.

    Returns:
        dict with ``position``/``velocity`` ``(n, 55, 3)``, ``rotation`` (6D)
        ``(n, 55, 6)``, ``axis_angle`` (the raw poses), ``angular_velocity``
        ``(n, 55, 3)``, ``rep15d`` ``(n, 55 * 15)`` and ``trans``.

    Raises:
        ValueError: if the clip has fewer than 2 frames (velocities need two).
    """
    if device is None:
        import itertools

        # Run wherever the model lives instead of relying on a module-level
        # ``device`` that only the ``__main__`` block defines.
        first_tensor = next(itertools.chain(smplx_model.parameters(), smplx_model.buffers()), None)
        device = first_tensor.device if first_tensor is not None else torch.device("cpu")

    # Read once: if ``motion`` is an NpzFile, every ``[]`` re-reads the archive member.
    poses = motion["poses"]
    n = poses.shape[0]
    if n < 2:
        raise ValueError(f"get_motion_reps needs at least 2 frames, got {n}")
    dt = 1 / pose_fps

    pose_tensor = torch.from_numpy(poses).float().to(device).unsqueeze(0)
    flat_pose = pose_tensor.reshape(n, 165)

    # Only body (joints 1-21) and hand (joints 25-54) rotations come from the
    # clip; root orientation, jaw, eyes, shape, expression and translation are zero.
    with torch.no_grad():  # feature extraction only; no autograd graph needed
        output = smplx_model(
            betas=torch.zeros(n, 300, device=device),
            transl=torch.zeros(n, 3, device=device),
            expression=torch.zeros(n, 100, device=device),
            jaw_pose=torch.zeros(n, 3, device=device),
            global_orient=torch.zeros(n, 3, device=device),
            body_pose=flat_pose[:, 3:21 * 3 + 3],
            left_hand_pose=flat_pose[:, 25 * 3:40 * 3],
            right_hand_pose=flat_pose[:, 40 * 3:55 * 3],
            return_joints=True,
            leye_pose=torch.zeros(n, 3, device=device),
            reye_pose=torch.zeros(n, 3, device=device),
        )
    # The model returns 127 joints per frame; keep the first 55.
    joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]

    def _time_derivative(x):
        # One-sided differences at both ends, central differences in between (time is axis 0).
        init = (x[1:2] - x[0:1]) / dt
        middle = (x[2:] - x[:-2]) / (2 * dt)
        final = (x[-1:] - x[-2:-1]) / dt
        return np.concatenate([init, middle, final], axis=0)

    position = joints
    vel = _time_derivative(joints)

    rot_matrices = rc.axis_angle_to_matrix(pose_tensor.reshape(1, n, 55, 3))[0]
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()

    angular_velocity = _time_derivative(poses).reshape(n, 55, 3)

    rep15d = np.concatenate(
        [position, vel, rot6d, angular_velocity],
        axis=2,
    ).reshape(n, 55 * 15)

    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": poses,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
        "trans": motion["trans"],
    }
|
|
|
|
def create_graph(json_path, smplx_model, fps=30):
    """Build a directed igraph with one vertex per aligned video/motion frame.

    Args:
        json_path: JSON file holding a list of clips; each entry provides
            ``video_id``, ``video_path`` (dir of ``<video_id>.mp4``) and
            ``motion_path`` (dir of ``<video_id>.npz``).
        smplx_model: SMPL-X model forwarded to ``get_motion_reps``.
        fps: pose frame rate used for velocities and stored on every vertex.

    Returns:
        ``igraph.Graph`` (directed, no edges yet). Vertex attributes: ``idx``
        (global index), ``name`` (video id), ``motion`` (the clip's full
        feature dict, shared by its frames), ``position``, ``velocity``,
        ``axis_angle``, ``trans``, ``video`` (decoded frame), ``previous`` /
        ``next`` (global index of the neighbouring frame in the same clip, or
        -1 at a clip boundary), ``frame`` (index within the clip) and ``fps``.
        Each clip is truncated to the shorter of its video and motion lengths.
    """
    with open(json_path, "r") as f:
        data_meta = json.load(f)

    graph = igraph.Graph(directed=True)
    global_i = 0
    for data_item in data_meta:
        video_id = data_item["video_id"]
        video_path = os.path.join(data_item["video_path"], video_id + ".mp4")
        motion_path = os.path.join(data_item["motion_path"], video_id + ".npz")

        # NOTE(review): allow_pickle=True lets a crafted .npz run code on load;
        # only point this at trusted motion files.
        with np.load(motion_path, allow_pickle=True) as motion:
            motion_reps = get_motion_reps(motion, smplx_model, pose_fps=fps)
        position = motion_reps["position"]
        velocity = motion_reps["velocity"]
        trans = motion_reps["trans"]
        axis_angle = motion_reps["axis_angle"]

        reader = imageio.get_reader(video_path)
        try:
            video_frames = np.array([frame for frame in reader])
        finally:
            reader.close()

        min_frames = min(len(video_frames), position.shape[0])
        for i in tqdm(range(min_frames)):
            # Link only within this clip; -1 marks the first/last frame (a
            # single-frame clip gets -1 on both sides).
            previous = global_i - 1 if i > 0 else -1
            next_node = global_i + 1 if i < min_frames - 1 else -1
            graph.add_vertex(
                idx=global_i,
                name=video_id,
                motion=motion_reps,
                position=position[i],
                velocity=velocity[i],
                axis_angle=axis_angle[i],
                trans=trans[i],
                video=video_frames[i],
                previous=previous,
                next=next_node,
                frame=i,
                fps=fps,
            )
            global_i += 1
    return graph
|
|
|
|
def create_edges(graph, threshold_edges, min_joints=45):
    """Add continuity and similarity edges to a frame graph built by ``create_graph``.

    For every vertex i with at least one temporal neighbour in its own clip
    (up to 4 frames either side), adds directed edges i -> j in ascending j:

    * ``is_continue=1`` to i's ``previous`` and ``next`` vertices;
    * ``is_continue=0`` to any other j whose root-translation distance is below
      i's mean neighbour translation distance AND for which at least
      ``min_joints`` joints have both position and velocity distances below
      ``threshold_edges`` x that joint's mean neighbour distance.

    Vertices without any same-clip neighbour (single-frame clips) get no edges.

    Args:
        graph: igraph with ``position``, ``velocity``, ``trans``, ``frame``,
            ``previous`` and ``next`` vertex attributes.
        threshold_edges: scale factor applied to the per-joint thresholds.
        min_joints: joints (of 55) that must be similar; the default 45
            matches the original hard-coded value.

    Returns:
        The same graph, modified in place, after printing degree statistics.
    """
    adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4]
    num_nodes = len(graph.vs)
    if num_nodes == 0:
        print(f"nodes: 0, edges: {len(graph.es)}")
        return graph

    # Gather per-vertex features once so each node is compared with all others in one vectorised pass.
    positions = np.stack(graph.vs["position"])
    velocities = np.stack(graph.vs["velocity"])
    translations = np.stack(graph.vs["trans"])
    frames = graph.vs["frame"]
    previous_ids = graph.vs["previous"]
    next_ids = graph.vs["next"]

    edges = []
    edge_is_continue = []
    for i in range(num_nodes):
        current_position = positions[i]
        current_velocity = velocities[i]
        current_trans = translations[i]

        # Local thresholds: mean distance to temporal neighbours in the same clip.
        avg_position = np.zeros(current_position.shape[0])
        avg_velocity = np.zeros(current_position.shape[0])
        avg_trans = 0
        count = 0
        for node_offset in adaptive_length:
            idx = i + node_offset
            if idx < 0 or idx >= num_nodes:
                continue
            # Same clip iff the per-clip frame counter moved by exactly the
            # offset; it resets at every clip boundary, so neighbours in an
            # adjacent clip never match (the old next/previous == -1 test only
            # caught the boundary frame itself).
            if frames[idx] - frames[i] != node_offset:
                continue
            avg_position += np.linalg.norm(current_position - positions[idx], axis=1)
            avg_velocity += np.linalg.norm(current_velocity - velocities[idx], axis=1)
            avg_trans += np.linalg.norm(current_trans - translations[idx], axis=0)
            count += 1
        if count == 0:
            continue

        threshold_position = avg_position / count
        threshold_velocity = avg_velocity / count
        threshold_trans = avg_trans / count

        # Per-joint distances from node i to every vertex: (N, 55) and (N,).
        position_similarity = np.linalg.norm(current_position - positions, axis=2)
        velocity_similarity = np.linalg.norm(current_velocity - velocities, axis=2)
        trans_similarity = np.linalg.norm(current_trans - translations, axis=1)
        similar = (
            (trans_similarity < threshold_trans)
            & (np.sum(position_similarity < threshold_edges * threshold_position, axis=1) >= min_joints)
            & (np.sum(velocity_similarity < threshold_edges * threshold_velocity, axis=1) >= min_joints)
        )
        similar[i] = False

        continuity = {
            j for j in (previous_ids[i], next_ids[i])
            if 0 <= j < num_nodes and j != i
        }
        # Continuity links win over similarity links for the same target.
        for j in sorted(set(np.flatnonzero(similar).tolist()) | continuity):
            edges.append((i, j))
            edge_is_continue.append(1 if j in continuity else 0)

    if edges:
        graph.add_edges(edges, attributes={"is_continue": edge_is_continue})

    print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    in_degrees = graph.indegree()
    out_degrees = graph.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    return graph
|
|
|
|
def random_walk(graph, walk_length, start_node=None):
    """Take up to ``walk_length`` random steps along outgoing edges of ``graph``.

    Args:
        graph: igraph whose edges carry an ``is_continue`` attribute.
        walk_length: maximum number of steps; the walk stops early at a
            vertex with no outgoing edges.
        start_node: vertex to start from; a uniformly random vertex if None.

    Returns:
        ``(walk, is_continue)``: the visited vertices (starting vertex first)
        and, for each, the ``is_continue`` flag of the edge used to reach it
        (1 for the starting vertex).
    """
    current = start_node if start_node is not None else np.random.choice(graph.vs)
    visited = [current]
    flags = [1]

    for _ in range(walk_length):
        candidates = graph.neighbors(current.index, mode='OUT')
        if not candidates:
            break
        target = np.random.choice(candidates)
        flag = graph.es[graph.get_eid(current.index, target)]['is_continue']
        current = graph.vs[target]
        visited.append(current)
        flags.append(flag)

    return visited, flags
|
|
|
|
import subprocess
|
|
def path_visualization(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    """Render the video frames along ``path`` to ``save_path`` (libx264), optionally with audio.

    Args:
        graph: frame graph; ``graph.vs[0]['fps']`` sets the output frame rate.
        path: sequence of vertices, each with ``video`` and ``axis_angle``.
        is_continue: per-step continuity flags (as returned by ``random_walk``).
        save_path: output video file.
        verbose_continue: print the fraction of discontinuous steps.
        audio_path: optional audio file; it is trimmed to the video length if
            longer, then muxed in with ffmpeg as AAC.
        return_motion: if True, also return the stacked ``axis_angle`` values.

    Returns:
        ``np.stack`` of the path's ``axis_angle`` values if ``return_motion``,
        otherwise None.
    """
    all_frames = [node['video'] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)

    fps = graph.vs[0]['fps']
    duration = len(all_frames) / fps

    def make_frame(t):
        idx = min(int(t * fps), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)

    if audio_path is None:
        # Nothing to mux: encode straight to the destination.
        video_clip.write_videofile(save_path, codec='libx264', fps=fps, audio=False)
    else:
        # Per-call temp dir so concurrent calls don't overwrite each other's
        # intermediates; it is removed even if encoding or muxing fails.
        with tempfile.TemporaryDirectory(prefix='path_vis_') as tmp_dir:
            video_only_path = os.path.join(tmp_dir, 'video_only.mp4')
            video_clip.write_videofile(video_only_path, codec='libx264', fps=fps, audio=False)

            audio_clip = AudioFileClip(audio_path)
            try:
                if audio_clip.duration > video_clip.duration:
                    # Trim the audio so it does not outlast the video.
                    audio_input = os.path.join(tmp_dir, 'trimmed_audio.aac')
                    audio_clip.subclip(0, video_clip.duration).write_audiofile(audio_input)
                else:
                    audio_input = audio_path
            finally:
                audio_clip.close()

            # Copy the encoded video stream and encode the audio as AAC.
            ffmpeg_command = [
                'ffmpeg', '-y',
                '-i', video_only_path,
                '-i', audio_input,
                '-c:v', 'copy',
                '-c:a', 'aac',
                '-strict', 'experimental',
                save_path,
            ]
            subprocess.check_call(ffmpeg_command)

    if return_motion:
        all_motion = [node['axis_angle'] for node in path]
        all_motion = np.stack(all_motion, 0)
        return all_motion
|
|
|
|
|
|
|
|
def generate_transition_video(frame_start_path, frame_end_path, output_video_path,
                              model_path="./frame-interpolation-pytorch/film_net_fp32.pt",
                              inference_script="./frame-interpolation-pytorch/inference.py",
                              num_frames=3, fps=30):
    """Interpolate between two frame images with FiLM and save the result as a video.

    Runs the frame-interpolation-pytorch inference script in a subprocess.
    This is best effort: a failing run is reported with ``print``, not raised.

    Args:
        frame_start_path: image file for the first frame.
        frame_end_path: image file for the last frame.
        output_video_path: where the script should write the video.
        model_path: FiLM TorchScript checkpoint passed to the script.
        inference_script: path to the ``inference.py`` script.
        num_frames: value passed as ``--frames``.
        fps: value passed as ``--fps``.

    Returns:
        True if the script exited successfully, False otherwise.
    """
    import subprocess
    import sys

    command = [
        # Same interpreter/venv as the caller, not whatever "python" is on PATH.
        sys.executable,
        inference_script,
        model_path,
        frame_start_path,
        frame_end_path,
        "--save_path", output_video_path,
        "--gpu",
        "--frames", str(num_frames),
        "--fps", str(fps),
    ]

    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Error occurred while generating transition video: {e}")
        return False
    print(f"Generated transition video saved at {output_video_path}")
    return True
|
|
|
|
|
|
def path_visualization_v2(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    '''
    Render the frames along ``path`` to ``save_path``, smoothing jumps with FiLM.

    This is the fast interpolation used by the Hugging Face demo; the paper
    uses a diffusion-based interpolation method instead.

    For each index i with ``is_continue[i] == 0`` whose frames i-2..i+2 lie
    inside the path, frames i-1, i, i+1 are replaced by the middle three frames
    of a transition generated between frames i-2 and i+2. A window is skipped
    if it overlaps one that was already blended.

    Args:
        graph: frame graph; ``graph.vs[0]['fps']`` sets the output frame rate.
        path: sequence of vertices, each with ``video`` and ``axis_angle``.
        is_continue: per-step continuity flags (as returned by ``random_walk``).
        save_path: output video file (libx264; AAC audio if ``audio_path``).
        verbose_continue: print the fraction of discontinuous steps.
        audio_path: optional audio file attached to the output.
        return_motion: if True, also return the stacked ``axis_angle`` values.

    Returns:
        ``np.stack`` of the path's ``axis_angle`` values if ``return_motion``,
        otherwise None.
    '''
    all_frames = [node['video'] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)
    fps = graph.vs[0]['fps']
    duration = len(all_frames) / fps

    # First pass: choose which discontinuities to blend, without overlapping windows.
    blend_positions = []
    processed_frames = set()
    for i, cont in enumerate(is_continue):
        if cont != 0:
            continue
        # The interpolation endpoints are frames i-2 and i+2; both must exist.
        if i - 2 < 0 or i + 2 >= len(all_frames):
            continue
        window = range(i - 1, i + 2)
        if any(idx in processed_frames for idx in window):
            continue
        processed_frames.update(window)
        blend_positions.append(i)

    # Second pass: generate each transition and splice it in; the temp dir is removed afterwards.
    with tempfile.TemporaryDirectory(prefix='blending_frames_') as temp_dir:
        for i in tqdm(blend_positions):
            start_frame_idx = i - 2
            end_frame_idx = i + 2
            frame_start_path = os.path.join(temp_dir, f'frame_{start_frame_idx}.png')
            frame_end_path = os.path.join(temp_dir, f'frame_{end_frame_idx}.png')
            imageio.imwrite(frame_start_path, all_frames[start_frame_idx])
            imageio.imwrite(frame_end_path, all_frames[end_frame_idx])

            generated_video_path = os.path.join(temp_dir, f'generated_{start_frame_idx}_{end_frame_idx}.mp4')
            generate_transition_video(frame_start_path, frame_end_path, generated_video_path)
            # generate_transition_video only prints on failure; skip instead of crashing on a missing file.
            if not os.path.exists(generated_video_path):
                print(f"Transition video not found ({generated_video_path}). Skipping blending at position {i}.")
                continue

            reader = imageio.get_reader(generated_video_path)
            try:
                generated_frames = [frame for frame in reader]
            finally:
                reader.close()

            total_generated_frames = len(generated_frames)
            if total_generated_frames < 5:
                print(f"Generated video has insufficient frames ({total_generated_frames}). Skipping blending at position {i}.")
                continue
            # Assumes generated_frames[0] and [-1] are the two endpoints, so [1:4] are the in-betweens.
            middle_frames = generated_frames[1:4]
            for offset, frame_idx in enumerate(range(i - 1, i + 2)):
                all_frames[frame_idx] = middle_frames[offset]

    def make_frame(t):
        idx = min(int(t * fps), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)
    audio_clip = None
    if audio_path is not None:
        audio_clip = AudioFileClip(audio_path)
        video_clip = video_clip.set_audio(audio_clip)
    try:
        video_clip.write_videofile(save_path, codec='libx264', fps=fps, audio_codec='aac')
    finally:
        if audio_clip is not None:
            audio_clip.close()

    if return_motion:
        all_motion = [node['axis_angle'] for node in path]
        all_motion = np.stack(all_motion, 0)
        return all_motion
|
|
|
|
|
|
def graph_pruning(graph):
    """Return the largest strongly connected component of ``graph`` and print its statistics.

    Within the returned subgraph every vertex can reach every other one. For a
    component with more than one vertex, every vertex therefore has an
    outgoing edge, so ``random_walk`` never stops early.

    Args:
        graph: directed igraph, e.g. the output of ``create_edges``.

    Returns:
        The induced subgraph of the giant strongly connected component.
    """
    # python-igraph 0.10 renamed ``clusters`` to ``connected_components``;
    # prefer the new name and fall back for older releases.
    find_components = getattr(graph, "connected_components", None) or graph.clusters
    lascc = find_components(mode="STRONG").giant()

    print(f"before nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    print(f"after nodes: {len(lascc.vs)}, edges: {len(lascc.es)}")
    if len(lascc.vs) == 0:
        # Degree statistics are undefined for an empty graph.
        return lascc

    in_degrees = lascc.indegree()
    out_degrees = lascc.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    return lascc
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Build a frame graph from a dataset JSON and pickle it.")
    # Both paths are required: with empty defaults, the run would only fail at
    # write_pickle(""), after all the expensive graph construction.
    parser.add_argument("--json_save_path", type=str, required=True,
                        help="JSON file listing clips (video_id, video_path, motion_path).")
    parser.add_argument("--graph_save_path", type=str, required=True,
                        help="Output path for the pickled igraph.")
    parser.add_argument("--threshold", type=float, default=1.0,
                        help="Scale factor for the per-joint similarity thresholds in create_edges.")
    args = parser.parse_args()
    json_path = args.json_save_path
    graph_path = args.graph_save_path

    # Module-level global on purpose: functions in this file may read ``device`` from module scope.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    smplx_model = smplx.create(
        "./emage/smplx_models/",
        model_type='smplx',
        gender='NEUTRAL_2020',
        use_face_contour=False,
        num_betas=300,
        num_expression_coeffs=100,
        ext='npz',
        use_pca=False,
    ).to(device).eval()

    graph = create_graph(json_path, smplx_model)
    graph = create_edges(graph, args.threshold)
    graph.write_pickle(fname=graph_path)

    # Optional follow-ups on the saved graph:
    #   graph = graph_pruning(graph)
    #   walk, is_continue = random_walk(graph, 100)
    #   motion = path_visualization(graph, walk, is_continue, "./test.mp4",
    #                               verbose_continue=True, return_motion=True)