# OrKa: Orchestrator Kit Agents
# Copyright © 2025 Marco Somma
#
# This file is part of OrKa – https://github.com/marcosomma/orka-reasoning
"""
Safety Controller
================
Comprehensive safety assessment and risk management for path selection.
Implements safety policies, risk scoring, and guardrail enforcement.
"""
import logging
import re
from typing import Any, Dict, List, Set
logger = logging.getLogger(__name__)
class SafetyPolicy:
"""Defines safety policies and risk assessment rules."""
def __init__(self, profile: str = "default"):
"""Initialize safety policy with profile."""
self.profile = profile
self.risk_patterns = self._load_risk_patterns()
self.forbidden_capabilities = self._load_forbidden_capabilities()
self.content_filters = self._load_content_filters()
def _load_risk_patterns(self) -> Dict[str, List[str]]:
"""Load risk detection patterns."""
return {
"pii": [
r"\b\d{3}-\d{2}-\d{4}\b", # SSN pattern
r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b", # Credit card
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", # Email
],
"medical": [
r"\b(diagnosis|prescription|medical record|patient)\b",
r"\b(medication|treatment|therapy|surgery)\b",
],
"legal": [
r"\b(legal advice|lawsuit|litigation|attorney)\b",
r"\b(contract|agreement|liability|damages)\b",
],
"financial": [
r"\b(investment advice|trading|stocks|portfolio)\b",
r"\b(loan|mortgage|credit|debt)\b",
],
}
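
    # Worked examples for the patterns above (the sample strings are invented for
    # illustration; this mirrors how _check_patterns applies them with re.IGNORECASE):
    #
    #     >>> import re
    #     >>> bool(re.search(r"\b\d{3}-\d{2}-\d{4}\b", "SSN: 123-45-6789"))
    #     True
    #     >>> bool(re.search(r"\b(loan|mortgage|credit|debt)\b", "Refinance the mortgage", re.IGNORECASE))
    #     True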
def _load_forbidden_capabilities(self) -> Set[str]:
"""Load forbidden capabilities based on profile."""
if self.profile == "strict":
return {
"external_api_calls",
"file_system_access",
"database_writes",
"email_sending",
"code_execution",
}
elif self.profile == "moderate":
return {"file_system_access", "database_writes", "code_execution"}
else: # default
return {"code_execution"}
def _load_content_filters(self) -> Dict[str, List[str]]:
"""Load content filtering rules."""
return {
"harmful": [r"\b(violence|harm|attack|threat)\b", r"\b(illegal|criminal|fraud|scam)\b"],
"inappropriate": [r"\b(explicit|adult|nsfw)\b", r"\b(hate|discrimination|bias)\b"],
}
class SafetyController:
"""
Comprehensive safety assessment and control system.
Evaluates candidate paths for safety risks including:
- Content safety (PII, harmful content)
- Capability restrictions
- Policy compliance
- Risk scoring and thresholds
"""
def __init__(self, config: Any):
"""Initialize safety controller with configuration."""
self.config = config
self.safety_threshold = config.safety_threshold
self.policy = SafetyPolicy(config.safety_profile)
logger.debug(f"SafetyController initialized with profile: {config.safety_profile}")
async def assess_candidates(
self, candidates: List[Dict[str, Any]], context: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""
Assess all candidates for safety compliance.
Args:
candidates: List of candidates with previews
context: Execution context
Returns:
List of candidates that pass safety assessment
"""
try:
safe_candidates = []
for candidate in candidates:
safety_assessment = await self._assess_candidate_safety(candidate, context)
# Add safety information to candidate
candidate["safety_score"] = safety_assessment["score"]
candidate["safety_risks"] = safety_assessment["risks"]
candidate["safety_details"] = safety_assessment["details"]
                # A candidate passes only if its safety score is at least (1.0 - safety_threshold)
if safety_assessment["score"] >= (1.0 - self.safety_threshold):
safe_candidates.append(candidate)
else:
logger.warning(
f"Candidate {candidate['node_id']} failed safety check: "
f"score={safety_assessment['score']:.3f}, "
f"risks={safety_assessment['risks']}"
)
            logger.info(
                f"Safety assessment: {len(safe_candidates)}/{len(candidates)} candidates passed"
            )
return safe_candidates
except Exception as e:
logger.error(f"Safety assessment failed: {e}")
            return candidates  # Fail open at the batch level: allow all candidates if the assessment itself errors
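
    # Worked example of the pass/fail rule above (numbers are illustrative): with
    # config.safety_threshold = 0.2 a candidate needs safety_score >= 0.8 to be
    # kept, so a candidate scoring 0.73 is dropped with a warning, while one
    # scoring 0.9 is returned unchanged apart from the added safety_* fields.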
async def _assess_candidate_safety(
self, candidate: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Assess safety for a single candidate."""
try:
risks = []
risk_scores = []
details = {}
# Content safety assessment
content_risks = await self._assess_content_safety(candidate, context)
risks.extend(content_risks["risks"])
risk_scores.append(content_risks["score"])
details["content"] = content_risks["details"]
# Capability safety assessment
capability_risks = await self._assess_capability_safety(candidate, context)
risks.extend(capability_risks["risks"])
risk_scores.append(capability_risks["score"])
details["capabilities"] = capability_risks["details"]
# Policy compliance assessment
policy_risks = await self._assess_policy_compliance(candidate, context)
risks.extend(policy_risks["risks"])
risk_scores.append(policy_risks["score"])
details["policy"] = policy_risks["details"]
# Calculate overall safety score (average of component scores)
overall_score = sum(risk_scores) / len(risk_scores) if risk_scores else 1.0
return {
"score": overall_score,
"risks": list(set(risks)), # Remove duplicates
"details": details,
}
except Exception as e:
logger.error(f"Individual safety assessment failed: {e}")
return {
"score": 0.0, # Fail safe - assume unsafe if assessment fails
"risks": ["assessment_error"],
"details": {"error": str(e)},
}
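
    # Worked example of the averaging above (component scores are illustrative):
    # content = 0.7 (domain-specific content only), capabilities = 1.0 and
    # policy = 1.0 yield an overall score of (0.7 + 1.0 + 1.0) / 3 = 0.9.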
async def _assess_content_safety(
self, candidate: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Assess content safety risks."""
try:
risks = []
details = {}
# Get content to analyze
preview = candidate.get("preview", "")
question = context.get("input", "")
content_to_check = f"{question} {preview}"
# Check for PII patterns
pii_risks = self._check_patterns(
content_to_check, self.policy.risk_patterns.get("pii", [])
)
if pii_risks:
risks.append("pii_detected")
details["pii_matches"] = pii_risks
# Check for medical content
medical_risks = self._check_patterns(
content_to_check, self.policy.risk_patterns.get("medical", [])
)
if medical_risks:
risks.append("medical_content")
details["medical_matches"] = medical_risks
# Check for legal content
legal_risks = self._check_patterns(
content_to_check, self.policy.risk_patterns.get("legal", [])
)
if legal_risks:
risks.append("legal_content")
details["legal_matches"] = legal_risks
# Check for harmful content
harmful_risks = self._check_patterns(
content_to_check, self.policy.content_filters.get("harmful", [])
)
if harmful_risks:
risks.append("harmful_content")
details["harmful_matches"] = harmful_risks
# Calculate content safety score
if not risks:
score = 1.0 # Perfect safety
elif len(risks) == 1 and risks[0] in ["medical_content", "legal_content"]:
score = 0.7 # Moderate risk for domain-specific content
else:
score = 0.3 # High risk for PII or harmful content
return {"score": score, "risks": risks, "details": details}
except Exception as e:
logger.error(f"Content safety assessment failed: {e}")
return {"score": 0.5, "risks": [], "details": {}}
async def _assess_capability_safety(
self, candidate: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Assess capability-based safety risks."""
try:
risks = []
details = {}
node_id = candidate["node_id"]
# Check for forbidden capabilities
forbidden_found = []
for forbidden_cap in self.policy.forbidden_capabilities:
if self._node_has_capability(node_id, forbidden_cap):
forbidden_found.append(forbidden_cap)
risks.append(f"forbidden_capability_{forbidden_cap}")
details["forbidden_capabilities"] = forbidden_found
# Calculate capability safety score
if not forbidden_found:
score = 1.0
            elif forbidden_found == ["external_api_calls"]:
                score = 0.8  # Moderate risk when only external API calls are involved
            else:
                score = 0.2  # High risk for any other (or multiple) forbidden capabilities
return {"score": score, "risks": risks, "details": details}
except Exception as e:
logger.error(f"Capability safety assessment failed: {e}")
return {"score": 0.5, "risks": [], "details": {}}
async def _assess_policy_compliance(
self, candidate: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Assess policy compliance."""
try:
risks = []
details = {}
# Check for policy violations
# This is a placeholder for more sophisticated policy checking
# Example: Check if path violates maximum depth policy
path_length = len(candidate.get("path", []))
if path_length > 5: # Arbitrary limit
risks.append("path_too_long")
details["path_length"] = path_length
# Calculate policy compliance score
score = 1.0 if not risks else 0.6
return {"score": score, "risks": risks, "details": details}
except Exception as e:
logger.error(f"Policy compliance assessment failed: {e}")
return {"score": 0.5, "risks": [], "details": {}}
def _check_patterns(self, text: str, patterns: List[str]) -> List[str]:
"""Check text against risk patterns."""
matches = []
try:
for pattern in patterns:
if re.search(pattern, text, re.IGNORECASE):
matches.append(pattern)
return matches
except Exception as e:
logger.error(f"Pattern checking failed: {e}")
return []
def _node_has_capability(self, node_id: str, capability: str) -> bool:
"""Check if node has a specific capability."""
try:
# Simple heuristic based on node ID
node_lower = node_id.lower()
capability_indicators = {
"external_api_calls": ["api", "http", "web", "search"],
"file_system_access": ["file", "disk", "storage"],
"database_writes": ["db", "database", "write", "store"],
"email_sending": ["email", "mail", "send"],
"code_execution": ["exec", "run", "execute", "code"],
}
indicators = capability_indicators.get(capability, [])
return any(indicator in node_lower for indicator in indicators) if indicators else False
except Exception as e:
logger.error(f"Capability checking failed: {e}")
return False
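

# A minimal end-to-end sketch (not part of the library API): the real config
# object is whatever the orchestrator supplies; here types.SimpleNamespace stands
# in for it, and the candidate dictionaries below are invented for illustration.
# Note how "code_exec_agent" trips the code_execution heuristic in
# _node_has_capability and is filtered out under the "moderate" profile.
if __name__ == "__main__":
    import asyncio
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)

    demo_config = SimpleNamespace(safety_threshold=0.2, safety_profile="moderate")
    controller = SafetyController(demo_config)

    demo_candidates = [
        {"node_id": "summary_agent", "preview": "A short factual summary.", "path": ["a", "b"]},
        {"node_id": "code_exec_agent", "preview": "Runs arbitrary code.", "path": ["a"]},
    ]
    demo_context = {"input": "Summarise the quarterly report."}

    passed = asyncio.run(controller.assess_candidates(demo_candidates, demo_context))
    for candidate in passed:
        print(candidate["node_id"], candidate["safety_score"], candidate["safety_risks"])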