# gui/prediction_handler.py
"""Background prediction handling for the GUI.

Runs per-asset predictions off the GUI thread (this module's handler is
driven from a QThread) and assembles the results into a SourceRule
hierarchy plus a flat per-file result list for the table view.
"""

import logging
import os  # For cpu_count
import sys
import time  # For timestamped log messages
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed  # For parallel prediction
from pathlib import Path

# --- PySide6 Imports ---
from PySide6.QtCore import QObject, Signal, QThread, Slot  # Import QThread and Slot

# --- Backend Imports ---
# Adjust path to ensure modules can be found relative to this file's location.
# NOTE: this must happen BEFORE the backend imports below, otherwise the
# imports can fail even though the modules exist under the project root.
script_dir = Path(__file__).parent
project_root = script_dir.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

try:
    # Imported inside the guard (and after the sys.path fix) so a missing
    # backend degrades to BACKEND_AVAILABLE = False instead of crashing the
    # whole module at import time.
    from rule_structure import SourceRule, AssetRule, FileRule
    from configuration import Configuration, ConfigurationError
    from asset_processor import AssetProcessor, AssetProcessingError
    BACKEND_AVAILABLE = True
except ImportError as e:
    print(f"ERROR (PredictionHandler): Failed to import backend modules: {e}")
    # Define placeholders if imports fail so annotations below still resolve.
    SourceRule = None
    AssetRule = None
    FileRule = None
    Configuration = None
    AssetProcessor = None
    ConfigurationError = Exception
    AssetProcessingError = Exception
    BACKEND_AVAILABLE = False

log = logging.getLogger(__name__)
# Basic config if logger hasn't been set up elsewhere
if not log.hasHandlers():
    logging.basicConfig(level=logging.INFO, format='%(levelname)s (PredictHandler): %(message)s')
|
class PredictionHandler(QObject):
    """
    Handles running predictions in a separate thread to avoid GUI freezes.

    ``run_prediction`` is intended to be invoked on a worker QThread; it fans
    out one prediction task per input path over a ThreadPoolExecutor, then
    emits both a hierarchical SourceRule and a flat list of per-file result
    dicts for the table view.
    """
    # --- Signals ---
    # Emits a list of dictionaries, each representing a file row for the table.
    # Dict format: {'original_path': str, 'predicted_asset_name': str | None,
    #               'predicted_output_name': str | None, 'status': str,
    #               'details': str | None, 'source_asset': str}
    prediction_results_ready = Signal(list)
    # Emitted when the hierarchical rule structure is ready
    rule_hierarchy_ready = Signal(object)  # Emits a SourceRule object
    # Emitted when all predictions for a batch are done
    prediction_finished = Signal()
    # Emitted for status updates (message text, timeout in ms; 0 = sticky)
    status_message = Signal(str, int)

    def __init__(self, parent=None):
        super().__init__(parent)
        # True while a prediction batch is in flight; guards re-entry.
        # No explicit cancel needed for prediction for now, it should be
        # fast per-item.
        self._is_running = False

    @property
    def is_running(self):
        """Whether a prediction batch is currently in progress."""
        return self._is_running

    @staticmethod
    def _error_entry(original_path: str, details: str, source_asset: str) -> dict:
        """Build one error row in the standard table-row dict format."""
        return {
            'original_path': original_path,
            'predicted_asset_name': None,
            'predicted_output_name': None,
            'status': 'Error',
            'details': details,
            'source_asset': source_asset,
        }

    def _predict_single_asset(self, input_path_str: str, config: Configuration, rules: SourceRule) -> list[dict] | dict:
        """
        Helper method to run detailed file prediction for a single input path.

        Runs within the ThreadPoolExecutor. Returns a list of file prediction
        dictionaries for the input; errors are reported as a single-entry list
        containing an error dict, so callers always receive a list.
        """
        input_path = Path(input_path_str)
        source_asset_name = input_path.name  # For reference in error reporting

        try:
            # Create AssetProcessor instance (needs dummy output path for
            # prediction). The detailed prediction method handles its own
            # workspace setup/cleanup.
            processor = AssetProcessor(input_path, config, Path("."))  # Dummy output path

            # Get the detailed file predictions (a list of dictionaries).
            detailed_predictions = processor.get_detailed_file_predictions(rules)

            if detailed_predictions is None:
                log.error(f"AssetProcessor.get_detailed_file_predictions returned None for {input_path_str}.")
                return [self._error_entry(source_asset_name, 'Prediction returned no results', source_asset_name)]

            # Add the source_asset name to each prediction result for grouping later
            for prediction in detailed_predictions:
                prediction['source_asset'] = source_asset_name

            log.debug(f"Generated {len(detailed_predictions)} detailed predictions for {input_path_str}.")
            return detailed_predictions

        except AssetProcessingError as e:
            log.error(f"Asset processing error during prediction for {input_path_str}: {e}")
            return [self._error_entry(source_asset_name, f'Asset Error: {e}', source_asset_name)]
        except Exception as e:
            log.exception(f"Unexpected error during prediction for {input_path_str}: {e}")
            return [self._error_entry(source_asset_name, f'Unexpected Error: {e}', source_asset_name)]

    def _load_configuration(self, preset_name: str):
        """Load the Configuration for *preset_name*.

        Returns the Configuration, or None on failure (after emitting an
        appropriate status message).
        """
        try:
            return Configuration(preset_name)
        except ConfigurationError as e:
            log.error(f"Failed to load configuration for preset '{preset_name}': {e}")
            self.status_message.emit(f"Error loading preset '{preset_name}': {e}", 5000)
        except Exception as e:
            log.exception(f"Unexpected error loading configuration for preset '{preset_name}': {e}")
            self.status_message.emit(f"Unexpected error loading preset '{preset_name}'.", 5000)
        return None

    def _collect_predictions(self, input_paths: list[str], config, rules) -> list[dict]:
        """Run _predict_single_asset for every input path in parallel.

        Returns the combined flat list of file prediction dicts; failures of
        individual tasks (or of the pool itself) are folded in as error rows
        rather than raised.
        """
        results: list[dict] = []
        # Half the CPUs, clamped to [1, 8] — prediction is mostly I/O-light.
        max_workers = min(max(1, (os.cpu_count() or 1) // 2), 8)
        log.info(f"Using ThreadPoolExecutor with max_workers={max_workers} for prediction.")

        try:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # One task per input path; each returns a list of row dicts.
                futures = [
                    executor.submit(self._predict_single_asset, input_path_str, config, rules)
                    for input_path_str in input_paths
                ]

                # Process results as they complete
                for future in as_completed(futures):
                    try:
                        result = future.result()
                        if isinstance(result, list):
                            results.extend(result)
                        elif isinstance(result, dict) and result.get('status') == 'Error':
                            # Legacy single error dict (tasks should return
                            # lists now, but handle a bare dict for safety).
                            results.append(result)
                        else:
                            log.error(f'Prediction task returned unexpected result type: {type(result)}')
                            results.append(self._error_entry(
                                '[Unknown Asset - Unexpected Result]',
                                f'Unexpected result type: {type(result)}',
                                '[Unknown]',
                            ))
                    except Exception as exc:
                        log.error(f'Prediction task generated an exception: {exc}', exc_info=True)
                        results.append(self._error_entry(
                            '[Unknown Asset - Executor Error]',
                            f'Executor Error: {exc}',
                            '[Unknown]',
                        ))
        except Exception as pool_exc:
            log.exception(f"An error occurred with the prediction ThreadPoolExecutor: {pool_exc}")
            self.status_message.emit(f"Error during prediction setup: {pool_exc}", 5000)
            results.append(self._error_entry(
                '[Prediction Pool Error]',
                f'Pool Error: {pool_exc}',
                '[System]',
            ))
        return results

    def _build_rule_hierarchy(self, file_predictions: list[dict]):
        """Build the SourceRule -> AssetRule -> FileRule hierarchy.

        Groups the flat file prediction dicts by 'predicted_asset_name'
        (falling back to the bracketed source asset name when no prediction
        exists) and returns the populated root SourceRule.
        """
        # Create the root SourceRule object. For now it is unnamed; later this
        # might be derived from input paths.
        source_rule = SourceRule()
        log.debug(f"Created root SourceRule object.")

        # Group file prediction results by predicted_asset_name.
        grouped_by_asset = defaultdict(list)
        for file_pred in file_predictions:
            asset_name = file_pred.get('predicted_asset_name')
            if asset_name is None:
                # Files without a predicted asset name are grouped under their
                # source asset name (bracketed) so they still appear in the tree.
                asset_name = f"[{file_pred.get('source_asset', 'UnknownSource')}]"
                log.debug(f"File '{file_pred.get('original_path', 'UnknownPath')}' has no predicted asset name, grouping under '{asset_name}' for hierarchy.")
            grouped_by_asset[asset_name].append(file_pred)

        # Create AssetRule objects from the grouped results.
        asset_rules = []
        for asset_name, file_preds in grouped_by_asset.items():
            asset_rule = AssetRule(asset_name=asset_name)
            for file_pred in file_preds:
                # Override fields are not predicted at this stage; they stay
                # at their neutral defaults.
                asset_rule.files.append(FileRule(
                    file_path=file_pred.get('original_path', 'UnknownPath'),
                    map_type_override=None,
                    resolution_override=None,
                    channel_merge_instructions={},
                    output_format_override=None,
                ))
            asset_rules.append(asset_rule)

        source_rule.assets = asset_rules
        log.debug(f"Built SourceRule with {len(asset_rules)} AssetRule(s).")
        return source_rule

    @Slot()
    def run_prediction(self, input_paths: list[str], preset_name: str, rules: SourceRule):
        """
        Runs the prediction logic for the given paths and preset using a
        ThreadPoolExecutor. Generates the hierarchical rule structure and
        detailed file predictions.

        This method is intended to be run in a separate QThread. Emits
        rule_hierarchy_ready, prediction_results_ready and finally
        prediction_finished; status_message carries progress text.
        """
        if self._is_running:
            log.warning("Prediction is already running.")
            return
        if not BACKEND_AVAILABLE:
            log.error("Backend modules not available. Cannot run prediction.")
            self.status_message.emit("Error: Backend components missing.", 5000)
            self.prediction_finished.emit()
            return
        if not preset_name:
            log.warning("No preset selected for prediction.")
            self.status_message.emit("No preset selected.", 3000)
            self.prediction_finished.emit()
            return

        self._is_running = True
        thread_id = QThread.currentThread()  # Get current thread object
        log.info(f"[{time.time():.4f}][T:{thread_id}] --> Entered PredictionHandler.run_prediction. Starting run for {len(input_paths)} items, Preset='{preset_name}'")
        self.status_message.emit(f"Updating preview for {len(input_paths)} items...", 0)

        # try/finally guarantees _is_running is cleared even when config
        # loading fails or an unexpected exception escapes — previously the
        # generic config-error path left the handler stuck "running".
        try:
            config = self._load_configuration(preset_name)
            if config is None:
                self.prediction_finished.emit()
                return

            # Collect all detailed file prediction results in parallel.
            all_file_prediction_results = self._collect_predictions(input_paths, config, rules)

            # --- Build and emit the hierarchical rule structure ---
            source_rule = self._build_rule_hierarchy(all_file_prediction_results)
            log.info(f"[{time.time():.4f}][T:{thread_id}] Parallel prediction run finished. Preparing to emit rule hierarchy.")
            self.rule_hierarchy_ready.emit(source_rule)
            log.info(f"[{time.time():.4f}][T:{thread_id}] Emitted rule_hierarchy_ready signal.")

            # Emit the combined list of detailed file results for the table view.
            log.info(f"[{time.time():.4f}][T:{thread_id}] Preparing to emit {len(all_file_prediction_results)} file results for table view.")
            log.debug(f"[{time.time():.4f}][T:{thread_id}] Type of all_file_prediction_results before emit: {type(all_file_prediction_results)}")
            try:
                log.debug(f"[{time.time():.4f}][T:{thread_id}] Content of all_file_prediction_results (first 5) before emit: {all_file_prediction_results[:5]}")
            except Exception as e:
                log.error(f"[{time.time():.4f}][T:{thread_id}] Error logging all_file_prediction_results content: {e}")
            log.info(f"[{time.time():.4f}][T:{thread_id}] Emitting prediction_results_ready signal...")
            self.prediction_results_ready.emit(all_file_prediction_results)
            log.info(f"[{time.time():.4f}][T:{thread_id}] Emitted prediction_results_ready signal.")

            self.status_message.emit("Preview update complete.", 3000)
            self.prediction_finished.emit()
        finally:
            self._is_running = False
            log.info(f"[{time.time():.4f}][T:{thread_id}] <-- Exiting PredictionHandler.run_prediction.")