"""Worker object that asks an LLM to predict asset structures for an input source's file list."""
import os
|
|
import json
|
|
import requests
|
|
from PySide6.QtCore import QObject, Signal, Slot, QThread
|
|
from typing import List, Dict, Any
|
|
|
|
# Assuming rule_structure defines SourceRule, AssetRule, FileRule etc.
|
|
# Adjust the import path if necessary based on project structure
|
|
from rule_structure import SourceRule, AssetRule, FileRule # Ensure AssetRule and FileRule are imported
|
|
|
|
# Assuming configuration loads app_settings.json
|
|
# Adjust the import path if necessary
|
|
# Removed Configuration import, will use load_base_config if needed or passed settings
|
|
# from configuration import Configuration
|
|
from configuration import load_base_config # Keep this for now if needed elsewhere, or remove if settings are always passed
|
|
|
|
class LLMPredictionHandler(QObject):
    """
    Interacts with an LLM (OpenAI-compatible chat-completions endpoint) to
    predict asset structures from an input source's file list.

    Designed to be moved to a QThread: `run()` performs the full
    prompt-build -> API-call -> parse pipeline and reports results or
    errors back to the GUI thread via Qt signals.
    """

    # Emitted when prediction for an input source is complete.
    # Arguments: input_path (str), results (List[SourceRule])
    prediction_ready = Signal(str, list)
    # Emitted on error.
    # Arguments: input_path (str), error_message (str)
    prediction_error = Signal(str, str)
    # Emitted to update the status message in the GUI.
    status_update = Signal(str)

    def __init__(self, input_path_str: str, file_list: list, llm_settings: dict, parent: QObject = None):
        """
        Initializes the handler.

        Args:
            input_path_str: Absolute path to the original input source
                (directory or archive); used for context and signal payloads.
            file_list: Relative file paths extracted from the input source.
            llm_settings: Dictionary with the LLM configuration (endpoint
                URL, API key, prompt template content, type definitions,
                examples, model name, temperature, timeout).
            parent: The parent QObject.
        """
        super().__init__(parent)
        self.input_path_str = input_path_str
        self.file_list = file_list
        self.llm_settings = llm_settings
        self.endpoint_url = self.llm_settings.get('llm_endpoint_url')
        self.api_key = self.llm_settings.get('llm_api_key')
        # Cooperative-cancellation flag; checked between pipeline stages.
        self._is_cancelled = False

    @Slot()
    def run(self):
        """
        Main entry point, invoked when the owning thread starts.

        Orchestrates prompt preparation, the LLM call, and response parsing.
        Emits `prediction_ready` on success (with a possibly-empty rule
        list) or `prediction_error` on failure, and checks the cancellation
        flag between each stage.
        """
        source_name = os.path.basename(self.input_path_str)
        try:
            self.status_update.emit(f"Preparing LLM input for {source_name}...")
            if self._is_cancelled:
                return

            # No files to interpret: report an empty result immediately.
            if not self.file_list:
                self.prediction_ready.emit(self.input_path_str, [])
                return
            if self._is_cancelled:
                return

            prompt = self._prepare_prompt(self.file_list)
            if self._is_cancelled:
                return

            self.status_update.emit(f"Calling LLM for {source_name}...")
            llm_response_json_str = self._call_llm(prompt)
            if self._is_cancelled:
                return

            self.status_update.emit(f"Parsing LLM response for {source_name}...")
            predicted_rules = self._parse_llm_response(llm_response_json_str)
            if self._is_cancelled:
                return

            self.prediction_ready.emit(self.input_path_str, predicted_rules)
            self.status_update.emit(f"LLM interpretation complete for {source_name}.")

        except Exception as e:
            # Top-level worker boundary: any failure is reported to the GUI
            # rather than escaping into the thread.
            error_msg = f"Error during LLM prediction for {self.input_path_str}: {e}"
            print(error_msg)  # Log the full error
            self.prediction_error.emit(self.input_path_str, f"An error occurred: {e}")

    @Slot()
    def cancel(self):
        """Sets the cancellation flag checked between pipeline stages."""
        self._is_cancelled = True
        self.status_update.emit(f"Cancellation requested for {os.path.basename(self.input_path_str)}...")

    def _prepare_prompt(self, file_list: List[str]) -> str:
        """
        Builds the full prompt string from the configured template.

        Substitutes the asset/file type definitions, few-shot examples, and
        the file list into the template's placeholders.

        Args:
            file_list: Relative file paths to present to the LLM.

        Returns:
            The complete prompt string.

        Raises:
            ValueError: If no template content is configured and the
                fallback template file cannot be read or is empty.
        """
        prompt_template = self.llm_settings.get('prompt_template_content')
        if not prompt_template:
            # Fall back to reading the default template file from disk when
            # the settings dict does not carry the template content.
            default_template_path = 'llm_prototype/prompt_template.txt'
            print(f"Warning: 'prompt_template_content' missing in llm_settings. Falling back to reading default file: {default_template_path}")
            try:
                with open(default_template_path, 'r', encoding='utf-8') as f:
                    prompt_template = f.read()
            except FileNotFoundError:
                raise ValueError(f"LLM predictor prompt template content missing in settings and default file not found at: {default_template_path}")
            except Exception as e:
                raise ValueError(f"Error reading default LLM prompt template file {default_template_path}: {e}")

        if not prompt_template:  # Final check after potential fallback
            raise ValueError("LLM predictor prompt template content is empty or could not be loaded.")

        # Serialize definitions and examples from the settings dictionary.
        asset_defs = json.dumps(self.llm_settings.get('asset_types', {}), indent=4)
        file_defs = json.dumps(self.llm_settings.get('file_types', {}), indent=4)
        examples = json.dumps(self.llm_settings.get('examples', []), indent=2)

        # Format the file list as a single newline-separated string.
        file_list_str = "\n".join(file_list)

        # Replace the template placeholders.
        prompt = prompt_template.replace('{ASSET_TYPE_DEFINITIONS}', asset_defs)
        prompt = prompt.replace('{FILE_TYPE_DEFINITIONS}', file_defs)
        prompt = prompt.replace('{EXAMPLE_INPUT_OUTPUT_PAIRS}', examples)
        prompt = prompt.replace('{FILE_LIST}', file_list_str)

        return prompt

    def _call_llm(self, prompt: str) -> str:
        """
        Calls the configured LLM API endpoint with the prepared prompt.

        Args:
            prompt: The complete prompt string.

        Returns:
            The content string from the LLM response, expected to be JSON.

        Raises:
            ConnectionError: If the request fails due to network issues or
                times out.
            ValueError: If the endpoint URL is not configured or the
                response structure is invalid.
        """
        if not self.endpoint_url:
            raise ValueError("LLM endpoint URL is not configured in settings.")

        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # Payload follows the OpenAI Chat Completions format.
        payload = {
            # Use configured model name, default to 'local-model'.
            "model": self.llm_settings.get("llm_model_name", "local-model"),
            "messages": [{"role": "user", "content": prompt}],
            # Use configured temperature, default to 0.5.
            "temperature": self.llm_settings.get("llm_temperature", 0.5),
            # Add max_tokens if needed/configurable:
            # "max_tokens": self.llm_settings.get("llm_max_tokens", 1024),
            # The prompt itself must instruct the LLM to return JSON; some
            # endpoints also support a dedicated JSON mode:
            # "response_format": { "type": "json_object" } # If supported by endpoint
        }

        # Hoisted so the request and its timeout error message always agree.
        request_timeout = self.llm_settings.get("llm_request_timeout", 120)

        self.status_update.emit(f"Sending request to LLM at {self.endpoint_url}...")
        print(f"--- Calling LLM API: {self.endpoint_url} ---")

        try:
            response = requests.post(
                self.endpoint_url,
                headers=headers,
                json=payload,
                timeout=request_timeout,
            )
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)

        except requests.exceptions.Timeout:
            error_msg = f"LLM request timed out after {request_timeout} seconds."
            print(error_msg)
            raise ConnectionError(error_msg)
        except requests.exceptions.RequestException as e:
            error_msg = f"LLM request failed: {e}"
            print(error_msg)
            # Attempt to get more detail from the response if available.
            try:
                if e.response is not None:
                    print(f"LLM Response Status Code: {e.response.status_code}")
                    print(f"LLM Response Text: {e.response.text[:500]}...")  # Log partial response text
                    error_msg += f" (Status: {e.response.status_code})"
            except Exception:
                pass  # Ignore errors during error reporting enhancement
            raise ConnectionError(error_msg)  # Raise a more generic error for the GUI

        # Parse the JSON response envelope.
        try:
            response_data = response.json()

            # Extract content -- OpenAI Chat Completions structure assumed.
            if "choices" in response_data and len(response_data["choices"]) > 0:
                message = response_data["choices"][0].get("message", {})
                content = message.get("content")
                if content:
                    # The content itself should be the JSON string we asked for.
                    print("--- LLM Response Content Extracted Successfully ---")
                    return content.strip()
                else:
                    raise ValueError("LLM response missing 'content' in choices[0].message.")
            else:
                raise ValueError("LLM response missing 'choices' array or it's empty.")

        except json.JSONDecodeError:
            error_msg = f"Failed to decode LLM JSON response. Response text: {response.text[:500]}..."
            print(error_msg)
            raise ValueError(error_msg)
        except Exception as e:
            # Capture the potentially problematic response_data in the error message.
            response_data_str = "Not available"
            try:
                response_data_str = json.dumps(response_data) if 'response_data' in locals() else response.text[:500] + "..."
            except Exception:
                pass  # Avoid errors during error reporting
            error_msg = f"Error parsing LLM response structure: {e}. Response data: {response_data_str}"
            print(error_msg)
            raise ValueError(error_msg)

    def _parse_llm_response(self, llm_response_json_str: str) -> List[SourceRule]:
        """
        Parses the LLM's JSON response string into a list of SourceRule objects.

        Returns:
            A single-element list containing one SourceRule for this
            handler's input source, populated with AssetRule/FileRule
            children for each valid prediction.

        Raises:
            ValueError: If the string is not valid JSON or lacks the
                expected 'predicted_assets' list.
        """
        # Strip potential markdown code fences before parsing.
        clean_json_str = llm_response_json_str.strip()
        if clean_json_str.startswith("```json"):
            clean_json_str = clean_json_str[7:]  # Remove ```json
        elif clean_json_str.startswith("```"):
            # FIX: also handle a plain ``` fence (no language tag), which
            # previously fell through and caused a spurious JSONDecodeError.
            clean_json_str = clean_json_str[3:]
        if clean_json_str.endswith("```"):
            clean_json_str = clean_json_str[:-3]  # Remove closing ```
        clean_json_str = clean_json_str.strip()  # Remove any extra whitespace

        try:
            response_data = json.loads(clean_json_str)
        except json.JSONDecodeError as e:
            # Log the full cleaned string that caused the error for better debugging.
            error_detail = f"Failed to decode LLM JSON response: {e}\nFull Cleaned Response:\n{clean_json_str}"
            print(f"ERROR: {error_detail}")
            raise ValueError(error_detail)

        if "predicted_assets" not in response_data or not isinstance(response_data["predicted_assets"], list):
            raise ValueError("Invalid LLM response format: 'predicted_assets' key missing or not a list.")

        source_rules = []
        # One SourceRule per input source processed by this handler instance.
        source_rule = SourceRule(input_path=self.input_path_str)

        # Valid type names come from the settings dictionary.
        valid_asset_types = list(self.llm_settings.get('asset_types', {}).keys())
        valid_file_types = list(self.llm_settings.get('file_types', {}).keys())

        for asset_data in response_data["predicted_assets"]:
            if not isinstance(asset_data, dict):
                print(f"Warning: Skipping invalid asset data (not a dict): {asset_data}")
                continue

            asset_name = asset_data.get("suggested_asset_name", "Unnamed_Asset")
            asset_type = asset_data.get("predicted_asset_type")

            # Unknown asset types are skipped rather than defaulted.
            if asset_type not in valid_asset_types:
                print(f"Warning: Invalid predicted_asset_type '{asset_type}' for asset '{asset_name}'. Defaulting or skipping.")
                continue

            asset_rule = AssetRule(asset_name=asset_name, asset_type=asset_type)
            source_rule.assets.append(asset_rule)

            if "files" not in asset_data or not isinstance(asset_data["files"], list):
                print(f"Warning: 'files' key missing or not a list in asset '{asset_name}'. Skipping files for this asset.")
                continue

            for file_data in asset_data["files"]:
                if not isinstance(file_data, dict):
                    print(f"Warning: Skipping invalid file data (not a dict) in asset '{asset_name}': {file_data}")
                    continue

                file_path_rel = file_data.get("file_path")
                file_type = file_data.get("predicted_file_type")

                if not file_path_rel:
                    print(f"Warning: Missing 'file_path' in file data for asset '{asset_name}'. Skipping file.")
                    continue

                # Convert the LLM's '/'-separated relative path back to an
                # absolute OS-specific path under the input source.
                file_path_abs = os.path.join(self.input_path_str, file_path_rel.replace('/', os.sep))

                # Unknown file types are coerced to EXTRA instead of skipped.
                if file_type not in valid_file_types:
                    print(f"Warning: Invalid predicted_file_type '{file_type}' for file '{file_path_rel}'. Defaulting to EXTRA.")
                    file_type = "EXTRA"

                file_rule = FileRule(file_path=file_path_abs, item_type=file_type)
                asset_rule.files.append(file_rule)

        source_rules.append(source_rule)
        return source_rules
|
|
|
|
# Example of how this might be used in MainWindow (conceptual)
|
|
# class MainWindow(QMainWindow):
|
|
# # ... other methods ...
|
|
# def _start_llm_prediction(self, directory_path):
|
|
# self.llm_thread = QThread()
|
|
# self.llm_handler = LLMPredictionHandler(directory_path, self.config_manager)
|
|
# self.llm_handler.moveToThread(self.llm_thread)
|
|
#
|
|
# # Connect signals
|
|
# self.llm_handler.prediction_ready.connect(self._on_llm_prediction_ready)
|
|
# self.llm_handler.prediction_error.connect(self._on_llm_prediction_error)
|
|
# self.llm_handler.status_update.connect(self.statusBar().showMessage)
|
|
# self.llm_thread.started.connect(self.llm_handler.run)
|
|
# self.llm_thread.finished.connect(self.llm_thread.deleteLater)
|
|
# self.llm_handler.prediction_ready.connect(self.llm_thread.quit) # Quit thread on success
|
|
# self.llm_handler.prediction_error.connect(self.llm_thread.quit) # Quit thread on error
|
|
#
|
|
# self.llm_thread.start()
|
|
#
|
|
# @Slot(str, list)
|
|
# def _on_llm_prediction_ready(self, directory_path, results):
|
|
# print(f"LLM Prediction ready for {directory_path}: {len(results)} source rules found.")
|
|
# # Process results, update model, etc.
|
|
# # Make sure to clean up thread/handler references if needed
|
|
# self.llm_handler.deleteLater() # Schedule handler for deletion
|
|
#
|
|
# @Slot(str, str)
|
|
# def _on_llm_prediction_error(self, directory_path, error_message):
|
|
# print(f"LLM Prediction error for {directory_path}: {error_message}")
|
|
# # Show error to user, clean up thread/handler
|
|
# self.llm_handler.deleteLater() |