Source code for orka.nodes.rag_node

import logging
from typing import Any

from ..contracts import Context, Registry
from .base_node import BaseNode

logger = logging.getLogger(__name__)


class RAGNode(BaseNode):
    """
    RAG Node Implementation
    =======================

    A specialized node that performs Retrieval-Augmented Generation (RAG)
    operations by combining semantic search with language model generation.

    Core Functionality
    ------------------

    **RAG Process:**

    1. **Query Processing**: Extract and prepare the input query
    2. **Embedding Generation**: Convert the query to a vector representation
    3. **Memory Search**: Find relevant documents using semantic similarity
    4. **Context Formatting**: Structure retrieved documents for LLM consumption
    5. **Answer Generation**: Use the LLM to generate a response based on the context

    **Integration Points:**

    - **Memory Backend**: Searches for relevant documents using vector similarity
    - **Embedder Service**: Generates query embeddings for semantic search
    - **LLM Service**: Generates final answers based on retrieved context
    - **Registry System**: Accesses shared resources through dependency injection

    Architecture Details
    --------------------

    **Node Configuration:**

    - `top_k`: Number of documents to retrieve (default: 5)
    - `score_threshold`: Minimum similarity score for relevance (default: 0.7)
    - `timeout`: Maximum execution time for the operation
    - `max_concurrency`: Limit on parallel executions

    **Resource Management:**

    - Lazy initialization of expensive resources (memory, embedder, LLM)
    - Registry-based dependency injection for shared services
    - Automatic resource cleanup and lifecycle management
    - Thread-safe execution for concurrent operations

    **Error Handling:**

    - Graceful handling of missing or invalid queries
    - Fallback responses when no relevant documents are found
    - Structured error reporting with context preservation
    - Automatic retry logic for transient failures

    Implementation Features
    -----------------------

    **Search Capabilities:**

    - Vector similarity search using embeddings
    - Configurable relevance thresholds
    - Top-k result limiting for performance
    - Metadata filtering and namespace support

    **Context Management:**

    - Intelligent document formatting for LLM consumption
    - Source attribution and reference tracking
    - Context length optimization for model limits
    - Structured output with sources and confidence scores

    **LLM Integration:**

    - Dynamic prompt construction with retrieved context
    - Configurable model parameters and settings
    - Response quality validation and filtering
    - Token usage tracking and optimization

    Usage Examples
    --------------

    **Basic Configuration:**

    ```yaml
    agents:
      - id: rag_assistant
        type: rag
        top_k: 5
        score_threshold: 0.7
        timeout: 30.0
    ```

    **Advanced Configuration:**

    ```yaml
    agents:
      - id: specialized_rag
        type: rag
        top_k: 10
        score_threshold: 0.8
        max_concurrency: 5
        llm_config:
          model: "gpt-4"
          temperature: 0.1
          max_tokens: 500
    ```

    **Integration with Memory:**

    ```python
    # The node automatically integrates with the memory system
    # Memory backend provides semantic search capabilities
    # Embedder service generates query vectors
    # LLM service generates final responses
    ```

    Response Format
    ---------------

    **Successful Response:**

    ```json
    {
        "result": {
            "answer": "Generated response based on retrieved context",
            "sources": [
                {
                    "content": "Source document content",
                    "score": 0.85,
                    "metadata": {...}
                }
            ]
        },
        "status": "success",
        "error": null,
        "metadata": {"node_id": "rag_assistant"}
    }
    ```

    **Error Response:**

    ```json
    {
        "result": null,
        "status": "error",
        "error": "Query is required for RAG operation",
        "metadata": {"node_id": "rag_assistant"}
    }
    ```

    **No Results Response:**

    ```json
    {
        "result": {
            "answer": "I couldn't find any relevant information to answer your question.",
            "sources": []
        },
        "status": "success",
        "error": null,
        "metadata": {"node_id": "rag_assistant"}
    }
    ```

    Performance Considerations
    --------------------------

    **Optimization Features:**

    - Lazy resource initialization to reduce startup time
    - Configurable concurrency limits for resource management
    - Efficient context formatting to minimize token usage
    - Caching strategies for frequently accessed documents

    **Scalability:**

    - Supports high-throughput query processing
    - Memory-efficient document handling
    - Parallel processing capabilities
    - Resource pooling for external services

    **Monitoring:**

    - Execution timing and performance metrics
    - Search quality and relevance tracking
    - LLM usage and cost monitoring
    - Error rate and pattern analysis
    """

    def __init__(
        self,
        node_id: str,
        registry: Registry,
        prompt: str = "",
        queue: str = "default",
        timeout: float | None = 30.0,
        max_concurrency: int = 10,
        top_k: int = 5,
        score_threshold: float = 0.7,
    ):
        super().__init__(
            node_id=node_id,
            prompt=prompt,
            queue=queue,
            timeout=timeout,
            max_concurrency=max_concurrency,
        )
        self.registry = registry
        self.top_k = top_k
        self.score_threshold = score_threshold
        self._memory = None
        self._embedder = None
        self._llm = None
        self._initialized = False
    async def initialize(self) -> None:
        """Initialize the node and its resources."""
        self._memory = self.registry.get("memory")
        self._embedder = self.registry.get("embedder")
        self._llm = self.registry.get("llm")
        self._initialized = True
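`initialize()` resolves the three services lazily, by name, from the registry. The sketch below is not part of the module; it only documents the interfaces those services appear to need, inferred from how `_run_impl`, `_get_embedding`, and `_generate_answer` call them. The real `orka.contracts` definitions may differ.

```python
# Interface sketch only -- inferred from the calls RAGNode makes on each service.
from typing import Any, Protocol


class MemoryBackend(Protocol):
    async def search(
        self,
        query_embedding: list[float],
        limit: int = 5,
        score_threshold: float = 0.7,
    ) -> list[dict[str, Any]]:
        """Return matching documents as dicts with at least a 'content' key."""
        ...


class Embedder(Protocol):
    async def encode(self, text: str) -> list[float]:
        """Return a vector embedding for the given text."""
        ...


# The LLM service is used like an OpenAI-style async client:
#     await llm.chat.completions.create(model=..., messages=[...])
# so any object exposing that surface should satisfy the node.
```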
    async def run(self, context: Context) -> dict[str, Any]:
        """Run the RAG node with the given context."""
        if not self._initialized:
            await self.initialize()

        try:
            result = await self._run_impl(context)
            return {
                "result": result,
                "status": "success",
                "error": None,
                "metadata": {"node_id": self.node_id},
            }
        except Exception as e:
            logger.error(f"RAGNode {self.node_id} failed: {e!s}")
            return {
                "result": None,
                "status": "error",
                "error": str(e),
                "metadata": {"node_id": self.node_id},
            }
    async def _run_impl(self, ctx: Context) -> dict[str, Any]:
        """Implementation of RAG operations."""
        query = ctx.get("query")
        if not query:
            raise ValueError("Query is required for RAG operation")

        # Get embedding for the query
        query_embedding = await self._get_embedding(query)

        # Search memory for relevant documents
        results = await self._memory.search(
            query_embedding,
            limit=self.top_k,
            score_threshold=self.score_threshold,
        )

        if not results:
            return {
                "answer": "I couldn't find any relevant information to answer your question.",
                "sources": [],
            }

        # Format context from results
        context = self._format_context(results)

        # Generate answer using LLM
        answer = await self._generate_answer(query, context)

        return {"answer": answer, "sources": results}

    async def _get_embedding(self, text: str) -> list[float]:
        """Get embedding for text using the embedder."""
        return await self._embedder.encode(text)

    def _format_context(self, results: list[dict[str, Any]]) -> str:
        """Format search results into context for the LLM."""
        context_parts = []
        for i, result in enumerate(results, 1):
            context_parts.append(f"Source {i}:\n{result['content']}\n")
        return "\n".join(context_parts)

    async def _generate_answer(self, query: str, context: str) -> str:
        """Generate answer using the LLM."""
        prompt = f"""Based on the following context, answer the question.
If the context doesn't contain relevant information, say so.

Context:
{context}

Question: {query}

Answer:"""

        response = await self._llm.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": "You are a helpful assistant that answers questions based on the provided context.",
                },
                {"role": "user", "content": prompt},
            ],
        )
        return response.choices[0].message.content
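For reference, a hedged end-to-end sketch of driving the node outside an orchestrator, with in-memory stand-ins for the three services. It assumes a plain dict is acceptable wherever `Context` is expected (the node only calls `.get("query")`) and that the registry only needs a `get(name)` method; the real `orka.contracts` types may be stricter.

```python
# Usage sketch with stub services -- assumptions: a dict works as Context,
# and the registry only needs `get(name)`. Real contracts may differ.
import asyncio
from types import SimpleNamespace

from orka.nodes.rag_node import RAGNode


class StubMemory:
    async def search(self, query_embedding, limit=5, score_threshold=0.7):
        return [{"content": "OrKa nodes are configured in YAML.", "score": 0.9, "metadata": {}}]


class StubEmbedder:
    async def encode(self, text):
        return [0.0] * 384  # fixed-size dummy vector


class StubLLM:
    """Mimics the OpenAI-style `chat.completions.create` surface the node uses."""

    def __init__(self):
        async def create(model, messages):
            answer = "OrKa nodes are configured in YAML."
            return SimpleNamespace(
                choices=[SimpleNamespace(message=SimpleNamespace(content=answer))]
            )

        self.chat = SimpleNamespace(completions=SimpleNamespace(create=create))


class StubRegistry:
    def __init__(self, services):
        self._services = services

    def get(self, name):
        return self._services[name]


async def main():
    registry = StubRegistry(
        {"memory": StubMemory(), "embedder": StubEmbedder(), "llm": StubLLM()}
    )
    node = RAGNode(node_id="rag_assistant", registry=registry, top_k=5, score_threshold=0.7)
    result = await node.run({"query": "How are OrKa nodes configured?"})
    print(result["status"], result["result"]["answer"])


asyncio.run(main())
```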