Module generate_all
Script to dynamically generate and update `__all__` variables in Python modules. Restricts `__all__` to symbols defined within the module itself. Generates backup files with `__all__` definitions in utils/__all__/ for review.
Expand source code
#!/usr/bin/env python3
"""
Script to dynamically generate and update __all__ variables in Python modules.
Restricts __all__ to symbols defined within the module itself.
Generates backup files with __all__ definitions in utils/__all__/ for review.
"""
# Maintained by INRAE\olivier.vitrac@agroparistech.fr
# Revision history: 2024-12-08
import ast
import inspect
import os
import pkgutil
import sys
import tokenize
def extract_symbols(module_name):
    """
    Return the public names that a module defines itself.

    The module is imported by name and its members are inspected. A name is
    kept when it does not start with an underscore and the ``__module__``
    attribute of the object equals ``module_name``; objects that expose no
    matching ``__module__`` (for example plain numbers) are left out.

    Returns the kept names in the order produced by ``inspect.getmembers``.
    If the import fails, the error is printed and an empty list is returned.
    """
    try:
        module = __import__(module_name, fromlist=["*"])
        return [
            member_name
            for member_name, member in inspect.getmembers(module)
            # public names only, and only objects that originate in this module
            if not member_name.startswith("_")
            and getattr(member, "__module__", None) == module_name
        ]
    except ImportError as err:
        print(f"Error importing module {module_name}: {err}")
        return []
def write_backup_all(module_name, symbols, all_dir):
    """
    Write a reviewable backup of a module's ``__all__`` into ``all_dir``.

    Parameters:
        module_name: dotted module name; only its last component is used for
            the backup file name (``<last component>_all.py``).
        symbols: list of names written as the value of ``__all__``.
        all_dir: directory receiving the backup; created if missing.

    The file holds a comment naming the module followed by the ``__all__``
    assignment. The path of the written file is printed.
    """
    # exist_ok removes the check-then-create race of a separate exists() test
    os.makedirs(all_dir, exist_ok=True)

    # NOTE(review): only the last component is used, so modules with the same
    # short name in different packages overwrite each other's backup.
    short_name = module_name.split(".")[-1]
    backup_path = os.path.join(all_dir, f"{short_name}_all.py")

    # Explicit UTF-8: identifiers may be non-ASCII and the locale default
    # encoding is not guaranteed to represent them.
    with open(backup_path, "w", encoding="utf-8") as f:
        f.write(f"# __all__ for {module_name}\n")
        f.write(f"__all__ = {symbols!r}\n")
    print(f"Generated backup __all__ for {module_name}: {backup_path}")
def _find_all_assignment(tree):
    """Return the first top-level statement that assigns ``__all__``, or None."""
    for node in tree.body:
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign):
            targets = [node.target]
        else:
            continue
        if any(isinstance(t, ast.Name) and t.id == "__all__" for t in targets):
            return node
    return None


def _all_insertion_index(tree, lines):
    """Return the index in ``lines`` where a new ``__all__`` line is inserted."""
    insert_index = None
    for position, node in enumerate(tree.body):
        is_docstring = (
            position == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        if not (is_docstring or isinstance(node, (ast.Import, ast.ImportFrom))):
            break
        # end_lineno is 1-based, so it is also the 0-based index of the line
        # that follows the statement (multi-line statements included)
        insert_index = node.end_lineno
    if insert_index is not None:
        return insert_index

    # No docstring and no leading imports: stay below the header comments
    # (shebang, encoding cookie, ...) but never go past the first statement.
    if tree.body:
        first = tree.body[0]
        # decorators sit above the line reported for a def/class
        first_line = min(
            [first.lineno] + [d.lineno for d in getattr(first, "decorator_list", [])]
        )
        limit = first_line - 1
    else:
        limit = len(lines)
    insert_index = 0
    while insert_index < limit and lines[insert_index].lstrip().startswith("#"):
        insert_index += 1
    return insert_index


def update_module_with_all(module_name, symbols, base_path, package_root):
    """
    Add or replace the ``__all__`` definition in a module's source file.

    Parameters:
        module_name: dotted module name, e.g. ``"pizza.private.mod"``.
        symbols: list of names written as the value of ``__all__``.
        base_path: directory corresponding to ``package_root`` on disk.
        package_root: leading package name stripped from ``module_name`` to
            obtain the path relative to ``base_path``.

    Behaviour:
        * If the file does not exist, or cannot be decoded/parsed, a message
          is printed and the file is left untouched.
        * An existing top-level ``__all__`` assignment is replaced as a whole,
          even when it spans several lines.
        * Otherwise ``__all__`` is inserted after the module docstring and the
          leading import statements; when there are none, it goes right after
          the leading comment lines (shebang, encoding cookie).

    Returns None in every case.
    """
    # Strip the package root only when it is a real prefix of the module name
    prefix = package_root + "."
    if module_name.startswith(prefix):
        relative_name = module_name[len(prefix):]
    else:
        relative_name = module_name
    module_path = os.path.join(base_path, *relative_name.split(".")) + ".py"
    if not os.path.exists(module_path):
        print(f"Module file not found: {module_path}")
        return

    try:
        # tokenize.open honours the BOM / PEP 263 encoding cookie of the file
        with tokenize.open(module_path) as f:
            encoding = f.encoding
            lines = f.readlines()
        tree = ast.parse("".join(lines), filename=module_path)
    except (SyntaxError, ValueError) as e:
        # ValueError also covers UnicodeDecodeError; never rewrite a file
        # whose structure could not be determined
        print(f"Cannot parse {module_path}; file left unchanged: {e}")
        return

    all_line = f"__all__ = {symbols!r}\n"
    existing = _find_all_assignment(tree)
    if existing is not None:
        # Replace every physical line of the old assignment
        lines[existing.lineno - 1:existing.end_lineno] = [all_line]
    else:
        lines.insert(_all_insertion_index(tree, lines), "\n" + all_line)

    # Write back with the encoding the file was read with
    with open(module_path, "w", encoding=encoding) as f:
        f.writelines(lines)
    print(f"Updated {module_name} with __all__ = {symbols!r}")
def generate_all_for_package(package_name, base_path, package_root, all_dir):
    """
    Produce ``__all__`` for every direct module of ``package_name``.

    The package is imported and its ``__path__`` is scanned with
    ``pkgutil.iter_modules``; sub-packages are skipped. For each module the
    symbols returned by ``extract_symbols`` are, when non-empty, passed to
    ``write_backup_all`` (with ``all_dir``) and then to
    ``update_module_with_all`` (with ``base_path`` and ``package_root``).
    Modules without symbols only produce a message.

    An ImportError raised while processing is printed and ends the run for
    this package.
    """
    try:
        package = __import__(package_name, fromlist=["*"])
        for _finder, short_name, is_subpackage in pkgutil.iter_modules(package.__path__):
            if is_subpackage:
                # only plain modules are handled here
                continue

            module_full_name = f"{package_name}.{short_name}"
            print(f"Processing module: {module_full_name}")

            symbols = extract_symbols(module_full_name)
            if not symbols:
                print(f"No public symbols found in {module_full_name}")
                continue

            # backup copy first, then the in-place edit of the module source
            write_backup_all(module_full_name, symbols, all_dir)
            update_module_with_all(module_full_name, symbols, base_path, package_root)
    except ImportError as err:
        print(f"Error importing package {package_name}: {err}")
if __name__ == "__main__":
    # The script locates everything relative to its own folder, which must be
    # named utils/
    script_dir = os.path.dirname(os.path.realpath(__file__))
    if os.path.basename(script_dir) != "utils":
        print("Error: This script must be run from the Pizza3/utils/ directory.")
        sys.exit(1)

    # Paths derived from the script location
    mainfolder = os.path.abspath(os.path.join(script_dir, "../"))  # parent of utils/
    pizza_base_path = os.path.join(mainfolder, "pizza")  # folder of the pizza package
    all_dir = os.path.join(script_dir, "__all__")  # destination of the backup files

    # Both folders go to the front of sys.path; the pizza folder is inserted
    # last and therefore ends up first
    for search_path in (mainfolder, pizza_base_path):
        sys.path.insert(0, search_path)

    # Packages whose direct modules are processed
    packages = ["pizza", "pizza.private"]
    for package in packages:
        generate_all_for_package(package, pizza_base_path, "pizza", all_dir)

    print("\n__all__ generation and updates completed.")
Functions
def extract_symbols(module_name)
-
Extracts all public symbols (classes, functions, variables) from a module, restricting to symbols defined in the module itself.
Expand source code
def extract_symbols(module_name):
    """
    Return the public names that a module defines itself.

    The module is imported by name and its members are inspected. A name is
    kept when it does not start with an underscore and the ``__module__``
    attribute of the object equals ``module_name``; objects that expose no
    matching ``__module__`` (for example plain numbers) are left out.

    Returns the kept names in the order produced by ``inspect.getmembers``.
    If the import fails, the error is printed and an empty list is returned.
    """
    try:
        module = __import__(module_name, fromlist=["*"])
        return [
            member_name
            for member_name, member in inspect.getmembers(module)
            # public names only, and only objects that originate in this module
            if not member_name.startswith("_")
            and getattr(member, "__module__", None) == module_name
        ]
    except ImportError as err:
        print(f"Error importing module {module_name}: {err}")
        return []
def generate_all_for_package(package_name, base_path, package_root, all_dir)
-
Generates and updates the `__all__` variable for all modules in the given package.
Expand source code
def generate_all_for_package(package_name, base_path, package_root, all_dir):
    """
    Produce ``__all__`` for every direct module of ``package_name``.

    The package is imported and its ``__path__`` is scanned with
    ``pkgutil.iter_modules``; sub-packages are skipped. For each module the
    symbols returned by ``extract_symbols`` are, when non-empty, passed to
    ``write_backup_all`` (with ``all_dir``) and then to
    ``update_module_with_all`` (with ``base_path`` and ``package_root``).
    Modules without symbols only produce a message.

    An ImportError raised while processing is printed and ends the run for
    this package.
    """
    try:
        package = __import__(package_name, fromlist=["*"])
        for _finder, short_name, is_subpackage in pkgutil.iter_modules(package.__path__):
            if is_subpackage:
                # only plain modules are handled here
                continue

            module_full_name = f"{package_name}.{short_name}"
            print(f"Processing module: {module_full_name}")

            symbols = extract_symbols(module_full_name)
            if not symbols:
                print(f"No public symbols found in {module_full_name}")
                continue

            # backup copy first, then the in-place edit of the module source
            write_backup_all(module_full_name, symbols, all_dir)
            update_module_with_all(module_full_name, symbols, base_path, package_root)
    except ImportError as err:
        print(f"Error importing package {package_name}: {err}")
def update_module_with_all(module_name, symbols, base_path, package_root)
-
Updates the `__all__` variable in the target module. Adds or replaces the `__all__` definition in the module file.
Expand source code
def _find_all_assignment(tree):
    """Return the first top-level statement that assigns ``__all__``, or None."""
    for node in tree.body:
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign):
            targets = [node.target]
        else:
            continue
        if any(isinstance(t, ast.Name) and t.id == "__all__" for t in targets):
            return node
    return None


def _all_insertion_index(tree, lines):
    """Return the index in ``lines`` where a new ``__all__`` line is inserted."""
    insert_index = None
    for position, node in enumerate(tree.body):
        is_docstring = (
            position == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        if not (is_docstring or isinstance(node, (ast.Import, ast.ImportFrom))):
            break
        # end_lineno is 1-based, so it is also the 0-based index of the line
        # that follows the statement (multi-line statements included)
        insert_index = node.end_lineno
    if insert_index is not None:
        return insert_index

    # No docstring and no leading imports: stay below the header comments
    # (shebang, encoding cookie, ...) but never go past the first statement.
    if tree.body:
        first = tree.body[0]
        # decorators sit above the line reported for a def/class
        first_line = min(
            [first.lineno] + [d.lineno for d in getattr(first, "decorator_list", [])]
        )
        limit = first_line - 1
    else:
        limit = len(lines)
    insert_index = 0
    while insert_index < limit and lines[insert_index].lstrip().startswith("#"):
        insert_index += 1
    return insert_index


def update_module_with_all(module_name, symbols, base_path, package_root):
    """
    Add or replace the ``__all__`` definition in a module's source file.

    Parameters:
        module_name: dotted module name, e.g. ``"pizza.private.mod"``.
        symbols: list of names written as the value of ``__all__``.
        base_path: directory corresponding to ``package_root`` on disk.
        package_root: leading package name stripped from ``module_name`` to
            obtain the path relative to ``base_path``.

    Behaviour:
        * If the file does not exist, or cannot be decoded/parsed, a message
          is printed and the file is left untouched.
        * An existing top-level ``__all__`` assignment is replaced as a whole,
          even when it spans several lines.
        * Otherwise ``__all__`` is inserted after the module docstring and the
          leading import statements; when there are none, it goes right after
          the leading comment lines (shebang, encoding cookie).

    Returns None in every case.
    """
    # Strip the package root only when it is a real prefix of the module name
    prefix = package_root + "."
    if module_name.startswith(prefix):
        relative_name = module_name[len(prefix):]
    else:
        relative_name = module_name
    module_path = os.path.join(base_path, *relative_name.split(".")) + ".py"
    if not os.path.exists(module_path):
        print(f"Module file not found: {module_path}")
        return

    try:
        # tokenize.open honours the BOM / PEP 263 encoding cookie of the file
        with tokenize.open(module_path) as f:
            encoding = f.encoding
            lines = f.readlines()
        tree = ast.parse("".join(lines), filename=module_path)
    except (SyntaxError, ValueError) as e:
        # ValueError also covers UnicodeDecodeError; never rewrite a file
        # whose structure could not be determined
        print(f"Cannot parse {module_path}; file left unchanged: {e}")
        return

    all_line = f"__all__ = {symbols!r}\n"
    existing = _find_all_assignment(tree)
    if existing is not None:
        # Replace every physical line of the old assignment
        lines[existing.lineno - 1:existing.end_lineno] = [all_line]
    else:
        lines.insert(_all_insertion_index(tree, lines), "\n" + all_line)

    # Write back with the encoding the file was read with
    with open(module_path, "w", encoding=encoding) as f:
        f.writelines(lines)
    print(f"Updated {module_name} with __all__ = {symbols!r}")
def write_backup_all(module_name, symbols, all_dir)
-
Generates a backup file for `__all__` in the utils/__all__/ directory.
Expand source code
def write_backup_all(module_name, symbols, all_dir):
    """
    Write a reviewable backup of a module's ``__all__`` into ``all_dir``.

    Parameters:
        module_name: dotted module name; only its last component is used for
            the backup file name (``<last component>_all.py``).
        symbols: list of names written as the value of ``__all__``.
        all_dir: directory receiving the backup; created if missing.

    The file holds a comment naming the module followed by the ``__all__``
    assignment. The path of the written file is printed.
    """
    # exist_ok removes the check-then-create race of a separate exists() test
    os.makedirs(all_dir, exist_ok=True)

    # NOTE(review): only the last component is used, so modules with the same
    # short name in different packages overwrite each other's backup.
    short_name = module_name.split(".")[-1]
    backup_path = os.path.join(all_dir, f"{short_name}_all.py")

    # Explicit UTF-8: identifiers may be non-ASCII and the locale default
    # encoding is not guaranteed to represent them.
    with open(backup_path, "w", encoding="utf-8") as f:
        f.write(f"# __all__ for {module_name}\n")
        f.write(f"__all__ = {symbols!r}\n")
    print(f"Generated backup __all__ for {module_name}: {backup_path}")