import argparse
import glob
import os
import sys
from typing import Any, Optional

import torch
from safetensors.torch import save_file


def find_first_tensor(data: Any) -> Optional[torch.Tensor]:
    """Return the first ``torch.Tensor`` found in a loaded object, or ``None``.

    The search is depth-first: dict values are visited in insertion order and
    list/tuple items in index order. Any other type is treated as a leaf that
    contains no tensor.

    Args:
        data: The object returned by ``torch.load`` (or any nested part of it).

    Returns:
        The first tensor encountered, or ``None`` if there is none.
    """
    if isinstance(data, torch.Tensor):
        return data

    if isinstance(data, dict):
        children = data.values()
    elif isinstance(data, (list, tuple)):
        children = data
    else:
        return None

    for child in children:
        result = find_first_tensor(child)
        if result is not None:
            return result
    return None


def convert_single_file(
    input_path: str,
    key_name: str,
    *,
    weights_only: Optional[bool] = None,
) -> bool:
    """Convert one ``.pt`` file into a sibling ``.safetensors`` file.

    The output path is ``input_path`` with its extension replaced by
    ``.safetensors``; an existing file at that path is overwritten. Only the
    first tensor found in the loaded object (see ``find_first_tensor``) is
    written, stored under ``key_name``.

    Args:
        input_path: Path to the input ``.pt`` file.
        key_name: Key under which the tensor is stored in the output file.
        weights_only: Forwarded to ``torch.load`` when not ``None``; ``None``
            keeps the default of the installed torch version.

    Returns:
        True if the conversion succeeded, False otherwise. Errors are reported
        on stdout instead of raised so that a batch run can continue.
    """
    base, _ = os.path.splitext(input_path)
    output_path = base + ".safetensors"
    input_name = os.path.basename(input_path)

    # SECURITY: torch.load unpickles the file. Unless weights_only=True is in
    # effect (torch only made it the default in 2.6), a malicious .pt file can
    # execute arbitrary code here. Pass weights_only=True for untrusted files.
    load_kwargs: dict = {"map_location": "cpu"}
    if weights_only is not None:
        load_kwargs["weights_only"] = weights_only

    try:
        loaded_data = torch.load(input_path, **load_kwargs)

        embedding_tensor = find_first_tensor(loaded_data)
        if embedding_tensor is None:
            print(f" ❌ Error: Could not find any tensor in '{input_name}'. Skipping.")
            return False

        # safetensors rejects non-contiguous tensors, hence .contiguous().
        save_file({key_name: embedding_tensor.contiguous()}, output_path)
    except Exception as e:  # batch boundary: report and move on to the next file
        print(f" ❌ An unexpected error occurred with '{input_name}': {e}. Skipping.")
        return False
    else:
        print(f" ✅ Successfully converted -> {os.path.basename(output_path)}")
        return True


def batch_convert_folder(
    folder_path: str,
    key_name: str,
    *,
    weights_only: Optional[bool] = None,
) -> bool:
    """Convert every ``*.pt`` file directly inside ``folder_path``.

    Subdirectories are not searched. Files are processed in sorted order so
    that runs are reproducible.

    Args:
        folder_path: Directory containing the ``.pt`` files.
        key_name: Key under which each tensor is stored in its output file.
        weights_only: Forwarded to ``convert_single_file``.

    Returns:
        True if the folder exists and every ``.pt`` file found was converted
        (also True when there are no ``.pt`` files), False otherwise.
    """
    if not os.path.isdir(folder_path):
        print(f"❌ Error: The provided path is not a directory: {folder_path}")
        return False

    # Escape the folder part so characters like '[' in the path are matched
    # literally; sort because glob returns files in arbitrary order.
    search_pattern = os.path.join(glob.escape(folder_path), "*.pt")
    pt_files = sorted(glob.glob(search_pattern))

    if not pt_files:
        print(f"ℹ️ No .pt files found in the directory: {folder_path}")
        return True

    total_files = len(pt_files)
    print(f"🔍 Found {total_files} .pt file(s). Starting batch conversion...")

    success_count = 0
    for index, file_path in enumerate(pt_files, start=1):
        print(f"\n--- [{index}/{total_files}] Processing: {os.path.basename(file_path)} ---")
        if convert_single_file(file_path, key_name, weights_only=weights_only):
            success_count += 1

    print("\n🎉 Batch conversion complete.")
    print(f"Successfully converted {success_count} out of {total_files} files.")
    return success_count == total_files


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Batch converts all .pt files in a specified folder to the .safetensors format."
    )
    parser.add_argument(
        "input_folder",
        type=str,
        help="Path to the input folder containing .pt files.",
    )
    parser.add_argument(
        "-k",
        "--key-name",
        type=str,
        default="emb_params",
        help="The key name for the embedding in the output files. (default: emb_params)",
    )
    parser.add_argument(
        "--weights-only",
        action="store_true",
        help=(
            "Load files with torch.load(weights_only=True), refusing arbitrary "
            "pickled objects. Recommended for .pt files from untrusted sources."
        ),
    )
    args = parser.parse_args()

    all_ok = batch_convert_folder(
        args.input_folder,
        args.key_name,
        # None keeps torch's own default when the flag is not given.
        weights_only=True if args.weights_only else None,
    )
    # Non-zero exit status lets shell scripts / CI detect failed conversions.
    sys.exit(0 if all_ok else 1)