Asset-Frameworker/gui/prediction_handler.py
2025-04-29 18:26:13 +02:00

232 lines
12 KiB
Python

# gui/prediction_handler.py
import logging
from pathlib import Path
import time # For potential delays if needed
import os # For cpu_count
from concurrent.futures import ThreadPoolExecutor, as_completed # For parallel prediction
# --- PySide6 Imports ---
from PySide6.QtCore import QObject, Signal, QThread # Import QThread
# --- Backend Imports ---
# Adjust path to ensure modules can be found relative to this file's location
import sys
# Put the project root (the parent of this file's directory) on sys.path so the
# backend modules below can be imported regardless of the launch directory.
script_dir = Path(__file__).parent
project_root = script_dir.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Import the backend. When it cannot be imported, fall back to placeholders so
# that this module still loads; BACKEND_AVAILABLE records which case applies.
try:
    from configuration import Configuration, ConfigurationError
    from asset_processor import AssetProcessor, AssetProcessingError
except ImportError as e:
    print(f"ERROR (PredictionHandler): Failed to import backend modules: {e}")
    # Placeholders: the classes become None, the error types become Exception.
    Configuration = AssetProcessor = None
    ConfigurationError = AssetProcessingError = Exception
    BACKEND_AVAILABLE = False
else:
    BACKEND_AVAILABLE = True

log = logging.getLogger(__name__)
# Only install a basic logging configuration when no handler is reachable yet.
if not log.hasHandlers():
    logging.basicConfig(
        level=logging.INFO,
        format='%(levelname)s (PredictHandler): %(message)s',
    )
class PredictionHandler(QObject):
    """
    Handles running predictions in a separate thread to avoid GUI freezes.

    ``run_prediction`` submits one task per input path to a ThreadPoolExecutor,
    gathers one row dictionary per predicted file and publishes the combined
    list through ``prediction_results_ready``.
    """

    # --- Signals ---
    # Emits a list of dictionaries, each representing a file row for the table.
    # Dict format: {'original_path': str, 'predicted_asset_name': str | None,
    #               'predicted_output_name': str | None, 'status': str,
    #               'details': str | None, 'source_asset': str}
    prediction_results_ready = Signal(list)
    # Emitted when all predictions for a batch are done (also on early exit).
    prediction_finished = Signal()
    # Emitted for status updates: (message, second int argument as passed below;
    # presumably a display timeout in milliseconds -- confirm against the receiver).
    status_message = Signal(str, int)

    def __init__(self, parent=None):
        """Create an idle handler; ``parent`` is forwarded to ``QObject``."""
        super().__init__(parent)
        # Set for the duration of a run_prediction() batch to reject overlapping runs.
        # No explicit cancel needed for prediction for now, it should be fast per-item.
        self._is_running = False

    @property
    def is_running(self):
        """Return True while ``run_prediction`` is processing a batch."""
        return self._is_running

    @staticmethod
    def _error_entry(original_path: str, details: str, source_asset: str) -> dict:
        """Build a result row with status 'Error' and no predicted names."""
        return {
            'original_path': original_path,
            'predicted_asset_name': None,
            'predicted_output_name': None,
            'status': 'Error',
            'details': details,
            'source_asset': source_asset,
        }

    def _predict_single_asset(self, input_path_str: str, config: Configuration) -> list[dict]:
        """
        Helper method to predict a single asset. Runs within the ThreadPoolExecutor.

        Args:
            input_path_str: Path of the asset to predict.
            config: Configuration instance handed to the AssetProcessor.

        Returns:
            A list of prediction dictionaries for the asset (one per file), or a
            list holding a single error dictionary when prediction fails.
        """
        input_path = Path(input_path_str)
        source_asset_name = input_path.name  # For reference in the results

        try:
            # A fresh AssetProcessor is created per task rather than shared.
            # NOTE(review): this assumes separate instances are safe to use from
            # different worker threads -- confirm against AssetProcessor.
            processor = AssetProcessor(input_path, config, Path("."))  # Dummy output path
            detailed_predictions = processor.get_detailed_file_predictions()
        except AssetProcessingError as e:  # Errors during processor instantiation or prediction setup
            log.error(f"Asset processing error during prediction setup for {input_path_str}: {e}")
            return [self._error_entry(source_asset_name, f'Asset Error: {e}', source_asset_name)]
        except Exception as e:  # Unexpected errors; keep the worker from raising
            log.exception(f"Unexpected error during prediction for {input_path_str}: {e}")
            return [self._error_entry(source_asset_name, f'Unexpected Error: {e}', source_asset_name)]

        if detailed_predictions is None:
            # None signals a failure for the whole asset: report one asset-level row,
            # using the asset name as a placeholder for the file path.
            log.error(f"Detailed prediction failed critically for {input_path_str}. Adding asset-level error.")
            return [
                self._error_entry(
                    source_asset_name,
                    'Critical prediction failure (check logs)',
                    source_asset_name,
                )
            ]

        log.debug(f"Received {len(detailed_predictions)} detailed predictions for {input_path_str}.")
        # Normalise every row so all expected keys are present (even if None) and
        # tag it with the asset it came from.
        return [
            {
                'original_path': prediction.get('original_path', '[Missing Path]'),
                'predicted_asset_name': prediction.get('predicted_asset_name'),
                'predicted_output_name': prediction.get('predicted_output_name'),
                'status': prediction.get('status', 'Error'),
                'details': prediction.get('details', '[Missing Details]'),
                'source_asset': source_asset_name,
            }
            for prediction in detailed_predictions
        ]

    def _load_prediction_config(self, preset_name: str):
        """
        Load the Configuration for ``preset_name``.

        Returns the Configuration instance, or None when loading failed; in the
        failure case the error has already been logged and a status message emitted.
        """
        try:
            return Configuration(preset_name)
        except ConfigurationError as e:
            log.error(f"Failed to load configuration for preset '{preset_name}': {e}")
            self.status_message.emit(f"Error loading preset '{preset_name}': {e}", 5000)
        except Exception as e:
            log.exception(f"Unexpected error loading configuration for preset '{preset_name}': {e}")
            self.status_message.emit(f"Unexpected error loading preset '{preset_name}'.", 5000)
        return None

    def _collect_predictions(self, input_paths: list[str], config: Configuration) -> list[dict]:
        """
        Predict every input path on a thread pool and return the combined rows.

        Rows are appended in task-completion order. Failures are turned into
        error rows rather than raised.
        """
        all_file_results = []
        # Determine number of workers - half the cores, minimum 1, maximum 8.
        max_workers = min(max(1, (os.cpu_count() or 1) // 2), 8)
        log.info(f"Using ThreadPoolExecutor with max_workers={max_workers} for prediction.")
        try:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Remember which input each future belongs to so that a failing
                # task can be reported against its asset.
                future_to_path = {
                    executor.submit(self._predict_single_asset, input_path_str, config): input_path_str
                    for input_path_str in input_paths
                }
                # Process results as they complete
                for future in as_completed(future_to_path):
                    try:
                        # Result is a list of dicts for one asset
                        asset_result_list = future.result()
                    except Exception as exc:
                        # Only reached for errors not already handled inside
                        # _predict_single_asset.
                        failed_path = future_to_path[future]
                        failed_name = Path(str(failed_path)).name
                        log.error(
                            f'Prediction task for {failed_path} generated an exception: {exc}',
                            exc_info=True,
                        )
                        all_file_results.append(
                            self._error_entry(failed_name, f'Executor Error: {exc}', failed_name)
                        )
                    else:
                        if asset_result_list:  # Skip empty result lists
                            all_file_results.extend(asset_result_list)
        except Exception as pool_exc:
            log.exception(f"An error occurred with the prediction ThreadPoolExecutor: {pool_exc}")
            self.status_message.emit(f"Error during prediction setup: {pool_exc}", 5000)
            # Add a generic error row if the pool itself fails
            all_file_results.append(
                self._error_entry('[Prediction Pool Error]', f'Pool Error: {pool_exc}', '[System]')
            )
        return all_file_results

    def run_prediction(self, input_paths: list[str], preset_name: str):
        """
        Runs the prediction logic for the given paths and preset using a ThreadPoolExecutor.
        This method is intended to be run in a separate QThread.

        Args:
            input_paths: Paths of the assets to predict.
            preset_name: Name of the preset used to build the Configuration.

        Emits ``prediction_results_ready`` with the combined rows on success, and
        ``prediction_finished`` on every handled exit path except when a run is
        already in progress (that call is ignored with a warning).
        """
        if self._is_running:
            log.warning("Prediction is already running.")
            return
        if not BACKEND_AVAILABLE:
            log.error("Backend modules not available. Cannot run prediction.")
            self.status_message.emit("Error: Backend components missing.", 5000)
            self.prediction_finished.emit()
            return
        if not preset_name:
            log.warning("No preset selected for prediction.")
            self.status_message.emit("No preset selected.", 3000)
            self.prediction_finished.emit()
            return

        self._is_running = True
        thread_id = QThread.currentThread()  # Current thread object, used in log lines
        try:
            log.info(f"[{time.time():.4f}][T:{thread_id}] --> Entered PredictionHandler.run_prediction. Starting run for {len(input_paths)} items, Preset='{preset_name}'")
            self.status_message.emit(f"Updating preview for {len(input_paths)} items...", 0)

            # Load the configuration once and share it across all tasks.
            config = self._load_prediction_config(preset_name)
            if config is None:
                self.prediction_finished.emit()
                return

            all_file_results = self._collect_predictions(input_paths, config)

            # Emit the combined list of detailed file results at the end
            log.info(f"[{time.time():.4f}][T:{thread_id}] Parallel prediction run finished. Preparing to emit {len(all_file_results)} file results.")
            log.debug(f"[{time.time():.4f}][T:{thread_id}] Type of all_file_results before emit: {type(all_file_results)}")
            # Lazy %-args: the rows are only formatted when DEBUG is enabled.
            log.debug(
                "[%.4f][T:%s] Content of all_file_results (first 5) before emit: %s",
                time.time(), thread_id, all_file_results[:5],
            )
            log.info(f"[{time.time():.4f}][T:{thread_id}] Emitting prediction_results_ready signal...")
            self.prediction_results_ready.emit(all_file_results)
            log.info(f"[{time.time():.4f}][T:{thread_id}] Emitted prediction_results_ready signal.")
            self.status_message.emit("Preview update complete.", 3000)
            self.prediction_finished.emit()
        finally:
            # Always clear the flag, including on config-load failure or an
            # unexpected exception; otherwise later calls would be rejected as
            # "already running".
            self._is_running = False
        log.info(f"[{time.time():.4f}][T:{thread_id}] <-- Exiting PredictionHandler.run_prediction.")