RpcHandlerDispatcher.java

/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *--------------------------------------------------------------------------------------------*/

package com.github.copilot.sdk;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.logging.Level;
import java.util.logging.Logger;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.copilot.sdk.events.AbstractSessionEvent;
import com.github.copilot.sdk.events.SessionEventParser;
import com.github.copilot.sdk.json.PermissionRequestResult;
import com.github.copilot.sdk.json.SessionLifecycleEvent;
import com.github.copilot.sdk.json.SessionLifecycleEventMetadata;
import com.github.copilot.sdk.json.ToolDefinition;
import com.github.copilot.sdk.json.ToolInvocation;
import com.github.copilot.sdk.json.ToolResultObject;
import com.github.copilot.sdk.json.UserInputRequest;

/**
 * Dispatches incoming JSON-RPC method calls to the appropriate handlers.
 * <p>
 * This class handles all server-to-client RPC calls including:
 * <ul>
 * <li>Session events</li>
 * <li>Tool calls</li>
 * <li>Permission requests</li>
 * <li>User input requests</li>
 * <li>Hooks invocations</li>
 * <li>Lifecycle events</li>
 * </ul>
 * <p>
 * Session and lifecycle events are processed inline and never produce a reply.
 * Tool calls, permission requests, user input requests and hooks invocations
 * are processed through {@link CompletableFuture#runAsync(Runnable)}; every
 * failure path of those four handlers attempts to send a reply so the peer is
 * not left waiting on a request that failed locally.
 */
final class RpcHandlerDispatcher {

    private static final Logger LOG = Logger.getLogger(RpcHandlerDispatcher.class.getName());
    private static final ObjectMapper MAPPER = JsonRpcClient.getObjectMapper();

    /** Error code sent when a request references a session that is not registered. */
    private static final int INVALID_PARAMS = -32602;

    /** Error code sent when handling a request fails unexpectedly. */
    private static final int INTERNAL_ERROR = -32603;

    /** Result kind sent whenever a permission request cannot be answered by a handler. */
    private static final String PERMISSION_DENIED_KIND = "denied-no-approval-rule-and-could-not-request-from-user";

    /** Generic text placed in a failure tool result when the tool itself failed. */
    private static final String TOOL_FAILURE_TEXT = "Invoking this tool produced an error. Detailed information is not available.";

    private final Map<String, CopilotSession> sessions;
    private final LifecycleEventDispatcher lifecycleDispatcher;

    /**
     * Creates a dispatcher with session registry and lifecycle dispatcher.
     *
     * @param sessions
     *            the session registry to look up sessions by ID
     * @param lifecycleDispatcher
     *            callback for dispatching lifecycle events
     */
    RpcHandlerDispatcher(Map<String, CopilotSession> sessions, LifecycleEventDispatcher lifecycleDispatcher) {
        this.sessions = sessions;
        this.lifecycleDispatcher = lifecycleDispatcher;
    }

    /**
     * Registers all RPC method handlers with the given JSON-RPC client.
     *
     * @param rpc
     *            the JSON-RPC client to register handlers with
     */
    void registerHandlers(JsonRpcClient rpc) {
        rpc.registerMethodHandler("session.event", (requestId, params) -> handleSessionEvent(params));
        rpc.registerMethodHandler("session.lifecycle", (requestId, params) -> handleLifecycleEvent(params));
        rpc.registerMethodHandler("tool.call", (requestId, params) -> handleToolCall(rpc, requestId, params));
        rpc.registerMethodHandler("permission.request",
                (requestId, params) -> handlePermissionRequest(rpc, requestId, params));
        rpc.registerMethodHandler("userInput.request",
                (requestId, params) -> handleUserInputRequest(rpc, requestId, params));
        rpc.registerMethodHandler("hooks.invoke", (requestId, params) -> handleHooksInvoke(rpc, requestId, params));
    }

    /**
     * Parses a {@code session.event} payload and forwards the event to the
     * session named by {@code sessionId}. Events for unknown sessions, payloads
     * without an {@code event} node and events the parser returns {@code null}
     * for are dropped. Any exception is logged; no reply is sent.
     *
     * @param params
     *            the request parameters carrying {@code sessionId} and
     *            {@code event}
     */
    private void handleSessionEvent(JsonNode params) {
        try {
            String sessionId = params.get("sessionId").asText();
            JsonNode eventNode = params.get("event");
            // Supplier form: the JSON tree is only rendered when FINE is enabled.
            LOG.fine(() -> "Received session.event: " + eventNode);

            CopilotSession session = sessions.get(sessionId);
            if (session != null && eventNode != null) {
                AbstractSessionEvent event = SessionEventParser.parse(eventNode);
                if (event != null) {
                    session.dispatchEvent(event);
                }
            }
        } catch (Exception e) {
            LOG.log(Level.SEVERE, "Error handling session event", e);
        }
    }

    /**
     * Builds a {@link SessionLifecycleEvent} from a {@code session.lifecycle}
     * payload and hands it to the lifecycle dispatcher. Missing {@code type} or
     * {@code sessionId} fields default to the empty string; {@code metadata} is
     * only set when present and non-null. Any exception is logged; no reply is
     * sent.
     *
     * @param params
     *            the request parameters
     */
    private void handleLifecycleEvent(JsonNode params) {
        try {
            String type = params.has("type") ? params.get("type").asText() : "";
            String sessionId = params.has("sessionId") ? params.get("sessionId").asText() : "";

            SessionLifecycleEvent event = new SessionLifecycleEvent();
            event.setType(type);
            event.setSessionId(sessionId);

            if (params.has("metadata") && !params.get("metadata").isNull()) {
                SessionLifecycleEventMetadata metadata = MAPPER.treeToValue(params.get("metadata"),
                        SessionLifecycleEventMetadata.class);
                event.setMetadata(metadata);
            }

            lifecycleDispatcher.dispatch(event);
        } catch (Exception e) {
            LOG.log(Level.SEVERE, "Error handling session lifecycle event", e);
        }
    }

    /**
     * Handles a {@code tool.call} request by invoking the matching tool handler
     * of the addressed session and replying with its result.
     * <p>
     * An unknown session is answered with an error; an unknown tool, a failing
     * tool handler and a tool result that cannot be serialized are each answered
     * with a failure tool result.
     *
     * @param rpc
     *            the client used to send the reply
     * @param requestId
     *            the ID of the request being answered
     * @param params
     *            the request parameters
     */
    private void handleToolCall(JsonRpcClient rpc, String requestId, JsonNode params) {
        CompletableFuture.runAsync(() -> {
            try {
                String sessionId = params.get("sessionId").asText();
                String toolCallId = params.get("toolCallId").asText();
                String toolName = params.get("toolName").asText();
                JsonNode arguments = params.get("arguments");

                CopilotSession session = sessions.get(sessionId);
                if (session == null) {
                    replyWithError(rpc, requestId, INVALID_PARAMS, "Unknown session " + sessionId);
                    return;
                }

                ToolDefinition tool = session.getTool(toolName);
                if (tool == null || tool.handler() == null) {
                    var result = ToolResultObject.failure("Tool '" + toolName + "' is not supported.",
                            "tool '" + toolName + "' not supported");
                    rpc.sendResponse(Long.parseLong(requestId), Map.of("result", result));
                    return;
                }

                var invocation = new ToolInvocation().setSessionId(sessionId).setToolCallId(toolCallId)
                        .setToolName(toolName).setArguments(arguments);

                tool.handler().invoke(invocation).thenAccept(result -> {
                    ToolResultObject toolResult;
                    try {
                        toolResult = toToolResult(result);
                    } catch (Exception e) {
                        // A result that cannot be serialized must still be answered,
                        // otherwise the peer waits indefinitely for this tool call.
                        LOG.log(Level.SEVERE, "Error serializing tool result", e);
                        toolResult = ToolResultObject.failure(TOOL_FAILURE_TEXT, describe(e));
                    }
                    try {
                        rpc.sendResponse(Long.parseLong(requestId), Map.of("result", toolResult));
                    } catch (Exception e) {
                        LOG.log(Level.SEVERE, "Error sending tool result", e);
                    }
                }).exceptionally(ex -> {
                    try {
                        var result = ToolResultObject.failure(TOOL_FAILURE_TEXT, describe(ex));
                        rpc.sendResponse(Long.parseLong(requestId), Map.of("result", result));
                    } catch (Exception e) {
                        LOG.log(Level.SEVERE, "Error sending tool error", e);
                    }
                    return null;
                });
            } catch (Exception e) {
                LOG.log(Level.SEVERE, "Error handling tool call", e);
                replyWithError(rpc, requestId, INTERNAL_ERROR, describe(e));
            }
        });
    }

    /**
     * Handles a {@code permission.request} by delegating to the addressed
     * session and replying with its decision.
     * <p>
     * Whenever no decision can be obtained (unknown session, failing handler or
     * a malformed request) the reply is a denial, so the request is always
     * answered.
     *
     * @param rpc
     *            the client used to send the reply
     * @param requestId
     *            the ID of the request being answered
     * @param params
     *            the request parameters
     */
    private void handlePermissionRequest(JsonRpcClient rpc, String requestId, JsonNode params) {
        CompletableFuture.runAsync(() -> {
            try {
                String sessionId = params.get("sessionId").asText();
                JsonNode permissionRequest = params.get("permissionRequest");

                CopilotSession session = sessions.get(sessionId);
                if (session == null) {
                    sendPermissionDenied(rpc, requestId);
                    return;
                }

                session.handlePermissionRequest(permissionRequest).thenAccept(result -> {
                    try {
                        rpc.sendResponse(Long.parseLong(requestId), Map.of("result", result));
                    } catch (IOException e) {
                        LOG.log(Level.SEVERE, "Error sending permission result", e);
                    }
                }).exceptionally(ex -> {
                    LOG.log(Level.WARNING, "Permission handler exception", ex);
                    sendPermissionDenied(rpc, requestId);
                    return null;
                });
            } catch (Exception e) {
                LOG.log(Level.SEVERE, "Error handling permission request", e);
                // Deny rather than leave the request unanswered.
                sendPermissionDenied(rpc, requestId);
            }
        });
    }

    /**
     * Handles a {@code userInput.request} by delegating to the addressed session
     * and replying with the user's answer.
     * <p>
     * An unknown session, a failing handler and a malformed request are each
     * answered with an error.
     *
     * @param rpc
     *            the client used to send the reply
     * @param requestId
     *            the ID of the request being answered
     * @param params
     *            the request parameters
     */
    private void handleUserInputRequest(JsonRpcClient rpc, String requestId, JsonNode params) {
        LOG.fine(() -> "Received userInput.request: " + params);
        CompletableFuture.runAsync(() -> {
            try {
                String sessionId = params.get("sessionId").asText();
                String question = params.get("question").asText();
                LOG.fine(() -> "Processing userInput for session " + sessionId + ", question: " + question);
                JsonNode choicesNode = params.get("choices");
                JsonNode allowFreeformNode = params.get("allowFreeform");

                CopilotSession session = sessions.get(sessionId);
                LOG.fine(() -> "Found session: " + (session != null));
                if (session == null) {
                    LOG.fine("Session not found, sending error");
                    replyWithError(rpc, requestId, INVALID_PARAMS, "Unknown session " + sessionId);
                    return;
                }

                var request = new UserInputRequest().setQuestion(question);
                if (choicesNode != null && choicesNode.isArray()) {
                    var choices = new ArrayList<String>();
                    for (JsonNode choice : choicesNode) {
                        choices.add(choice.asText());
                    }
                    request.setChoices(choices);
                }
                if (allowFreeformNode != null) {
                    request.setAllowFreeform(allowFreeformNode.asBoolean());
                }

                session.handleUserInputRequest(request).thenAccept(response -> {
                    try {
                        // Ensure answer is never null - CLI requires a non-null string
                        String answer = response.getAnswer() != null ? response.getAnswer() : "";
                        LOG.fine(() -> "Sending userInput response: answer=" + answer + ", wasFreeform="
                                + response.isWasFreeform());
                        rpc.sendResponse(Long.parseLong(requestId),
                                Map.of("answer", answer, "wasFreeform", response.isWasFreeform()));
                    } catch (IOException e) {
                        LOG.log(Level.SEVERE, "Error sending user input response", e);
                    }
                }).exceptionally(ex -> {
                    LOG.log(Level.WARNING, "User input handler exception", ex);
                    replyWithError(rpc, requestId, INTERNAL_ERROR, "User input handler error: " + describe(ex));
                    return null;
                });
            } catch (Exception e) {
                LOG.log(Level.SEVERE, "Error handling user input request", e);
                replyWithError(rpc, requestId, INTERNAL_ERROR, "User input handler error: " + describe(e));
            }
        });
    }

    /**
     * Handles a {@code hooks.invoke} request by delegating to the addressed
     * session and replying with the hook output.
     * <p>
     * An unknown session, a failing handler and a malformed request are each
     * answered with an error.
     *
     * @param rpc
     *            the client used to send the reply
     * @param requestId
     *            the ID of the request being answered
     * @param params
     *            the request parameters
     */
    private void handleHooksInvoke(JsonRpcClient rpc, String requestId, JsonNode params) {
        CompletableFuture.runAsync(() -> {
            try {
                String sessionId = params.get("sessionId").asText();
                String hookType = params.get("hookType").asText();
                JsonNode input = params.get("input");

                CopilotSession session = sessions.get(sessionId);
                if (session == null) {
                    replyWithError(rpc, requestId, INVALID_PARAMS, "Unknown session " + sessionId);
                    return;
                }

                session.handleHooksInvoke(hookType, input).thenAccept(output -> {
                    try {
                        // singletonMap (unlike Map.of) accepts a null output value.
                        rpc.sendResponse(Long.parseLong(requestId), Collections.singletonMap("output", output));
                    } catch (IOException e) {
                        LOG.log(Level.SEVERE, "Error sending hooks response", e);
                    }
                }).exceptionally(ex -> {
                    replyWithError(rpc, requestId, INTERNAL_ERROR, "Hooks handler error: " + describe(ex));
                    return null;
                });
            } catch (Exception e) {
                LOG.log(Level.SEVERE, "Error handling hooks invoke", e);
                replyWithError(rpc, requestId, INTERNAL_ERROR, "Hooks handler error: " + describe(e));
            }
        });
    }

    /**
     * Converts whatever a tool handler produced into a {@link ToolResultObject}:
     * an existing result object is passed through, a string is used as-is and
     * anything else is serialized to JSON.
     *
     * @param result
     *            the value produced by the tool handler
     * @return the tool result to send back
     * @throws IOException
     *             if the value cannot be serialized to JSON
     */
    private static ToolResultObject toToolResult(Object result) throws IOException {
        if (result instanceof ToolResultObject ready) {
            return ready;
        }
        String text = result instanceof String s ? s : MAPPER.writeValueAsString(result);
        return ToolResultObject.success(text);
    }

    /**
     * Sends an error reply, logging instead of throwing if the request ID is not
     * numeric or the reply cannot be written.
     *
     * @param rpc
     *            the client used to send the reply
     * @param requestId
     *            the ID of the request being answered
     * @param code
     *            the error code to report
     * @param message
     *            the error message to report
     */
    private static void replyWithError(JsonRpcClient rpc, String requestId, int code, String message) {
        try {
            rpc.sendErrorResponse(Long.parseLong(requestId), code, message);
        } catch (IOException | RuntimeException e) {
            LOG.log(Level.SEVERE, "Failed to send error response", e);
        }
    }

    /**
     * Replies to a permission request with a denial, logging instead of throwing
     * if the request ID is not numeric or the reply cannot be written.
     *
     * @param rpc
     *            the client used to send the reply
     * @param requestId
     *            the ID of the request being answered
     */
    private static void sendPermissionDenied(JsonRpcClient rpc, String requestId) {
        try {
            var denied = new PermissionRequestResult().setKind(PERMISSION_DENIED_KIND);
            rpc.sendResponse(Long.parseLong(requestId), Map.of("result", denied));
        } catch (IOException | RuntimeException e) {
            LOG.log(Level.SEVERE, "Error sending permission denied", e);
        }
    }

    /**
     * Produces a non-null description of a failure. Dependent
     * {@link CompletableFuture} stages report failures wrapped in
     * {@link CompletionException}; the wrapper is removed so the description
     * reflects the original cause.
     *
     * @param error
     *            the failure to describe
     * @return the cause's message, or its {@code toString()} when it has none
     */
    private static String describe(Throwable error) {
        Throwable cause = error;
        while (cause instanceof CompletionException && cause.getCause() != null) {
            cause = cause.getCause();
        }
        String message = cause.getMessage();
        return message != null ? message : cause.toString();
    }

    /**
     * Functional interface for dispatching lifecycle events.
     */
    @FunctionalInterface
    interface LifecycleEventDispatcher {

        /**
         * Delivers a lifecycle event built from a {@code session.lifecycle} call.
         *
         * @param event
         *            the event to deliver
         */
        void dispatch(SessionLifecycleEvent event);
    }
}