JsonRpcClient.java
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
*--------------------------------------------------------------------------------------------*/
package com.github.copilot.sdk;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.github.copilot.sdk.json.JsonRpcError;
import com.github.copilot.sdk.json.JsonRpcRequest;
import com.github.copilot.sdk.json.JsonRpcResponse;
/**
 * JSON-RPC 2.0 client implementation for communicating with the Copilot CLI.
 * <p>
 * Every message is framed as a {@code Content-Length} header, a blank line, and
 * a UTF-8 encoded JSON body. A single daemon thread named {@code jsonrpc-reader}
 * reads incoming messages. It completes the futures returned by {@link #invoke}
 * and dispatches server-initiated method calls to handlers registered with
 * {@link #registerMethodHandler}. Outgoing writes are serialized by a
 * {@code synchronized} send method.
 * <p>
 * Response futures are completed, and method handlers are called, on the reader
 * thread. Non-async {@link CompletableFuture} stages attached to an
 * {@code invoke} result, and registered handlers, should therefore not block
 * waiting for another response from this client. Only the reader thread can
 * deliver that response, so blocking there would deadlock.
 *
 * @since 1.0.0
 */
class JsonRpcClient implements AutoCloseable {

    private static final Logger LOG = Logger.getLogger(JsonRpcClient.class.getName());
    private static final ObjectMapper MAPPER = createObjectMapper();

    /** Framing header name, including the colon, used for sending and receiving. */
    private static final String CONTENT_LENGTH_HEADER = "Content-Length:";

    /** Returned by {@link #readContentLength} when the stream ends before a header block completes. */
    private static final int END_OF_STREAM = -2;

    /** JSON-RPC "Internal error" code; sent when a registered handler throws. */
    private static final int INTERNAL_ERROR = -32603;

    /** JSON-RPC "Method not found" code; sent when no handler is registered for a request. */
    private static final int METHOD_NOT_FOUND = -32601;

    private final InputStream inputStream;
    private final OutputStream outputStream;
    private final Socket socket;
    private final Process process;
    private final AtomicLong requestIdCounter = new AtomicLong(0);
    private final Map<Long, CompletableFuture<JsonNode>> pendingRequests = new ConcurrentHashMap<>();
    private final Map<String, BiConsumer<String, JsonNode>> notificationHandlers = new ConcurrentHashMap<>();
    private final ExecutorService readerExecutor;

    /** True until {@link #close()} is called or the reader thread stops. */
    private volatile boolean running = true;

    /**
     * Wraps the given streams and immediately starts the background reader.
     *
     * @param inputStream  stream the server's messages are read from
     * @param outputStream stream requests, notifications and responses are written to
     * @param socket       owning socket, closed by {@link #close()}; may be {@code null}
     * @param process      owning process, destroyed by {@link #close()}; may be {@code null}
     */
    private JsonRpcClient(InputStream inputStream, OutputStream outputStream, Socket socket, Process process) {
        this.inputStream = inputStream;
        this.outputStream = outputStream;
        this.socket = socket;
        this.process = process;
        this.readerExecutor = Executors.newSingleThreadExecutor(r -> {
            Thread t = new Thread(r, "jsonrpc-reader");
            t.setDaemon(true);
            return t;
        });
        startReader();
    }

    /**
     * Builds the Jackson mapper used for all (de)serialization. It registers
     * java.time support, ignores unknown properties, writes dates as ISO strings
     * rather than timestamps, and omits {@code null} properties.
     *
     * @return a newly configured {@link ObjectMapper}
     */
    static ObjectMapper createObjectMapper() {
        var mapper = new ObjectMapper();
        mapper.registerModule(new JavaTimeModule());
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        mapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false);
        mapper.setDefaultPropertyInclusion(JsonInclude.Include.NON_NULL);
        return mapper;
    }

    /**
     * Returns the shared mapper instance used by this client.
     *
     * @return the shared {@link ObjectMapper}
     */
    public static ObjectMapper getObjectMapper() {
        return MAPPER;
    }

    /**
     * Creates a JSON-RPC client that talks to a process over its stdin/stdout.
     *
     * @param process the running process; its stdout is read and its stdin is written
     * @return a client whose reader is already running
     */
    public static JsonRpcClient fromProcess(Process process) {
        return new JsonRpcClient(process.getInputStream(), process.getOutputStream(), null, process);
    }

    /**
     * Creates a JSON-RPC client over a connected TCP socket.
     *
     * @param socket the connected socket
     * @return a client whose reader is already running
     * @throws IOException if the socket's streams cannot be obtained
     */
    public static JsonRpcClient fromSocket(Socket socket) throws IOException {
        return new JsonRpcClient(socket.getInputStream(), socket.getOutputStream(), socket, null);
    }

    /**
     * Registers a handler for JSON-RPC method calls (requests or notifications)
     * sent by the server. Any existing handler for the same method is replaced.
     * Handlers run on the reader thread.
     *
     * @param method  the JSON-RPC method name
     * @param handler receives the request id rendered as JSON text, or
     *                {@code null} for notifications, plus the {@code params}
     *                node, which may be {@code null}. String ids keep their
     *                surrounding JSON quotes.
     */
    public void registerMethodHandler(String method, BiConsumer<String, JsonNode> handler) {
        notificationHandlers.put(method, handler);
    }

    /**
     * Sends a JSON-RPC request and returns a future for its result.
     * <p>
     * The future completes exceptionally in these cases:
     * <ul>
     * <li>with a {@link JsonRpcException} if the server replies with an error;</li>
     * <li>with an {@link IOException} if sending fails;</li>
     * <li>with an {@link IOException} if the client is closed or the connection
     * ends before a response arrives.</li>
     * </ul>
     *
     * @param method       the JSON-RPC method name
     * @param params       request parameters, serialized with the shared mapper
     * @param responseType type the {@code result} is converted to; {@code Void}
     *                     or {@code void} yields {@code null}
     * @param <T>          the result type
     * @return a future for the converted result
     */
    public <T> CompletableFuture<T> invoke(String method, Object params, Class<T> responseType) {
        long id = requestIdCounter.incrementAndGet();
        var future = new CompletableFuture<JsonNode>();
        pendingRequests.put(id, future);
        // Checked after registering the future so that a concurrent close() or
        // reader shutdown either fails this entry itself or is seen here.
        if (!running) {
            pendingRequests.remove(id);
            future.completeExceptionally(new IOException("Client closed"));
        } else {
            var request = new JsonRpcRequest();
            request.setJsonrpc("2.0");
            request.setId(id);
            request.setMethod(method);
            request.setParams(params);
            try {
                sendMessage(request);
            } catch (IOException e) {
                pendingRequests.remove(id);
                future.completeExceptionally(e);
            }
        }
        return future.thenApply(result -> {
            try {
                if (responseType == Void.class || responseType == void.class) {
                    return null;
                }
                return MAPPER.treeToValue(result, responseType);
            } catch (JsonProcessingException e) {
                throw new CompletionException(e);
            }
        });
    }

    /**
     * Sends a JSON-RPC notification, which has no id, so no response is expected.
     *
     * @param method the JSON-RPC method name
     * @param params notification parameters
     * @throws IOException if serialization or writing fails
     */
    public void notify(String method, Object params) throws IOException {
        var notification = new JsonRpcRequest();
        notification.setJsonrpc("2.0");
        notification.setMethod(method);
        notification.setParams(params);
        sendMessage(notification);
    }

    /**
     * Sends a successful JSON-RPC response to a server request.
     *
     * @param id     the id of the request being answered
     * @param result the result payload
     * @throws IOException if serialization or writing fails
     */
    public void sendResponse(Object id, Object result) throws IOException {
        var response = new JsonRpcResponse();
        response.setJsonrpc("2.0");
        response.setId(id);
        response.setResult(result);
        sendMessage(response);
    }

    /**
     * Sends a JSON-RPC error response to a server request.
     *
     * @param id      the id of the request being answered
     * @param code    the JSON-RPC error code
     * @param message the error message
     * @throws IOException if serialization or writing fails
     */
    public void sendErrorResponse(Object id, int code, String message) throws IOException {
        var response = new JsonRpcResponse();
        response.setJsonrpc("2.0");
        response.setId(id);
        var error = new JsonRpcError();
        error.setCode(code);
        error.setMessage(message);
        response.setError(error);
        sendMessage(response);
    }

    /**
     * Serializes {@code message} and writes it with Content-Length framing.
     * The method is {@code synchronized} so that concurrent senders cannot
     * interleave a header with another message's body.
     */
    private synchronized void sendMessage(Object message) throws IOException {
        String json = MAPPER.writeValueAsString(message);
        byte[] content = json.getBytes(StandardCharsets.UTF_8);
        String header = CONTENT_LENGTH_HEADER + " " + content.length + "\r\n\r\n";
        outputStream.write(header.getBytes(StandardCharsets.UTF_8));
        outputStream.write(content);
        outputStream.flush();
        LOG.fine(() -> "Sent: " + json);
    }

    /**
     * Starts the background loop that reads framed messages until the stream
     * ends, an error occurs, or the client is closed. When the loop stops, all
     * outstanding requests are failed and {@code running} is cleared.
     */
    private void startReader() {
        readerExecutor.submit(() -> {
            try {
                // Content-Length counts bytes, not characters, so the stream is read as
                // raw bytes. A character Reader would miscount multi-byte UTF-8 sequences.
                var in = new BufferedInputStream(inputStream);
                while (running) {
                    int contentLength = readContentLength(in);
                    if (contentLength == END_OF_STREAM) {
                        return;
                    }
                    if (contentLength <= 0) {
                        // No usable Content-Length in this header block; nothing to read.
                        continue;
                    }
                    byte[] body = new byte[contentLength];
                    if (in.readNBytes(body, 0, contentLength) < contentLength) {
                        return; // stream ended in the middle of a message body
                    }
                    String content = new String(body, StandardCharsets.UTF_8);
                    LOG.fine(() -> "Received: " + content);
                    handleMessage(content);
                }
            } catch (Exception e) {
                if (running) {
                    LOG.log(Level.SEVERE, "Error in JSON-RPC reader", e);
                }
            } finally {
                // Without the reader no response can ever be delivered. Fail outstanding
                // requests instead of leaving their futures incomplete forever, and make
                // later invoke() calls fail fast.
                running = false;
                failPendingRequests("JSON-RPC connection closed");
            }
        });
    }

    /**
     * Reads one header block, terminated by an empty line.
     * <p>
     * A {@code \r} is dropped when it is immediately followed by {@code \n} or
     * another {@code \r}. Otherwise it is kept as part of the line.
     *
     * @param in the byte stream positioned at the start of a header block
     * @return the Content-Length value; {@code -1} if no Content-Length header
     *         was present or its value was negative; or {@link #END_OF_STREAM}
     *         if the stream ended first
     * @throws IOException if reading fails or the Content-Length value is not an integer
     */
    private static int readContentLength(InputStream in) throws IOException {
        int contentLength = -1;
        var headerLine = new StringBuilder();
        boolean lastWasCR = false;
        while (true) {
            int b = in.read();
            if (b == -1) {
                return END_OF_STREAM;
            }
            if (b == '\r') {
                lastWasCR = true;
            } else if (b == '\n') {
                String line = headerLine.toString();
                headerLine.setLength(0);
                lastWasCR = false;
                if (line.isEmpty()) {
                    return contentLength; // blank line ends the header block
                }
                // Case-insensitive and locale-independent header-name match.
                if (line.regionMatches(true, 0, CONTENT_LENGTH_HEADER, 0, CONTENT_LENGTH_HEADER.length())) {
                    String value = line.substring(CONTENT_LENGTH_HEADER.length()).trim();
                    int parsed;
                    try {
                        parsed = Integer.parseInt(value);
                    } catch (NumberFormatException e) {
                        throw new IOException("Invalid Content-Length header: " + line, e);
                    }
                    contentLength = parsed < 0 ? -1 : parsed;
                }
            } else {
                if (lastWasCR) {
                    headerLine.append('\r');
                    lastWasCR = false;
                }
                headerLine.append((char) b);
            }
        }
    }

    /**
     * Parses one JSON message and routes it. A message with a non-null id and a
     * {@code result} or {@code error} member is treated as a response to one of
     * our requests. Otherwise, a message with a {@code method} member is treated
     * as a server call. Anything else is ignored.
     */
    private void handleMessage(String content) {
        JsonNode node;
        try {
            node = MAPPER.readTree(content);
        } catch (JsonProcessingException e) {
            LOG.log(Level.SEVERE, "Error parsing JSON-RPC message", e);
            return;
        }
        if (node == null) {
            return; // no JSON content at all
        }
        JsonNode idNode = node.get("id");
        boolean hasId = idNode != null && !idNode.isNull();
        if (hasId && (node.has("result") || node.has("error"))) {
            handleResponse(idNode.asLong(), node);
        } else if (node.has("method")) {
            handleServerMethod(node, hasId ? idNode : null);
        }
    }

    /**
     * Completes the pending request with the given id. The future completes
     * exceptionally with a {@link JsonRpcException} if the response carries an
     * {@code error}, and normally with the {@code result} node otherwise.
     * Responses for unknown ids are ignored.
     */
    private void handleResponse(long id, JsonNode node) {
        CompletableFuture<JsonNode> future = pendingRequests.remove(id);
        if (future == null) {
            LOG.fine(() -> "Ignoring response for unknown request id " + id);
            return;
        }
        if (node.has("error")) {
            JsonNode errorNode = node.get("error");
            String errorMessage = errorNode.has("message")
                    ? errorNode.get("message").asText()
                    : "Unknown error";
            int errorCode = errorNode.has("code") ? errorNode.get("code").asInt() : -1;
            future.completeExceptionally(new JsonRpcException(errorCode, errorMessage));
        } else {
            future.complete(node.get("result"));
        }
    }

    /**
     * Dispatches a server-initiated call to its registered handler. For requests
     * (non-null {@code id}), an error response is sent back when no handler is
     * registered or the handler throws.
     *
     * @param node the parsed message
     * @param id   the request id node, or {@code null} for notifications
     */
    private void handleServerMethod(JsonNode node, JsonNode id) {
        String method = node.get("method").asText();
        JsonNode params = node.get("params");
        LOG.fine(() -> "Received method: " + method);
        BiConsumer<String, JsonNode> handler = notificationHandlers.get(method);
        if (handler == null) {
            LOG.fine(() -> "No handler for method: " + method);
            if (id != null) {
                trySendErrorResponse(id, METHOD_NOT_FOUND, "Method not found: " + method);
            }
            return;
        }
        try {
            // The handler receives the id as JSON text so it can address its reply.
            handler.accept(id != null ? id.toString() : null, params);
        } catch (Exception e) {
            LOG.log(Level.SEVERE, "Error handling method " + method, e);
            if (id != null) {
                // JSON-RPC requires an error message string, but getMessage() may be null.
                String message = e.getMessage() != null ? e.getMessage() : e.getClass().getName();
                trySendErrorResponse(id, INTERNAL_ERROR, message);
            }
        }
    }

    /** Sends an error response, logging rather than propagating any I/O failure. */
    private void trySendErrorResponse(Object id, int code, String message) {
        try {
            sendErrorResponse(id, code, message);
        } catch (IOException ioe) {
            LOG.log(Level.SEVERE, "Failed to send error response", ioe);
        }
    }

    /**
     * Removes every pending request and completes its future exceptionally with
     * an {@link IOException} carrying {@code reason}. Each entry is removed
     * individually, so a request registered concurrently is not silently dropped.
     */
    private void failPendingRequests(String reason) {
        for (Long id : pendingRequests.keySet()) {
            CompletableFuture<JsonNode> future = pendingRequests.remove(id);
            if (future != null) {
                future.completeExceptionally(new IOException(reason));
            }
        }
    }

    /**
     * Stops the reader and fails all pending requests with
     * {@code IOException("Client closed")}. It then closes the socket, if any,
     * and destroys the process, if any.
     */
    @Override
    public void close() {
        running = false;
        readerExecutor.shutdownNow();
        failPendingRequests("Client closed");
        try {
            if (socket != null) {
                socket.close();
            }
        } catch (IOException e) {
            LOG.log(Level.FINE, "Error closing socket", e);
        }
        if (process != null) {
            process.destroy();
        }
    }

    /**
     * Reports whether the underlying transport still appears to be alive.
     *
     * @return for a socket, whether it is connected and not closed; for a
     *         process, whether it is alive; otherwise {@code false}
     */
    public boolean isConnected() {
        if (socket != null) {
            return socket.isConnected() && !socket.isClosed();
        }
        if (process != null) {
            return process.isAlive();
        }
        return false;
    }

    /**
     * Returns the process this client was created from.
     *
     * @return the process, or {@code null} for socket-based clients
     */
    public Process getProcess() {
        return process;
    }
}