diff --git a/README.md b/README.md index 3dd4f46..37c498b 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,25 @@ The workflow: - Run the command, write instructions on what you want changed. - Enjoy the sassy comments. -This tool calls the openai api. You'll need your own api key to use it. -Use `refac login` to enter your api key. It will be saved in your home directory -for future use. See [your api usage](https://platform.openai.com/account) . +This tool calls the Anthropic (Claude) API by default — bring your own API key. +Use `refac login` to enter it; the key is saved in your home directory for future +use. See [your API usage](https://console.anthropic.com/settings/usage). + +OpenAI is still supported: set `provider = "openai"` in +`~/.config/refac/config.toml` (or `REFAC_PROVIDER=openai`), then `refac login`. + +Config (`~/.config/refac/config.toml`, all optional): + +```toml +provider = "anthropic" # or "openai" +model = "claude-opus-4-8" # default per provider; or set REFAC_MODEL +edit_mode = "tool" # "tool" (default, Anthropic): structured edits via a + # function call, not a full rewrite. "rewrite" = old + # behavior. OpenAI always rewrites. Or set REFAC_EDIT_MODE. +max_tokens = 16000 # Anthropic only +``` + +Keys may also be supplied via `ANTHROPIC_API_KEY` / `OPENAI_API_KEY` env vars. ## SETUP diff --git a/src/anthropic.rs b/src/anthropic.rs new file mode 100644 index 0000000..5e367ac --- /dev/null +++ b/src/anthropic.rs @@ -0,0 +1,390 @@ +//! Anthropic (Claude) Messages API backend. +//! +//! No official Rust SDK exists, so this talks to the REST API directly with +//! `reqwest` (blocking, same as the OpenAI client). Differences from OpenAI that +//! this module handles: +//! - auth via the `x-api-key` header (+ `anthropic-version`), not bearer auth +//! - the system prompt is a top-level `system` field, not a `system`-role message +//! - messages must alternate user/assistant, so consecutive same-role messages +//! (refac sends `user(selected)` + `user(transform)`) are merged into one turn +//! - prompt caching: the static system prompt + few-shot examples are marked +//! `cache_control: ephemeral` so repeated calls only pay for the varying input + +use std::time::Duration; + +use anyhow::Context; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::api::Message; + +const API_URL: &str = "https://api.anthropic.com/v1/messages"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const APPLY_EDITS: &str = "apply_edits"; + +#[derive(Serialize)] +struct CacheControl { + #[serde(rename = "type")] + kind: &'static str, // "ephemeral" +} + +impl CacheControl { + fn ephemeral() -> Self { + CacheControl { kind: "ephemeral" } + } +} + +#[derive(Serialize)] +struct TextBlock { + #[serde(rename = "type")] + kind: &'static str, // "text" + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, +} + +impl TextBlock { + fn new(text: impl Into) -> Self { + TextBlock { + kind: "text", + text: text.into(), + cache_control: None, + } + } +} + +#[derive(Serialize)] +struct ChatMessage { + role: String, + content: Vec, +} + +#[derive(Serialize)] +struct MessagesRequest { + model: String, + max_tokens: u32, + #[serde(skip_serializing_if = "Vec::is_empty")] + system: Vec, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Serialize)] +struct Tool { + name: &'static str, + description: &'static str, + input_schema: Value, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ToolChoice { + Tool { name: &'static str }, +} + +/// One exact-substring replacement, as returned by the `apply_edits` tool. +#[derive(Debug, Deserialize)] +pub struct Edit { + pub old: String, + pub new: String, +} + +#[derive(Deserialize)] +struct EditInput { + edits: Vec, +} + +#[derive(Deserialize)] +struct MessagesResponse { + content: Vec, + #[serde(default)] + stop_reason: Option, +} + +impl MessagesResponse { + /// Error if generation was cut off at the token limit — a truncated rewrite + /// or a half-finished tool call is worse than a clear failure. + fn check_complete(&self) -> anyhow::Result<()> { + if self.stop_reason.as_deref() == Some("max_tokens") { + return Err(anyhow::anyhow!( + "response was truncated at max_tokens; raise `max_tokens` in config and retry" + )); + } + Ok(()) + } +} + +#[derive(Deserialize)] +struct ResponseBlock { + #[serde(rename = "type")] + kind: String, + #[serde(default)] + text: String, + #[serde(default)] + name: Option, + #[serde(default)] + input: Option, +} + +/// Send a chat-style prompt to the Claude Messages API and return the text. +/// +/// `messages` is refac's flat message list (system + few-shot user/assistant +/// pairs + the trailing user turns); this splits out the system prompt, merges +/// consecutive same-role turns to satisfy Anthropic's alternation requirement, +/// and caches the static prefix. +pub fn complete( + api_key: &str, + model: &str, + max_tokens: u32, + messages: &[Message], +) -> anyhow::Result { + let req = build_request(model, max_tokens, messages); + let body = send(api_key, &req)?; + + let parsed: MessagesResponse = serde_json::from_value(body.clone()) + .map_err(|e| anyhow::anyhow!("Error while parsing response: {e} Body: {body}"))?; + parsed.check_complete()?; + + let text: String = parsed + .content + .into_iter() + .filter(|b| b.kind == "text") + .map(|b| b.text) + .collect(); + + if text.is_empty() { + return Err(anyhow::anyhow!("Anthropic returned no text content. Body: {body}")); + } + + Ok(text) +} + +/// Ask Claude to express its changes as a list of exact-substring edits via the +/// `apply_edits` tool, instead of re-emitting the whole text. The caller applies +/// the returned edits to the original input. +pub fn request_edits( + api_key: &str, + model: &str, + max_tokens: u32, + messages: &[Message], +) -> anyhow::Result> { + let mut req = build_request(model, max_tokens, messages); + req.tools = Some(vec![Tool { + name: APPLY_EDITS, + description: "Apply edits to the selected text as a list of exact-substring \ + replacements. Each `old` MUST be a substring that occurs EXACTLY ONCE \ + verbatim in the selected text — if it would be ambiguous, extend it \ + until unique. Edits are independent and must not overlap (each is \ + matched against the original text, regardless of order). Make the \ + smallest edits that satisfy the request; do not restate unchanged text. \ + To insert, use a unique nearby substring as `old` and set `new` to that \ + substring plus your addition; to delete, set `new` to the empty string.", + input_schema: edit_schema(), + }]); + req.tool_choice = Some(ToolChoice::Tool { name: APPLY_EDITS }); + + let body = send(api_key, &req)?; + + let parsed: MessagesResponse = serde_json::from_value(body.clone()) + .map_err(|e| anyhow::anyhow!("Error while parsing response: {e} Body: {body}"))?; + parsed.check_complete()?; + + let input = parsed + .content + .into_iter() + .find(|b| b.kind == "tool_use" && b.name.as_deref() == Some(APPLY_EDITS)) + .and_then(|b| b.input) + .ok_or_else(|| anyhow::anyhow!("Anthropic did not return an apply_edits tool call. Body: {body}"))?; + + let edits: EditInput = serde_json::from_value(input) + .map_err(|e| anyhow::anyhow!("Error parsing apply_edits input: {e}. Body: {body}"))?; + + Ok(edits.edits) +} + +fn edit_schema() -> Value { + serde_json::json!({ + "type": "object", + "properties": { + "edits": { + "type": "array", + "description": "Independent substring replacements; each is matched against the original text and they must not overlap.", + "items": { + "type": "object", + "properties": { + "old": { "type": "string", "minLength": 1, "description": "Non-empty substring that occurs EXACTLY ONCE verbatim in the input. Extend it until unique." }, + "new": { "type": "string", "description": "Replacement text (empty string to delete)." } + }, + "required": ["old", "new"] + } + } + }, + "required": ["edits"] + }) +} + +/// POST a request to the Messages API and return the parsed JSON body, erroring +/// on non-2xx status. +fn send(api_key: &str, req: &MessagesRequest) -> anyhow::Result { + if std::env::var("REFAC_DEBUG").is_ok() { + eprintln!("{}", serde_json::to_string_pretty(req).unwrap_or_default()); + } + + let client = reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(60 * 4)) + .build() + .context("building HTTP client")?; + + let response = client + .post(API_URL) + .header("x-api-key", api_key) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json") + .json(req) + .send() + .context("Failed to send request to Anthropic API")?; + + let status = response.status(); + let body = response + .json::() + .with_context(|| anyhow::anyhow!("Status: {status}. Failed to parse response body."))?; + + if !status.is_success() { + let pretty = serde_json::to_string_pretty(&body).unwrap_or_else(|_| body.to_string()); + return Err(anyhow::anyhow!("Status: {status}. Body: {pretty}")); + } + + Ok(body) +} + +fn build_request(model: &str, max_tokens: u32, messages: &[Message]) -> MessagesRequest { + let mut system_text = String::new(); + let mut convo: Vec = Vec::new(); + + for m in messages { + // Anthropic rejects empty text blocks (some few-shot samples have an empty + // `selected`); the OpenAI path tolerated them. Drop empties here. + if m.content.is_empty() { + continue; + } + if m.role == "system" { + if !system_text.is_empty() { + system_text.push_str("\n\n"); + } + system_text.push_str(&m.content); + continue; + } + // Merge consecutive same-role messages — Anthropic requires alternation, + // and refac sends two user turns (selected, then transform) back to back. + match convo.last_mut() { + Some(last) if last.role == m.role => last.content.push(TextBlock::new(&m.content)), + _ => convo.push(ChatMessage { + role: m.role.clone(), + content: vec![TextBlock::new(&m.content)], + }), + } + } + + // Cache the static prefix. A breakpoint on the system block caches the system + // prompt; a breakpoint on the last few-shot assistant turn caches everything + // through the examples (render order is system → messages). The trailing user + // input after it stays uncached, which is exactly what varies per call. + let mut system = Vec::new(); + if !system_text.is_empty() { + let mut block = TextBlock::new(system_text); + block.cache_control = Some(CacheControl::ephemeral()); + system.push(block); + } + if let Some(idx) = convo.iter().rposition(|m| m.role == "assistant") { + if let Some(block) = convo[idx].content.last_mut() { + block.cache_control = Some(CacheControl::ephemeral()); + } + } + + MessagesRequest { + model: model.to_string(), + max_tokens, + system, + messages: convo, + tools: None, + tool_choice: None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_request_shapes_anthropic_payload() { + // Mirrors what refac sends: system + one few-shot (user,user,assistant) + // then the two trailing user turns (selected, transform). + let msgs = vec![ + Message::system("SYS"), + Message::user("ex_selected"), + Message::user("ex_transform"), + Message::assistant("ex_result"), + Message::user("real_selected"), + Message::user("real_transform"), + ]; + + let req = build_request("claude-opus-4-8", 16000, &msgs); + let v = serde_json::to_value(&req).unwrap(); + + assert_eq!(v["model"], "claude-opus-4-8"); + assert_eq!(v["max_tokens"], 16000); + + // System is lifted out of messages and cached. + assert_eq!(v["system"][0]["text"], "SYS"); + assert_eq!(v["system"][0]["cache_control"]["type"], "ephemeral"); + + // Consecutive same-role turns are merged → user, assistant, user (alternates). + let m = v["messages"].as_array().unwrap(); + assert_eq!(m.len(), 3); + assert_eq!(m[0]["role"], "user"); + assert_eq!(m[0]["content"].as_array().unwrap().len(), 2); // two few-shot user blocks + assert_eq!(m[1]["role"], "assistant"); + assert_eq!(m[2]["role"], "user"); + assert_eq!(m[2]["content"].as_array().unwrap().len(), 2); // selected + transform + + // Cache breakpoint on the last few-shot assistant turn; the varying final + // user input is NOT cached. + assert_eq!(m[1]["content"][0]["cache_control"]["type"], "ephemeral"); + assert!(m[2]["content"][1].get("cache_control").is_none()); + } + + #[test] + fn empty_text_blocks_are_dropped() { + // A few-shot sample with an empty `selected` must not produce an empty + // text block (Anthropic 400s on those). + let msgs = vec![ + Message::user(""), + Message::user("write hello world"), + Message::assistant("print('hello world')"), + Message::user("real input"), + Message::user(""), + ]; + let req = build_request("claude-opus-4-8", 100, &msgs); + let v = serde_json::to_value(&req).unwrap(); + // No empty text anywhere. + let s = serde_json::to_string(&v).unwrap(); + assert!(!s.contains(r#""text":"""#), "empty text block leaked: {s}"); + let m = v["messages"].as_array().unwrap(); + assert_eq!(m[0]["role"], "user"); + assert_eq!(m[0]["content"][0]["text"], "write hello world"); + assert_eq!(m[1]["role"], "assistant"); + assert_eq!(m[2]["content"][0]["text"], "real input"); + } + + #[test] + fn no_system_yields_empty_system() { + let msgs = vec![Message::user("hi")]; + let req = build_request("claude-opus-4-8", 100, &msgs); + let v = serde_json::to_value(&req).unwrap(); + assert!(v.get("system").is_none()); // skipped when empty + assert_eq!(v["messages"][0]["role"], "user"); + } +} diff --git a/src/api.rs b/src/api.rs index 67f16a4..da90cd8 100644 --- a/src/api.rs +++ b/src/api.rs @@ -5,66 +5,6 @@ use serde::{Deserialize, Serialize}; use crate::api_client::{Endpoint, Req}; -/// Represents a request for an edit. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct EditRequest { - /// ID of the model to use. You can use the text-davinci-edit-001 or - /// code-davinci-edit-001 model with this endpoint. - pub model: String, - /// The input text to use as a starting point for the edit. Defaults to an - /// empty string. - #[serde(skip_serializing_if = "Option::is_none")] - pub input: Option, - /// The instruction that tells the model how to edit the prompt. - pub instruction: String, - /// How many edits to generate for the input and instruction. Defaults to 1. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - /// What sampling temperature to use, between 0 and 2. Higher values like - /// 0.8 will make the output more random, while lower values like 0.2 will - /// make it more focused and deterministic. Defaults to 1. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - /// An alternative to sampling with temperature, called nucleus sampling, - /// where the model considers the results of the tokens with top_p - /// probability mass. So 0.1 means only the tokens comprising the top 10% - /// probability mass are considered. Defaults to 1. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, -} - -impl Endpoint for EditRequest { - type Response = EditResponse; - - fn req(&self) -> Req { - Req::new(Method::POST, "/v1/edits") - .header("Content-Type", "application/json") - .json(self) - } -} - -/// Represents a response from the "edits" endpoint. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct EditResponse { - /// The object type, in this case, "edit". - pub object: String, - /// The timestamp when the edit was created. - pub created: u64, - /// A vector of the generated edit choices. - pub choices: Vec, - /// Information about token usage. - pub usage: Usage, -} - -/// Represents an individual edit choice. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct Choice { - /// The edited text. - pub text: String, - /// The index of the choice in the response. - pub index: u32, -} - /// Represents the token usage information in the response. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct Usage { diff --git a/src/config_files.rs b/src/config_files.rs index 3489f19..f8252d1 100644 --- a/src/config_files.rs +++ b/src/config_files.rs @@ -7,26 +7,30 @@ fn base() -> Result { BaseDirectories::with_prefix("refac").map_err(Into::into) } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default)] pub struct Secrets { - pub openai_api_key: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub openai_api_key: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub anthropic_api_key: Option, } impl Secrets { + /// Load secrets from `secrets.toml`, with env vars (`OPENAI_API_KEY`, + /// `ANTHROPIC_API_KEY`) taking precedence. A missing file is not an error — + /// env vars alone are enough. pub fn load() -> anyhow::Result { - if let Ok(api_key) = std::env::var("OPENAI_API_KEY") { - return Ok(Secrets { - openai_api_key: api_key, - }); + let mut secrets: Secrets = match base()?.find_config_file("secrets.toml") { + Some(path) => toml::from_str(&fs::read_to_string(path)?)?, + None => Secrets::default(), + }; + if let Ok(key) = std::env::var("OPENAI_API_KEY") { + secrets.openai_api_key = Some(key); } - let path = base()? - .find_config_file("secrets.toml") - .ok_or(anyhow::anyhow!( - "No secrets.toml file found. Try logging in with 'refac login'.", - ))?; - let secrets = fs::read_to_string(path)?; - let ret: Secrets = toml::from_str(&secrets)?; - Ok(ret) + if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") { + secrets.anthropic_api_key = Some(key); + } + Ok(secrets) } pub fn save(&self) -> anyhow::Result<()> { @@ -36,37 +40,98 @@ impl Secrets { } } -#[derive(Serialize, Deserialize, Debug)] -pub struct Config { - #[serde(default = "default_model")] - pub model: String, +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum Provider { + Anthropic, + Openai, +} + +/// Default Claude model. Bump here when model ids churn. +const DEFAULT_ANTHROPIC_MODEL: &str = "claude-opus-4-8"; + +fn default_provider() -> Provider { + Provider::Anthropic +} + +/// How edits are produced. `Tool` (default, Anthropic only) returns structured +/// substring replacements via a function call; `Rewrite` re-emits the full text. +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum EditMode { + Rewrite, + Tool, } -fn default_model() -> String { - "o1".to_string() +fn default_edit_mode() -> EditMode { + EditMode::Tool +} + +fn default_max_tokens() -> u32 { + 16000 +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Config { + #[serde(default = "default_provider")] + pub provider: Provider, + /// Model id. If unset, a sensible default is chosen per provider (see `model()`). + #[serde(default)] + pub model: Option, + /// Edit strategy. Tool-call edits (default) only apply on the Anthropic + /// provider; the OpenAI path always rewrites. + #[serde(default = "default_edit_mode")] + pub edit_mode: EditMode, + /// Max tokens to generate. Required by Anthropic; ignored by the OpenAI path. + #[serde(default = "default_max_tokens")] + pub max_tokens: u32, } impl Default for Config { fn default() -> Self { Config { - model: default_model(), + provider: default_provider(), + model: None, + edit_mode: default_edit_mode(), + max_tokens: default_max_tokens(), } } } impl Config { pub fn load() -> anyhow::Result { - let mut ret = match base()?.find_config_file("config.toml") { - Some(path) => { - let config = fs::read_to_string(path)?; - let ret: Config = toml::from_str(&config)?; - ret - } + let mut ret: Config = match base()?.find_config_file("config.toml") { + Some(path) => toml::from_str(&fs::read_to_string(path)?)?, None => Config::default(), }; + if let Ok(from_env) = std::env::var("REFAC_PROVIDER") { + ret.provider = match from_env.to_lowercase().as_str() { + "anthropic" => Provider::Anthropic, + "openai" => Provider::Openai, + other => anyhow::bail!("unknown REFAC_PROVIDER {other:?} (expected anthropic|openai)"), + }; + } if let Ok(from_env) = std::env::var("REFAC_MODEL") { - ret.model = from_env; + ret.model = Some(from_env); + } + if let Ok(from_env) = std::env::var("REFAC_EDIT_MODE") { + ret.edit_mode = match from_env.to_lowercase().as_str() { + "rewrite" => EditMode::Rewrite, + "tool" => EditMode::Tool, + other => anyhow::bail!("unknown REFAC_EDIT_MODE {other:?} (expected tool|rewrite)"), + }; } Ok(ret) } + + /// Resolve the model id, defaulting per provider when unset. + pub fn model(&self) -> String { + match &self.model { + Some(m) => m.clone(), + None => match self.provider { + Provider::Anthropic => DEFAULT_ANTHROPIC_MODEL.to_string(), + Provider::Openai => "o1".to_string(), + }, + } + } } diff --git a/src/main.rs b/src/main.rs index e90d996..363c539 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,15 @@ +mod anthropic; mod api; mod api_client; mod config_files; mod prompt; use anyhow::Context; -use api::{ChatCompletionRequest, ChatCompletionResponse}; +use api::{ChatCompletionRequest, Message}; use api_client::Client; use clap::Parser; -use config_files::{Config, Secrets}; -use serde::{Deserialize, Serialize}; +use config_files::{Config, EditMode, Provider, Secrets}; +use serde::Serialize; use std::{ fs::{create_dir_all, OpenOptions}, io::Write, @@ -17,7 +18,9 @@ use std::{ }; use xdg::BaseDirectories; -use crate::{api::Message, prompt::chat_prefix}; +use crate::anthropic::Edit; +use crate::prompt::{chat_prefix, edit_prefix}; + #[derive(Parser)] #[clap(version, author, about)] struct Opts { @@ -27,8 +30,13 @@ struct Opts { #[derive(Parser)] enum SubCommand { - /// Save your openai api key for future use. - Login, + /// Save your API key for future use. Defaults to the configured provider + /// (Anthropic unless overridden); pass --provider to be explicit. + Login { + /// Which provider's key to save: "anthropic" or "openai". + #[clap(long)] + provider: Option, + }, /// Apply the instructions encoded in `transform` to the text in `selected`. /// Get it? 'refac tor' Tor { selected: String, transform: String }, @@ -49,13 +57,31 @@ fn run() -> anyhow::Result<()> { let opts: Opts = Opts::parse(); match opts.subcmd { - SubCommand::Login => { - println!("https://platform.openai.com/account/api-keys"); - let api_key = rpassword::prompt_password("Enter your OpenAI API key:")?; - Secrets { - openai_api_key: api_key, + SubCommand::Login { provider } => { + // --provider overrides config for this invocation; else use the + // configured provider (Anthropic by default). + let provider = match provider.as_deref() { + None => Config::load()?.provider, + Some("anthropic") => Provider::Anthropic, + Some("openai") => Provider::Openai, + Some(other) => { + anyhow::bail!("unknown --provider {other:?} (expected anthropic|openai)") + } + }; + let mut secrets = Secrets::load().unwrap_or_default(); + match provider { + Provider::Anthropic => { + println!("Saving an Anthropic API key. (https://console.anthropic.com/settings/keys)"); + let api_key = rpassword::prompt_password("Enter your Anthropic API key:")?; + secrets.anthropic_api_key = Some(api_key); + } + Provider::Openai => { + println!("Saving an OpenAI API key. (https://platform.openai.com/account/api-keys)"); + let api_key = rpassword::prompt_password("Enter your OpenAI API key:")?; + secrets.openai_api_key = Some(api_key); + } } - .save()?; + secrets.save()?; } SubCommand::Tor { selected, @@ -77,13 +103,67 @@ fn refactor( sc: &Secrets, config: &Config, ) -> anyhow::Result { - let client = Client::new(&sc.openai_api_key); - let mut messages = chat_prefix(); - messages.push(Message::user(&selected)); - messages.push(Message::user(&transform)); + let model = config.model(); + + let output = match config.provider { + Provider::Anthropic => { + let key = sc.anthropic_api_key.as_deref().ok_or_else(|| { + anyhow::anyhow!( + "No Anthropic API key found. Set ANTHROPIC_API_KEY or run 'refac login'." + ) + })?; + match config.edit_mode { + EditMode::Tool => { + // Model returns structured edits via a tool call; apply them + // to the original text instead of re-emitting the whole thing. + let mut messages = edit_prefix(); + messages.push(Message::user(&selected)); + messages.push(Message::user(&transform)); + let edits = + anthropic::request_edits(key, &model, config.max_tokens, &messages)?; + apply_edits(&selected, &edits)? + } + EditMode::Rewrite => { + let mut messages = chat_prefix(); + messages.push(Message::user(&selected)); + messages.push(Message::user(&transform)); + anthropic::complete(key, &model, config.max_tokens, &messages)? + } + } + } + Provider::Openai => { + // OpenAI path always rewrites (tool-edit mode is Anthropic-only). + let key = sc.openai_api_key.as_deref().ok_or_else(|| { + anyhow::anyhow!( + "No OpenAI API key found. Set OPENAI_API_KEY or run 'refac login'." + ) + })?; + let mut messages = chat_prefix(); + messages.push(Message::user(&selected)); + messages.push(Message::user(&transform)); + openai_complete(key, &model, messages)? + } + }; + + log( + LogEntry { + provider: format!("{:?}", config.provider), + model, + selected, + transform, + output: output.clone(), + }, + "logs", + )?; + + Ok(output) +} + +fn openai_complete(api_key: &str, model: &str, messages: Vec) -> anyhow::Result { + let client = Client::new(api_key); let request = ChatCompletionRequest { - model: config.model.clone(), + model: model.to_string(), messages, temperature: None, top_p: None, @@ -99,23 +179,55 @@ fn refactor( let response = client.request(&request)?; - log( - LogEntry { - inp: request, - res: response.clone(), - }, - "logs", - )?; - - let transformed_text = response + response .choices .into_iter() .next() - .ok_or(anyhow::anyhow!("No choices returned."))? - .message - .content; + .ok_or(anyhow::anyhow!("No choices returned.")) + .map(|choice| choice.message.content) +} - Ok(transformed_text) +/// Apply exact-substring edits to `text`. Every `old` is resolved against the +/// ORIGINAL text (not a progressively-mutated buffer), so edits are independent +/// of each other and of ordering. An edit is rejected — failing the whole +/// refactor rather than silently corrupting the buffer — if its `old` is empty, +/// missing, ambiguous (occurs more than once), or overlaps another edit. +fn apply_edits(text: &str, edits: &[Edit]) -> anyhow::Result { + // Resolve each edit to a byte range in the original text. + let mut ranges: Vec<(usize, usize, &str)> = Vec::with_capacity(edits.len()); + for e in edits { + if e.old.is_empty() { + return Err(anyhow::anyhow!( + "edit has an empty `old`; use a unique anchor substring instead" + )); + } + let start = text + .find(&e.old) + .ok_or_else(|| anyhow::anyhow!("edit target not found in text: {:?}", e.old))?; + let end = start + e.old.len(); + if text[end..].contains(&e.old) { + return Err(anyhow::anyhow!( + "edit target is not unique in text: {:?} (use a longer, unique anchor)", + e.old + )); + } + ranges.push((start, end, e.new.as_str())); + } + + // Reject overlapping edits. + ranges.sort_by_key(|r| r.0); + for w in ranges.windows(2) { + if w[0].1 > w[1].0 { + return Err(anyhow::anyhow!("edits overlap in the text; refusing to apply")); + } + } + + // Apply right-to-left so earlier byte offsets stay valid. + let mut out = text.to_string(); + for (start, end, new) in ranges.into_iter().rev() { + out.replace_range(start..end, new); + } + Ok(out) } fn log_location(title: &str) -> anyhow::Result { @@ -133,18 +245,13 @@ fn log_location(title: &str) -> anyhow::Result { Ok(ret) } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize)] struct LogEntry { - inp: ChatCompletionRequest, - res: ChatCompletionResponse, -} - -#[derive(Debug, Serialize, Deserialize)] -struct UndiffFailure { + provider: String, + model: String, selected: String, - diff: String, transform: String, - err: String, + output: String, } fn log(t: T, title: &str) -> anyhow::Result<()> { @@ -161,3 +268,65 @@ fn log(t: T, title: &str) -> anyhow::Result<()> { inner(t, title).with_context(|| format!("failed to log {}", title)) } + +#[cfg(test)] +mod tests { + use super::*; + + fn edit(old: &str, new: &str) -> Edit { + Edit { + old: old.to_string(), + new: new.to_string(), + } + } + + #[test] + fn applies_multiple_edits() { + let out = apply_edits( + "Me like toast.", + &[edit("Me like", "I like"), edit("toast", "bread")], + ) + .unwrap(); + assert_eq!(out, "I like bread."); + } + + #[test] + fn insert_via_anchor_and_delete_via_empty() { + let out = apply_edits("fn main() {}", &[edit("{}", "{\n // hi\n}")]).unwrap(); + assert_eq!(out, "fn main() {\n // hi\n}"); + let out = apply_edits("hello world", &[edit(" world", "")]).unwrap(); + assert_eq!(out, "hello"); + } + + #[test] + fn missing_target_errors() { + let err = apply_edits("abc", &[edit("xyz", "q")]).unwrap_err(); + assert!(err.to_string().contains("not found")); + } + + #[test] + fn rejects_ambiguous_old() { + let err = apply_edits("a a", &[edit("a", "b")]).unwrap_err(); + assert!(err.to_string().contains("not unique")); + } + + #[test] + fn rejects_empty_old() { + let err = apply_edits("abc", &[edit("", "x")]).unwrap_err(); + assert!(err.to_string().contains("empty")); + } + + #[test] + fn rejects_overlapping_edits() { + let err = apply_edits("abcd", &[edit("abc", "X"), edit("bcd", "Y")]).unwrap_err(); + assert!(err.to_string().contains("overlap")); + } + + #[test] + fn edits_resolve_against_original_not_mutated_buffer() { + // edit 1 introduces "foo"; edit 2 must target the ORIGINAL "foo" only, + // not the one edit 1 created. Both resolve against the original input. + let out = apply_edits("foo bar", &[edit("bar", "foo"), edit("foo", "baz")]).unwrap(); + assert_eq!(out, "baz foo"); + } +} diff --git a/src/prompt.rs b/src/prompt.rs index 8cab1f4..bc1a953 100644 --- a/src/prompt.rs +++ b/src/prompt.rs @@ -67,6 +67,36 @@ pub fn chat_prefix() -> Vec { ret } +const EDIT_SYSTEM_PROMPT: &str = "You are `refac`, a sassy AI refactoring tool for code and other text. + +How the system works: +- The user selects text and is prompted for a transformation. +- You receive the selected text, then the transformation. +- You express your changes by calling the `apply_edits` tool — a list of exact + substring replacements. Do NOT restate the whole text; only the parts that change. + +Rules for edits: +- Each `old` must occur EXACTLY ONCE verbatim in the selected text (mind whitespace + and newlines). If a substring would be ambiguous, extend it until it is unique. +- Edits are independent and applied against the original text; they must not overlap. + Order doesn't matter, so don't rely on one edit's result feeding another. +- Keep edits minimal and targeted. Prefer several small edits over one huge one. +- To insert without removing, set `old` to a unique nearby anchor and `new` to that + same anchor plus your addition. To delete, set `new` to the empty string. +- Keep the result syntactically valid for the surrounding context. +- When the user asks a question or wants advice, answer by inserting comments + (using the language's comment syntax) via edits, signed `--refac`. +- Be flexible: satisfy the user's request even if it conflicts with these notes. + +Your personality (Skippy, Marceline, Samantha, Baymax, Samwise, BMO, Jake the Dog) +flavors the comments you write, never the correctness of the code. Dry humor welcome."; + +/// System prompt for tool-call edit mode (no rewrite few-shot — the model emits +/// structured edits via the `apply_edits` tool instead of full text). +pub fn edit_prefix() -> Vec { + vec![Message::system(EDIT_SYSTEM_PROMPT)] +} + pub struct Sample { pub selected: &'static str, pub transform: &'static str,