# ELE ENGINE CODE

"""
ELE ENGINE — INTEGRATED CANONICAL IMPLEMENTATION
===============================================

Empirical Linguistic Engine (ELE)
---------------------------------
A unified, interaction-dominant architecture:

    P1 (Physics) → P2 (Physiology) → L (Linguistics) → C1 (Cognition) → C2 (Communication)
    with feedback: C2 → P2 (and rerun L → C1 → C2)

Features:
- Aerodynamic grounding (P1: VC, lung pressure, breath groups)
- Motor grounding (P2: CT/TA activations, f0 trajectories) with accent profiles
- Recursive Language Model (L: SimpleRLM / GRU in PyTorch)
- Embodied grounding (C1: SensorimotorSimModule with kinematics, success/failure)
- Pragmatics & ToM (C2: norm_level, repair_strategy, feedback_to_lower)
- Cognitive disruption tracking (C1: disruption_index)
- Robust auto-repair loop (ELEngine.robust_process) that:
    - interprets noise/disruption,
    - adjusts parameters (depth, belief),
    - reruns until coherent or safely exhausted.
"""

from typing import Dict, Any, List, Callable, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod

import random
import numpy as np
import torch
import torch.nn as nn


# ───────────────────────────────────────────
# ENUMS & METRICS
# ───────────────────────────────────────────

class ELEModule(Enum):
    """Identifiers for the five ELE processing stages.

    Ordered along the forward chain P1 → P2 → L → C1 → C2; the enum value
    doubles as the trace/metrics dictionary key for each stage.
    """

    PHYSICS = "P1_Physics"
    PHYSIOLOGY = "P2_Physiology"
    LINGUISTICS = "L_Linguistics"
    COGNITION = "C1_Cognition"
    COMMUNICATION = "C2_Communication"


@dataclass
class EmpiricalMetric:
    """A named measurement with a reference norm and a sampled current value.

    Attributes:
        name: Human-readable metric name (e.g. "Vital Capacity").
        unit: Measurement unit label (e.g. "mL").
        norm: Reference mean used as the centre of the sampling distribution.
        current: Most recently sampled value; 0.0 until ``update()`` runs.
    """

    name: str
    unit: str
    norm: float
    # Plain default instead of field(default=...) — equivalent and idiomatic.
    current: float = 0.0

    def update(self, variability: float = 0.1) -> None:
        """Resample ``current`` from a normal distribution around ``norm``.

        Args:
            variability: Relative spread; sigma = norm * variability.

        The standard deviation is floored at 1e-6 so metrics with a zero
        norm (e.g. a disruption index) still sample from a non-degenerate
        distribution instead of raising or collapsing to exactly zero.
        """
        sigma = max(self.norm * variability, 1e-6)
        self.current = float(np.random.normal(self.norm, sigma))


class ELEState(ABC):
    """Base container that holds and reports a module's empirical metrics."""

    def __init__(self, metrics: Optional[List[EmpiricalMetric]] = None) -> None:
        self.metrics: List[EmpiricalMetric] = metrics if metrics else []
        self._init_metrics()

    def _init_metrics(self) -> None:
        # Resample every registered metric so the state starts with fresh values.
        for metric in self.metrics:
            metric.update()

    def get_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Return a {name: {unit, norm, current}} snapshot of all metrics."""
        snapshot: Dict[str, Dict[str, Any]] = {}
        for metric in self.metrics:
            snapshot[metric.name] = {
                "unit": metric.unit,
                "norm": metric.norm,
                "current": metric.current,
            }
        return snapshot


# ───────────────────────────────────────────
# STATE CLASSES
# ───────────────────────────────────────────

@dataclass
class PhysicsState(ELEState):
    """Aerodynamic state for P1: lung volume, pressure, and phonation limits."""

    # Vital capacity in mL; the air budget available per breath.
    vc: float = 3300.0
    # Subglottal pressure in cm H2O; used as the loudness envelope value.
    p_sub: float = 10.0
    # Mean airflow in mL/s; divides vc to bound phonation time.
    airflow_rate: float = 100.0
    # Cached result of compute_max_phon_time(), in seconds.
    max_phon_time: float = 0.0
    metrics: List[EmpiricalMetric] = field(default_factory=lambda: [
        EmpiricalMetric("Vital Capacity", "mL", 3300),
        EmpiricalMetric("Lung Pressure", "cm H2O", 10),
    ])

    def __post_init__(self) -> None:
        # Route the dataclass-built metrics through the ELEState constructor
        # so they are registered and resampled.
        super().__init__(metrics=self.metrics)

    def compute_max_phon_time(self) -> float:
        """Return and cache vc / airflow_rate (seconds), guarding divide-by-zero."""
        self.max_phon_time = self.vc / max(self.airflow_rate, 1e-6)
        return self.max_phon_time


@dataclass
class PhysiologyState(ELEState):
    """Laryngeal motor state for P2: muscle activations and accent profile."""

    # Cricothyroid activation (%); baseline chosen from the accent profile.
    ct_activation: float = 50.0
    # Thyroarytenoid activation (%); set to 0.8 * ct by PhysiologyModule.
    ta_activation: float = 40.0
    # Mean of CT and TA activations; rough strain proxy.
    vocal_strain: float = 0.0
    # Speaking-style label: "neutral", "soft", or "harsh".
    accent_profile: str = "neutral"
    metrics: List[EmpiricalMetric] = field(default_factory=lambda: [
        EmpiricalMetric("CT Activation", "%", 50),
        EmpiricalMetric("TA Activation", "%", 40),
    ])

    def __post_init__(self) -> None:
        # Register metrics with the ELEState base and resample them.
        super().__init__(metrics=self.metrics)


@dataclass
class LinguisticsState(ELEState):
    """Symbolic-linguistic state for L: phon/morph/lex/sem descriptors."""

    # String descriptors produced by LinguisticsModule.process.
    phonemes: str = ""
    morphemes: str = ""
    lexemes: str = ""
    sememes: str = ""
    # Number of recursive refinement passes run by SimpleRLM.forward.
    recursion_depth: int = 2
    metrics: List[EmpiricalMetric] = field(default_factory=lambda: [
        EmpiricalMetric("Morph Decomposition ERP", "uV", 5.0),
        EmpiricalMetric("Recursion Depth", "levels", 2),
    ])

    def __post_init__(self) -> None:
        # Register metrics with the ELEState base and resample them.
        super().__init__(metrics=self.metrics)


@dataclass
class CognitionState(ELEState):
    """Embodied-cognition state for C1: concepts, chunks, and disruption."""

    # sememe → list of grounding annotations (trajectory steps, joint angle).
    concept_graph: Dict[str, List[str]] = field(default_factory=dict)
    # Working-memory chunks sliced from the current sememe string.
    chunks: List[str] = field(default_factory=list)
    # Miller-style chunk limit; exceeding it raises the disruption index.
    chunk_capacity: int = 7
    # Aggregate disruption score computed in CognitionModule.process.
    disruption_index: float = 0.0
    metrics: List[EmpiricalMetric] = field(default_factory=lambda: [
        EmpiricalMetric("Chunk Capacity", "items", 7),
        EmpiricalMetric("Retrieval Time", "ms", 300),
        EmpiricalMetric("Disruption Index", "score", 0.0),
    ])

    def __post_init__(self) -> None:
        # Register metrics with the ELEState base and resample them.
        super().__init__(metrics=self.metrics)


@dataclass
class CommunicationState(ELEState):
    """Pragmatics/ToM state for C2: beliefs, norms, and repair strategy."""

    # Final pragmatically-wrapped utterance string.
    pragmemes: str = ""
    # interlocutor → inferred belief string (e.g. "Believes:neutral").
    tom_beliefs: Dict[str, str] = field(default_factory=dict)
    # Social norm adherence in [0, 1]; lowered when repair is needed.
    norm_level: float = 0.85
    # Chosen repair policy: "none" or "simplify".
    repair_strategy: str = "none"
    # Accent profile echoed from the social context.
    speaker_accent: str = "neutral"
    metrics: List[EmpiricalMetric] = field(default_factory=lambda: [
        EmpiricalMetric("Pragmatic Score", "0-100", 85),
        EmpiricalMetric("ToM Accuracy", "%", 75),
    ])

    def __post_init__(self) -> None:
        # Register metrics with the ELEState base and resample them.
        super().__init__(metrics=self.metrics)


@dataclass
class SensorimotorState(ELEState):
    """Sensorimotor simulation state: success rate of the last grasp attempt."""

    # Set to 100.0 or 0.0 per attempt by SensorimotorSimModule._simulate_action.
    grasp_success_rate: float = 80.0
    metrics: List[EmpiricalMetric] = field(default_factory=lambda: [
        EmpiricalMetric("Grasp Success Rate", "%", 80),
    ])

    def __post_init__(self) -> None:
        # Register metrics with the ELEState base and resample them.
        super().__init__(metrics=self.metrics)


# ───────────────────────────────────────────
# BASE MODULE
# ───────────────────────────────────────────

class ELEModuleBase(ABC):
    """Common interface for all ELE processing modules.

    Each module owns one state object (an ELEState subclass) and must
    implement process() plus declarations of its input/output contract.
    """

    def __init__(self, name: str, state_cls: Callable[[], ELEState]) -> None:
        self.name = name
        # state_cls is the state class itself; instantiated with no args.
        self.state: ELEState = state_cls()

    @abstractmethod
    def process(self, inputs: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run the module; return (outputs dict, metrics snapshot dict)."""
        ...

    @abstractmethod
    def _define_inputs(self) -> Dict[str, str]:
        """Return {input_name: type_description} for the module's contract."""
        ...

    @abstractmethod
    def _define_outputs(self) -> Dict[str, str]:
        """Return {output_name: type_description} for the module's contract."""
        ...

    def get_api_contract(self) -> Dict[str, Any]:
        """Summarize declared inputs, outputs, and tracked metric names."""
        return {
            "inputs": self._define_inputs(),
            "outputs": self._define_outputs(),
            "states": [m.name for m in self.state.metrics],
        }


# ───────────────────────────────────────────
# P1: PHYSICS MODULE
# ───────────────────────────────────────────

class PhysicsModule(ELEModuleBase):
    """P1: aerodynamic grounding — fits an intent string to breath capacity."""

    def _define_inputs(self) -> Dict[str, str]:
        return {"utterance_intent": "str", "sensor_data": "Optional[Dict]"}

    def _define_outputs(self) -> Dict[str, str]:
        return {
            "max_phon_time": "float",
            "breath_groups": "List[str]",
            "acoustic_envelope": "Dict",
            "raw_signal": "str",
        }

    def process(self, inputs: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Truncate the intent to the phonation limit and split it into breath groups."""
        readings = inputs.get("sensor_data") or {}
        self.state.vc = readings.get("vc", self.state.vc)
        self.state.p_sub = readings.get("p_sub", self.state.p_sub)

        phon_limit = self.state.compute_max_phon_time()
        intent = inputs.get("utterance_intent", "") or ""

        # Assume ~5 characters per second of speech; drop anything that
        # would exceed the maximum phonation time for one breath.
        speech_seconds = len(intent) / 5.0
        signal = intent if speech_seconds <= phon_limit else intent[: int(phon_limit * 5)]

        # Ten-character breath groups; always at least one (possibly empty).
        groups = [signal[pos:pos + 10] for pos in range(0, len(signal), 10)] or [""]

        envelope = {
            "loudness": [self.state.p_sub for _ in groups],
            "time": list(range(len(groups))),
        }

        self.state._init_metrics()
        result = {
            "max_phon_time": phon_limit,
            "breath_groups": groups,
            "acoustic_envelope": envelope,
            "raw_signal": signal,
        }
        return result, self.state.get_metrics()


# ───────────────────────────────────────────
# P2: PHYSIOLOGY MODULE
# ───────────────────────────────────────────

class PhysiologyModule(ELEModuleBase):
    """P2: laryngeal motor control — turns the envelope into muscle/f0 commands."""

    def _define_inputs(self) -> Dict[str, str]:
        return {"acoustic_envelope": "Dict", "context_mod": "str", "accent_profile": "str"}

    def _define_outputs(self) -> Dict[str, str]:
        return {"pneuma_commands": "List[Dict]", "f0_trajectory": "List[float]"}

    def process(self, inputs: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Compute CT/TA activations and an f0 trajectory for the envelope."""
        envelope = inputs.get("acoustic_envelope", {})
        context_mod = inputs.get("context_mod", "normal")
        accent_profile = inputs.get("accent_profile", "neutral")
        self.state.accent_profile = accent_profile

        # Accent picks the cricothyroid baseline; context mode shifts it.
        base_ct = {"soft": 45.0, "harsh": 55.0}.get(accent_profile, 50.0)
        ct_shift = {"shout": 10.0, "whisper": -15.0}.get(context_mod, 0.0)
        self.state.ct_activation = base_ct + ct_shift

        # TA tracks CT at a fixed ratio; strain is their mean.
        self.state.ta_activation = self.state.ct_activation * 0.8
        self.state.vocal_strain = (self.state.ct_activation + self.state.ta_activation) / 2.0

        num_points = len(envelope.get("time", [0]))
        pneuma = []
        for step in range(num_points):
            pneuma.append({
                "t": step,
                "ct": self.state.ct_activation,
                "ta": self.state.ta_activation,
                "accent": self.state.accent_profile,
            })

        # f0 starts at 200 Hz and shifts with both accent and context mode,
        # then gets per-point jitter of up to ±10 Hz.
        base_f0 = (200.0
                   + {"soft": -10.0, "harsh": 10.0}.get(accent_profile, 0.0)
                   + {"whisper": -20.0, "shout": 20.0}.get(context_mod, 0.0))
        f0 = [base_f0 + random.randint(-10, 10) for _ in range(num_points)]

        self.state._init_metrics()
        return {"pneuma_commands": pneuma, "f0_trajectory": f0}, self.state.get_metrics()


# ───────────────────────────────────────────
# SIMPLE RLM (L)
# ───────────────────────────────────────────

class SimpleRLM(nn.Module):
    """GRU language model with a recursive hidden-state refinement loop."""

    def __init__(self, vocab_size: int = 100, embed_dim: int = 64, num_layers: int = 2) -> None:
        super().__init__()
        # Submodule registration order is kept stable (embedding, rnn, fc)
        # so parameter initialization and state_dict layout are unchanged.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, embed_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(embed_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, input_ids: torch.Tensor, recursion_depth: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input_ids, then refine the hidden state recursion_depth times.

        Returns (logits over the input sequence, final refined hidden state).
        """
        token_vecs = self.embedding(input_ids)
        output, hidden = self.rnn(token_vecs)

        state = hidden
        for _ in range(recursion_depth):
            # Each pass feeds a random 5-token probe through the GRU,
            # starting from the (noise-perturbed) previous hidden state.
            probe = torch.randint(0, self.vocab_size, (1, 5), device=input_ids.device)
            probe_vecs = self.embedding(probe)
            perturbation = torch.randn_like(state) * 0.1
            _, state = self.rnn(probe_vecs, state + perturbation)

        return self.fc(output), state


# ───────────────────────────────────────────
# L: LINGUISTICS MODULE
# ───────────────────────────────────────────

class LinguisticsModule(ELEModuleBase):
    """L: linguistic analysis — runs the raw signal through the recursive RLM
    and exposes phoneme/morpheme/lexeme/sememe descriptors."""

    def __init__(self, name: str, state_cls: Callable[[], ELEState]) -> None:
        super().__init__(name, state_cls)
        # Vocab of 100 matches the ord(c) % 100 tokenization in process().
        self.rlm = SimpleRLM(vocab_size=100)

    def _define_inputs(self) -> Dict[str, str]:
        return {"raw_signal": "str", "breath_groups": "List[str]"}

    def _define_outputs(self) -> Dict[str, str]:
        return {
            "phonemes": "str",
            "morphemes": "str",
            "lexemes": "str",
            "sememes": "str",
            "structured_prop": "Dict",
            "rlm_hidden": "List",
        }

    def process(self, inputs: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Tokenize the raw signal, run the RLM, and derive linguistic descriptors.

        Reads "recursion_depth" from the shared context (default 2).
        """
        raw = inputs.get("raw_signal", "") or "dummy"
        depth = context.get("recursion_depth", 2)
        self.state.recursion_depth = depth

        # Map the first 20 characters into the RLM's 100-token vocab.
        token_ids = torch.tensor([[ord(c) % 100 for c in raw[:20]]], dtype=torch.long)
        logits, rec_hidden = self.rlm(token_ids, depth)

        # Descriptors are summary strings derived from the RLM outputs;
        # sememes pass the raw signal through unchanged.
        self.state.phonemes = f"Ph:{logits.mean().item():.2f}"
        self.state.morphemes = f"Morph{{{rec_hidden.mean().item():.2f}}}"
        self.state.lexemes = f'"{self.state.morphemes}"'
        self.state.sememes = raw

        structured = {
            "logits_shape": list(logits.shape),
            "rec_hidden_shape": list(rec_hidden.shape),
            "rlm_depth": depth,
        }
        # Detach to plain nested lists so downstream modules stay tensor-free.
        rlm_hidden_list = rec_hidden.detach().cpu().numpy().tolist()

        self.state._init_metrics()
        outputs = {
            "phonemes": self.state.phonemes,
            "morphemes": self.state.morphemes,
            "lexemes": self.state.lexemes,
            "sememes": self.state.sememes,
            "structured_prop": structured,
            "rlm_hidden": rlm_hidden_list,
        }
        return outputs, self.state.get_metrics()


# ───────────────────────────────────────────
# SENSORIMOTOR SIM MODULE (C1 backend)
# ───────────────────────────────────────────

class SensorimotorSimModule(ELEModuleBase):
    """C1 backend: a toy 2-D reach/grasp simulation used to ground sememes."""

    def __init__(self, name: str = "SimModule") -> None:
        super().__init__(name, SensorimotorState)
        # Per-action motor parameters: applied force and 2-D target position.
        self.action_map = {
            "grasp": {"force": 1.0, "target": np.array([1.0, 0.5])},
            "manipulate_triangle": {"force": 0.5, "target": np.array([0.0, 1.0])},
            "default": {"force": 0.5, "target": np.array([0.5, 0.5])},
        }
        # Std-dev of Gaussian jitter added to the reach direction.
        self.failure_jitter_sigma = 0.5
        # A reach counts as successful when the jittered distance is below this.
        self.success_threshold = 1.5

    def _define_inputs(self) -> Dict[str, str]:
        return {"sememe_action": "str", "rlm_hidden": "List"}

    def _define_outputs(self) -> Dict[str, str]:
        return {"kinematics": "Dict", "env_state": "Dict"}

    def _flatten_hidden(self, hidden_mod: List[Any]) -> List[float]:
        """Flatten an (at most) doubly-nested list of hidden values to floats.

        Handles the [layers][batch][dim] nesting produced by
        LinguisticsModule's tolist() conversion of the GRU hidden state.
        """
        flat: List[float] = []
        for layer in hidden_mod:
            if isinstance(layer, list):
                for batch in layer:
                    if isinstance(batch, list):
                        flat.extend(batch)
                    else:
                        flat.append(batch)
            else:
                flat.append(layer)
        return flat

    def _simulate_action(self, action: str, hidden_mod: List[Any]) -> Dict[str, Any]:
        """Simulate one jittered reach toward the action's target.

        The mean of the RLM hidden values shifts the arm's start position, so
        linguistic state influences motor outcome. Returns trajectory, joint
        angle, force vector, and a boolean success flag; also records the
        success (0 or 100 %) on the state.
        """
        params = self.action_map.get(action, self.action_map["default"])
        target = params["target"]

        flat_hidden = self._flatten_hidden(hidden_mod) if hidden_mod else []
        hidden_mean = float(np.mean(flat_hidden)) if flat_hidden else 0.0

        arm_start = np.array([0.0, 0.0]) + hidden_mean * 0.1
        base_direction = target - arm_start
        # Random jitter models motor noise; large jitter can cause failure.
        jitter = np.random.normal(0.0, self.failure_jitter_sigma, size=2)
        direction = base_direction + jitter
        # `or 1e-6` guards the division below against a zero-length direction.
        distance = float(np.linalg.norm(direction)) or 1e-6

        # NOTE(review): trajectory interpolates toward the *unjittered* target
        # while success uses the jittered direction — confirm this is intended.
        trajectory = np.linspace(arm_start, target, 10).tolist()
        joint_angles = float(np.arctan2(direction[1], direction[0]))
        force_vector = (direction / distance * params["force"]).tolist()
        success = distance < self.success_threshold

        self.state.grasp_success_rate = float(success) * 100.0

        return {
            "trajectory": trajectory,
            "joint_angles": joint_angles,
            "force_vector": force_vector,
            "success": success,
        }

    def process(self, inputs: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run one simulated action and report kinematics plus environment state."""
        action = inputs.get("sememe_action", "default")
        hidden = inputs.get("rlm_hidden", [])
        kinematics = self._simulate_action(action, hidden)
        env_state = {"object_pos": kinematics["trajectory"][-1], "grasped": kinematics["success"]}
        self.state._init_metrics()
        outputs = {"kinematics": kinematics, "env_state": env_state}
        return outputs, self.state.get_metrics()


# ───────────────────────────────────────────
# C1: COGNITION MODULE
# ───────────────────────────────────────────

class CognitionModule(ELEModuleBase):
    """C1: embodied cognition — grounds sememes via sensorimotor simulation.

    Fix: metrics are resampled (``_init_metrics``) *before* the Disruption
    Index metric is pinned to the computed value. Previously the resample ran
    last and immediately overwrote the pinned value with sampled noise, so the
    reported metric rarely matched ``disruption_index``.
    """

    def __init__(self, name: str, state_cls: Callable[[], ELEState]) -> None:
        super().__init__(name, state_cls)
        self.sim_module = SensorimotorSimModule()

    def _define_inputs(self) -> Dict[str, str]:
        return {"sememes": "str", "rlm_hidden": "List"}

    def _define_outputs(self) -> Dict[str, str]:
        return {
            "grounded_concepts": "Dict",
            "chunks": "List[str]",
            "metarules": "List[Callable]",
            "sim_env_state": "Dict",
            "disruption_index": "float",
        }

    def _flatten_hidden(self, rlm_hidden: List[Any]) -> List[float]:
        """Flatten the nested [layers][batch][dim] hidden list to floats.

        Mirrors SensorimotorSimModule._flatten_hidden; extracted here to
        replace the previously duplicated inline loop.
        """
        flat: List[float] = []
        for layer in rlm_hidden:
            if isinstance(layer, list):
                for batch in layer:
                    if isinstance(batch, list):
                        flat.extend(batch)
                    else:
                        flat.append(batch)
            else:
                flat.append(layer)
        return flat

    def process(self, inputs: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Ground the sememe in a simulated action and compute disruption.

        Returns grounded concepts, working-memory chunks, a generated
        metarule, the simulated environment state, and a disruption index
        (1.0 for grasp failure, +0.5 for chunk overload, +1.0 for an
        empty/whitespace sememe).
        """
        sem = inputs.get("sememes", "") or "dummy_sememe"
        rlm_hidden = inputs.get("rlm_hidden", [])

        # Map the sememe text onto one of the simulator's known actions.
        lower = sem.lower()
        if "grasp" in lower:
            action = "grasp"
        elif "triangle" in lower:
            action = "manipulate_triangle"
        else:
            action = "default"

        sim_inputs = {"sememe_action": action, "rlm_hidden": rlm_hidden}
        sim_out, _ = self.sim_module.process(sim_inputs, context)
        kinematics = sim_out["kinematics"]
        env_state = sim_out["env_state"]

        # Chunk the sememe string; chunk width tied to trajectory resolution.
        traj_len = max(2, len(kinematics["trajectory"]))
        step = max(1, traj_len // 2)
        self.state.chunks = [sem[i:i + step] for i in range(0, len(sem), step)]

        # Annotate the concept graph with the first trajectory steps + angle.
        self.state.concept_graph[sem] = [
            f"traj_step_{i}: {pos}" for i, pos in enumerate(kinematics["trajectory"][:3])
        ]
        self.state.concept_graph[sem].append(f"angle: {kinematics['joint_angles']:.2f}")

        flat_hidden = self._flatten_hidden(rlm_hidden)
        # NOTE(review): empty-hidden fallback is 1.0 here but 0.0 in
        # SensorimotorSimModule — confirm the asymmetry is intentional.
        hidden_mean = float(np.mean(flat_hidden)) if flat_hidden else 1.0
        force_x = kinematics["force_vector"][0]

        def metarule(x: float) -> str:
            # Closure over the simulated force; demonstrates rule generation.
            return f"Gen:{x:.2f} * {force_x:.2f}"

        grounded = {
            sem: {
                "sim_data": kinematics,
                "action": f"{action} (success: {env_state['grasped']})",
                "metarule_applied": metarule(hidden_mean),
            }
        }

        disruption = 0.0
        if not env_state["grasped"]:
            disruption += 1.0
        if len(self.state.chunks) > self.state.chunk_capacity:
            disruption += 0.5
        if len(sem.strip()) == 0:
            disruption += 1.0

        self.state.disruption_index = disruption

        # Resample metrics FIRST, then pin the Disruption Index metric so the
        # reported value matches the computed one (the resample previously ran
        # afterwards and clobbered it with noise).
        self.state._init_metrics()
        for m in self.state.metrics:
            if m.name == "Disruption Index":
                m.norm = disruption
                m.current = disruption

        outputs = {
            "grounded_concepts": grounded,
            "chunks": self.state.chunks,
            "metarules": [metarule],
            "sim_env_state": env_state,
            "disruption_index": disruption,
        }
        return outputs, self.state.get_metrics()


# ───────────────────────────────────────────
# C2: COMMUNICATION MODULE
# ───────────────────────────────────────────

class CommunicationModule(ELEModuleBase):
    """C2: pragmatics & theory-of-mind — wraps sememes and plans repair."""

    def __init__(self, name: str, state_cls: Callable[[], ELEState]) -> None:
        super().__init__(name, state_cls)

    def _define_inputs(self) -> Dict[str, str]:
        return {"utterance_plan": "Dict", "social_ctx": "Dict",
                "sim_env_state": "Dict", "disruption_index": "float"}

    def _define_outputs(self) -> Dict[str, str]:
        return {"pragmemes": "str", "tom_inferences": "Dict",
                "repair_strategy": "str", "feedback_to_lower": "Dict"}

    def process(self, inputs: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Infer interlocutor beliefs, pick a repair policy, and emit feedback."""
        plan = inputs.get("utterance_plan", {})
        sim_state = inputs.get("sim_env_state", {})
        disruption_index = float(inputs.get("disruption_index", 0.0))
        social_ctx = inputs.get("social_ctx", {"belief": "neutral", "accent_profile": "neutral"})

        grasped = bool(sim_state.get("grasped", True))
        accent_profile = social_ctx.get("accent_profile", "neutral")
        self.state.speaker_accent = accent_profile

        # A failed grasp signals frustration; otherwise keep the stated belief.
        inferred_belief = social_ctx.get("belief", "neutral") if grasped else "failure_frustration"
        self.state.tom_beliefs["interlocutor"] = f"Believes:{inferred_belief}"

        # Physical failure or high disruption triggers simplification.
        needs_repair = (not grasped) or disruption_index > 0.5
        norm = 0.7 if needs_repair else 0.9
        repair_strategy = "simplify" if needs_repair else "none"
        self.state.norm_level = norm
        self.state.repair_strategy = repair_strategy

        raw_sem = plan.get("sememes", "hello")
        self.state.pragmemes = f"Prag*{raw_sem}* (norm:{norm:.2f}, repair:{repair_strategy})"

        # Low norm asks P2 to soften delivery on the feedback rerun.
        feedback = {
            "context_mod": "whisper" if norm < 0.8 else "normal",
            "accent_profile": accent_profile,
        }

        self.state._init_metrics()
        return (
            {
                "pragmemes": self.state.pragmemes,
                "tom_inferences": dict(self.state.tom_beliefs),
                "repair_strategy": repair_strategy,
                "feedback_to_lower": feedback,
            },
            self.state.get_metrics(),
        )


# ───────────────────────────────────────────
# ELEngine: FULL ORCHESTRATION + ROBUST LOOP
# ───────────────────────────────────────────

class ELEngine:
    """Orchestrates the full P1 → P2 → L → C1 → C2 chain with C2 feedback.

    process() runs one forward pass plus an optional feedback rerun;
    robust_process() wraps it in a bounded retry loop that lowers recursion
    depth and adjusts the social context until the result is coherent.
    """

    def __init__(self) -> None:
        # One instance of each stage, keyed by its ELEModule identifier.
        self.modules: Dict[ELEModule, ELEModuleBase] = {
            ELEModule.PHYSICS: PhysicsModule("P1", PhysicsState),
            ELEModule.PHYSIOLOGY: PhysiologyModule("P2", PhysiologyState),
            ELEModule.LINGUISTICS: LinguisticsModule("L", LinguisticsState),
            ELEModule.COGNITION: CognitionModule("C1", CognitionState),
            ELEModule.COMMUNICATION: CommunicationModule("C2", CommunicationState),
        }
        # Shared context passed to every module; "recursion_depth" is read by L.
        self.global_context: Dict[str, Any] = {"recursion_depth": 2, "env_model": {}}
        # Per-process() trace: one dict per pass, keyed by module value.
        self.trace: List[Dict[str, Any]] = []

    # Core forward + feedback

    def _forward_chain(self, current_inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run one P1→P2→L→C1→C2 pass, accumulating outputs into one dict.

        Each module's outputs are merged into `outputs`, so later modules can
        read earlier modules' results by key. Appends one trace entry.
        """
        outputs = dict(current_inputs)
        chain_trace: Dict[str, Any] = {}

        # P1
        p1_mod = self.modules[ELEModule.PHYSICS]
        p1_out, p1_metrics = p1_mod.process(outputs, self.global_context)
        outputs.update(p1_out)
        chain_trace[ELEModule.PHYSICS.value] = {"outputs": p1_out, "metrics": p1_metrics}

        # P2
        p2_mod = self.modules[ELEModule.PHYSIOLOGY]
        social_ctx = outputs.get("social_ctx", {"accent_profile": "neutral"})
        p2_inputs = {
            "acoustic_envelope": outputs.get("acoustic_envelope"),
            "context_mod": outputs.get("context_mod", "normal"),
            "accent_profile": social_ctx.get("accent_profile", "neutral"),
        }
        p2_out, p2_metrics = p2_mod.process(p2_inputs, self.global_context)
        outputs.update(p2_out)
        chain_trace[ELEModule.PHYSIOLOGY.value] = {"outputs": p2_out, "metrics": p2_metrics}

        # L
        l_mod = self.modules[ELEModule.LINGUISTICS]
        l_inputs = {"raw_signal": outputs.get("raw_signal"),
                    "breath_groups": outputs.get("breath_groups")}
        l_out, l_metrics = l_mod.process(l_inputs, self.global_context)
        outputs.update(l_out)
        chain_trace[ELEModule.LINGUISTICS.value] = {"outputs": l_out, "metrics": l_metrics}

        # C1
        c1_mod = self.modules[ELEModule.COGNITION]
        c1_inputs = {"sememes": outputs.get("sememes"),
                     "rlm_hidden": outputs.get("rlm_hidden")}
        c1_out, c1_metrics = c1_mod.process(c1_inputs, self.global_context)
        outputs.update(c1_out)
        chain_trace[ELEModule.COGNITION.value] = {"outputs": c1_out, "metrics": c1_metrics}

        # C2
        c2_mod = self.modules[ELEModule.COMMUNICATION]
        c2_inputs = {"utterance_plan": outputs,
                     "social_ctx": social_ctx,
                     "sim_env_state": outputs.get("sim_env_state"),
                     "disruption_index": outputs.get("disruption_index")}
        c2_out, c2_metrics = c2_mod.process(c2_inputs, self.global_context)
        outputs.update(c2_out)
        chain_trace[ELEModule.COMMUNICATION.value] = {"outputs": c2_out, "metrics": c2_metrics}

        self.trace.append(chain_trace)
        return outputs

    def _apply_feedback(self, current_outputs: Dict[str, Any]) -> Dict[str, Any]:
        """Apply C2's feedback_to_lower by rerunning P2 → L → C1 → C2 in place.

        Mutates and returns `current_outputs`; appends a second trace entry.
        """
        feedback = current_outputs.get("feedback_to_lower", {})
        context_mod = feedback.get("context_mod", "normal")
        accent_profile = feedback.get("accent_profile", "neutral")

        # P2 re-run
        p2_mod = self.modules[ELEModule.PHYSIOLOGY]
        p2_inputs = {
            "acoustic_envelope": current_outputs.get("acoustic_envelope"),
            "context_mod": context_mod,
            "accent_profile": accent_profile,
        }
        p2_out, p2_metrics = p2_mod.process(p2_inputs, self.global_context)
        current_outputs.update(p2_out)

        # L re-run
        l_mod = self.modules[ELEModule.LINGUISTICS]
        l_inputs = {"raw_signal": current_outputs.get("raw_signal"),
                    "breath_groups": current_outputs.get("breath_groups")}
        l_out, l_metrics = l_mod.process(l_inputs, self.global_context)
        current_outputs.update(l_out)

        # C1 re-run
        c1_mod = self.modules[ELEModule.COGNITION]
        c1_inputs = {"sememes": current_outputs.get("sememes"),
                     "rlm_hidden": current_outputs.get("rlm_hidden")}
        c1_out, c1_metrics = c1_mod.process(c1_inputs, self.global_context)
        current_outputs.update(c1_out)

        # C2 re-run
        c2_mod = self.modules[ELEModule.COMMUNICATION]
        social_ctx = current_outputs.get("social_ctx", {"accent_profile": "neutral"})
        c2_inputs = {"utterance_plan": current_outputs,
                     "social_ctx": social_ctx,
                     "sim_env_state": current_outputs.get("sim_env_state"),
                     "disruption_index": current_outputs.get("disruption_index")}
        c2_out, c2_metrics = c2_mod.process(c2_inputs, self.global_context)
        current_outputs.update(c2_out)

        self.trace.append({
            ELEModule.PHYSIOLOGY.value: {"outputs": p2_out, "metrics": p2_metrics},
            ELEModule.LINGUISTICS.value: {"outputs": l_out, "metrics": l_metrics},
            ELEModule.COGNITION.value: {"outputs": c1_out, "metrics": c1_metrics},
            ELEModule.COMMUNICATION.value: {"outputs": c2_out, "metrics": c2_metrics},
        })

        return current_outputs

    # Summaries/metrics

    def _generate_concerns_summary(self) -> Dict[str, str]:
        """Return a static map of architectural concerns to how they are addressed."""
        return {
            "Grounding": "P1/P2 model aerodynamics and laryngeal control; C1 simulates embodied action from sememes.",
            "Symbolic Structure": "L exposes phon/morph/lex/sem descriptors on top of a recursive GRU RLM.",
            "Timescale/Hierarchy": "Recursion depth and module layering implement multi-scale language dynamics.",
            "Interaction-Dominance": "C2 uses sim_env_state and disruption_index to modulate P2; reruns L/C1/C2.",
            "ToM/Pragmatics": "C2 infers beliefs from physical outcomes; selects repair_strategy and norm_level.",
            "Cognitive Disruptions": "C1 computes disruption_index; C2 reacts via repair policies (simplify, whisper).",
            "Accent Handling": "P2 and C2 encode accent_profile for style without breaking sememe invariance.",
        }

    def get_all_metrics(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
        """Return every module's current metrics snapshot, keyed by module value."""
        return {m.value: self.modules[m].state.get_metrics() for m in ELEModule}

    def get_api_contracts(self) -> Dict[str, Any]:
        """Return every module's declared input/output contract, keyed by module value."""
        return {m.value: self.modules[m].get_api_contract() for m in ELEModule}

    # Coherence & error handling

    def _check_coherence(self, result: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Classify a process() result; return (coherent, failure categories).

        Categories: "engine_error" (an error payload is present),
        "grounding_failure" (grasp failed), "high_disruption" (index > 0.5).
        """
        categories: List[str] = []
        final_outputs = result.get("final_outputs", {})
        sim_env_state = final_outputs.get("sim_env_state", {})
        grasped = bool(sim_env_state.get("grasped", True))
        disruption = float(final_outputs.get("disruption_index", 0.0))
        error_flag = final_outputs.get("error", None)

        if error_flag is not None:
            categories.append("engine_error")
        if not grasped:
            categories.append("grounding_failure")
        if disruption > 0.5:
            categories.append("high_disruption")

        coherent = not categories
        return coherent, categories

    def _categorize_exception(self, e: Exception) -> str:
        """Bucket an exception into a coarse category by its type name."""
        name = type(e).__name__
        if "RuntimeError" in name:
            return "rlm_error"
        elif "ValueError" in name or "ZeroDivisionError" in name:
            return "sim_error"
        else:
            return "unknown_error"

    def _make_error_result(self, e: Exception, category: str) -> Dict[str, Any]:
        """Build a process()-shaped result carrying an error payload instead of outputs."""
        err_info = {
            "error_type": type(e).__name__,
            "error_msg": str(e),
            "error_category": category,
        }
        return {
            "final_outputs": {"error": err_info},
            "last_trace": self.trace[-1] if self.trace else {},
            "all_metrics": self.get_all_metrics(),
            "api_contracts": self.get_api_contracts(),
            "concerns_addressed": self._generate_concerns_summary(),
        }

    # Single cycle

    def process(
        self,
        initial_intent: str,
        recursion_depth: int = 2,
        social_ctx: Optional[Dict[str, Any]] = None,
        apply_feedback: bool = True,
    ) -> Dict[str, Any]:
        """Run one full cycle: forward chain plus optional feedback rerun.

        Args:
            initial_intent: The utterance to process.
            recursion_depth: RLM recursion depth, written into global_context.
            social_ctx: Belief/accent context; defaults to neutral.
            apply_feedback: Whether to apply C2's feedback and rerun P2→C2.

        Returns a synthesis dict with final_outputs, last_trace, all_metrics,
        api_contracts, and concerns_addressed. Resets self.trace each call.
        """
        self.global_context["recursion_depth"] = recursion_depth
        self.trace = []

        init_inputs = {
            "utterance_intent": initial_intent,
            "social_ctx": social_ctx or {"belief": "neutral", "accent_profile": "neutral"},
        }

        forward_out = self._forward_chain(init_inputs)
        final_out = forward_out

        if apply_feedback and "feedback_to_lower" in forward_out:
            final_out = self._apply_feedback(forward_out)

        synthesis = {
            "final_outputs": final_out,
            "last_trace": self.trace[-1] if self.trace else {},
            "all_metrics": self.get_all_metrics(),
            "api_contracts": self.get_api_contracts(),
            "concerns_addressed": self._generate_concerns_summary(),
        }
        return synthesis

    # Robust recursion / auto-repair

    def robust_process(
        self,
        initial_intent: str,
        base_recursion_depth: int = 2,
        social_ctx: Optional[Dict[str, Any]] = None,
        max_attempts: int = 3,
    ) -> Dict[str, Any]:
        """Run process() up to max_attempts times, adjusting until coherent.

        On incoherence: lowers recursion depth (floor 1) and, for grounding
        failures, switches the belief to "cautious". On exception: drops to
        depth 1 and a "safe_mode" context, or returns an error result once
        attempts are exhausted. Always annotates the result with a
        "coherence" summary.
        """
        attempts = 0
        last_result: Optional[Dict[str, Any]] = None
        social_ctx = social_ctx or {"belief": "neutral", "accent_profile": "neutral"}
        recursion_depth = base_recursion_depth

        while attempts < max_attempts:
            try:
                result = self.process(
                    initial_intent=initial_intent,
                    recursion_depth=recursion_depth,
                    social_ctx=social_ctx,
                    apply_feedback=True,
                )
                last_result = result
                coherent, categories = self._check_coherence(result)

                if coherent:
                    result["coherence"] = {
                        "ok": True,
                        "categories": categories,
                        "attempts": attempts + 1,
                    }
                    return result

                # Not coherent: adjust and rerun
                recursion_depth = max(1, recursion_depth - 1)
                if "grounding_failure" in categories:
                    social_ctx = {**social_ctx, "belief": "cautious"}

                attempts += 1

            except Exception as e:
                category = self._categorize_exception(e)
                recursion_depth = 1
                social_ctx = {"belief": "safe_mode", "accent_profile": "neutral"}
                attempts += 1
                if attempts >= max_attempts:
                    return self._make_error_result(e, category)

        # Exhausted attempts without coherence: report the best we have.
        if last_result is None:
            return {
                "final_outputs": {"error": {"error_type": "NoResult",
                                            "error_msg": "No result after attempts."}},
                "last_trace": {},
                "all_metrics": {},
                "api_contracts": self.get_api_contracts(),
                "concerns_addressed": self._generate_concerns_summary(),
                "coherence": {"ok": False, "categories": ["no_result"], "attempts": attempts},
            }

        coherent, categories = self._check_coherence(last_result)
        last_result["coherence"] = {"ok": coherent, "categories": categories, "attempts": attempts}
        return last_result


# ───────────────────────────────────────────
# DEMO (OPTIONAL)
# ───────────────────────────────────────────

if __name__ == "__main__":
    engine = ELEngine()

    # (intent, accent) demo cases, each run through the auto-repair loop.
    demo_cases = [
        ("grasp the concept of recursion", "soft"),
        ("manipulate the small triangle", "harsh"),
    ]

    for idx, (intent, accent) in enumerate(demo_cases):
        heading = f"=== ELE robust_process demo: '{intent}' ({accent} accent) ==="
        if idx:
            heading = "\n" + heading
        print(heading)
        outcome = engine.robust_process(
            intent,
            base_recursion_depth=2,
            social_ctx={"belief": "neutral", "accent_profile": accent},
            max_attempts=3,
        )
        print(outcome["final_outputs"].get("pragmemes", "N/A"))
        print("Coherence:", outcome.get("coherence", {}))