Source code for orka.memory.base_logger

# OrKa: Orchestrator Kit Agents
# Copyright © 2025 Marco Somma
#
# This file is part of OrKa – https://github.com/marcosomma/orka-reasoning
#
# Licensed under the Apache License, Version 2.0 (Apache 2.0).
# You may not use this file for commercial purposes without explicit permission.
#
# Full license: https://www.apache.org/licenses/LICENSE-2.0
# For commercial use, contact: marcosomma.work@gmail.com
#
# Required attribution: OrKa by Marco Somma – https://github.com/marcosomma/orka-reasoning

"""
Base Memory Logger
==================

Abstract base class for memory loggers that defines the interface that must be
implemented by all memory backends.
"""

import hashlib
import json
import logging
import threading
from abc import ABC, abstractmethod
from datetime import UTC, datetime
from typing import Any, Dict, List, Set

from .file_operations import FileOperationsMixin
from .serialization import SerializationMixin

logger = logging.getLogger(__name__)


class BaseMemoryLogger(ABC, SerializationMixin, FileOperationsMixin):
    """
    Base Memory Logger
    ==================

    Abstract base class that defines the interface and common functionality for all
    memory logger implementations in OrKa. This class provides the foundation for
    persistent memory storage across different backends.

    Core Responsibilities
    ---------------------

    **Interface Definition**

    - Defines abstract methods that all memory backends must implement
    - Provides common initialization and configuration patterns
    - Establishes consistent behavior across different storage backends

    **Memory Lifecycle Management**

    - Automatic memory decay based on configurable rules
    - Importance scoring for memory retention decisions
    - Memory type classification (short-term vs long-term)
    - Category-based memory organization (logs vs stored memories)

    **Data Optimization**

    - Blob deduplication for large objects to reduce storage overhead
    - Serialization mixins for consistent data handling
    - File operation mixins for export/import functionality
    - Configurable thresholds for optimization decisions

    **Background Decay Scheduling**

    - Optional daemon thread that periodically calls ``cleanup_expired_memories()``
    - Explicit shutdown through ``stop_decay_scheduler()``
    - This base class does not lock its own in-memory state (``memory``, blob store);
      backends used from several threads must provide their own synchronization

    Architecture Details
    --------------------

    **Memory Classification System**

    - **Categories**: "log" (orchestration events) vs "stored" (persistent memories)
    - **Types**: "short_term" (temporary) vs "long_term" (persistent)
    - **Importance Scoring**: 0.0-1.0 scale based on event type, agent id and payload
    - **Decay Rules**: Configurable retention policies per category/type

    **Blob Deduplication**

    - SHA256 hashing of canonical JSON for content identification
    - Per-blob usage counters (``_blob_usage``); this base class never evicts blobs itself
    - Configurable size threshold (default: 200 characters)

    **Decay Management**

    - Background thread for periodic cleanup
    - Configurable check interval (default: 30 minutes)
    - Back-off (interval doubled, capped at one hour, never shortened) after
      repeated consecutive cleanup failures
    - Dry-run support for testing cleanup operations

    Implementation Requirements
    ---------------------------

    **Required Abstract Methods**

    All concrete implementations must provide:

    - ``log()`` - Store orchestration events and memory entries
    - ``tail()`` - Retrieve recent entries for debugging
    - ``cleanup_expired_memories()`` - Remove expired entries
    - ``get_memory_stats()`` - Provide storage statistics
    - Redis-compatible methods: ``hset``, ``hget``, ``hkeys``, ``hdel``, ``get``,
      ``set``, ``delete``
    - Set operations: ``smembers``, ``sadd``, ``srem``

    **Optional Enhancements**

    Implementations may provide:

    - Vector search capabilities for semantic similarity
    - Advanced filtering and querying options
    - Performance optimizations for specific use cases
    - Integration with external systems (Redis, etc.)

    Configuration Options
    ---------------------

    **Decay Configuration**

    .. code-block:: python

        decay_config = {
            "enabled": True,
            "default_short_term_hours": 1.0,
            "default_long_term_hours": 24.0,
            "check_interval_minutes": 30,
            "memory_type_rules": {
                "long_term_events": ["success", "completion", "write", "result"],
                "short_term_events": ["debug", "processing", "start", "progress"]
            },
            "importance_rules": {
                "base_score": 0.5,
                "event_type_boosts": {"write": 0.3, "success": 0.2},
                "agent_type_boosts": {"memory": 0.2, "openai-answer": 0.1}
            }
        }

    Nested dictionaries are merged with the built-in defaults one level deep.

    **Blob Deduplication**

    - ``_blob_threshold``: Minimum serialized size for deduplication (default: 200 chars)
    - Per-blob usage counting
    - SHA256 hashing for content identification

    Usage Patterns
    --------------

    **Implementing a Custom Backend**

    .. code-block:: python

        from orka.memory.base_logger import BaseMemoryLogger

        class CustomMemoryLogger(BaseMemoryLogger):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self._storage = {}  # Your storage implementation

            def log(self, agent_id, event_type, payload, **kwargs):
                # Implement storage logic
                pass

            def cleanup_expired_memories(self, dry_run=False):
                # Implement cleanup logic
                pass

            # ... implement other abstract methods

    **Memory Classification Logic**

    - Orchestration logs are always classified as short-term
    - Only "stored" memories can be classified as long-term
    - Importance scoring influences retention decisions
    - Event types and agent types affect classification

    **Threading Considerations**

    - When decay is enabled, ``cleanup_expired_memories()`` runs on a background
      daemon thread (first run after one check interval)
    - Call ``stop_decay_scheduler()`` for a clean shutdown; it signals the thread and
      waits up to 5 seconds for it to exit
    - No locking is performed by this base class
    """

    def __init__(
        self,
        stream_key: str = "orka:memory",
        debug_keep_previous_outputs: bool = False,
        decay_config: dict[str, Any] | None = None,
        memory_preset: str | None = None,
    ) -> None:
        """
        Initialize the memory logger.

        Args:
            stream_key: Key for the memory stream. Defaults to "orka:memory".
            debug_keep_previous_outputs: If True, keeps previous_outputs in log files
                for debugging.
            decay_config: Configuration for memory decay functionality.
            memory_preset: Name of memory preset to use (sensory, working, episodic,
                semantic, procedural, meta). If provided, preset config is used as
                base and merged with decay_config.
        """
        self.stream_key = stream_key
        self.memory: list[dict[str, Any]] = []  # Local memory buffer
        self.debug_keep_previous_outputs = debug_keep_previous_outputs

        # Blob deduplication state is created BEFORE the decay thread may be
        # started below, so the background thread never sees a half-built instance.
        # Blob deduplication storage: SHA256 -> actual blob content
        self._blob_store: dict[str, Any] = {}
        # Track blob usage count for potential cleanup
        self._blob_usage: dict[str, int] = {}
        # Minimum size threshold for blob deduplication (in chars)
        self._blob_threshold = 200

        # Handle memory preset configuration
        effective_decay_config = self._resolve_memory_preset(
            memory_preset, decay_config or {}, operation=None
        )

        # Initialize decay configuration
        self.decay_config = self._init_decay_config(effective_decay_config)

        # Decay state management
        self._decay_thread: threading.Thread | None = None
        self._decay_stop_event = threading.Event()
        self._last_decay_check = datetime.now(UTC)

        # Initialize automatic decay if enabled. Attributes a subclass sets after
        # super().__init__() are not yet present here, but the scheduler waits one
        # full interval before its first cleanup call.
        if self.decay_config.get("enabled", False):
            self._start_decay_scheduler()

    def _resolve_memory_preset(
        self, memory_preset: str | None, decay_config: dict[str, Any], operation: str | None = None
    ) -> dict[str, Any]:
        """
        Resolve memory preset configuration and merge with custom config.

        Args:
            memory_preset: Name of the memory preset to use
            decay_config: Custom decay configuration to override preset values
            operation: Memory operation type ('read' or 'write') for operation-specific defaults

        Returns:
            Merged configuration dictionary with operation-specific defaults applied.
            Falls back to ``decay_config`` unchanged when no preset is given, the
            presets module cannot be imported, or merging fails.
        """
        if not memory_preset:
            return decay_config

        try:
            from .presets import merge_preset_with_config

            return merge_preset_with_config(memory_preset, decay_config, operation)
        except ImportError:
            logger.warning("Memory presets not available, using custom config only")
            return decay_config
        except Exception as e:
            logger.error(f"Failed to load memory preset '{memory_preset}': {e}")
            logger.warning("Falling back to custom decay config")
            return decay_config

    def _init_decay_config(self, decay_config: dict[str, Any]) -> dict[str, Any]:
        """
        Initialize decay configuration with defaults.

        Nested dictionaries are merged one level deep: for a key such as
        ``importance_rules``, each inner key supplied by the caller replaces the
        default inner value wholesale (e.g. a supplied ``event_type_boosts`` dict
        replaces the default ``event_type_boosts`` dict).

        Args:
            decay_config: Raw decay configuration

        Returns:
            Processed decay configuration with defaults applied
        """
        default_config = {
            "enabled": False,  # Disable by default to prevent logs from disappearing
            "default_short_term_hours": 1.0,
            "default_long_term_hours": 24.0,
            "check_interval_minutes": 30,
            "memory_type_rules": {
                "long_term_events": ["success", "completion", "write", "result"],
                "short_term_events": ["debug", "processing", "start", "progress"],
            },
            "importance_rules": {
                "base_score": 0.5,
                "event_type_boosts": {
                    "write": 0.3,
                    "success": 0.2,
                    "completion": 0.2,
                    "result": 0.1,
                },
                "agent_type_boosts": {
                    "memory": 0.2,
                    "openai-answer": 0.1,
                },
            },
        }

        # Merge with defaults (nested dicts one level deep). default_config is a
        # fresh literal on every call, so updating its inner dicts is safe.
        merged_config = default_config.copy()
        for key, value in decay_config.items():
            if isinstance(value, dict) and key in merged_config:
                target_dict = merged_config.get(key)
                if isinstance(target_dict, dict):
                    target_dict.update(value)
                else:
                    # If merged_config[key] is not a dict, replace it entirely
                    merged_config[key] = value
            else:
                merged_config[key] = value

        return merged_config

    def _calculate_importance_score(
        self,
        event_type: str,
        agent_id: str,
        payload: dict[str, Any],
    ) -> float:
        """
        Calculate importance score for a memory entry.

        The score starts at ``importance_rules.base_score``, adds the boost for
        ``event_type``, adds the first matching agent-type boost (substring match
        on ``agent_id``), adds 0.1 if the payload has a truthy ``result`` or
        ``response`` and subtracts 0.1 if it has a truthy ``error``.

        Args:
            event_type: Type of the event
            agent_id: ID of the agent generating the event
            payload: Event payload

        Returns:
            Importance score between 0.0 and 1.0
        """
        rules = self.decay_config.get("importance_rules", {})
        score = rules.get("base_score", 0.5)

        # Apply event type boosts
        score += rules.get("event_type_boosts", {}).get(event_type, 0.0)

        # Apply agent type boosts (only the first matching agent type counts)
        for agent_type, boost in rules.get("agent_type_boosts", {}).items():
            if agent_type in agent_id:
                score += boost
                break

        # Check payload for result indicators
        if isinstance(payload, dict):
            if payload.get("result") or payload.get("response"):
                score += 0.1
            if payload.get("error"):
                score -= 0.1

        # Clamp score between 0.0 and 1.0
        return_value: float = max(0.0, min(1.0, score))
        return return_value

    def _classify_memory_type(
        self,
        event_type: str,
        importance_score: float,
        category: str = "log",
    ) -> str:
        """
        Classify memory entry as short-term or long-term.

        Args:
            event_type: Type of the event
            importance_score: Calculated importance score
            category: Memory category ("stored" or "log")

        Returns:
            "short_term" or "long_term"
        """
        # Only "stored" memories may be long-term; orchestration logs are always
        # short-term to avoid confusion.
        if category == "log":
            return "short_term"

        rules = self.decay_config.get("memory_type_rules", {})

        # Check explicit rules first (only for stored memories)
        if event_type in rules.get("long_term_events", []):
            return "long_term"
        if event_type in rules.get("short_term_events", []):
            return "short_term"

        # Fallback to importance score (only for stored memories)
        return "long_term" if importance_score >= 0.7 else "short_term"

    def _classify_memory_category(
        self,
        event_type: str,
        agent_id: str,
        payload: dict[str, Any],
        log_type: str = "log",
    ) -> str:
        """
        Classify memory entry category for separation between logs and stored memories.

        Args:
            event_type: Type of the event
            agent_id: ID of the agent generating the event
            payload: Event payload
            log_type: Explicit log type ("log" or "memory"). Any other value enables
                the legacy heuristics below.

        Returns:
            "stored" for memory writer outputs, "log" for other events
        """
        # The explicit log_type parameter takes precedence
        if log_type == "memory":
            return "stored"
        elif log_type == "log":
            return "log"

        # Fallback to legacy detection (for backward compatibility)
        # Memory writes from memory writer nodes should be categorized as "stored"
        if event_type == "write" and ("memory" in agent_id.lower() or "writer" in agent_id.lower()):
            return "stored"

        # Check payload for memory content indicators
        if isinstance(payload, dict):
            # If payload contains content field, it's likely stored memory
            if payload.get("content") and payload.get("metadata"):
                return "stored"
            # If it's a memory operation result
            if payload.get("memory_object") or payload.get("memories"):
                return "stored"

        # Default to log for orchestration events
        return "log"

    def _start_decay_scheduler(self) -> None:
        """
        Start the automatic decay scheduler thread.

        Does nothing if a scheduler thread is already registered. The thread waits
        ``check_interval_minutes`` between calls to ``cleanup_expired_memories()``.
        After three consecutive failures the interval is doubled (capped at one
        hour, never shortened). A non-numeric or non-positive interval falls back
        to one minute instead of busy-looping or crashing the thread.
        """
        if self._decay_thread is not None:
            return  # Already running

        interval_minutes = self.decay_config.get("check_interval_minutes", 1)
        if not isinstance(interval_minutes, (int, float)) or interval_minutes <= 0:
            logger.warning(
                f"Invalid check_interval_minutes={interval_minutes!r}; using 1 minute instead",
            )
            interval_minutes = 1
        initial_interval_seconds = interval_minutes * 60

        # stop_decay_scheduler() leaves the event set; reusing it would make a
        # restarted thread exit immediately. Each scheduler run gets its own
        # event, captured by the closure, so an old thread that outlived its join
        # timeout still sees its own (set) event and exits.
        if self._decay_stop_event.is_set():
            self._decay_stop_event = threading.Event()
        stop_event = self._decay_stop_event

        def decay_scheduler() -> None:
            interval_seconds = initial_interval_seconds
            consecutive_failures = 0
            max_consecutive_failures = 3
            max_backoff_seconds = 3600  # back-off cap: 1 hour

            while not stop_event.wait(interval_seconds):
                try:
                    self.cleanup_expired_memories()
                    consecutive_failures = 0  # Reset on success
                except Exception as e:
                    consecutive_failures += 1
                    logger.error(
                        f"Error during automatic memory decay (failure {consecutive_failures}): {e}"
                    )

                    # Too many consecutive failures: back off to prevent spam.
                    # max() guarantees the cap never shortens a longer interval.
                    if consecutive_failures >= max_consecutive_failures:
                        new_interval = max(
                            interval_seconds, min(interval_seconds * 2, max_backoff_seconds)
                        )
                        logger.warning(
                            f"Memory decay has failed {consecutive_failures} times consecutively. "
                            f"Backing off: interval is now {new_interval} seconds to prevent resource exhaustion."
                        )
                        interval_seconds = new_interval
                        consecutive_failures = 0  # Reset counter after backing off

        self._decay_thread = threading.Thread(
            target=decay_scheduler, name="orka-memory-decay", daemon=True
        )
        self._decay_thread.start()

        logger.info(
            f"Started automatic memory decay scheduler (interval: {interval_minutes} minutes)",
        )

    def stop_decay_scheduler(self) -> None:
        """
        Stop the automatic decay scheduler.

        Signals the scheduler thread and waits up to 5 seconds for it to exit.
        If a cleanup run is still in progress after that, a warning is logged; the
        thread exits as soon as that run returns because its stop event stays set.
        """
        if self._decay_thread is None:
            return

        self._decay_stop_event.set()
        self._decay_thread.join(timeout=5)
        if self._decay_thread.is_alive():
            logger.warning(
                "Memory decay scheduler did not stop within 5 seconds; "
                "it will exit after its current cleanup run"
            )
        self._decay_thread = None
        logger.info("Stopped automatic memory decay scheduler")

    @abstractmethod
    def cleanup_expired_memories(self, dry_run: bool = False) -> dict[str, Any]:
        """
        Clean up expired memory entries based on decay configuration.

        Args:
            dry_run: If True, return what would be deleted without actually deleting

        Returns:
            Dictionary containing cleanup statistics
        """

    @abstractmethod
    def get_memory_stats(self) -> dict[str, Any]:
        """
        Get memory usage statistics.

        Returns:
            Dictionary containing memory statistics
        """

    @abstractmethod
    def log(
        self,
        agent_id: str,
        event_type: str,
        payload: dict[str, Any],
        step: int | None = None,
        run_id: str | None = None,
        fork_group: str | None = None,
        parent: str | None = None,
        previous_outputs: dict[str, Any] | None = None,
        agent_decay_config: dict[str, Any] | None = None,
        log_type: str = "log",  # "log" for orchestration, "memory" for stored memories
    ) -> None:
        """Log an event to the memory backend."""

    @abstractmethod
    def tail(self, count: int = 10) -> list[dict[str, Any]]:
        """Retrieve the most recent events."""

    @abstractmethod
    def hset(self, name: str, key: str, value: str | bytes | int | float) -> int:
        """Set a field in a hash structure."""

    @abstractmethod
    def hget(self, name: str, key: str) -> str | None:
        """Get a field from a hash structure."""

    @abstractmethod
    def hkeys(self, name: str) -> list[str]:
        """Get all keys in a hash structure."""

    @abstractmethod
    def hdel(self, name: str, *keys: str) -> int:
        """Delete fields from a hash structure."""

    @abstractmethod
    def smembers(self, name: str) -> list[str]:
        """Get all members of a set."""

    @abstractmethod
    def sadd(self, name: str, *values: str) -> int:
        """Add members to a set."""

    @abstractmethod
    def srem(self, name: str, *values: str) -> int:
        """Remove members from a set."""

    @abstractmethod
    def get(self, key: str) -> str | None:
        """Get a value by key."""

    @abstractmethod
    def set(self, key: str, value: str | bytes | int | float) -> bool:
        """Set a value by key."""

    @abstractmethod
    def delete(self, *keys: str) -> int:
        """Delete keys."""

    def _compute_blob_hash(self, obj: Any) -> str:
        """
        Compute SHA256 hash of a JSON-serializable object.

        Args:
            obj: Object to hash

        Returns:
            SHA256 hash as hex string. Objects that cannot be JSON-serialized are
            hashed via their ``str()`` representation instead.
        """
        try:
            # Canonical JSON (sorted keys, compact separators) for stable hashing
            json_str = json.dumps(obj, sort_keys=True, separators=(",", ":"))
            return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
        except Exception:
            # Best effort: fall back to the string representation
            return hashlib.sha256(str(obj).encode("utf-8")).hexdigest()

    def _should_deduplicate_blob(self, obj: Any) -> bool:
        """
        Determine if an object should be deduplicated as a blob.

        Args:
            obj: Object to check

        Returns:
            True if ``obj`` is a dict whose compact JSON form is at least
            ``_blob_threshold`` characters long; False otherwise, including when
            it cannot be serialized.
        """
        try:
            # Only deduplicate large dictionary payloads
            if not isinstance(obj, dict):
                return False
            json_str = json.dumps(obj, separators=(",", ":"))
            return len(json_str) >= self._blob_threshold
        except Exception:
            return False

    def _store_blob(self, obj: Any) -> str:
        """
        Store a blob and return its reference hash.

        Args:
            obj: Object to store as blob

        Returns:
            SHA256 hash reference; the blob's usage counter is incremented.
        """
        blob_hash = self._compute_blob_hash(obj)

        # Store the blob if not already present
        if blob_hash not in self._blob_store:
            self._blob_store[blob_hash] = obj
            self._blob_usage[blob_hash] = 0

        self._blob_usage[blob_hash] += 1
        return blob_hash

    def _create_blob_reference(
        self,
        blob_hash: str,
        original_keys: list[str] | None = None,
    ) -> dict[str, Any]:
        """
        Create a blob reference object.

        Args:
            blob_hash: SHA256 hash of the blob
            original_keys: List of keys that were in the original object (for reference)

        Returns:
            Blob reference dictionary
        """
        ref: dict[str, Any] = {
            "ref": blob_hash,
            "_type": "blob_reference",
            "_original_keys": None,
        }
        if original_keys:
            ref["_original_keys"] = original_keys
        return ref

    def _deduplicate_object(self, obj: Any) -> Any:
        """
        Recursively deduplicate an object, replacing large blobs with references.

        Args:
            obj: Object to deduplicate

        Returns:
            Deduplicated object with blob references
        """
        if not isinstance(obj, dict):
            return obj

        # Store the whole dict as a blob when it is large enough
        if self._should_deduplicate_blob(obj):
            blob_hash = self._store_blob(obj)
            return self._create_blob_reference(blob_hash, list(obj.keys()))

        # Otherwise recurse into its values
        return {key: self._recursive_deduplicate(value) for key, value in obj.items()}

    def _recursive_deduplicate(self, obj: Any) -> Any:
        """Apply deduplication to dicts and to every element of lists."""
        if isinstance(obj, dict):
            return self._deduplicate_object(obj)
        elif isinstance(obj, list):
            return [self._recursive_deduplicate(item) for item in obj]
        else:
            return obj

    def _process_memory_for_saving(
        self, memory_entries: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Process memory entries before saving.

        Unless ``debug_keep_previous_outputs`` is set, each entry's
        ``previous_outputs`` is replaced by ``previous_outputs_summary``
        (``{"count": ..., "keys": [...]}``). A non-dict value such as ``None`` is
        summarized as an empty mapping. Entries are shallow-copied; the input is
        not modified.
        """
        processed_entries = []
        for entry in memory_entries:
            new_entry = entry.copy()
            if not self.debug_keep_previous_outputs and "previous_outputs" in new_entry:
                previous = new_entry.pop("previous_outputs")
                keys = list(previous.keys()) if isinstance(previous, dict) else []
                # Store a summary instead of the full object to reduce log size
                new_entry["previous_outputs_summary"] = {
                    "count": len(keys),
                    "keys": keys,
                }
            processed_entries.append(new_entry)
        return processed_entries

    def _sanitize_for_json(self, obj: Any, _seen: Set[Any] | None = None) -> Any:
        """
        Sanitize an object to ensure it's JSON serializable.

        Scalars are returned unchanged, lists/tuples become lists, dicts are
        sanitized value-by-value, and anything else becomes a
        ``"<non-serializable: TypeName>"`` string. A container that contains
        itself is replaced by ``"<circular-reference: TypeName>"`` instead of
        recursing forever.

        Args:
            obj: Object to sanitize.
            _seen: Internal: ids of containers on the current recursion path.
        """
        if isinstance(obj, (int, float, str, bool)) or obj is None:
            return obj
        if not isinstance(obj, (list, tuple, dict)):
            # Fallback for non-serializable objects
            return f"<non-serializable: {type(obj).__name__}>"

        seen = set() if _seen is None else _seen
        marker = id(obj)
        if marker in seen:
            return f"<circular-reference: {type(obj).__name__}>"
        seen.add(marker)
        try:
            if isinstance(obj, dict):
                return {k: self._sanitize_for_json(v, seen) for k, v in obj.items()}
            return [self._sanitize_for_json(item, seen) for item in obj]
        finally:
            # Only the current path counts, so shared (non-cyclic) sub-objects
            # are still sanitized fully each time they appear.
            seen.discard(marker)

    def _should_use_deduplication_format(self) -> bool:
        """
        Determine whether to use the deduplication format for saving logs.

        This is based on whether any blobs were actually stored.
        """
        return bool(self._blob_store)

    def _build_previous_outputs(self, logs: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Build a dictionary of previous agent outputs from the execution logs.

        Results already stored in the ``agent_results`` hash are loaded first; a
        single undecodable entry is skipped with a warning. Then, for each log
        entry that has an ``agent_id``, results found in its payload
        (``result``, a JoinNode's ``result["merged"]``, ``response``, or
        ``memories``) overwrite earlier values, and the agent's output is written
        back with ``set("agent_result:<agent_id>")`` and
        ``hset("agent_results", <agent_id>)``. Backend failures are logged as
        warnings, not raised.

        Args:
            logs: Execution log entries (dicts with ``agent_id`` and ``payload``).

        Returns:
            Mapping of agent id to its latest known output.
        """
        outputs: dict[str, Any] = {}
        group_key = "agent_results"

        # First, try to get results from Redis
        try:
            result_keys = self.hkeys(group_key)
        except Exception as e:
            logger.warning(f"Failed to load results from Redis: {e}")
            result_keys = []

        for agent_id in result_keys:
            try:
                result_str = self.hget(group_key, agent_id)
                if result_str:
                    outputs[agent_id] = json.loads(result_str)
                    logger.debug(f"- Loaded result for agent {agent_id} from Redis")
            except Exception as e:
                logger.warning(f"Failed to load result for agent {agent_id} from Redis: {e}")

        # Then process logs to update/add any missing results
        for log in logs:
            raw_agent_id = log.get("agent_id")
            # Check before str(): str(None) would be the truthy string "None"
            if not raw_agent_id:
                continue
            agent_id = str(raw_agent_id)

            payload = log.get("payload")
            if not isinstance(payload, dict):
                # Non-dict payloads carry no recognizable result fields
                payload = {}

            # Case: regular agent output
            if "result" in payload:
                outputs[agent_id] = payload["result"]
                # Case: JoinNode with merged dict
                if isinstance(payload["result"], dict):
                    merged = payload["result"].get("merged")
                    if isinstance(merged, dict):
                        outputs.update(merged)

            # Case: Current run agent responses
            if "response" in payload:
                outputs[agent_id] = {
                    "response": payload["response"],
                    "confidence": payload.get("confidence", "0.0"),
                    "internal_reasoning": payload.get("internal_reasoning", ""),
                    "_metrics": payload.get("_metrics", {}),
                    "formatted_prompt": payload.get("formatted_prompt", ""),
                }

            # Case: Memory agent responses
            if "memories" in payload:
                outputs[agent_id] = {
                    "memories": payload["memories"],
                    "query": payload.get("query", ""),
                    "backend": payload.get("backend", ""),
                    "search_type": payload.get("search_type", ""),
                    "num_results": payload.get("num_results", 0),
                }

            if agent_id not in outputs:
                continue  # Nothing known for this agent; nothing to persist

            # Store the result in Redis for future access
            try:
                serialized = json.dumps(outputs[agent_id])
                # Store individual result
                self.set(f"agent_result:{agent_id}", serialized)
                logger.debug(f"- Stored result for agent {agent_id}")

                # Store in group hash
                self.hset(group_key, agent_id, serialized)
                logger.debug(f"- Stored result in group for agent {agent_id}")
            except Exception as e:
                logger.warning(f"Failed to store result in Redis: {e}")

        return outputs

    def save_enhanced_trace(self, file_path: str, enhanced_data: Dict[str, Any]) -> None:
        """
        Save enhanced trace data with memory backend references and blob deduplication.

        If writing the deduplicated trace fails, the raw ``enhanced_data`` is
        written instead; if that also fails, ``save_to_file(file_path)`` is used as
        a last resort.

        Args:
            file_path: Destination JSON file.
            enhanced_data: Enhanced trace data to save.
        """
        try:
            # Apply blob deduplication to the enhanced trace data
            deduplicated_data = self._apply_deduplication_to_enhanced_trace(enhanced_data)

            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(deduplicated_data, f, indent=2, default=str)

            # Log deduplication statistics
            metadata = deduplicated_data.get("_metadata")
            if isinstance(metadata, dict) and "deduplication_enabled" in metadata:
                if metadata["deduplication_enabled"]:
                    stats = metadata.get("stats", {})
                    blob_count = metadata.get("total_blobs_stored", 0)
                    size_reduction = stats.get("size_reduction", 0)
                    logger.info(
                        f"Enhanced trace saved with deduplication: {blob_count} blobs, {size_reduction} bytes saved"
                    )
                else:
                    logger.info("Enhanced trace saved (no deduplication needed)")
            else:
                logger.info(f"Enhanced trace saved to {file_path}")

        except Exception as e:
            logger.error(f"Failed to save enhanced trace with deduplication: {e}")
            # Fallback to simple JSON dump
            try:
                with open(file_path, "w", encoding="utf-8") as f:
                    json.dump(enhanced_data, f, indent=2, default=str)
                logger.info(f"Enhanced trace saved (fallback mode) to {file_path}")
            except Exception as fallback_e:
                logger.error(f"Fallback save also failed: {fallback_e}")
                # Last resort: use the original save_to_file method
                self.save_to_file(file_path)

    def _build_trace_events(
        self, enhanced_data: Dict[str, Any]
    ) -> tuple[List[Dict[str, Any]], Dict[str, int]]:
        """
        Convert ``agent_executions`` into the original events format.

        Payloads whose compact JSON size reaches the blob threshold are passed
        through ``_deduplicate_object`` (which stores blobs in ``self._blob_store``).

        Returns:
            ``(events, blob_stats)``.
        """
        events: List[Dict[str, Any]] = []
        blob_stats = {
            "total_entries": 0,
            "deduplicated_blobs": 0,
            "size_reduction": 0,
        }
        threshold = getattr(self, "_blob_threshold", 200)

        for execution in enhanced_data.get("agent_executions") or []:
            blob_stats["total_entries"] += 1

            # Preserve the original event structure (agent_id, event_type, timestamp)
            event = {
                "agent_id": execution.get("agent_id"),
                "event_type": execution.get("event_type"),
                "timestamp": execution.get("timestamp"),
            }

            # Add other top-level fields if they exist
            for key in ["step", "run_id", "fork_group", "parent"]:
                if key in execution:
                    event[key] = execution[key]

            # Handle payload separately - only deduplicate if large. default=str
            # matches the final json.dump, so non-JSON values do not abort sizing.
            if "payload" in execution:
                payload = execution["payload"]
                payload_size = len(json.dumps(payload, separators=(",", ":"), default=str))

                # ">=" matches _should_deduplicate_blob's threshold semantics
                if payload_size >= threshold:
                    deduplicated_payload = self._deduplicate_object(payload)
                    new_size = len(
                        json.dumps(deduplicated_payload, separators=(",", ":"), default=str)
                    )
                    if new_size < payload_size:
                        blob_stats["deduplicated_blobs"] += 1
                        blob_stats["size_reduction"] += payload_size - new_size
                    event["payload"] = deduplicated_payload
                else:
                    event["payload"] = payload

            # Add enhanced trace specific fields (memory_references, template_resolution)
            for key in ["memory_references", "template_resolution"]:
                if key in execution:
                    event[key] = execution[key]

            events.append(event)

        return events, blob_stats

    def _apply_deduplication_to_enhanced_trace(
        self, enhanced_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Apply blob deduplication to enhanced trace data using original events format.

        Uses a temporary, empty blob store and usage map for this operation. Both
        are restored afterwards (in ``finally``), so the logger's own blob
        bookkeeping is never altered by saving a trace.

        Returns:
            The "blob_store + events" format when at least one blob was stored;
            otherwise a shallow copy of ``enhanced_data`` with ``_metadata`` and
            ``cost_analysis`` added. On any error the original ``enhanced_data``
            is returned unchanged.
        """
        original_blob_store = getattr(self, "_blob_store", {})
        original_blob_usage = getattr(self, "_blob_usage", {})
        self._blob_store = {}
        self._blob_usage = {}
        try:
            events, blob_stats = self._build_trace_events(enhanced_data)

            if self._should_use_deduplication_format():
                # Extract token and cost data while the trace blob store is active
                cost_analysis = self._extract_cost_analysis(enhanced_data, events)

                # Create the original blob_store + events format with cost analysis
                return {
                    "_metadata": {
                        "version": "1.2.0",  # Use original version for compatibility
                        "deduplication_enabled": True,
                        "blob_threshold_chars": getattr(self, "_blob_threshold", 200),
                        "total_blobs_stored": len(self._blob_store),
                        "stats": blob_stats,
                        "generated_at": datetime.now(UTC).isoformat(),
                    },
                    "blob_store": self._blob_store.copy(),
                    "events": events,  # Use 'events' key like original format
                    "cost_analysis": cost_analysis,  # New key for token/cost data
                }

            # No deduplication needed - use enhanced format with metadata
            result = enhanced_data.copy()
            result["_metadata"] = {
                "version": "1.2.0",
                "deduplication_enabled": False,
                "generated_at": datetime.now(UTC).isoformat(),
            }
            # Add cost analysis even when no deduplication
            result["cost_analysis"] = self._extract_cost_analysis(enhanced_data, events)
            return result

        except Exception as e:
            logger.error(f"Failed to apply deduplication to enhanced trace: {e}")
            # Return original data if deduplication fails
            return enhanced_data
        finally:
            # Restore the logger's own blob bookkeeping on every path
            self._blob_store = original_blob_store
            self._blob_usage = original_blob_usage

    @staticmethod
    def _metric_number(metrics: Dict[str, Any], key: str, default: float) -> Any:
        """Return ``metrics[key]``, treating a missing or ``None`` value as ``default``."""
        value = metrics.get(key)
        return default if value is None else value

    @staticmethod
    def _empty_cost_bucket() -> Dict[str, Any]:
        """Return a fresh per-model / per-provider aggregation bucket."""
        return {
            "agents": 0,
            "total_tokens": 0,
            "total_cost_usd": 0.0,
            "total_latency_ms": 0.0,
        }

    def _extract_cost_analysis(
        self, enhanced_data: Dict[str, Any], events: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Extract token and cost analysis from agent executions.

        Only events whose ``event_type`` contains "LLMAgent" and that yield
        metrics (see ``_extract_agent_metrics``) are counted. Missing or ``None``
        numeric metrics count as zero; a missing or empty model/provider is
        reported as "unknown". Model and provider lists are sorted for
        deterministic output.

        Note: ``summary["total_agents"]`` and the per-model / per-provider
        ``"agents"`` counters are incremented once per matching execution, so an
        agent that runs several times is counted several times.

        Returns:
            Dictionary with "summary", "agents", "by_model" and "by_provider"
            sections, or ``{"error": <message>}`` if extraction fails.
        """
        try:
            summary: Dict[str, Any] = {
                "total_agents": 0,
                "total_tokens": 0,
                "total_prompt_tokens": 0,
                "total_completion_tokens": 0,
                "total_cost_usd": 0.0,
                "total_latency_ms": 0.0,
                "models_used": set(),
                "providers_used": set(),
            }
            agents: Dict[str, Any] = {}
            by_model: Dict[str, Any] = {}
            by_provider: Dict[str, Any] = {}
            cost_analysis: Dict[str, Any] = {
                "summary": summary,
                "agents": agents,
                "by_model": by_model,
                "by_provider": by_provider,
            }

            for event in events:
                agent_id = event.get("agent_id")
                event_type = event.get("event_type")

                # Only process LLM agents that have cost data
                if not agent_id or not event_type or "LLMAgent" not in str(event_type):
                    continue

                # Extract metrics from payload or blob_store
                metrics = self._extract_agent_metrics(event, enhanced_data)
                if not isinstance(metrics, dict) or not metrics:
                    continue

                tokens = self._metric_number(metrics, "tokens", 0)
                prompt_tokens = self._metric_number(metrics, "prompt_tokens", 0)
                completion_tokens = self._metric_number(metrics, "completion_tokens", 0)
                cost_usd = self._metric_number(metrics, "cost_usd", 0.0)
                latency_ms = self._metric_number(metrics, "latency_ms", 0.0)
                model = metrics.get("model") or "unknown"
                provider = metrics.get("provider") or "unknown"

                # Update agent-specific data (event_type recorded on first sight)
                if agent_id not in agents:
                    agents[agent_id] = {
                        "executions": 0,
                        "total_tokens": 0,
                        "total_prompt_tokens": 0,
                        "total_completion_tokens": 0,
                        "total_cost_usd": 0.0,
                        "total_latency_ms": 0.0,
                        "models": set(),
                        "providers": set(),
                        "event_type": event_type,
                    }
                agent_data = agents[agent_id]
                agent_data["executions"] += 1
                agent_data["total_tokens"] += tokens
                agent_data["total_prompt_tokens"] += prompt_tokens
                agent_data["total_completion_tokens"] += completion_tokens
                agent_data["total_cost_usd"] += cost_usd
                agent_data["total_latency_ms"] += latency_ms
                agent_data["models"].add(model)
                agent_data["providers"].add(provider)

                # Update summary
                summary["total_agents"] += 1
                summary["total_tokens"] += tokens
                summary["total_prompt_tokens"] += prompt_tokens
                summary["total_completion_tokens"] += completion_tokens
                summary["total_cost_usd"] += cost_usd
                summary["total_latency_ms"] += latency_ms
                summary["models_used"].add(model)
                summary["providers_used"].add(provider)

                # Update by_model and by_provider aggregations
                for bucket in (
                    by_model.setdefault(model, self._empty_cost_bucket()),
                    by_provider.setdefault(provider, self._empty_cost_bucket()),
                ):
                    bucket["agents"] += 1
                    bucket["total_tokens"] += tokens
                    bucket["total_cost_usd"] += cost_usd
                    bucket["total_latency_ms"] += latency_ms

            # Convert sets to (sorted) lists for JSON serialization
            summary["models_used"] = sorted(summary["models_used"], key=str)
            summary["providers_used"] = sorted(summary["providers_used"], key=str)
            for agent_data in agents.values():
                agent_data["models"] = sorted(agent_data["models"], key=str)
                agent_data["providers"] = sorted(agent_data["providers"], key=str)

            return cost_analysis

        except Exception as e:
            logger.error(f"Failed to extract cost analysis: {e}")
            return {"error": str(e)}

    @staticmethod
    def _merge_previous_output_metrics(previous_outputs: Any) -> Dict[str, Any]:
        """
        Merge the ``_metrics`` dicts found in ``previous_outputs`` values.

        Counters (tokens, prompt_tokens, completion_tokens, cost_usd, latency_ms)
        are summed, skipping ``None`` values; any other key (model, provider, ...)
        keeps the last value seen. Non-dict inputs yield an empty dict.
        """
        merged: Dict[str, Any] = {}
        if not isinstance(previous_outputs, dict):
            return merged
        for agent_response in previous_outputs.values():
            if not isinstance(agent_response, dict):
                continue
            agent_metrics = agent_response.get("_metrics")
            if not isinstance(agent_metrics, dict):
                continue
            for key, value in agent_metrics.items():
                if key in ("tokens", "prompt_tokens", "completion_tokens"):
                    if value is not None:
                        merged[key] = merged.get(key, 0) + value
                elif key in ("cost_usd", "latency_ms"):
                    if value is not None:
                        merged[key] = merged.get(key, 0.0) + value
                else:
                    merged[key] = value  # model, provider, etc.
        return merged

    def _extract_agent_metrics(
        self, event: Dict[str, Any], enhanced_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Extract metrics from an agent event, resolving blob references if needed.

        A ``blob_reference`` payload is resolved from the current
        ``self._blob_store`` first, then from ``enhanced_data["blob_store"]``.
        Metrics are taken, in order of preference, from the payload's
        ``_metrics``, from the merged ``_metrics`` of its ``previous_outputs``, or
        from ``response["_metrics"]``.

        Returns:
            The metrics dict, or an empty dict when none can be found.
        """
        try:
            payload = event.get("payload", {})

            # If payload is a blob reference, resolve it
            if isinstance(payload, dict) and payload.get("_type") == "blob_reference":
                blob_ref = payload.get("ref")
                local_store = getattr(self, "_blob_store", None) or {}
                trace_store = enhanced_data.get("blob_store") or {}
                if blob_ref and blob_ref in local_store:
                    resolved_payload = local_store[blob_ref]
                elif blob_ref and isinstance(trace_store, dict) and blob_ref in trace_store:
                    resolved_payload = trace_store[blob_ref]
                else:
                    return {}
            else:
                resolved_payload = payload

            if not isinstance(resolved_payload, dict):
                return {}

            # Direct _metrics field
            if "_metrics" in resolved_payload:
                metrics = resolved_payload["_metrics"]
                return metrics if isinstance(metrics, dict) else {}

            # previous_outputs with per-agent _metrics
            if "previous_outputs" in resolved_payload:
                return self._merge_previous_output_metrics(resolved_payload["previous_outputs"])

            # response field with _metrics
            response = resolved_payload.get("response")
            if isinstance(response, dict) and isinstance(response.get("_metrics"), dict):
                return response["_metrics"]

            return {}

        except Exception as e:
            logger.error(f"Failed to extract agent metrics: {e}")
            return {}