# Source code for orka.nodes.join_node

# OrKa: Orchestrator Kit Agents
# Copyright © 2025 Marco Somma
#
# This file is part of OrKa – https://github.com/marcosomma/orka-reasoning
#
# Licensed under the Apache License, Version 2.0 (Apache 2.0).
# You may not use this file for commercial purposes without explicit permission.
#
# Full license: https://www.apache.org/licenses/LICENSE-2.0
# For commercial use, contact: marcosomma.work@gmail.com
#
# Required attribution: OrKa by Marco Somma – https://github.com/marcosomma/orka-reasoning


import json
import logging

from .base_node import BaseNode


class JoinNode(BaseNode):
    """
    A node that waits for and merges results from parallel branches created
    by a ForkNode. Uses a max retry counter to prevent infinite waiting.
    """

    def __init__(self, node_id, prompt, queue, memory_logger=None, **kwargs):
        """
        Initialize the join node.

        Args:
            node_id (str): Unique identifier of this node.
            prompt: Prompt forwarded to BaseNode.
            queue: Queue forwarded to BaseNode.
            memory_logger: Backend-agnostic storage adapter exposing
                hget/hset/hdel/hkeys/smembers/set (Redis-style API).
            **kwargs: May contain ``group`` (fork group id to wait on) and
                ``max_retries`` (poll attempts before timing out, default 30).
        """
        super().__init__(node_id, prompt, queue, **kwargs)
        self.memory_logger = memory_logger
        # Fork group this join waits on; can be overridden per-run via input_data.
        self.group_id = kwargs.get("group")
        # Maximum number of run() polls before giving up with status "timeout".
        self.max_retries = kwargs.get("max_retries", 30)
        # Hash field under which the merged output is stored.
        self.output_key = f"{self.node_id}:output"
        # Hash field tracking how many times this join has polled.
        self._retry_key = f"{self.node_id}:join_retry_count"

    def run(self, input_data):
        """
        Run the join operation by collecting and merging results from forked agents.

        Args:
            input_data (dict): May contain ``fork_group_id``; falls back to
                the ``group`` configured at construction time.

        Returns:
            dict: One of three shapes, keyed by ``status``:
                - "done": merged results (see :meth:`_complete`)
                - "waiting": some fork targets have not reported yet
                - "timeout": max_retries exceeded while still waiting
        """
        fork_group_id = input_data.get("fork_group_id", self.group_id)
        state_key = "waitfor:join_parallel_checks:inputs"

        # Get or increment retry count using backend-agnostic hash operations.
        retry_count_str = self.memory_logger.hget("join_retry_counts", self._retry_key)
        if retry_count_str is None:
            # NOTE(review): the counter starts at 3, not 1, so three polls are
            # already "spent" on the first attempt — confirm this offset is
            # intentional rather than an off-by-N bug.
            retry_count = 3
        else:
            retry_count = int(retry_count_str) + 1
        self.memory_logger.hset("join_retry_counts", self._retry_key, str(retry_count))

        # Collect which agents have reported vs. which are expected.
        # Values may come back as bytes from the backend; normalize to str.
        inputs_received = self.memory_logger.hkeys(state_key)
        received = [i.decode() if isinstance(i, bytes) else i for i in inputs_received]
        fork_targets = self.memory_logger.smembers(f"fork_group:{fork_group_id}")
        fork_targets = [i.decode() if isinstance(i, bytes) else i for i in fork_targets]
        pending = [agent for agent in fork_targets if agent not in received]

        # All expected agents have completed: clean up the counter and merge.
        if not pending:
            self.memory_logger.hdel("join_retry_counts", self._retry_key)
            return self._complete(fork_targets, state_key)

        # Too many polls without completion: give up.
        if retry_count >= self.max_retries:
            self.memory_logger.hdel("join_retry_counts", self._retry_key)
            return {
                "status": "timeout",
                "pending": pending,
                "received": received,
                "max_retries": self.max_retries,
            }

        # Not all agents have completed yet; report progress and keep waiting.
        return {
            "status": "waiting",
            "pending": pending,
            "received": received,
            "retry_count": retry_count,
            "max_retries": self.max_retries,
        }

    def _complete(self, fork_targets, state_key):
        """
        Complete the join operation by merging all fork results.

        Args:
            fork_targets (list): List of agent IDs to collect results from
            state_key (str): Redis key where results are stored

        Returns:
            dict: Merged results from all agents, with ``status="done"`` and
                each agent's result also exposed at the top level.
        """
        logger = logging.getLogger(__name__)

        # Gather each agent's stored result and normalize its shape.
        merged = {}
        for agent_id in fork_targets:
            try:
                result_str = self.memory_logger.hget(state_key, agent_id)
                if not result_str:
                    logger.warning(f"No result found for agent {agent_id}")
                    continue

                # Stored values are usually JSON; fall back to the raw value.
                try:
                    result = json.loads(result_str)
                except (json.JSONDecodeError, TypeError):
                    result = result_str

                if isinstance(result, dict):
                    if "result" in result:
                        # Nested result field takes precedence.
                        merged[agent_id] = result["result"]
                    elif "response" in result:
                        # Common shape for LLM agents: keep the key fields.
                        merged[agent_id] = {
                            "response": result["response"],
                            "confidence": result.get("confidence", "0.0"),
                            "internal_reasoning": result.get("internal_reasoning", ""),
                            "_metrics": result.get("_metrics", {}),
                            "formatted_prompt": result.get("formatted_prompt", ""),
                        }
                    else:
                        # Otherwise use the whole result dict.
                        merged[agent_id] = result
                else:
                    # Not a dict (e.g. raw string): use as is.
                    merged[agent_id] = result
                logger.debug(f"- Merged result for agent {agent_id}")

                # BUG FIX: the original called result.get("fork_group", ...)
                # unconditionally; for non-dict results that raised
                # AttributeError, which the broad except below then turned into
                # an {"error": ...} entry, clobbering an already-merged good
                # result. Guard the lookup with isinstance instead.
                if isinstance(result, dict):
                    fork_group_id = result.get("fork_group", "unknown")
                else:
                    fork_group_id = "unknown"

                # Store the result under a direct-access key.
                agent_key = f"agent_result:{fork_group_id}:{agent_id}"
                self.memory_logger.set(agent_key, json.dumps(merged[agent_id]))
                logger.debug(f"- Stored result for agent {agent_id}")

                # Also store in a per-group hash for group-level tracking.
                group_key = f"fork_group_results:{fork_group_id}"
                self.memory_logger.hset(group_key, agent_id, json.dumps(merged[agent_id]))
                logger.debug(f"- Stored result in group for agent {agent_id}")
            except Exception as e:
                logger.error(f"Error processing result for agent {agent_id}: {e}")
                # Add error result to show something went wrong
                merged[agent_id] = {"error": str(e)}

        # Store merged output using hash operations.
        self.memory_logger.hset("join_outputs", self.output_key, json.dumps(merged))

        # Clean up consumed inputs; hdel requires at least one field.
        if fork_targets:
            self.memory_logger.hdel(state_key, *fork_targets)

        # Return merged results with status and individual agent results.
        result = {
            "status": "done",
            "merged": merged,
            **merged,  # Expose individual agent results at top level
        }

        # Store the final result under a direct-access key.
        join_key = f"join_result:{self.node_id}"
        self.memory_logger.set(join_key, json.dumps(result))
        logger.debug(f"- Stored final join result: {join_key}")

        # Store in a hash for group-level tracking.
        group_key = f"join_results:{self.node_id}"
        self.memory_logger.hset(group_key, "result", json.dumps(result))
        logger.debug("- Stored final result in group")

        return result