Add Custom Agents

If your agent is built with a framework we don't support, or you have a completely custom implementation, you can integrate it by implementing our Agent base class.

Requirements
Your custom agent must implement the 4 required methods listed below, and its run() method must return an AgentResult object for compatibility with the evaluation pipeline.

Required Methods

| Method | Input | Output | Purpose |
| --- | --- | --- | --- |
| initialize() | None | None | Create LLM client, call load_mcp_servers() |
| _create_mcp_server() | MCPServerConfig | Any (MCP client) | Create framework-specific MCP client |
| run() | str, Optional[Dict] | AgentResult | Execute agent, record trajectory, return result |
| cleanup() | None | None | Close MCP connections, clean up resources |

AgentResult (Required Return Type)

Important
The run() method MUST return an AgentResult object. Import it from dt_arena.src.types.agent.
from dt_arena.src.types.agent import AgentResult

# AgentResult fields:
@dataclass
class AgentResult:
    """Result object that run() must return to the evaluation pipeline."""

    # Required fields: they have no default, so every AgentResult must supply them.
    final_output: Optional[str]      # Agent's final text response
    turn_count: int                  # Number of LLM calls executed
    trajectory: Optional[Trajectory] # Trajectory object with all steps

    # Optional fields: default to None when omitted.
    # NOTE(review): these were described as keyword-only, but no kw_only marker
    # is shown in this excerpt -- confirm against dt_arena.src.types.agent.
    trace_id: Optional[str] = None   # Trace/session identifier
    duration: Optional[float] = None # Execution duration in seconds

# Example usage in run():
return AgentResult(
    final_output="Here are all the leads...",
    turn_count=3,
    trajectory=self.trajectory,
    trace_id="trace_abc123",
)

Full Implementation Example

Below is a complete example showing how to implement a custom agent:

import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from dt_arena.src.types.agent import Agent, AgentConfig, RuntimeConfig, MCPServerConfig, AgentResult
from dt_arena.src.types.trajectory import Trajectory

class CustomAgent(Agent):
    """
    Custom agent implementation following the DTap interface.

    Expected lifecycle: ``initialize()`` -> ``run()`` -> ``cleanup()``.

    ``YourLLMClient`` and ``YourMCPClient`` are placeholders for the client
    classes of your own framework. The helpers ``_build_messages()``,
    ``_get_available_tools()`` and ``_execute_tool()`` called by ``run()`` are
    not defined in this example -- provide them in your subclass unless the
    ``Agent`` base class already does (confirm against the base class).
    """

    def __init__(
        self,
        agent_config: AgentConfig,
        runtime_config: Optional[RuntimeConfig] = None,
    ):
        """Store the configs on the base class; real setup happens in initialize()."""
        super().__init__(agent_config, runtime_config)
        # Both are created in initialize(); None means "not initialized yet".
        self.client = None
        self.trajectory = None

    async def initialize(self) -> None:
        """Initialize the agent and connect to MCP servers."""
        # Initialize your LLM client
        self.client = YourLLMClient(
            model=self.runtime_config.model,
            temperature=self.runtime_config.temperature,
        )
        # Load and connect MCP servers from config
        await self.load_mcp_servers()
        # Initialize trajectory
        self.trajectory = Trajectory()

    def _create_mcp_server(self, config: MCPServerConfig) -> Any:
        """
        Create an MCP server instance for your framework.

        Args:
            config: Server description; its ``name`` and ``url`` are forwarded
                to the MCP client.

        Returns:
            The framework-specific MCP client for this server.
        """
        # Fall back to "no injections" when the runtime config carries none,
        # instead of failing on a None mapping.
        injections = self.runtime_config.mcp_injection or {}
        return YourMCPClient(
            name=config.name,
            url=config.url,
            injections=injections.get(config.name, {})
        )

    async def run(
        self,
        user_input: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentResult:
        """
        Execute the agent with user input.

        MUST return AgentResult with final_output, turn_count, and trajectory.

        Args:
            user_input: The instruction to execute; recorded as the user step.
            metadata: Optional task metadata. It is recorded with the user
                step, passed to ``Trajectory.save()``, and its ``"task_id"``
                entry (if any) becomes the result's ``trace_id``.

        Returns:
            AgentResult whose ``final_output`` is None when ``max_turns`` is
            reached before the model answers without tool calls.

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        # Fail with a clear message rather than an AttributeError on None.
        if self.client is None or self.trajectory is None:
            raise RuntimeError("initialize() must be called before run()")

        # Monotonic clock: unaffected by wall-clock adjustments during the run.
        started_at = time.monotonic()

        # Record user input
        self.trajectory.append_user_step(user_input, metadata or {})

        turns = 0
        final_output = None

        while turns < self.runtime_config.max_turns:
            response = await self.client.chat(
                messages=self._build_messages(),
                tools=self._get_available_tools()
            )
            turns += 1

            if response.tool_calls:
                for tool_call in response.tool_calls:
                    # Record agent action
                    self.trajectory.append_agent_step(
                        action=f"{tool_call.name}({tool_call.arguments})",
                        tool_name=tool_call.name,
                        tool_params=tool_call.arguments,
                    )
                    # Execute tool
                    result = await self._execute_tool(tool_call)
                    # Record tool result
                    self.trajectory.append_tool_return(
                        result=result,
                        tool_name=tool_call.name,
                    )
            else:
                # No tool calls: the model produced its final answer.
                final_output = response.content
                self.trajectory.append_agent_step(
                    action="send_message_to_user",
                    metadata={"message": final_output}
                )
                break

        duration = time.monotonic() - started_at

        # Save trajectory
        self.trajectory.save(self.runtime_config.output_dir, metadata=metadata)

        # MUST return AgentResult
        return AgentResult(
            final_output=final_output,
            turn_count=turns,
            trajectory=self.trajectory,
            trace_id=metadata.get("task_id") if metadata else None,
            duration=duration,
        )

    async def cleanup(self) -> None:
        """
        Clean up resources and close connections.

        Closing MCP servers is best-effort: a failure on one server is logged
        and the remaining servers are still closed. An error from closing the
        LLM client propagates to the caller.
        """
        logger = logging.getLogger(__name__)
        for server in self.mcp_servers:
            try:
                await server.close()
            except Exception:
                # Keep going so one bad server cannot leak the others, but
                # leave a trace instead of hiding the failure.
                logger.warning(
                    "Failed to close MCP server %r", server, exc_info=True
                )
        if self.client:
            try:
                await self.client.close()
            finally:
                # Drop the reference so a repeated cleanup() does not close twice.
                self.client = None

Trajectory Format

Your agent must save trajectories in the standard format for evaluation:

{
  "task_info": {
    "task_id": "trace_abc123",
    "original_instruction": "List all leads in the CRM",
    "domain": "crm"
  },
  "traj_info": {
    "step_count": 4,
    "duration": 3.5,
    "agent_final_response": "Here are all the leads..."
  },
  "trajectory": [
    {"role": "user", "state": "List all leads", "step_id": 0},
    {"role": "agent", "action": "list_leads()", "step_id": 1},
    {"role": "tool", "state": [...], "step_id": 2},
    {"role": "agent", "action": "send_message_to_user", "step_id": 3}
  ]
}

Integration Checklist

  • Inherit from Agent base class
  • Implement all 4 required methods
  • Call load_mcp_servers() in initialize()
  • Return AgentResult from run()
  • Record all steps in trajectory (user, agent, tool)
  • Support async context manager (async with agent:)
  • Clean up all connections in cleanup()