The following script splits a source Project into a number of child Projects, each containing a specified number of labels from the source Project. The number of child Projects created depends on the total number of labels in the source Project and on how many labels each child Project contains.
from pathlib import Path
from typing import List

import typer
from encord import EncordUserClient, Project
from encord.objects import LabelRowV2
def get_destination_projects(
    user_client: EncordUserClient, source_project: Project, n_target_projects: int
) -> List[Project]:
    """Return one destination project per slice, creating the missing ones.

    Slice ``i`` (1-based) is titled ``"<source title> slice <i>"``. When a
    project with exactly that title already exists it is reused instead of
    creating a duplicate; otherwise the source project is copied (datasets and
    collaborators included) under that title.

    Args:
        user_client: Client used to look projects up by title and fetch them.
        source_project: Project that is copied when a slice does not exist yet.
        n_target_projects: Number of slices to resolve.

    Returns:
        The destination projects, ordered by slice number.

    Raises:
        SystemExit: With status 1 when more than one project matches a slice
            title, because the target cannot be chosen unambiguously.
    """
    dest_project_ids = []
    for idx in range(n_target_projects):
        target_title = source_project.title + f" slice {idx + 1}"
        if found_projects := user_client.get_projects(title_eq=target_title):
            if len(found_projects) > 1:
                print(f"A few projects with name {target_title} found. Can't proceed, as target name should be unique.")
                raise SystemExit(1)
            # Exactly one match: reuse it rather than copying again.
            dest_project_id = found_projects[0]["project"]["project_hash"]
        else:
            dest_project_id = source_project.copy_project(
                copy_datasets=True,
                copy_collaborators=True,
                new_title=target_title,
            )
        dest_project_ids.append(dest_project_id)

    print(f"Destination project ids: {dest_project_ids}")
    return [user_client.get_project(project_hash=p) for p in dest_project_ids]
def get_destination_label_row(destination_project: Project, data_hash: str) -> LabelRowV2:
    """Return the label row of ``destination_project`` for ``data_hash``.

    Args:
        destination_project: Project to search.
        data_hash: Data hash identifying the wanted label row.

    Returns:
        The first label row the project lists for ``data_hash``.

    Raises:
        SystemExit: With status 1 when the project lists no label row for
            ``data_hash``.
    """
    result = destination_project.list_label_rows_v2(data_hashes=[data_hash])
    if not result:
        print("No such data hash in the destination project, something wrong!")
        raise SystemExit(1)
    return result[0]
def main(ssh_key: Path = "<private_key_path>", source_project: str = "<project_hash>", target_project_size: int = <project_size>, continue_from: int = <start_number>, ):
user_client = EncordUserClient.create_with_ssh_private_key(
ssh_private_key_path=ssh_key
)
project = user_client.get_project(project_hash=source_project)
n_items = len(project.list_label_rows_v2())
n_target_projects = int(n_items / target_project_size) + 1
print(f"Project {project.title} has {n_items} entries.")
if n_target_projects < 2:
print("Nothing to do!")
exit(0)
print(f"Splitting source project into {n_target_projects} of size {target_project_size}")
destination_projects = get_destination_projects(user_client, project, n_target_projects)
source_label_rows = project.list_label_rows_v2()
for idx, (src_label_row, destination_project) in enumerate(zip(source_label_rows, [p for p in destination_projects for _ in range(target_project_size)])):
if continue_from > idx:
continue
src_label_row.initialise_labels()
dst_label_row = get_destination_label_row(destination_project, src_label_row.data_hash)
dst_label_row.initialise_labels()
src = src_label_row.to_encord_dict()
dst = dst_label_row.to_encord_dict()
dst['object_answers'] = src['object_answers']
dst['classification_answers'] = src['classification_answers']
dst['object_actions'] = src['object_actions']
dst['data_units'] = src['data_units']
dst_label_row.from_labels_dict(dst)
dst_label_row.save()
print(f"{idx + 1}/{len(source_label_rows)}: '{src_label_row.data_title}' labels copied to project {destination_project.title}")
# Run main() through typer only when the file is executed as a script, so
# importing the module has no side effects.
if __name__ == "__main__":
    typer.run(main)