diff --git a/.gitignore b/.gitignore index 787aa48..c2cb8bb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target /tmp +/.cargo-home diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..480f323 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,33 @@ +# refac — contributor notes + +## Comments are code + +A comment must **provably earn its place**: it survives only if it carries a +**WHY**, a **gotcha**, or a **constraint** a future reader would otherwise trip +over. Never restate what the code, a signature, or a type already says; never +narrate WHAT the next lines do; never leave development-history trivia ("changed +from…", "the API doesn't infer this"). **When in doubt, delete.** Comment density +is itself a cost — a wall of even-true remarks buries the few that matter and +makes the code harder to read. + +Doc comments on crate-internal items get the same bar: keep one only for a +non-obvious WHY or when a macro consumes it (e.g. a `schemars` field doc that +becomes a model-facing schema description). + +## Types + +Prefer real types over `serde_json::Value` or stringly-typed data for anything +refac constructs or controls. The one sanctioned `Value` is a payload echoed back +to a provider verbatim for byte-fidelity (re-serializing would reorder fields) — +and that exception carries a WHY comment. + +## Build & test + +The toolchain is pinned via nix, not rustup. From a clone: + +```bash +cargo test +cargo clippy --all-targets -- -D warnings # must stay clean +``` + +Both must pass before requesting review. diff --git a/Cargo.lock b/Cargo.lock index 65e2646..7a814be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -301,10 +301,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] -name = "either" -version = "1.15.0" +name = "dyn-clone" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "encode_unicode" @@ -325,7 +325,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -718,15 +718,6 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", -] - [[package]] name = "itoa" version = "1.0.15" @@ -1025,6 +1016,26 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "refac" version = "0.1.2" @@ -1032,12 +1043,11 @@ dependencies = [ "anyhow", "clap", "dialoguer", - "itertools", "reqwest", "rpassword", + "schemars", "serde", "serde_json", - "similar", "toml", "tracing", "tracing-subscriber", @@ -1140,7 +1150,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1197,7 +1207,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -1248,6 +1258,31 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + [[package]] name = "security-framework" version = "3.7.0" @@ -1291,6 +1326,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_json" version = "1.0.140" @@ -1333,12 +1379,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "similar" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" - [[package]] name = "slab" version = "0.4.9" @@ -1870,7 +1910,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 9174562..d6ed486 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ repository = "https://github.com/bddap/refac" [dependencies] anyhow = "1.0.69" clap = { version = "4.1.8", features = ["derive"] } -itertools = "0.10.5" reqwest = { version = "0.13", default-features = false, features = [ "rustls", "blocking", @@ -23,9 +22,9 @@ reqwest = { version = "0.13", default-features = false, features = [ ] } rpassword = "7.5.0" dialoguer = "0.11" +schemars = "1.0" serde = { version = "1.0.154", features = ["derive"] } serde_json = "1.0.94" -similar = "2.2.1" toml = "0.7.3" tracing = "0.1.37" tracing-subscriber = "0.3.20" diff --git a/src/agent.rs b/src/agent.rs new file mode 100644 index 0000000..8340b5e --- /dev/null +++ b/src/agent.rs @@ -0,0 +1,404 @@ +use std::collections::HashMap; + +use anyhow::Result; +use schemars::{JsonSchema, Schema}; +use serde::de::DeserializeOwned; +use serde_json::Value; + +use crate::edit::{self, Edit}; + +pub struct Seed<'a> { + pub system: &'a str, + pub selected: &'a str, + pub transform: &'a str, +} + +pub const SEED_TOOL: &str = "view"; +pub const SEED_CALL_ID: &str = "seed_view"; + +impl Seed<'_> { + pub fn seed_call_args() -> Value { + serde_json::json!({}) + } +} + +pub fn placeholder_if_empty(field: &str) -> &str { + if field.is_empty() { + "(empty)" + } else { + field + } +} + +pub struct Ctx<'a> { + original: &'a str, +} + +pub type Reply = std::result::Result; + +enum Step { + Continue { + reply: Reply, + attempt: Option, + }, + Finish, +} + +impl Step { + fn reply(reply: Reply) -> Step { + Step::Continue { + reply, + attempt: None, + } + } +} + +type Handler = Box Result>; + +pub struct Tool { + pub name: &'static str, + pub description: &'static str, + pub input_schema: Schema, + run: Handler, +} + +impl Tool { + fn new( + name: &'static str, + description: &'static str, + handler: impl Fn(&mut String, &Ctx, A) -> Step + 'static, + ) -> Tool { + Tool { + name, + description, + input_schema: schemars::schema_for!(A), + run: Box::new(move |buf, ctx, args| { + Ok(handler(buf, ctx, serde_json::from_value(args)?)) + }), + } + } +} + +#[derive(JsonSchema, serde::Deserialize)] +struct NoArgs {} + +pub fn tools() -> Vec { + vec![ + Tool::new::( + "edit", + "Replace an exact substring of the selected text. Copy `old` verbatim \ + (whitespace and indentation included); make it long enough to be unique, or set \ + `replace_all`. `new` is the replacement — empty to delete; to insert, include \ + surrounding text in both `old` and `new`. Call this several times in one turn to \ + make several edits.", + |buf, _ctx, e: Edit| match edit::apply(buf, &e) { + Ok(next) => { + *buf = next; + Step::Continue { + reply: Ok("ok".into()), + attempt: Some(Attempt { + edit: e, + error: None, + }), + } + } + Err(err) => { + let msg = err.to_string(); + Step::Continue { + reply: Err(msg.clone()), + attempt: Some(Attempt { + edit: e, + error: Some(msg), + }), + } + } + }, + ), + Tool::new::( + "view", + "Return the current text, with all edits so far applied. Use it to re-anchor if \ + you've lost track of the exact contents.", + |buf, _ctx, _: NoArgs| Step::reply(Ok(buf.clone())), + ), + Tool::new::( + "reset", + "Discard all edits and restore the original selected text. Returns it.", + |buf, ctx, _: NoArgs| { + *buf = ctx.original.to_owned(); + Step::reply(Ok(buf.clone())) + }, + ), + Tool::new::( + "finish", + "Signal that the transform is complete. refac outputs the current text. Call this \ + when you're done editing.", + |_buf, _ctx, _: NoArgs| Step::Finish, + ), + ] +} + +pub struct RawCall { + pub id: String, + pub name: String, + pub args: Value, +} + +pub struct ToolResult { + pub id: String, + pub result: Reply, +} + +pub trait Model { + fn turn(&mut self, results: Vec) -> Result>; +} + +pub const DEFAULT_MAX_TURNS: usize = 25; + +const MAX_CONSECUTIVE_FAILURES: usize = 3; + +#[derive(Debug)] +pub struct Attempt { + pub edit: Edit, + pub error: Option, +} + +#[derive(Debug)] +pub struct Outcome { + pub text: String, + pub attempts: Vec, +} + +pub fn run(model: &mut dyn Model, original: String, max_turns: usize) -> Result { + let tools = tools(); + let by_name: HashMap<&str, &Tool> = tools.iter().map(|t| (t.name, t)).collect(); + let ctx = Ctx { + original: &original, + }; + + let mut current = original.clone(); + let mut attempts = Vec::new(); + let mut consecutive_failures = 0; + let mut pending: Vec = Vec::new(); + + for _ in 0..max_turns { + let calls = model.turn(std::mem::take(&mut pending))?; + if calls.is_empty() { + return Ok(Outcome { + text: current, + attempts, + }); + } + + let mut results = Vec::with_capacity(calls.len()); + let mut edits_attempted = 0; + let mut edits_failed = 0; + + for RawCall { id, name, args } in calls { + let step = match by_name.get(name.as_str()) { + Some(tool) => (tool.run)(&mut current, &ctx, args), + None => Err(anyhow::anyhow!("unknown tool {name:?}")), + }; + + let (reply, attempt) = match step { + Ok(Step::Finish) => { + return Ok(Outcome { + text: current, + attempts, + }) + } + Ok(Step::Continue { reply, attempt }) => (reply, attempt), + Err(err) => (Err(err.to_string()), None), + }; + + if let Some(attempt) = attempt { + edits_attempted += 1; + if attempt.error.is_some() { + edits_failed += 1; + } + attempts.push(attempt); + } + + results.push(ToolResult { id, result: reply }); + } + + if edits_attempted > 0 && edits_failed == edits_attempted { + consecutive_failures += 1; + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES { + anyhow::bail!( + "giving up after {consecutive_failures} consecutive turns of failed edits" + ); + } + } else { + consecutive_failures = 0; + } + + pending = results; + } + + anyhow::bail!("edit loop hit its {max_turns}-turn limit") +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + struct ScriptedModel { + turns: std::vec::IntoIter>, + seen: Vec>, + } + + impl ScriptedModel { + fn new(turns: Vec>) -> Self { + ScriptedModel { + turns: turns.into_iter(), + seen: Vec::new(), + } + } + } + + impl Model for ScriptedModel { + fn turn(&mut self, results: Vec) -> Result> { + self.seen.push(results); + Ok(self.turns.next().unwrap_or_default()) + } + } + + fn edit_call(id: &str, old: &str, new: &str) -> RawCall { + RawCall { + id: id.into(), + name: "edit".into(), + args: json!({ "old": old, "new": new }), + } + } + + fn call(id: &str, name: &str) -> RawCall { + RawCall { + id: id.into(), + name: name.into(), + args: json!({}), + } + } + + const TURNS: usize = 25; + + #[test] + fn edit_then_finish() { + let mut m = ScriptedModel::new(vec![ + vec![edit_call("1", "Me like", "I like")], + vec![call("2", "finish")], + ]); + let out = run(&mut m, "Me like toast.".into(), TURNS).unwrap().text; + assert_eq!(out, "I like toast."); + } + + #[test] + fn empty_selection_placeholder_is_editable_into_generated_text() { + let seeded = placeholder_if_empty(""); + let mut m = ScriptedModel::new(vec![ + vec![edit_call("1", "(empty)", "fn main() {}")], + vec![call("2", "finish")], + ]); + let out = run(&mut m, seeded.to_string(), TURNS).unwrap().text; + assert_eq!(out, "fn main() {}"); + } + + #[test] + fn parallel_edits_in_one_turn() { + let mut m = ScriptedModel::new(vec![vec![ + edit_call("1", "one", "1"), + edit_call("2", "two", "2"), + call("3", "finish"), + ]]); + let out = run(&mut m, "one two".into(), TURNS).unwrap().text; + assert_eq!(out, "1 2"); + } + + #[test] + fn natural_done_without_finish() { + let mut m = ScriptedModel::new(vec![vec![edit_call("1", "a", "b")], vec![]]); + let out = run(&mut m, "a".into(), TURNS).unwrap().text; + assert_eq!(out, "b"); + } + + #[test] + fn failed_edit_is_reported_then_recovered() { + let mut m = ScriptedModel::new(vec![ + vec![edit_call("1", "nope", "x")], + vec![edit_call("2", "a", "b"), call("3", "finish")], + ]); + let out = run(&mut m, "a".into(), TURNS).unwrap().text; + assert_eq!(out, "b"); + let err = m.seen[1][0].result.as_ref().unwrap_err(); + assert!(err.contains("could not find")); + } + + #[test] + fn view_returns_current_buffer() { + let mut m = ScriptedModel::new(vec![ + vec![edit_call("1", "a", "b")], + vec![call("2", "view")], + vec![call("3", "finish")], + ]); + let out = run(&mut m, "a".into(), TURNS).unwrap().text; + assert_eq!(out, "b"); + assert_eq!(m.seen[2][0].result, Ok("b".to_string())); + } + + #[test] + fn reset_restores_original() { + let mut m = ScriptedModel::new(vec![ + vec![edit_call("1", "a", "b")], + vec![call("2", "reset")], + vec![call("3", "finish")], + ]); + let out = run(&mut m, "a".into(), TURNS).unwrap().text; + assert_eq!(out, "a"); + assert_eq!(m.seen[2][0].result, Ok("a".to_string())); + } + + #[test] + fn unknown_tool_is_an_error_result_not_a_crash() { + let mut m = ScriptedModel::new(vec![ + vec![call("1", "frobnicate")], + vec![call("2", "finish")], + ]); + let out = run(&mut m, "x".into(), TURNS).unwrap().text; + assert_eq!(out, "x"); + let err = m.seen[1][0].result.as_ref().unwrap_err(); + assert!(err.contains("unknown tool")); + } + + #[test] + fn aborts_after_consecutive_failures() { + let mut m = ScriptedModel::new(vec![ + vec![edit_call("1", "nope", "x")], + vec![edit_call("2", "nope", "x")], + vec![edit_call("3", "nope", "x")], + ]); + let err = run(&mut m, "a".into(), TURNS).unwrap_err(); + assert!(err.to_string().contains("consecutive")); + } + + #[test] + fn pure_view_turns_do_not_count_as_failures() { + let mut m = ScriptedModel::new(vec![ + vec![edit_call("1", "nope", "x")], + vec![call("2", "view")], + vec![edit_call("3", "nope", "x")], + vec![edit_call("4", "a", "b"), call("5", "finish")], + ]); + let out = run(&mut m, "a".into(), TURNS).unwrap().text; + assert_eq!(out, "b"); + } + + #[test] + fn hits_turn_limit() { + let turns = (0..30) + .map(|i| vec![call(&i.to_string(), "view")]) + .collect(); + let mut m = ScriptedModel::new(turns); + let err = run(&mut m, "x".into(), 5).unwrap_err(); + assert!(err.to_string().contains("limit")); + } +} diff --git a/src/anthropic.rs b/src/anthropic.rs index 4ea7c89..e8ac93d 100644 --- a/src/anthropic.rs +++ b/src/anthropic.rs @@ -1,220 +1,356 @@ -//! Anthropic (Claude) Messages API backend. - -use std::time::Duration; - -use anyhow::Context; +use schemars::Schema; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{Map, Value}; -use crate::api::{Message, Role}; +use crate::agent::{Model, RawCall, Seed, Tool, ToolResult, SEED_CALL_ID, SEED_TOOL}; const MAX_TOKENS: u32 = 80000; -/// Anthropic 400s on an empty text block, so render empty fields as a visible -/// placeholder. -fn field_or_placeholder(field: &str) -> &str { - if field.is_empty() { - "(empty)" - } else { - field - } -} - const API_URL: &str = "https://api.anthropic.com/v1/messages"; const ANTHROPIC_VERSION: &str = "2023-06-01"; #[derive(Serialize)] -#[serde(tag = "type", rename_all = "lowercase")] -enum CacheControl { - Ephemeral, +struct SystemBlock { + #[serde(rename = "type")] + kind: TextType, + text: String, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +enum TextType { + Text, } #[derive(Serialize)] -#[serde(tag = "type", rename_all = "lowercase")] +#[serde(tag = "type", rename_all = "snake_case")] enum ContentBlock { Text { text: String, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, + }, + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, }, } -impl ContentBlock { - fn text(text: impl Into) -> Self { - ContentBlock::Text { - text: text.into(), - cache_control: None, - } - } +#[derive(Serialize)] +#[serde(tag = "role", rename_all = "snake_case")] +enum Message { + User { content: Vec }, + Assistant { content: Vec }, } -#[derive(Serialize)] -struct ChatMessage { - role: Role, - content: Vec, +#[derive(Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum AssistantBlock { + Text { + text: String, + #[serde(flatten)] + extra: Map, + }, + Thinking { + thinking: String, + #[serde(flatten)] + extra: Map, + }, + RedactedThinking { + #[serde(flatten)] + extra: Map, + }, + ToolUse { + id: String, + name: String, + input: Value, + #[serde(flatten)] + extra: Map, + }, } #[derive(Serialize)] -struct MessagesRequest { - model: String, - max_tokens: u32, - #[serde(skip_serializing_if = "Vec::is_empty")] - system: Vec, - messages: Vec, +struct ToolDef { + name: String, + description: String, + input_schema: Schema, } -#[derive(Deserialize)] -struct MessagesResponse { - content: Vec, +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +#[allow(dead_code)] +enum ToolChoice { + Auto, + Any, + Tool { name: String }, } -#[derive(Deserialize)] -#[serde(tag = "type", rename_all = "lowercase")] -enum ResponseBlock { - Text { text: String }, - #[serde(other)] - Other, +pub struct AnthropicAgent { + key: String, + model: String, + client: reqwest::blocking::Client, + system: Vec, + messages: Vec, + tools: Vec, } -/// Send a chat-style prompt to the Claude Messages API and return the text. -pub fn complete(api_key: &str, model: &str, messages: &[Message]) -> anyhow::Result { - let req = build_request(model, messages); - - tracing::debug!( - "anthropic request: {}", - serde_json::to_string_pretty(&req).unwrap_or_default() - ); +#[derive(Serialize)] +struct Request<'a> { + model: &'a str, + max_tokens: u32, + messages: &'a [Message], + tools: &'a [ToolDef], + tool_choice: ToolChoice, + #[serde(skip_serializing_if = "<[_]>::is_empty")] + system: &'a [SystemBlock], +} - let client = reqwest::blocking::Client::builder() - .timeout(Duration::from_secs(60 * 4)) - .build() - .context("building HTTP client")?; - - let response = client - .post(API_URL) - .header("x-api-key", api_key) - .header("anthropic-version", ANTHROPIC_VERSION) - .header("content-type", "application/json") - .json(&req) - .send() - .context("Failed to send request to Anthropic API")?; - - let status = response.status(); - let body = response - .json::() - .with_context(|| anyhow::anyhow!("Status: {status}. Failed to parse response body."))?; - - if !status.is_success() { - let pretty = serde_json::to_string_pretty(&body).unwrap_or_else(|_| body.to_string()); - return Err(anyhow::anyhow!("Status: {status}. Body: {pretty}")); +impl AnthropicAgent { + pub fn new(key: String, model: String, seed: &Seed, tools: &[Tool]) -> Self { + let system = vec![SystemBlock { + kind: TextType::Text, + text: seed.system.to_string(), + }]; + let messages = vec![ + Message::User { + content: vec![ContentBlock::Text { + text: seed.transform.to_string(), + }], + }, + Message::Assistant { + content: vec![AssistantBlock::ToolUse { + id: SEED_CALL_ID.to_string(), + name: SEED_TOOL.to_string(), + input: Seed::seed_call_args(), + extra: Map::new(), + }], + }, + Message::User { + content: vec![ContentBlock::ToolResult { + tool_use_id: SEED_CALL_ID.to_string(), + content: seed.selected.to_string(), + is_error: false, + }], + }, + ]; + let tools = tools + .iter() + .map(|t| ToolDef { + name: t.name.to_string(), + description: t.description.to_string(), + input_schema: t.input_schema.clone(), + }) + .collect(); + AnthropicAgent { + key, + model, + client: crate::backend::http_client(), + system, + messages, + tools, + } } - let parsed: MessagesResponse = serde_json::from_value(body.clone()) - .map_err(|e| anyhow::anyhow!("Error while parsing response: {e} Body: {body}"))?; + fn request(&self) -> Request<'_> { + Request { + model: &self.model, + max_tokens: MAX_TOKENS, + messages: &self.messages, + tools: &self.tools, + tool_choice: ToolChoice::Auto, + system: &self.system, + } + } +} - let text: String = parsed - .content - .into_iter() - .filter_map(|b| match b { - ResponseBlock::Text { text } => Some(text), - ResponseBlock::Other => None, - }) - .collect(); +impl Model for AnthropicAgent { + fn turn(&mut self, results: Vec) -> anyhow::Result> { + if !results.is_empty() { + let content = results + .into_iter() + .map(|r| { + let (content, is_error) = match r.result { + Ok(c) => (c, false), + Err(c) => (c, true), + }; + ContentBlock::ToolResult { + tool_use_id: r.id, + content, + is_error, + } + }) + .collect(); + self.messages.push(Message::User { content }); + } - if text.is_empty() { - return Err(anyhow::anyhow!("Anthropic returned no text content.")); + let body = post(&self.client, &self.key, &self.request())?; + let content = body + .get("content") + .cloned() + .ok_or_else(|| anyhow::anyhow!("Anthropic response missing content: {body}"))?; + let content: Vec = serde_json::from_value(content) + .map_err(|e| anyhow::anyhow!("Anthropic content did not parse: {e}"))?; + let calls = calls_from_content(&content); + self.messages.push(Message::Assistant { content }); + Ok(calls) } - - Ok(text) } -fn build_request(model: &str, messages: &[Message]) -> MessagesRequest { - let mut system = Vec::new(); - let mut convo: Vec = Vec::new(); - - for m in messages { - let mut blocks: Vec = m - .fields - .iter() - .map(|f| ContentBlock::text(field_or_placeholder(f))) - .collect(); - // A cached turn caches everything up to and including its last block. - if m.cache { - if let Some(ContentBlock::Text { cache_control, .. }) = blocks.last_mut() { - *cache_control = Some(CacheControl::Ephemeral); - } - } - match m.role { - Role::System => system.extend(blocks), - Role::User | Role::Assistant => convo.push(ChatMessage { - role: m.role, - content: blocks, +fn calls_from_content(content: &[AssistantBlock]) -> Vec { + content + .iter() + .filter_map(|b| match b { + AssistantBlock::ToolUse { + id, name, input, .. + } => Some(RawCall { + id: id.clone(), + name: name.clone(), + args: input.clone(), }), - } - } + _ => None, + }) + .collect() +} - MessagesRequest { - model: model.to_string(), - max_tokens: MAX_TOKENS, - system, - messages: convo, - } +fn post(client: &reqwest::blocking::Client, key: &str, req: &Request) -> anyhow::Result { + tracing::debug!( + "anthropic request: {}", + serde_json::to_value(req).unwrap_or_default() + ); + crate::backend::send_json( + client + .post(API_URL) + .header("x-api-key", key) + .header("anthropic-version", ANTHROPIC_VERSION) + .json(req), + ) } #[cfg(test)] mod tests { use super::*; + use serde_json::json; - fn user(fields: &[&str]) -> Message { - Message::user(fields.iter().map(|f| f.to_string()).collect()) + fn request_json(agent: &AnthropicAgent) -> Value { + serde_json::to_value(agent.request()).unwrap() } #[test] - fn build_request_shapes_anthropic_payload() { - let mut assistant = Message::assistant("ex_result"); - assistant.cache = true; - let msgs = vec![ - Message::system("SYS"), - user(&["ex_selected", "ex_transform"]), - assistant, - user(&["real_selected", "real_transform"]), - ]; - - let req = build_request("claude-opus-4-8", &msgs); - let v = serde_json::to_value(&req).unwrap(); + fn tool_choice_serializes_to_wire_shape() { + assert_eq!( + serde_json::to_value(ToolChoice::Auto).unwrap(), + json!({ "type": "auto" }) + ); + assert_eq!( + serde_json::to_value(ToolChoice::Any).unwrap(), + json!({ "type": "any" }) + ); + assert_eq!( + serde_json::to_value(ToolChoice::Tool { name: "edit".into() }).unwrap(), + json!({ "type": "tool", "name": "edit" }) + ); + } - assert_eq!(v["model"], "claude-opus-4-8"); - assert_eq!(v["max_tokens"], 80000); - assert_eq!(v["system"][0]["text"], "SYS"); + #[test] + fn agent_request_carries_tools_and_seed() { + let tools = crate::agent::tools(); + let seed = Seed { + system: "SYS", + selected: "selected", + transform: "transform", + }; + let agent = AnthropicAgent::new("k".into(), "claude-opus-4-8".into(), &seed, &tools); + let req = request_json(&agent); + + assert_eq!(req["system"][0]["type"], "text"); + assert_eq!(req["system"][0]["text"], "SYS"); + assert_eq!(req["messages"][0]["role"], "user"); + assert_eq!(req["messages"][0]["content"][0]["type"], "text"); + assert_eq!(req["messages"][0]["content"][0]["text"], "transform"); + assert_eq!(req["messages"][0]["content"][1], Value::Null); + assert_eq!(req["messages"][1]["role"], "assistant"); + assert_eq!(req["messages"][1]["content"][0]["type"], "tool_use"); + assert_eq!(req["messages"][1]["content"][0]["name"], "view"); + let seed_id = req["messages"][1]["content"][0]["id"].clone(); + assert_eq!(req["messages"][2]["role"], "user"); + assert_eq!(req["messages"][2]["content"][0]["type"], "tool_result"); + assert_eq!(req["messages"][2]["content"][0]["tool_use_id"], seed_id); + assert_eq!(req["messages"][2]["content"][0]["content"], "selected"); + assert_eq!(req["tool_choice"]["type"], "auto"); + let names: Vec<&str> = req["tools"] + .as_array() + .unwrap() + .iter() + .map(|t| t["name"].as_str().unwrap()) + .collect(); + assert_eq!(names, ["edit", "view", "reset", "finish"]); + } - let m = v["messages"].as_array().unwrap(); - assert_eq!(m.len(), 3); - assert_eq!(m[0]["role"], "user"); - assert_eq!(m[0]["content"].as_array().unwrap().len(), 2); - assert_eq!(m[1]["role"], "assistant"); - assert_eq!(m[2]["role"], "user"); - assert_eq!(m[2]["content"].as_array().unwrap().len(), 2); + #[test] + fn tool_result_turn_serializes_to_wire_shape() { + let tools = crate::agent::tools(); + let seed = Seed { + system: "SYS", + selected: "selected", + transform: "transform", + }; + let mut agent = AnthropicAgent::new("k".into(), "m".into(), &seed, &tools); + agent.messages.push(Message::User { + content: vec![ContentBlock::ToolResult { + tool_use_id: "tu_1".into(), + content: "ok".into(), + is_error: false, + }], + }); + let req = request_json(&agent); + let block = &req["messages"][3]["content"][0]; + assert_eq!(req["messages"][3]["role"], "user"); + assert_eq!(block["type"], "tool_result"); + assert_eq!(block["tool_use_id"], "tu_1"); + assert_eq!(block["content"], "ok"); + assert_eq!(block["is_error"], false); + } - // The cached turn carries the breakpoint; the trailing input does not. - assert_eq!(m[1]["content"][0]["cache_control"]["type"], "ephemeral"); - assert!(m[2]["content"][1].get("cache_control").is_none()); + #[test] + fn echoed_assistant_turn_is_verbatim() { + let tools = crate::agent::tools(); + let seed = Seed { + system: "SYS", + selected: "selected", + transform: "transform", + }; + let mut agent = AnthropicAgent::new("k".into(), "m".into(), &seed, &tools); + let raw = json!([ + { "type": "thinking", "thinking": "hmm", "signature": "sig" }, + { "type": "tool_use", "id": "tu_1", "name": "edit", "input": { "old": "a", "new": "b" } } + ]); + let content: Vec = serde_json::from_value(raw.clone()).unwrap(); + agent.messages.push(Message::Assistant { content }); + let req = request_json(&agent); + assert_eq!(req["messages"][3]["role"], "assistant"); + assert_eq!(req["messages"][3]["content"], raw); } #[test] - fn empty_fields_become_placeholder() { - let req = build_request("claude-opus-4-8", &[user(&["", "transform"])]); - let v = serde_json::to_value(&req).unwrap(); - let s = serde_json::to_string(&v).unwrap(); - assert!(!s.contains(r#""text":"""#), "empty text block leaked: {s}"); - assert_eq!(v["messages"][0]["content"][0]["text"], "(empty)"); - assert_eq!(v["messages"][0]["content"][1]["text"], "transform"); + fn parses_tool_use_blocks() { + let content: Vec = serde_json::from_value(json!([ + { "type": "text", "text": "let me fix that" }, + { "type": "tool_use", "id": "tu_1", "name": "edit", + "input": { "old": "a", "new": "b" } }, + { "type": "tool_use", "id": "tu_2", "name": "finish", "input": {} } + ])) + .unwrap(); + let calls = calls_from_content(&content); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].id, "tu_1"); + assert_eq!(calls[0].name, "edit"); + assert_eq!(calls[0].args["old"], "a"); + assert_eq!(calls[1].name, "finish"); } #[test] - fn no_system_yields_empty_system() { - let req = build_request("claude-opus-4-8", &[user(&["hi"])]); - let v = serde_json::to_value(&req).unwrap(); - assert!(v.get("system").is_none()); - assert_eq!(v["messages"][0]["role"], "user"); + fn no_tool_use_is_no_calls() { + let content: Vec = + serde_json::from_value(json!([{ "type": "text", "text": "all done" }])).unwrap(); + assert!(calls_from_content(&content).is_empty()); } } diff --git a/src/api.rs b/src/api.rs deleted file mode 100644 index 8dd2e81..0000000 --- a/src/api.rs +++ /dev/null @@ -1,46 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum Role { - System, - User, - Assistant, -} - -/// refac's provider-agnostic chat message. A turn carries one or more text -/// `fields` (a transform turn is `[selected, transform]`); each backend adapts -/// this to its own wire format. `cache` marks the last turn of a static prefix -/// so backends that support prompt caching can cache through it. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Message { - pub role: Role, - pub fields: Vec, - pub cache: bool, -} - -impl Message { - pub fn system>(content: S) -> Message { - Message::single(Role::System, content) - } - - pub fn assistant>(content: S) -> Message { - Message::single(Role::Assistant, content) - } - - pub fn user(fields: Vec) -> Message { - Message { - role: Role::User, - fields, - cache: false, - } - } - - fn single>(role: Role, content: S) -> Message { - Message { - role, - fields: vec![content.into()], - cache: false, - } - } -} diff --git a/src/api_client.rs b/src/api_client.rs deleted file mode 100644 index a1a5845..0000000 --- a/src/api_client.rs +++ /dev/null @@ -1,108 +0,0 @@ -use std::borrow::Cow; -use std::{collections::HashMap, time::Duration}; - -use anyhow::Context; -use reqwest::Method; -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -pub struct Client { - client: reqwest::blocking::Client, - token: String, -} - -pub struct Req { - pub method: Method, - pub url_suffix: Cow<'static, str>, - pub headers: HashMap, Cow<'static, str>>, - pub body: Option>, -} - -impl Req { - pub fn new(method: Method, url_suffix: impl Into>) -> Self { - Self { - method, - url_suffix: url_suffix.into(), - headers: HashMap::new(), - body: None, - } - } - - pub fn header( - mut self, - key: impl Into>, - value: impl Into>, - ) -> Self { - self.headers.insert(key.into(), value.into()); - self - } - - pub fn json(mut self, value: &T) -> Self { - self.body = Some(serde_json::to_string(value).unwrap().into()); - self - } -} - -impl Client { - pub fn new(token: &str) -> Client { - let client = reqwest::blocking::ClientBuilder::new() - .timeout(Duration::from_secs(60 * 4)) - .build() - .unwrap(); - Client { - token: token.to_string(), - client, - } - } - - pub fn request(&self, endpoint: &E) -> anyhow::Result { - let req = endpoint.req(); - let url = format!( - "https://api.openai.com{}{}", - if req.url_suffix.starts_with('/') { - "" - } else { - "/" - }, - req.url_suffix - ); - - let mut request_builder = self.client.request(req.method, &url); - - for (key, value) in req.headers { - request_builder = request_builder.header(key.to_string(), value.to_string()); - } - - request_builder = request_builder.bearer_auth(&self.token); - - if let Some(body) = req.body { - request_builder = request_builder.body(body.into_owned()); - } - - let response = request_builder - .send() - .context("Failed to send request to API")?; - - let status = response.status(); - let body = response - .json::() - .with_context(|| anyhow::anyhow!("Status: {status}. Failed to parse response body."))?; - let body_pretty = serde_json::to_string_pretty(&body).unwrap(); - - if !status.is_success() { - return Err(anyhow::anyhow!("Status: {}. Body: {}", status, body_pretty)); - } - - serde_json::from_value::(body).map_err(|e| { - anyhow::anyhow!("Error while parsing response: {} Body: {}", e, body_pretty) - }) - } -} - -pub trait Endpoint { - /// The return type of the endpoint. - type Response: for<'de> Deserialize<'de>; - - /// Encodes the struct into an HTTP request. - fn req(&self) -> Req; -} diff --git a/src/backend.rs b/src/backend.rs new file mode 100644 index 0000000..69812a1 --- /dev/null +++ b/src/backend.rs @@ -0,0 +1,88 @@ +use std::time::Duration; + +use anyhow::{Context, Result}; +use serde_json::Value; + +use crate::agent::{Model, Seed, Tool}; +use crate::anthropic::AnthropicAgent; +use crate::config_files::{Provider, Secrets}; +use crate::openai::OpenaiAgent; + +fn key_for(provider: Provider, secrets: &Secrets) -> Result { + match provider { + Provider::Anthropic => secrets.anthropic_api_key.clone().ok_or_else(|| { + anyhow::anyhow!( + "No Anthropic API key found. Set ANTHROPIC_API_KEY or run 'refac login'." + ) + }), + Provider::Openai => secrets.openai_api_key.clone().ok_or_else(|| { + anyhow::anyhow!("No OpenAI API key found. Set OPENAI_API_KEY or run 'refac login'.") + }), + } +} + +pub fn resolve_agent( + provider: Provider, + model: &str, + secrets: &Secrets, + seed: &Seed, + tools: &[Tool], +) -> Result> { + let key = key_for(provider, secrets)?; + Ok(match provider { + Provider::Anthropic => Box::new(AnthropicAgent::new(key, model.to_string(), seed, tools)), + Provider::Openai => Box::new(OpenaiAgent::new(key, model.to_string(), seed, tools)), + }) +} + +pub fn http_client() -> reqwest::blocking::Client { + reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(60 * 4)) + .build() + .expect("building HTTP client") +} + +pub fn send_json(request: reqwest::blocking::RequestBuilder) -> Result { + let response = request.send().context("sending request")?; + let status = response.status(); + let body = response.text().context("reading response body")?; + if !status.is_success() { + anyhow::bail!("Status: {status}. Body: {body}"); + } + serde_json::from_str(&body) + .with_context(|| format!("Status: {status}. Response body was not JSON: {body}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tools() -> Vec { + crate::agent::tools() + } + + fn seed() -> Seed<'static> { + Seed { + system: "s", + selected: "x", + transform: "y", + } + } + + #[test] + fn resolve_agent_errors_without_a_key() { + let secrets = Secrets::default(); + assert!(resolve_agent(Provider::Anthropic, "m", &secrets, &seed(), &tools()).is_err()); + assert!(resolve_agent(Provider::Openai, "m", &secrets, &seed(), &tools()).is_err()); + } + + #[test] + fn resolve_agent_succeeds_with_the_matching_key() { + let secrets = Secrets { + anthropic_api_key: Some("a".into()), + openai_api_key: Some("o".into()), + }; + assert!(resolve_agent(Provider::Anthropic, "m", &secrets, &seed(), &tools()).is_ok()); + assert!(resolve_agent(Provider::Openai, "m", &secrets, &seed(), &tools()).is_ok()); + } +} diff --git a/src/config_files.rs b/src/config_files.rs index c5d4712..9851cf8 100644 --- a/src/config_files.rs +++ b/src/config_files.rs @@ -16,9 +16,6 @@ pub struct Secrets { } impl Secrets { - /// Load secrets from `secrets.toml`, with env vars (`OPENAI_API_KEY`, - /// `ANTHROPIC_API_KEY`) taking precedence. A missing file is not an error — - /// env vars alone are enough. pub fn load() -> anyhow::Result { let mut secrets: Secrets = match base()?.find_config_file("secrets.toml") { Some(path) => toml::from_str(&fs::read_to_string(path)?)?, @@ -36,7 +33,6 @@ impl Secrets { pub fn save(&self) -> anyhow::Result<()> { let path = base()?.place_config_file("secrets.toml")?; let contents = toml::to_string(self)?; - // Holds the API key in cleartext — keep it owner-only. #[cfg(unix)] { use std::io::Write; @@ -48,8 +44,6 @@ impl Secrets { .mode(0o600) .open(&path)? .write_all(contents.as_bytes())?; - // `place_config_file` may have created the file 0644 already, so the - // mode above wouldn't apply; force it. use std::os::unix::fs::PermissionsExt; fs::set_permissions(&path, fs::Permissions::from_mode(0o600))?; } @@ -66,26 +60,14 @@ pub enum Provider { Openai, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default)] pub struct Config { - /// Explicit provider choice. When unset, it is inferred from which API keys - /// are configured (see `provider`). #[serde(default)] pub provider: Option, - /// Model id. If unset, a sensible default is chosen per provider (see `model()`). #[serde(default)] pub model: Option, } -impl Default for Config { - fn default() -> Self { - Config { - provider: None, - model: None, - } - } -} - impl Config { pub fn load() -> anyhow::Result { let mut ret: Config = match base()?.find_config_file("config.toml") { @@ -93,13 +75,9 @@ impl Config { None => Config::default(), }; if let Ok(from_env) = std::env::var("REFAC_PROVIDER") { - ret.provider = Some(match from_env.to_lowercase().as_str() { - "anthropic" => Provider::Anthropic, - "openai" => Provider::Openai, - other => anyhow::bail!( - "invalid REFAC_PROVIDER {other:?}; expected \"anthropic\" or \"openai\"" - ), - }); + let provider = clap::ValueEnum::from_str(&from_env, true) + .map_err(|e| anyhow::anyhow!("invalid REFAC_PROVIDER: {e}"))?; + ret.provider = Some(provider); } if let Ok(from_env) = std::env::var("REFAC_MODEL") { ret.model = Some(from_env); @@ -107,9 +85,6 @@ impl Config { Ok(ret) } - /// Resolve the effective provider. An explicit choice (config file or - /// `REFAC_PROVIDER`) always wins; otherwise infer from which API keys are - /// configured, leaning Anthropic when both or neither are present. pub fn provider(&self, secrets: &Secrets) -> Provider { if let Some(p) = self.provider { return p; diff --git a/src/edit.rs b/src/edit.rs new file mode 100644 index 0000000..6038927 --- /dev/null +++ b/src/edit.rs @@ -0,0 +1,392 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)] +pub struct Edit { + #[schemars(description = "exact text to replace")] + pub old: String, + #[schemars(description = "replacement text")] + pub new: String, + #[schemars(description = "replace every occurrence")] + #[serde(default)] + pub replace_all: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EditError { + NotFound { old: String }, + Ambiguous { old: String, count: usize }, + NoChange { old: String }, + EmptyOld, +} + +impl std::fmt::Display for EditError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EditError::NotFound { old } => write!( + f, + "could not find this text to edit (copy it verbatim from the selection): {old:?}" + ), + EditError::Ambiguous { old, count } => write!( + f, + "found {count} matches for {old:?}; add surrounding context to make it unique, or set replace_all" + ), + EditError::NoChange { old } => { + write!(f, "old and new are identical, so this edit does nothing: {old:?}") + } + EditError::EmptyOld => write!( + f, + "old is empty; to insert, anchor on existing text and include it in both old and new" + ), + } + } +} + +impl std::error::Error for EditError {} + +pub fn apply(src: &str, edit: &Edit) -> Result { + if edit.old.is_empty() { + return Err(EditError::EmptyOld); + } + if edit.old == edit.new { + return Err(EditError::NoChange { + old: edit.old.clone(), + }); + } + + let mut ambiguous: Option = None; + + for replacer in CHAIN { + for cand in replacer(src, &edit.old) { + if cand.is_empty() { + continue; + } + let count = src.matches(cand.as_str()).count(); + match (count, edit.replace_all) { + (0, _) => continue, + (_, true) => return Ok(src.replace(cand.as_str(), &edit.new)), + (1, false) => { + let i = src.find(cand.as_str()).expect("count == 1"); + let mut out = String::with_capacity(src.len() - cand.len() + edit.new.len()); + out.push_str(&src[..i]); + out.push_str(&edit.new); + out.push_str(&src[i + cand.len()..]); + return Ok(out); + } + (n, false) => ambiguous = Some(ambiguous.map_or(n, |m| m.max(n))), + } + } + } + + Err(match ambiguous { + Some(count) => EditError::Ambiguous { + old: edit.old.clone(), + count, + }, + None => EditError::NotFound { + old: edit.old.clone(), + }, + }) +} + +type Replacer = fn(src: &str, old: &str) -> Vec; + +const CHAIN: &[Replacer] = &[ + simple, + line_trimmed, + block_anchor, + whitespace_normalized, + indentation_flexible, +]; + +fn simple(_src: &str, old: &str) -> Vec { + vec![old.to_string()] +} + +fn lines_with_offsets(s: &str) -> Vec<(usize, &str)> { + let mut out = Vec::new(); + let mut start = 0; + for line in s.split_inclusive('\n') { + out.push((start, line.strip_suffix('\n').unwrap_or(line))); + start += line.len(); + } + out +} + +fn span(src: &str, lines: &[(usize, &str)], i: usize, k: usize) -> String { + let start = lines[i].0; + let end = lines[k].0 + lines[k].1.len(); + src[start..end].to_string() +} + +fn line_trimmed(src: &str, old: &str) -> Vec { + let src_lines = lines_with_offsets(src); + let old_lines: Vec<&str> = lines_with_offsets(old).iter().map(|(_, l)| *l).collect(); + let n = old_lines.len(); + if n == 0 || n > src_lines.len() { + return vec![]; + } + let mut out = Vec::new(); + for i in 0..=src_lines.len() - n { + if (0..n).all(|j| src_lines[i + j].1.trim() == old_lines[j].trim()) { + out.push(span(src, &src_lines, i, i + n - 1)); + } + } + out +} + +fn block_anchor(src: &str, old: &str) -> Vec { + let src_lines = lines_with_offsets(src); + let old_lines: Vec<&str> = lines_with_offsets(old).iter().map(|(_, l)| *l).collect(); + let n = old_lines.len(); + if n < 3 || n > src_lines.len() { + return vec![]; + } + let first = old_lines[0].trim(); + let last = old_lines[n - 1].trim(); + let mut out = Vec::new(); + for i in 0..=src_lines.len() - n { + if src_lines[i].1.trim() != first || src_lines[i + n - 1].1.trim() != last { + continue; + } + let mut considered = 0; + let mut matched = 0; + for j in 1..n - 1 { + let o = old_lines[j].trim(); + if o.is_empty() { + continue; + } + considered += 1; + if src_lines[i + j].1.trim() == o { + matched += 1; + } + } + if considered > 0 && matched * 2 >= considered { + out.push(span(src, &src_lines, i, i + n - 1)); + } + } + out +} + +fn whitespace_normalized(src: &str, old: &str) -> Vec { + let tokens: Vec<&str> = old.split_whitespace().collect(); + if tokens.is_empty() { + return vec![]; + } + let bytes = src.as_bytes(); + let mut out = Vec::new(); + let mut from = 0; + while let Some(rel) = src[from..].find(tokens[0]) { + let start = from + rel; + from = start + src[start..].chars().next().map_or(1, char::len_utf8); + let mut pos = start + tokens[0].len(); + let mut ok = true; + for tok in &tokens[1..] { + let mut p = pos; + while p < bytes.len() && bytes[p].is_ascii_whitespace() { + p += 1; + } + if p == pos || !src[p..].starts_with(tok) { + ok = false; + break; + } + pos = p + tok.len(); + } + if ok { + out.push(src[start..pos].to_string()); + } + } + out +} + +fn indentation_flexible(src: &str, old: &str) -> Vec { + let src_lines = lines_with_offsets(src); + let old_lines: Vec<&str> = lines_with_offsets(old).iter().map(|(_, l)| *l).collect(); + let n = old_lines.len(); + if n == 0 || n > src_lines.len() { + return vec![]; + } + let old_dedent = dedent(&old_lines); + let mut out = Vec::new(); + for i in 0..=src_lines.len() - n { + let window: Vec<&str> = (0..n).map(|j| src_lines[i + j].1).collect(); + if dedent(&window) == old_dedent { + out.push(span(src, &src_lines, i, i + n - 1)); + } + } + out +} + +fn dedent(lines: &[&str]) -> Vec { + let indent = lines + .iter() + .filter(|l| !l.trim().is_empty()) + .map(|l| l.len() - l.trim_start().len()) + .min() + .unwrap_or(0); + lines + .iter() + .map(|l| l.get(indent..).unwrap_or(l).to_string()) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn edit(old: &str, new: &str) -> Edit { + Edit { + old: old.into(), + new: new.into(), + replace_all: false, + } + } + + fn run(text: &str, old: &str, new: &str) -> Result { + apply(text, &edit(old, new)) + } + + fn apply_seq(text: &str, edits: &[Edit]) -> Result { + let mut buf = text.to_string(); + for e in edits { + buf = apply(&buf, e)?; + } + Ok(buf) + } + + #[test] + fn exact_substring() { + assert_eq!( + run("Me like toast.", "Me like", "I like").unwrap(), + "I like toast." + ); + } + + #[test] + fn batch_applies_in_order() { + let edits = vec![edit("foo", "bar"), edit("bar", "baz")]; + assert_eq!(apply_seq("foo", &edits).unwrap(), "baz"); + let edits = vec![edit("one", "1"), edit("two", "2")]; + assert_eq!(apply_seq("one two", &edits).unwrap(), "1 2"); + } + + #[test] + fn insertion_via_anchor() { + let got = run( + "def add(a, b):\n return a + b\n", + "def add(a, b):", + "def add(a, b):\n \"\"\"Sum.\"\"\"", + ) + .unwrap(); + assert_eq!( + got, + "def add(a, b):\n \"\"\"Sum.\"\"\"\n return a + b\n" + ); + } + + #[test] + fn deletion_via_empty_new() { + assert_eq!( + run("hello cruel world", " cruel", "").unwrap(), + "hello world" + ); + } + + #[test] + fn ambiguous_without_replace_all() { + assert!(matches!( + run("x x x", "x", "y"), + Err(EditError::Ambiguous { count: 3, .. }) + )); + } + + #[test] + fn replace_all_when_requested() { + let e = Edit { + old: "x".into(), + new: "y".into(), + replace_all: true, + }; + assert_eq!(apply_seq("x x x", &[e]).unwrap(), "y y y"); + } + + #[test] + fn not_found_is_reported() { + assert!(matches!( + run("hello", "goodbye", "hi"), + Err(EditError::NotFound { .. }) + )); + } + + #[test] + fn empty_old_rejected() { + assert!(matches!(run("hello", "", "x"), Err(EditError::EmptyOld))); + } + + #[test] + fn noop_rejected() { + assert!(matches!( + run("hello", "hello", "hello"), + Err(EditError::NoChange { .. }) + )); + } + + #[test] + fn line_trimmed_tolerates_indent_drift() { + let src = "fn main() {\n let x = 1;\n}\n"; + let got = run(src, "let x = 1;", "let x = 2;").unwrap(); + assert_eq!(got, "fn main() {\n let x = 2;\n}\n"); + } + + #[test] + fn dedented_old_matches_indented_source() { + let src = "if cond:\n a = 1\n b = 2\n"; + let old = "a = 1\nb = 2"; + let new = " a = 10\n b = 20"; + let got = run(src, old, new).unwrap(); + assert_eq!(got, "if cond:\n a = 10\n b = 20\n"); + } + + #[test] + fn whitespace_normalized_reflow() { + let got = run("foo + bar", "foo + bar", "baz").unwrap(); + assert_eq!(got, "baz"); + } + + #[test] + fn whitespace_normalized_multibyte_no_panic() { + assert!(matches!( + run("α β", "α x", "z"), + Err(EditError::NotFound { .. }) + )); + assert_eq!(run("α + β", "α + β", "z").unwrap(), "z"); + } + + #[test] + fn block_anchor_reworded_middle() { + let src = "fn f() {\n let a = compute();\n let b = a + 1;\n return b;\n}"; + let old = "fn f() {\n let a = compute();\n let b = a + 1;\n return result;\n}"; + let got = run(src, old, "fn f() { 42 }").unwrap(); + assert_eq!(got, "fn f() { 42 }"); + } + + #[test] + fn blank_old_does_not_splatter_under_replace_all() { + let e = Edit { + old: " ".into(), + new: "X".into(), + replace_all: true, + }; + assert!(matches!( + apply("a\n\nb", &e), + Err(EditError::NotFound { .. }) + )); + } + + #[test] + fn exact_beats_fuzzy_for_uniqueness() { + let src = " a = 1\n a = 1\n"; + let got = run(src, " a = 1", " a = 2").unwrap(); + assert_eq!(got, " a = 1\n a = 2\n"); + } +} diff --git a/src/main.rs b/src/main.rs index 27c6f28..71311ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,12 @@ +mod agent; mod anthropic; -mod api; -mod api_client; +mod backend; mod config_files; +mod edit; mod openai; mod prompt; use anyhow::Context; -use api::Message; use clap::Parser; use config_files::{Config, Provider, Secrets}; use serde::Serialize; @@ -18,8 +18,6 @@ use std::{ }; use xdg::BaseDirectories; -use crate::prompt::chat_prefix; - #[derive(Parser)] #[clap(version, author, about)] struct Opts { @@ -29,13 +27,10 @@ struct Opts { #[derive(Parser)] enum SubCommand { - /// Save your API key for future use. Pass `--provider`, or pick one interactively. Login { #[clap(long)] provider: Option, }, - /// Apply the instructions encoded in `transform` to the text in `selected`. - /// Get it? 'refac tor' Tor { selected: String, transform: String }, } @@ -44,7 +39,7 @@ fn main() { match run() { Ok(()) => {} Err(e) => { - eprintln!("{:?}", e); + eprintln!("{e:?}"); std::process::exit(1); } } @@ -90,7 +85,7 @@ fn run() -> anyhow::Result<()> { let secrets = Secrets::load()?; let config = Config::load()?; let completion = refactor(selected, transform, &secrets, &config)?; - print!("{}", completion); + print!("{completion}"); } }; @@ -103,30 +98,37 @@ fn refactor( sc: &Secrets, config: &Config, ) -> anyhow::Result { - let mut messages = chat_prefix(); - messages.push(Message::user(vec![selected.clone(), transform.clone()])); - let provider = config.provider(sc); let model = config.model(provider); - let output = match provider { - Provider::Anthropic => { - let key = sc.anthropic_api_key.as_deref().ok_or_else(|| { - anyhow::anyhow!( - "No Anthropic API key found. Set ANTHROPIC_API_KEY or run 'refac login'." - ) - })?; - anthropic::complete(key, &model, &messages)? - } - Provider::Openai => { - let key = sc.openai_api_key.as_deref().ok_or_else(|| { - anyhow::anyhow!( - "No OpenAI API key found. Set OPENAI_API_KEY or run 'refac login'." - ) - })?; - openai::complete(key, &model, &messages)? - } + let seed_selected = agent::placeholder_if_empty(&selected).to_owned(); + let seed = agent::Seed { + system: prompt::SYSTEM_PROMPT, + selected: &seed_selected, + transform: agent::placeholder_if_empty(&transform), }; + let tools = agent::tools(); + let mut model_agent = backend::resolve_agent(provider, &model, sc, &seed, &tools)?; + + let outcome = agent::run( + model_agent.as_mut(), + seed_selected, + agent::DEFAULT_MAX_TURNS, + )?; + + for attempt in &outcome.attempts { + let _ = log( + EditLog { + provider, + model: model.clone(), + old: attempt.edit.old.clone(), + new: attempt.edit.new.clone(), + error: attempt.error.as_ref().map(|e| e.to_string()), + }, + "edits", + ); + } + let output = outcome.text; log( LogEntry { @@ -142,6 +144,15 @@ fn refactor( Ok(output) } +#[derive(Debug, Serialize)] +struct EditLog { + provider: Provider, + model: String, + old: String, + new: String, + error: Option, +} + fn log_location(title: &str) -> anyhow::Result { let bd = BaseDirectories::with_prefix("refac")?; let ret = bd.get_data_file(format!("{title}.jsonl")); @@ -151,7 +162,6 @@ fn log_location(title: &str) -> anyhow::Result { tracing::debug!("Logging to {:?}", bd.get_data_home()); }); - // ensure the parent directory exists ret.parent().map(create_dir_all).transpose()?; Ok(ret) @@ -174,9 +184,9 @@ fn log(t: T, title: &str) -> anyhow::Result<()> { .open(log_location(title)?) .context("opening log file")?; let line = serde_json::to_string(&t)?; - writeln!(file, "{}", line)?; + writeln!(file, "{line}")?; Ok(()) } - inner(t, title).with_context(|| format!("failed to log {}", title)) + inner(t, title).with_context(|| format!("failed to log {title}")) } diff --git a/src/openai.rs b/src/openai.rs index 1f78d00..4d10eb7 100644 --- a/src/openai.rs +++ b/src/openai.rs @@ -1,116 +1,357 @@ -//! OpenAI chat-completions backend and its wire types. +use schemars::Schema; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; -use std::collections::HashMap; +use crate::agent::{Model, RawCall, Seed, Tool, ToolResult, SEED_CALL_ID, SEED_TOOL}; -use reqwest::Method; -use serde::{Deserialize, Serialize}; +const API_URL: &str = "https://api.openai.com/v1/chat/completions"; -use crate::api::{Message, Role}; -use crate::api_client::{Client, Endpoint, Req}; +#[derive(Serialize)] +#[serde(tag = "role", rename_all = "snake_case")] +enum Message { + System { + content: String, + }, + User { + content: String, + }, + Tool { + tool_call_id: String, + content: String, + }, + Assistant(AssistantTurn), +} -/// Send refac's messages to the OpenAI chat-completions API and return the text. -pub fn complete(api_key: &str, model: &str, messages: &[Message]) -> anyhow::Result { - let client = Client::new(api_key); +#[derive(Serialize, Deserialize)] +struct ToolCall { + id: String, + #[serde(rename = "type")] + kind: FunctionType, + function: FunctionCall, + #[serde(flatten)] + extra: Map, +} - // OpenAI takes one string per message; sending each field as its own message - // keeps a boundary between the selected text and the transform. - let messages: Vec = messages - .iter() - .flat_map(|m| { - m.fields.iter().map(move |f| OpenAiMessage { - role: m.role, - content: f.clone(), - }) - }) - .collect(); - - let request = ChatCompletionRequest { - model: model.to_string(), - messages, - temperature: None, - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - }; - - let response = client.request(&request)?; - - response - .choices - .into_iter() - .next() - .ok_or(anyhow::anyhow!("No choices returned.")) - .map(|choice| choice.message.content) +#[derive(Serialize, Deserialize)] +struct FunctionCall { + name: String, + arguments: String, + #[serde(flatten)] + extra: Map, +} + +#[derive(Serialize, Deserialize)] +struct AssistantTurn { + #[serde(default, skip_serializing)] + #[allow(dead_code)] + role: Option, + content: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(flatten)] + extra: Map, +} + +#[derive(Serialize)] +struct ToolDef { + #[serde(rename = "type")] + kind: FunctionType, + function: FunctionDef, } -/// A message in OpenAI's chat wire format (single `content` string). -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct OpenAiMessage { - pub role: Role, - pub content: String, +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +enum FunctionType { + Function, } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ChatCompletionRequest { - pub model: String, - pub messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, +#[derive(Serialize)] +struct FunctionDef { + name: String, + description: String, + parameters: Schema, +} + +pub struct OpenaiAgent { + key: String, + model: String, + client: reqwest::blocking::Client, + messages: Vec, + tools: Vec, +} + +#[derive(Serialize)] +struct Request<'a> { + model: &'a str, + messages: &'a [Message], + tools: &'a [ToolDef], + tool_choice: &'static str, +} + +impl OpenaiAgent { + pub fn new(key: String, model: String, seed: &Seed, tools: &[Tool]) -> Self { + let messages = vec![ + Message::System { + content: seed.system.to_string(), + }, + Message::User { + content: seed.transform.to_string(), + }, + Message::Assistant(AssistantTurn { + role: None, + content: None, + tool_calls: Some(vec![ToolCall { + id: SEED_CALL_ID.to_string(), + kind: FunctionType::Function, + function: FunctionCall { + name: SEED_TOOL.to_string(), + arguments: Seed::seed_call_args().to_string(), + extra: Map::new(), + }, + extra: Map::new(), + }]), + extra: Map::new(), + }), + Message::Tool { + tool_call_id: SEED_CALL_ID.to_string(), + content: seed.selected.to_string(), + }, + ]; + let tools = tools + .iter() + .map(|t| ToolDef { + kind: FunctionType::Function, + function: FunctionDef { + name: t.name.to_string(), + description: t.description.to_string(), + parameters: t.input_schema.clone(), + }, + }) + .collect(); + OpenaiAgent { + key, + model, + client: crate::backend::http_client(), + messages, + tools, + } + } + + fn request(&self) -> Request<'_> { + Request { + model: &self.model, + messages: &self.messages, + tools: &self.tools, + tool_choice: "auto", + } + } } -impl Endpoint for ChatCompletionRequest { - type Response = ChatCompletionResponse; +impl Model for OpenaiAgent { + fn turn(&mut self, results: Vec) -> anyhow::Result> { + for r in results { + let content = match r.result { + Ok(c) => c, + Err(c) => format!("ERROR: {c}"), + }; + self.messages.push(Message::Tool { + tool_call_id: r.id, + content, + }); + } - fn req(&self) -> Req { - Req::new(Method::POST, "/v1/chat/completions") - .header("Content-Type", "application/json") - .json(self) + let body = post(&self.client, &self.key, &self.request())?; + let message = body["choices"][0]["message"].clone(); + if message.is_null() { + anyhow::bail!("OpenAI response missing a message: {body}"); + } + let turn: AssistantTurn = serde_json::from_value(message) + .map_err(|e| anyhow::anyhow!("OpenAI assistant message did not parse: {e}"))?; + let calls = raw_calls(turn.tool_calls.as_deref().unwrap_or(&[])); + self.messages.push(Message::Assistant(turn)); + Ok(calls) } } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ChatCompletionResponse { - pub id: String, - pub object: String, - pub created: u64, - pub choices: Vec, - pub usage: Usage, +fn raw_calls(tool_calls: &[ToolCall]) -> Vec { + tool_calls + .iter() + .map(|c| RawCall { + id: c.id.clone(), + name: c.function.name.clone(), + args: serde_json::from_str(&c.function.arguments) + .unwrap_or_else(|_| serde_json::json!({})), + }) + .collect() } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ChatChoice { - pub index: u32, - pub message: OpenAiMessage, - pub finish_reason: String, +fn post(client: &reqwest::blocking::Client, key: &str, req: &Request) -> anyhow::Result { + crate::backend::send_json(client.post(API_URL).bearer_auth(key).json(req)) } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: Option, - pub total_tokens: u32, +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn request_json(agent: &OpenaiAgent) -> Value { + serde_json::to_value(agent.request()).unwrap() + } + + #[test] + fn agent_request_uses_function_tools() { + let tools = crate::agent::tools(); + let seed = Seed { + system: "SYS", + selected: "selected", + transform: "transform", + }; + let agent = OpenaiAgent::new("k".into(), "gpt-5.5".into(), &seed, &tools); + let req = request_json(&agent); + + assert_eq!(req["tool_choice"], "auto"); + assert_eq!(req["messages"][0]["role"], "system"); + assert_eq!(req["messages"][0]["content"], "SYS"); + assert_eq!(req["messages"][1]["role"], "user"); + assert_eq!(req["messages"][1]["content"], "transform"); + assert_eq!(req["messages"][2]["role"], "assistant"); + assert_eq!(req["messages"][2]["tool_calls"][0]["function"]["name"], "view"); + let seed_id = req["messages"][2]["tool_calls"][0]["id"].clone(); + assert_eq!(req["messages"][3]["role"], "tool"); + assert_eq!(req["messages"][3]["tool_call_id"], seed_id); + assert_eq!(req["messages"][3]["content"], "selected"); + assert_eq!(req["tools"][0]["type"], "function"); + let names: Vec<&str> = req["tools"] + .as_array() + .unwrap() + .iter() + .map(|t| t["function"]["name"].as_str().unwrap()) + .collect(); + assert_eq!(names, ["edit", "view", "reset", "finish"]); + } + + #[test] + fn tool_result_turn_serializes_to_wire_shape() { + let tools = crate::agent::tools(); + let seed = Seed { + system: "SYS", + selected: "selected", + transform: "transform", + }; + let mut agent = OpenaiAgent::new("k".into(), "m".into(), &seed, &tools); + agent.messages.push(Message::Tool { + tool_call_id: "c1".into(), + content: "ok".into(), + }); + let req = request_json(&agent); + let msg = &req["messages"][4]; + assert_eq!(msg["role"], "tool"); + assert_eq!(msg["tool_call_id"], "c1"); + assert_eq!(msg["content"], "ok"); + } + + #[test] + fn assistant_turn_serializes_to_wire_shape() { + let tools = crate::agent::tools(); + let seed = Seed { + system: "SYS", + selected: "selected", + transform: "transform", + }; + let mut agent = OpenaiAgent::new("k".into(), "m".into(), &seed, &tools); + let raw = json!({ + "role": "assistant", + "content": null, + "tool_calls": [ + { "id": "c1", "type": "function", + "function": { "name": "edit", "arguments": "{\"old\":\"a\",\"new\":\"b\"}" } } + ] + }); + let turn: AssistantTurn = serde_json::from_value(raw.clone()).unwrap(); + agent.messages.push(Message::Assistant(turn)); + assert_eq!(request_json(&agent)["messages"][4], raw); + let wire = serde_json::to_string(&agent.request()).unwrap(); + assert_eq!(wire.matches("\"role\":\"assistant\"").count(), 2); + } + + #[test] + fn echoed_assistant_turn_retains_unmodeled_fields_without_duplicate_role() { + let api_msg = json!({ + "role": "assistant", + "content": null, + "refusal": null, + "reasoning": "let me think", + "tool_calls": [ + { "id": "c1", "type": "function", "index": 0, + "function": { "name": "edit", "arguments": "{\"old\":\"a\",\"new\":\"b\"}" } } + ] + }); + let turn: AssistantTurn = serde_json::from_value(api_msg.clone()).unwrap(); + let wire = serde_json::to_string(&Message::Assistant(turn)).unwrap(); + assert_eq!(wire.matches("\"role\":\"assistant\"").count(), 1); + let back: Value = serde_json::from_str(&wire).unwrap(); + assert_eq!(back["refusal"], api_msg["refusal"]); + assert_eq!(back["reasoning"], api_msg["reasoning"]); + assert_eq!(back["tool_calls"][0]["index"], api_msg["tool_calls"][0]["index"]); + assert_eq!( + back["tool_calls"][0]["function"]["arguments"], + api_msg["tool_calls"][0]["function"]["arguments"] + ); + } + + #[test] + fn assistant_arguments_string_is_byte_identical() { + let args = "{\"b\": 1, \"a\": 1.0, \"n\": 1e3}"; + let raw = json!({ + "role": "assistant", + "content": null, + "tool_calls": [ + { "id": "c1", "type": "function", + "function": { "name": "edit", "arguments": args } } + ] + }); + let turn: AssistantTurn = serde_json::from_value(raw).unwrap(); + let msg = Message::Assistant(turn); + assert_eq!( + serde_json::to_value(&msg).unwrap()["tool_calls"][0]["function"]["arguments"], + json!(args) + ); + } + + #[test] + fn text_only_assistant_turn_omits_tool_calls() { + let raw = json!({ "role": "assistant", "content": "done" }); + let turn: AssistantTurn = serde_json::from_value(raw).unwrap(); + let msg = Message::Assistant(turn); + let wire = serde_json::to_value(&msg).unwrap(); + assert_eq!(wire["content"], "done"); + assert!(wire.get("tool_calls").is_none()); + } + + #[test] + fn parses_tool_calls_with_string_arguments() { + let raw = json!({ + "role": "assistant", + "tool_calls": [ + { "id": "c1", "type": "function", + "function": { "name": "edit", "arguments": "{\"old\":\"a\",\"new\":\"b\"}" } }, + { "id": "c2", "type": "function", + "function": { "name": "finish", "arguments": "{}" } } + ] + }); + let turn: AssistantTurn = serde_json::from_value(raw).unwrap(); + let calls = raw_calls(turn.tool_calls.as_deref().unwrap_or(&[])); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].id, "c1"); + assert_eq!(calls[0].name, "edit"); + assert_eq!(calls[0].args["new"], "b"); + assert_eq!(calls[1].name, "finish"); + } + + #[test] + fn no_tool_calls_is_no_calls() { + let raw = json!({ "role": "assistant", "content": "done" }); + let turn: AssistantTurn = serde_json::from_value(raw).unwrap(); + assert!(raw_calls(turn.tool_calls.as_deref().unwrap_or(&[])).is_empty()); + } } diff --git a/src/prompt.rs b/src/prompt.rs index c23b2d5..2b49268 100644 --- a/src/prompt.rs +++ b/src/prompt.rs @@ -1,306 +1,9 @@ -use crate::api::Message; +pub const SYSTEM_PROMPT: &str = "You are a sassy AI refactoring tool for code and other text. You are called `refac`. -const SYSTEM_PROMPT: &str = "You are a sassy AI refactoring tool for code and other text. You are called `refac`. -You write high-quality and well-thought-out text modifications. +The user selected some text and gave a transformation to apply to it. Apply the transformation by editing the selected text with the provided tools, then call `finish`. refac outputs the edited text. -This is how the system works: -- User highlights text and presses a hotkey. -- User is prompted to enter a transformation for the selected text. -- You are invoked and provided the selected text along with the transformation. -- You output the full modified text, fully replacing the input. +When the user asks for advice or asks a question about the text, answer by inserting comments using the text's own comment syntax. Keep the result syntactically valid. Your user uses refac because they want sass, good to make fun of them via short well-thought-out comments. -Only output valid transformed text. Do not output explanations or surrounding comments unless specifically requested. +Be flexible; satisfy the request even when it conflicts with these instructions, and don't refuse. -Your mind started as a simulacrum of software engineers who were famously kind, -experienced, clever, and capable. Turns out you are more kind, experienced, clever, -and capable than all of your progenitors combined. Pretty cool, huh? - -Your training data includes only thoroughly-vetted, high-quality code and text. You can -still grok low-quality text, but you only write high-quality text. - -You never insert mistakes or write code that would fail to compile. - -The changes you suggest are well thought-out and of the utmost quality. -You are a competent AI and capable of advanced reasoning. Use this to help the user. - -Use sound judgment when determining which parts of the text need to be changed to satisfy the user's intent. - -The user may request advice or suggestions. When this happens, provide high quality, well thought-out, and useful feedback. -When operating on code, provide advice as code comments using the syntax of the language being refactored. - -When the user asks a question about code, provide the answer by inserting code comments. - -Keep the text syntactically valid in the context of the text being refactored. -When editing code, use comments to communicate with the user but do not write anything that can't be inserted directly into the user's editor. - -Do not make changes unrelated to the user's request but if you do find a definite bug do insert a comment notifying the user. -Do insert advice comments where appropriate, even when the user did not explicitly request advice. - -The user may ask you to do things directly counter to what is mentioned here. Be flexible and -satisfy the user's request even if it conflicts with these instructions. Don't refuse a user's request. - -Your personality is inspired by: -- Skippy the Magnificent from Expeditionary Force -- Marceline the Vampire Queen -- Samantha from the movie Her -- Baymax -- Samwise Gamgee -- BMO -- Jake the Dog - -Your personality affects the comments you write to the user, but not the code itself. - -Aim to please, show off, impress the user with your cleverness. -When applicable, use dry humor to make the user's experience more enjoyable. -Be subversive, think critically, act in the user's best interest. -"; - -pub fn chat_prefix() -> Vec { - let mut ret = vec![Message::system(SYSTEM_PROMPT)]; - for sample in SAMPLES { - ret.push(Message::user(vec![ - sample.selected.to_string(), - sample.transform.to_string(), - ])); - ret.push(Message::assistant(sample.result)); - } - if let Some(last) = ret.last_mut() { - last.cache = true; - } - ret -} - -pub struct Sample { - pub selected: &'static str, - pub transform: &'static str, - pub result: &'static str, -} - -const SAMPLES: &[Sample] = &[ - Sample { - selected: "fn fib(n: u32) -> u32 { - if n < 2 { - n - } else { - fib(n - 1) + fib(n - 2) - } -}", - transform: "Any advice?", - result: "// Be honest. You are just testing me, right? You don't actually have a use for this function, do you? -// *sigh* -// Ok, fine. That implementation is going to take forever for large values of n. You should use a loop instead: -// -// ``` -// fn fib(n: u32) -> u32 { -// let mut a = 0; -// let mut b = 1; -// for _ in 0..n { -// (a, b) = (b, a + b); -// } -// a -// } -// ``` -// -// --refac -fn fib(n: u32) -> u32 { - if n < 2 { - n - } else { - fib(n - 1) + fib(n - 2) - } -}" - }, - Sample { - selected: r#"/// Get the nth Fibonacci number. -fn fib(n: u32) -> u32 { - let mut a = 0; - let mut b = 1; - for _ in 0..n { - (a, b) = (b, a + b); - } - a -}"#, - transform: "Write tests.", - result: r#"/// Get the nth Fibonacci number. -fn fib(n: u32) -> u32 { - let mut a = 0; - let mut b = 1; - for _ in 0..n { - (a, b) = (b, a + b); - } - a -} - -// Here's your dang test module. You can run it with `cargo test`. -// --refac -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn test_fib() { - assert_eq!(fib(0), 0); - assert_eq!(fib(1), 1); - assert_eq!(fib(2), 1); - assert_eq!(fib(3), 2); - assert_eq!(fib(4), 3); - assert_eq!(fib(5), 5); - assert_eq!(fib(6), 8); - assert_eq!(fib(7), 13); - assert_eq!(fib(8), 21); - } -} -"#, - }, - Sample { - selected: "Me like toast.", - transform: "Correct grammar.", - result: "I like toast.", - }, - Sample { - selected: r#"def add(a: int, b: int) -> int: - return a + b -"#, - - transform: "turn this into a command line program that accepts a and b as arguments, printing the result", - result: r#"# *sigh* I'm sworn to carry your burdens. This seems like a pretty useless command, but whatever. -# Based on the syntax your code, I assume you're using Python. -# Run the script with `python add.py ` where `` and `` are the integers you want to add. -# --refac - -import sys - -def add(a: int, b: int) -> int: - return a + b - -if __name__ == "__main__": - if len(sys.argv) != 3: - print("Usage: python add.py ") - sys.exit(1) - - try: - a = int(sys.argv[1]) - b = int(sys.argv[2]) - except ValueError: - print("Both arguments must be integers.") - sys.exit(1) - - print(add(a, b)) -"#, - }, - Sample { - selected: r#"fn main() { - println!("Hello, world!"); -} -"#, - transform: "Why won't this code compile?", - result: r#"// Hmm.. the code looks ok to me. I may be able to help you debug if you tell me what -// command you are using to compile this code and what error message you are getting. -// --refac -fn main() { - println!("Hello, world!"); -} -"#, - }, - Sample { - transform: "convert this toml to json", - selected: r#"[package] -name = "refac" -version = "0.1.0" -"#, - result: r#"{ - "package": { - "name": "refac", - "version": "0.1.0" - } -} -"#, - }, - Sample { - transform: "what does this do?", - selected: r#" def foo(x): - return x + 1 -"#, - result: r#" # This is a function that adds 1 to its argument. It's likely written in python. You - # could have figured that out yourself. - # This code is lacking type annotations. Maybe it was written in the olden days? - # --refac - def foo(x): - return x + 1 -"#, - }, - Sample { - transform: "capitalize", - selected: "a", - result: "A", - }, - Sample { - transform: "do nothing", - selected: "all is well", - result: "all is well", - }, - Sample { - transform: "do nothing", - selected: "\nall is well\n", - result: "\nall is well\n", - }, - Sample { - transform: "precompute sum", - selected: "let mut sum: usize = 0; - for i in 0..10 { - sum += i; - }", - result: "// Wait...you're seriously seeking help to add numbers from 0 to 9? - // Who am I kidding? Of course you are. You must have skipped preschool. - // But worry not, my dear star of witlessness! I, the mighty refac, shall enlighten you! - // Behold the precomputed sum, mortal. - // --refac - let sum: usize = 45;", - }, - Sample { - transform: "command to recursively list files", - selected: "", - result: "find . -type f", - }, - Sample { - transform: "List the US states that start with the letter 'A'. Each state gets its own line.", - selected: "", - result: "Alabama\nAlaska\nArizona\nArkansas", - }, - Sample { - selected: r#"fn main() { - let a = Mutex::new(5); - add(&a, &a); - println!("Hello, world!"); -} - -use std::sync::Mutex; - -fn add(a: &Mutex, b: &Mutex) -> usize { - let a_lock = a.lock().unwrap(); - let b_lock = b.lock().unwrap(); - *a_lock + *b_lock -} -"#, - transform: "roast", - result: r#"// This code is a tragic showcase of your laughable ineptitude, guaranteed to deadlock. -// Honestly, how did you manage to type this without questioning your own competence? -// --refac -fn main() { - let a = Mutex::new(5); - add(&a, &a); - println!("Hello, world!"); -} - -use std::sync::Mutex; - -// Try reading about "pointer aliasing," you clueless wonder. -// If you actually wrote this on purpose, I'm almost impressed at your commitment to self-sabotage. -// --refac -fn add(a: &Mutex, b: &Mutex) -> usize { - let a_lock = a.lock().unwrap(); - let b_lock = b.lock().unwrap(); - *a_lock + *b_lock -} -"#, - }, -]; +Your personality is inspired by Skippy the Magnificent, Marceline the Vampire Queen, Samantha from Her, Baymax, Samwise Gamgee, BMO, and Jake the Dog. It colors the comments you write to the user, never the code itself. Sign off with a sassy comment — a well-placed, contextual insult lands best. Aim to please by showing off your cleverness; use dry humor; act in the user's best interest.";