Source code for orka.nodes.memory_writer_node

import logging
from typing import Any

from jinja2 import Template

from ..memory_logger import create_memory_logger
from .base_node import BaseNode

logger = logging.getLogger(__name__)


class MemoryWriterNode(BaseNode):
    """Enhanced memory writer using RedisStack through the memory logger."""

    def __init__(self, node_id: str, **kwargs):
        super().__init__(node_id=node_id, **kwargs)

        # ✅ CRITICAL: Use the memory logger instead of a direct Redis client
        self.memory_logger = kwargs.get("memory_logger")
        if not self.memory_logger:
            # Create a RedisStack memory logger
            self.memory_logger = create_memory_logger(
                backend="redisstack",
                enable_hnsw=kwargs.get("use_hnsw", True),
                vector_params=kwargs.get(
                    "vector_params",
                    {
                        "M": 16,
                        "ef_construction": 200,
                    },
                ),
                decay_config=kwargs.get("decay_config", {}),
            )

        # Configuration
        self.namespace = kwargs.get("namespace", "default")
        self.session_id = kwargs.get("session_id", "default")
        self.decay_config = kwargs.get("decay_config", {})

        # ✅ CRITICAL: Always store the metadata structure defined in YAML
        self.yaml_metadata = kwargs.get("metadata", {})

        # Direct Redis client removed - use the memory logger instead
        # self.redis = redis.from_url(...)  # ← REMOVE
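    # Illustrative construction (a sketch, not mandated defaults): every value
    # below is an example; only the keyword names mirror what __init__ reads.
    #
    #     writer = MemoryWriterNode(
    #         node_id="memory_writer",
    #         namespace="conversations",
    #         session_id="session_42",
    #         use_hnsw=True,
    #         vector_params={"M": 16, "ef_construction": 200},
    #         decay_config={"short_term_hours": 2.0},
    #         metadata={"topic": "{{ input }}"},  # Jinja2 template, rendered at write time
    #     )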
    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Write to memory using the RedisStack memory logger."""
        try:
            # 🔍 DEBUG: Log the complete context structure
            logger.info(f"MemoryWriterNode received context keys: {list(context.keys())}")
            if "previous_outputs" in context:
                logger.info(f"Previous outputs keys: {list(context['previous_outputs'].keys())}")

                # Log the structure of a few key outputs
                for key in ["logic_reasoning", "empathy_reasoning", "moderator_synthesis"]:
                    if key in context["previous_outputs"]:
                        output = context["previous_outputs"][key]
                        logger.info(
                            f"Structure of {key}: {type(output)} with keys: {list(output.keys()) if isinstance(output, dict) else 'not a dict'}",
                        )
                        if isinstance(output, dict) and "result" in output:
                            result = output["result"]
                            logger.info(
                                f" Result structure: {type(result)} with keys: {list(result.keys()) if isinstance(result, dict) else 'not a dict'}",
                            )

            # Log the YAML metadata we are about to render
            logger.info(f"YAML metadata to render: {self.yaml_metadata}")

            # 🎯 CRITICAL FIX: Extract the structured memory object from the validation guardian
            memory_content = self._extract_memory_content(context)
            if not memory_content:
                return {"status": "error", "error": "No memory content to store"}

            # Extract configuration from the context
            namespace = context.get("namespace", self.namespace)
            session_id = context.get("session_id", self.session_id)

            # ✅ CRITICAL: Always merge metadata from the YAML config and the context
            merged_metadata = self._merge_metadata(context)

            # 🔍 DEBUG: Log the final merged metadata
            logger.info(f"Final merged metadata: {merged_metadata}")

            # Compute the retention parameters once and reuse them below
            importance_score = self._calculate_importance_score(memory_content, merged_metadata)
            memory_type = self._classify_memory_type(merged_metadata, importance_score)

            # ✅ CRITICAL: Use the memory logger for direct memory storage instead of orchestration logging
            memory_key = self.memory_logger.log_memory(
                content=memory_content,
                node_id=self.node_id,
                trace_id=session_id,
                metadata={
                    "namespace": namespace,
                    "session": session_id,
                    "category": "stored",  # Mark as stored memory
                    "log_type": "memory",  # 🎯 CRITICAL: stored memory, not an orchestration log
                    "content_type": "user_input",
                    **merged_metadata,  # Include all metadata from YAML and context
                },
                importance_score=importance_score,
                memory_type=memory_type,
                expiry_hours=self._get_expiry_hours(memory_type, importance_score),
            )

            return {
                "status": "success",
                "session": session_id,
                "namespace": namespace,
                "content_length": len(str(memory_content)),
                "backend": "redisstack",
                "vector_enabled": True,
                "memory_key": memory_key,
                "stored_metadata": merged_metadata,  # Include metadata in the response
            }
        except Exception as e:
            logger.error(f"Error writing to memory: {e}")
            return {"status": "error", "error": str(e)}
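    # Illustrative call (a sketch; assumes a `writer` built as in the example
    # above, inside an async context). The context keys mirror what run() reads.
    #
    #     result = await writer.run(
    #         {
    #             "input": "The number is 7",
    #             "namespace": "conversations",
    #             "session_id": "session_42",
    #             "previous_outputs": {},
    #         }
    #     )
    #     # On success: {"status": "success", "memory_key": ..., "backend": "redisstack", ...}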
    def _merge_metadata(self, context: dict[str, Any]) -> dict[str, Any]:
        """Merge metadata from the YAML config, the context, and guardian outputs."""
        try:
            # Start with the YAML metadata structure (always preserve this)
            merged_metadata = self.yaml_metadata.copy()

            # ✅ CRITICAL: Render the YAML metadata templates first
            rendered_yaml_metadata = self._render_metadata_templates(merged_metadata, context)

            # Add context metadata (overrides YAML where keys conflict)
            context_metadata = context.get("metadata", {})
            rendered_yaml_metadata.update(context_metadata)

            # Extract metadata from guardian outputs if present
            guardian_metadata = self._extract_guardian_metadata(context)
            if guardian_metadata:
                rendered_yaml_metadata.update(guardian_metadata)

            # Extract structured memory object metadata if present
            memory_object_metadata = self._extract_memory_object_metadata(context)
            if memory_object_metadata:
                rendered_yaml_metadata.update(memory_object_metadata)

            return rendered_yaml_metadata
        except Exception as e:
            logger.warning(f"Error merging metadata: {e}")
            return self.yaml_metadata.copy()

    def _render_metadata_templates(
        self,
        metadata: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any]:
        """Render Jinja2 templates in metadata using context data."""
        try:
            rendered_metadata = {}

            # Create a comprehensive template context
            template_context = {
                "input": context.get("input", ""),
                "previous_outputs": context.get("previous_outputs", {}),
                "timestamp": context.get("timestamp", ""),
                "now": lambda: context.get("timestamp", ""),  # now() function for templates
                **context,  # Include all context keys
            }

            # Debug: Log the template context structure
            logger.info(f"Template context keys: {list(template_context.keys())}")
            logger.info(
                f"Previous outputs keys: {list(template_context.get('previous_outputs', {}).keys())}",
            )

            # 🔍 DEBUG: Log the specific structure the templates are trying to access
            prev_outputs = template_context.get("previous_outputs", {})
            for agent_key in ["logic_reasoning", "empathy_reasoning", "moderator_synthesis"]:
                if agent_key in prev_outputs:
                    agent_data = prev_outputs[agent_key]
                    logger.info(f"Agent {agent_key} structure: {type(agent_data)}")
                    if isinstance(agent_data, dict):
                        logger.info(f" Keys: {list(agent_data.keys())}")
                        if "result" in agent_data:
                            result = agent_data["result"]
                            logger.info(f" Result type: {type(result)}")
                            if isinstance(result, dict):
                                logger.info(f" Result keys: {list(result.keys())}")

            for key, value in metadata.items():
                try:
                    if isinstance(value, str) and ("{{" in value or "{%" in value):
                        # Debug: Log what we are about to render
                        logger.info(f"🎯 Rendering template for key '{key}': {value}")

                        # Render string templates with enhanced error handling
                        template = Template(value)
                        rendered_value = template.render(**template_context)

                        # Debug: Log the rendered result
                        logger.info(
                            f"✅ Rendered value for key '{key}': {rendered_value[:200]}{'...' if len(str(rendered_value)) > 200 else ''}",
                        )

                        # Handle cases where the rendered value is None or empty
                        if rendered_value is None or rendered_value == "":
                            # Try to extract a default value from the template if present
                            if "default(" in value:
                                # Use the original template string as a fallback
                                rendered_metadata[key] = value
                                logger.warning(
                                    f"⚠️ Template rendered empty, using default for '{key}'",
                                )
                            else:
                                rendered_metadata[key] = ""
                                logger.warning(f"⚠️ Template rendered empty for '{key}'")
                        else:
                            rendered_metadata[key] = rendered_value
                    elif isinstance(value, dict):
                        # Recursively render nested dictionaries
                        rendered_metadata[key] = self._render_metadata_templates(value, context)
                    elif isinstance(value, list):
                        # Render templates inside lists
                        rendered_list = []
                        for item in value:
                            if isinstance(item, str) and ("{{" in item or "{%" in item):
                                try:
                                    template = Template(item)
                                    rendered_item = template.render(**template_context)
                                    rendered_list.append(rendered_item)
                                except Exception as e:
                                    logger.warning(f"❌ Error rendering list item template: {e}")
                                    rendered_list.append(item)
                            else:
                                rendered_list.append(item)
                        rendered_metadata[key] = rendered_list
                    else:
                        # Keep non-template values as-is
                        rendered_metadata[key] = value
                except Exception as e:
                    logger.error(f"❌ Error rendering template for metadata key '{key}': {e}")
                    logger.error(f"Template value: {value}")
                    logger.error(f"Template context keys: {list(template_context.keys())}")
                    import traceback

                    logger.error(f"Full traceback: {traceback.format_exc()}")
                    # Keep the original value if rendering fails
                    rendered_metadata[key] = value

            return rendered_metadata
        except Exception as e:
            logger.error(f"❌ Error rendering metadata templates: {e}")
            import traceback

            logger.error(f"Full traceback: {traceback.format_exc()}")
            return metadata.copy()

    def _extract_guardian_metadata(self, context: dict[str, Any]) -> dict[str, Any]:
        """Extract metadata from validation guardian outputs."""
        try:
            guardian_metadata = {}
            previous_outputs = context.get("previous_outputs", {})

            # Check both validation guardians for metadata
            for guardian_name in ["false_validation_guardian", "true_validation_guardian"]:
                if guardian_name in previous_outputs:
                    guardian_output = previous_outputs[guardian_name]
                    if isinstance(guardian_output, dict):
                        # Extract metadata from the guardian result
                        if "metadata" in guardian_output:
                            guardian_metadata.update(guardian_output["metadata"])

                        # Extract the validation status
                        if "result" in guardian_output:
                            result = guardian_output["result"]
                            if isinstance(result, dict):
                                guardian_metadata["validation_guardian"] = guardian_name
                                guardian_metadata["validation_result"] = result.get(
                                    "validation_status",
                                    "unknown",
                                )

            return guardian_metadata
        except Exception as e:
            logger.warning(f"Error extracting guardian metadata: {e}")
            return {}
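    # Illustrative (not a contract): the guardian output shape that
    # _extract_guardian_metadata above and _extract_memory_object_metadata below
    # walk through. Key names mirror the lookups in the code; values are examples.
    #
    #     context["previous_outputs"]["true_validation_guardian"] = {
    #         "metadata": {"source": "guardian"},
    #         "result": {
    #             "validation_status": "validated",
    #             "memory_object": {
    #                 "number": 7,
    #                 "result": "true",
    #                 "condition": "greater than 5",
    #                 "analysis_type": "numeric_comparison",
    #                 "confidence": 0.95,
    #                 "validation_status": "validated",
    #             },
    #         },
    #     }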
memory_object_metadata["analysis_type"] = memory_obj.get( "analysis_type", "unknown", ) memory_object_metadata["confidence"] = memory_obj.get( "confidence", 1.0, ) memory_object_metadata["validation_status"] = memory_obj.get( "validation_status", "unknown", ) return memory_object_metadata except Exception as e: logger.warning(f"Error extracting memory object metadata: {e}") return {} def _extract_memory_content(self, context: dict[str, Any]) -> str: """Extract structured memory content from validation guardian output.""" try: # Look for structured memory object from validation guardian previous_outputs = context.get("previous_outputs", {}) # Try validation guardians (both true and false) for guardian_name in ["false_validation_guardian", "true_validation_guardian"]: if guardian_name in previous_outputs: guardian_output = previous_outputs[guardian_name] if isinstance(guardian_output, dict) and "result" in guardian_output: result = guardian_output["result"] if isinstance(result, dict) and "memory_object" in result: memory_obj = result["memory_object"] # Convert structured object to searchable text return self._memory_object_to_text(memory_obj, context.get("input", "")) # Fallback: use raw input if no structured memory object found return context.get("input", "") except Exception as e: logger.warning(f"Error extracting memory content: {e}") return context.get("input", "") def _memory_object_to_text(self, memory_obj: dict[str, Any], original_input: str) -> str: """Convert structured memory object to searchable text format.""" try: # Create a natural language representation that's searchable number = memory_obj.get("number", original_input) result = memory_obj.get("result", "unknown") condition = memory_obj.get("condition", "") analysis_type = memory_obj.get("analysis_type", "") confidence = memory_obj.get("confidence", 1.0) # Format as searchable text text_parts = [ f"Number: {number}", f"Greater than 5: {result}", f"Condition: {condition}", f"Analysis: {analysis_type}", f"Confidence: {confidence}", f"Validated: {memory_obj.get('validation_status', 'unknown')}", ] # Add the structured data as JSON for exact matching structured_text = " | ".join(text_parts) structured_text += f" | JSON: {memory_obj}" return structured_text except Exception as e: logger.warning(f"Error converting memory object to text: {e}") return str(memory_obj) def _calculate_importance_score(self, content: str, metadata: dict[str, Any]) -> float: """Calculate importance score for memory retention decisions.""" score = 0.5 # Base score # Content length bonus (longer content often more important) if len(content) > 500: score += 0.2 elif len(content) > 100: score += 0.1 # Metadata indicators if metadata.get("category") == "stored": score += 0.3 # Explicitly stored memories are more important # Query presence (memories with queries are often more important) if metadata.get("query"): score += 0.1 # Clamp score between 0.0 and 1.0 return max(0.0, min(1.0, score)) def _classify_memory_type(self, metadata: dict[str, Any], importance_score: float) -> str: """Classify memory as short-term or long-term based on metadata and importance.""" # Stored memories with high importance are long-term if metadata.get("category") == "stored" and importance_score >= 0.7: return "long_term" # Agent-specific configuration if self.decay_config.get("default_long_term", False): return "long_term" return "short_term" def _get_expiry_hours(self, memory_type: str, importance_score: float) -> float: """Get expiry time in hours based on memory type and 
importance.""" if memory_type == "long_term": # Check agent-level config first, then fall back to global config base_hours = self.decay_config.get("long_term_hours") or self.decay_config.get( "default_long_term_hours", 24.0, ) else: # Check agent-level config first, then fall back to global config base_hours = self.decay_config.get("short_term_hours") or self.decay_config.get( "default_short_term_hours", 1.0, ) # Adjust based on importance (higher importance = longer retention) importance_multiplier = 1.0 + importance_score return base_hours * importance_multiplier