Update create_graph.py

This commit is contained in:
Haiyang Liu
2024-10-19 09:04:01 +00:00
committed by system
parent 86bc7b0fdd
commit 763739cc73

View File

@@ -181,7 +181,7 @@ def create_graph(json_path, smplx_model):
global_i += 1
return graph
def create_edges(graph):
def create_edges(graph, threshold_edges):
adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4]
# print()
for i, node in enumerate(graph.vs):
@@ -231,7 +231,7 @@ def create_edges(graph):
velocity_similarity = np.linalg.norm(current_velocity - other_velocity, axis=1)
trans_similarity = np.linalg.norm(current_trans - other_trans, axis=0)
if trans_similarity < threshold_trans:
if np.sum(position_similarity < threshold_position) >= 45 and np.sum(velocity_similarity < threshold_velocity) >= 45:
if np.sum(position_similarity < threshold_edges*threshold_position) >= 45 and np.sum(velocity_similarity < threshold_edges*threshold_velocity) >= 45:
graph.add_edge(i, j, is_continue=0)
print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
@@ -456,6 +456,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--json_save_path", type=str, default="")
parser.add_argument("--graph_save_path", type=str, default="")
parser.add_argument("--threshold", type=float, default=1.0)
args = parser.parse_args()
json_path = args.json_save_path
graph_path = args.graph_save_path
@@ -475,16 +476,15 @@ if __name__ == '__main__':
# single_test
# graph = create_graph('/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/Abortion_Laws_-_Last_Week_Tonight_with_John_Oliver_HBO-DRauXXz6t0Y.webm.json')
graph = create_graph(json_path, smplx_model)
graph = create_edges(graph)
graph = create_edges(graph, args.threshold)
# pool_path = "/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/show-oliver-test.pkl"
# graph = igraph.Graph.Read_Pickle(fname=pool_path)
# graph = igraph.Graph.Read_Pickle(fname="/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/test.pkl")
walk, is_continue = random_walk(graph, 100)
motion = path_visualization(graph, walk, is_continue, "./test.mp4", audio_path=None, verbose_continue=True, return_motion=True)
# walk, is_continue = random_walk(graph, 100)
# motion = path_visualization(graph, walk, is_continue, "./test.mp4", audio_path=None, verbose_continue=True, return_motion=True)
# print(motion.shape)
save_graph = graph.write_pickle(fname=graph_path)
graph = graph_pruning(graph)
# graph = graph_pruning(graph)
# show-oliver
# json_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/"