From e0f092fc2fd6217d766176b35bc4d2dbf22a17f1 Mon Sep 17 00:00:00 2001 From: Riccardo Strina Date: Sun, 19 Apr 2026 22:52:35 +0200 Subject: [PATCH] Add jdt:// URI rewriting for documentation responses --- proxy/src/decompile.rs | 77 +++++++++++++++++++++++++++++++++++++ proxy/src/lsp.rs | 11 ++++++ proxy/src/main.rs | 86 +++++++++++++++++++++++++++++++----------- 3 files changed, 152 insertions(+), 22 deletions(-) diff --git a/proxy/src/decompile.rs b/proxy/src/decompile.rs index 1d9e6b6..c65a355 100644 --- a/proxy/src/decompile.rs +++ b/proxy/src/decompile.rs @@ -141,3 +141,80 @@ pub fn rewrite_jdt_locations( } rewritten } + +/// A jdt:// URI in embedded markdown/text terminates at whitespace or any of these +/// delimiters commonly used in markdown links and JSON strings. The URI itself only +/// contains URL-encoded forms of these characters, so scanning until we hit one of +/// them is safe. +fn jdt_uri_end(s: &str) -> usize { + s.find(|c: char| c.is_whitespace() || matches!(c, ')' | ']' | '"' | '>' | '`' | '\'')) + .unwrap_or(s.len()) +} + +/// Extract all unique `jdt://` URIs appearing inside any string in `value`. +fn collect_jdt_uris(value: &Value, out: &mut Vec) { + match value { + Value::String(s) => { + let mut rest = s.as_str(); + while let Some(pos) = rest.find("jdt://") { + let tail = &rest[pos..]; + let end = jdt_uri_end(tail); + let uri = tail[..end].to_string(); + if !out.contains(&uri) { + out.push(uri); + } + rest = &tail[end..]; + } + } + Value::Array(arr) => arr.iter().for_each(|v| collect_jdt_uris(v, out)), + Value::Object(obj) => obj.values().for_each(|v| collect_jdt_uris(v, out)), + _ => {} + } +} + +/// Replace all occurrences of any key in `map` with its value, inside every string +/// contained in `value` (recursively). +fn replace_in_strings(value: &mut Value, map: &HashMap) { + match value { + Value::String(s) => { + for (from, to) in map { + if s.contains(from.as_str()) { + *s = s.replace(from.as_str(), to); + } + } + } + Value::Array(arr) => arr.iter_mut().for_each(|v| replace_in_strings(v, map)), + Value::Object(obj) => obj.values_mut().for_each(|v| replace_in_strings(v, map)), + _ => {} + } +} + +/// Scan a documentation response (hover, signatureHelp, completionItem/resolve, …) +/// for embedded `jdt://` URIs, resolve each one to a `file://` URI backed by a temp +/// file, and replace the URIs in-place in every string of `msg.result`. +pub fn rewrite_jdt_in_strings( + msg: &mut Value, + writer: &Arc>, + pending: &Arc>>>, + next_id: &mut impl FnMut() -> Value, +) { + let Some(result) = msg.get_mut("result") else { + return; + }; + + let mut uris = Vec::new(); + collect_jdt_uris(result, &mut uris); + if uris.is_empty() { + return; + } + + let mut map = HashMap::new(); + for uri in uris { + if let Some(file_uri) = resolve_jdt_uri(&uri, writer, pending, next_id()) { + map.insert(uri, file_uri); + } + } + if !map.is_empty() { + replace_in_strings(result, &map); + } +} diff --git a/proxy/src/lsp.rs b/proxy/src/lsp.rs index eeb343e..7493861 100644 --- a/proxy/src/lsp.rs +++ b/proxy/src/lsp.rs @@ -54,6 +54,17 @@ pub fn parse_lsp_content(raw: &[u8]) -> Option { serde_json::from_slice(&raw[sep_pos + 4..]).ok() } +/// Cheap check for the presence of an `"id"` key in the JSON body of a raw LSP +/// message. Used to skip full JSON parsing for notifications, which carry no +/// `id` and therefore cannot be responses or completion results. +pub fn raw_has_id(raw: &[u8]) -> bool { + let Some(sep_pos) = raw.windows(4).position(|w| w == HEADER_SEP) else { + return false; + }; + let body = &raw[sep_pos + 4..]; + body.windows(5).any(|w| w == b"\"id\":") +} + pub fn encode_lsp(value: &impl Serialize) -> String { let json = serde_json::to_string(value).unwrap(); format!("{CONTENT_LENGTH}: {}\r\n\r\n{json}", json.len()) diff --git a/proxy/src/main.rs b/proxy/src/main.rs index becfbdd..63581e9 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -6,13 +6,13 @@ mod lsp; mod platform; use completions::{should_sort_completions, sort_completions_by_param_count}; -use decompile::rewrite_jdt_locations; +use decompile::{rewrite_jdt_in_strings, rewrite_jdt_locations}; use http::handle_http; -use lsp::{parse_lsp_content, write_raw, write_to_stdout, LspReader}; +use lsp::{parse_lsp_content, raw_has_id, write_raw, write_to_stdout, LspReader}; use platform::spawn_parent_monitor; use serde_json::Value; use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, env, fs, io::{self, BufReader, Write}, net::TcpListener, @@ -25,6 +25,12 @@ use std::{ thread, }; +#[derive(Clone, Copy)] +enum TrackedKind { + Definition, + Doc, +} + fn main() { let args: Vec = env::args().skip(1).collect(); if args.len() < 2 { @@ -90,29 +96,41 @@ fn main() { let id_counter = Arc::new(AtomicU64::new(1)); - // Track definition/typeDefinition/implementation request IDs for jdt:// rewriting - let definition_ids: Arc>> = Arc::new(Mutex::new(HashSet::new())); + // Track definition/typeDefinition/implementation and documentation request IDs + // so their responses can be intercepted and rewritten. + let tracked_ids: Arc>> = Arc::new(Mutex::new(HashMap::new())); // --- Thread 1: Zed stdin -> JDTLS stdin (track definition requests) --- let stdin_writer = Arc::clone(&child_stdin); let alive_stdin = Arc::clone(&alive); - let def_ids_in = Arc::clone(&definition_ids); + let tracked_in = Arc::clone(&tracked_ids); thread::spawn(move || { let stdin = io::stdin().lock(); - let mut reader = LspReader::new(stdin); + let mut reader = LspReader::new(BufReader::new(stdin)); while alive_stdin.load(Ordering::Relaxed) { match reader.read_message() { Ok(Some(raw)) => { - if let Some(msg) = parse_lsp_content(&raw) { - if let Some(method) = msg.get("method").and_then(|m| m.as_str()) { - if matches!( - method, - "textDocument/definition" + // Only requests (not notifications) carry an `id`; skip the + // JSON parse entirely for high-volume notifications like + // textDocument/didChange. + if raw_has_id(&raw) { + if let Some(msg) = parse_lsp_content(&raw) { + if let Some(method) = msg.get("method").and_then(|m| m.as_str()) { + let kind = match method { + "textDocument/definition" | "textDocument/typeDefinition" - | "textDocument/implementation" - ) { - if let Some(id) = msg.get("id").cloned() { - def_ids_in.lock().unwrap().insert(id); + | "textDocument/implementation" => { + Some(TrackedKind::Definition) + } + "textDocument/hover" + | "textDocument/signatureHelp" + | "completionItem/resolve" => Some(TrackedKind::Doc), + _ => None, + }; + if let Some(kind) = kind { + if let Some(id) = msg.get("id").cloned() { + tracked_in.lock().unwrap().insert(id, kind); + } } } } @@ -131,7 +149,7 @@ fn main() { // --- Thread 2: JDTLS stdout -> rewrite jdt:// URIs, modify completions -> Zed stdout / resolve pending --- let pending_out = Arc::clone(&pending); let alive_out = Arc::clone(&alive); - let def_ids_out = Arc::clone(&definition_ids); + let tracked_out = Arc::clone(&tracked_ids); let decompile_writer = Arc::clone(&child_stdin); let decompile_pending = Arc::clone(&pending); let decompile_counter = Arc::clone(&id_counter); @@ -141,6 +159,13 @@ fn main() { while alive_out.load(Ordering::Relaxed) { match reader.read_message() { Ok(Some(raw)) => { + // Fast path: notifications (no `id`) can't be responses we + // need to intercept. Forward the raw bytes without parsing. + if !raw_has_id(&raw) { + write_raw(&mut io::stdout().lock(), &raw); + continue; + } + let Some(mut msg) = parse_lsp_content(&raw) else { write_raw(&mut io::stdout().lock(), &raw); continue; @@ -154,11 +179,11 @@ fn main() { } } - // Rewrite jdt:// URIs in definition responses - // Spawns a thread so this loop stays unblocked and can - // route the java/classFileContents response back via `pending`. + // Rewrite jdt:// URIs in definition or documentation responses. + // Spawns a thread so this loop stays unblocked and can route + // the java/classFileContents response back via `pending`. if let Some(id) = msg.get("id").cloned() { - if def_ids_out.lock().unwrap().remove(&id) { + if let Some(kind) = tracked_out.lock().unwrap().remove(&id) { let writer = Arc::clone(&decompile_writer); let pending = Arc::clone(&decompile_pending); let pid = decompile_proxy_id.clone(); @@ -168,7 +193,24 @@ fn main() { let seq = counter.fetch_add(1, Ordering::Relaxed); Value::String(format!("{pid}-decompile-{seq}")) }; - rewrite_jdt_locations(&mut msg, &writer, &pending, &mut next_id); + match kind { + TrackedKind::Definition => { + rewrite_jdt_locations( + &mut msg, + &writer, + &pending, + &mut next_id, + ); + } + TrackedKind::Doc => { + rewrite_jdt_in_strings( + &mut msg, + &writer, + &pending, + &mut next_id, + ); + } + } write_to_stdout(&msg); }); continue;