diff --git a/Cargo.lock b/Cargo.lock index cd0dd64..65e2646 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -241,6 +241,19 @@ dependencies = [ "memchr", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "core-foundation" version = "0.10.1" @@ -257,6 +270,19 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror 1.0.69", + "zeroize", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -280,12 +306,34 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -761,6 +809,12 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "litemap" version = "0.7.5" @@ -977,6 +1031,7 @@ version = "0.1.2" dependencies = [ "anyhow", "clap", + "dialoguer", "itertools", "reqwest", "rpassword", @@ -1075,6 +1130,19 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + [[package]] name = "rustls" version = "0.23.39" @@ -1129,7 +1197,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -1253,6 +1321,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77" + [[package]] name = "shlex" version = "1.3.0" @@ -1349,6 +1423,18 @@ dependencies = [ "syn", ] +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys 0.52.0", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -1597,6 +1683,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "untrusted" version = "0.9.0" @@ -1778,7 +1870,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.48.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b5eb87a..9174562 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ reqwest = { version = "0.13", default-features = false, features = [ "json", ] } rpassword = "7.5.0" +dialoguer = "0.11" serde = { version = "1.0.154", features = ["derive"] } serde_json = "1.0.94" similar = "2.2.1" diff --git a/README.md b/README.md index 3dd4f46..dcf598e 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,9 @@ The workflow: - Run the command, write instructions on what you want changed. - Enjoy the sassy comments. -This tool calls the openai api. You'll need your own api key to use it. -Use `refac login` to enter your api key. It will be saved in your home directory -for future use. See [your api usage](https://platform.openai.com/account) . +Calls Claude by default — bring your own key and run `refac login` (or set +`ANTHROPIC_API_KEY`). For OpenAI, set `REFAC_PROVIDER=openai` and `OPENAI_API_KEY`. +Optional `provider` / `model` config lives in `~/.config/refac/config.toml`. ## SETUP @@ -30,31 +30,25 @@ THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG. > refac tor ' def add(a: int, b: int): return a + b -' 'turn this into a command line program that accepts a and b as arguments, printing the result'` -# I've transformed your `add` function into a command-line script that accepts two integer arguments and prints their sum. -# Based on the syntax of your code, I assume you're using Python. If this is incorrect, please let me know. -# Run the script with `python add.py ` where `` and `` are the integers you want to add. +' 'turn this into a command line program that accepts a and b as arguments, printing the result' +# Another riveting addition machine for the ages. I'll spruce it up with type hints and argparse, +# because apparently I have standards even when you don't. +# Run it with `python add.py `. # --refac -import sys +import argparse -def add(a: int, b: int): - return a + b -if __name__ == "__main__": - if len(sys.argv) != 3: - print("Usage: python add.py ") - sys.exit(1) +def add(a: int, b: int) -> int: + return a + b - try: - a = int(sys.argv[1]) - b = int(sys.argv[2]) - except ValueError: - print("Both arguments must be integers.") - sys.exit(1) - result = add(a, b) - print(f"The result of {a} + {b} is {result}.") +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Add two integers.") + parser.add_argument("a", type=int, help="The first integer.") + parser.add_argument("b", type=int, help="The second integer.") + args = parser.parse_args() + print(add(args.a, args.b)) > refac tor ' fn factorial(a: usize) -> usize { @@ -65,11 +59,12 @@ fn factorial(a: usize) -> usize { } } ' 'switch to an iterative implementation' +// Recursion is elegant and all, but why summon the stack-overflow goblins when a loop will do? +// Here's your iterative factorial, served fresh. +// --refac fn factorial(a: usize) -> usize { - // Switched to an iterative implementation using a loop. - // --refac let mut result = 1; - for i in 1..=a { + for i in 2..=a { result *= i; } result @@ -77,30 +72,28 @@ fn factorial(a: usize) -> usize { > refac tor ' fn factorial(a: usize) -> usize { - let mut result = 1; - for i in 1..=a { - result *= i; - } - result + let mut result = 1; + for i in 1..=a { + result *= i; + } + result } ' 'use an iterator and .fold()' +// Look at you, embracing the functional style! I'm getting all warm and fuzzy inside. +// --refac fn factorial(a: usize) -> usize { - (1..=a).fold(1, |acc, x| acc * x) + (1..=a).fold(1, |result, i| result * i) } > refac tor '' 'implement hello world in rust' -// Here's a simple "Hello, World!" program in Rust. -// Compile it with `rustc hello_world.rs` and run it with `./hello_world`. +// Behold! A masterpiece of unparalleled complexity. Try not to strain yourself. // --refac - fn main() { - println!("Hello, World!"); + println!("Hello, world!"); } > refac tor '' 'write fizzbuzz in rust' -// Oh, great. The ultimate test of awesomeness. *eyeroll* -// Fine. Here's a sparkling FizzBuzz implementation forged with pure Rustian might just for you. -// Run with `cargo run`. +// FizzBuzz, the eternal rite of passage. You've ascended to junior dev status, congrats. // --refac fn main() { for i in 1..=100 { @@ -116,32 +109,29 @@ fn main() { > refac tor ' Hey [Name], -I\'m writing to ask if you can lend me $100. I promise I\'ll pay you back as soon as possible. +I'm writing to ask if you can lend me $100. I promise I'll pay you back as soon as possible. Thanks, [Your Name] ' 'make the email more formal and appropriate for a professional setting' -Subject: Request for Temporary Financial Assistance - Dear [Name], -I hope this email finds you well. I am reaching out to request a temporary financial loan of $100. Due to unforeseen circumstances, I find myself in a situation where I require immediate support. Please know that I do not take this request lightly, and I understand the responsibility that comes with borrowing money. +I hope this message finds you well. I am writing to respectfully request a short-term loan of $100. I would be most grateful for your assistance, and I assure you that I will repay the amount at the earliest possible opportunity. -I would like to emphasize my commitment to repaying the loan as promptly as possible. Your assistance during this challenging time would be greatly appreciated, and I assure you that I am doing everything in my power to improve my financial situation. +Please let me know if this is something you would be willing to consider. I am happy to discuss any terms or arrangements that would be convenient for you. -Should you have any concerns, questions, or require additional information, please do not hesitate to reach out to me. I am more than happy to discuss the matter further. +Thank you very much for your time and consideration. -Thank you very much for your time and consideration. I eagerly await your response. - -Sincerely, +Best regards, +[Your Name] ``` ## Using Refac From Your Favorite Text Editor First, make sure you have: - [ ] installed refac -- [ ] entered your [api key](https://platform.openai.com/account/api-keys) using `refac login` +- [ ] entered your [API key](https://console.anthropic.com/settings/keys) using `refac login` ### Emacs diff --git a/src/anthropic.rs b/src/anthropic.rs new file mode 100644 index 0000000..4ea7c89 --- /dev/null +++ b/src/anthropic.rs @@ -0,0 +1,220 @@ +//! Anthropic (Claude) Messages API backend. + +use std::time::Duration; + +use anyhow::Context; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::api::{Message, Role}; + +const MAX_TOKENS: u32 = 80000; + +/// Anthropic 400s on an empty text block, so render empty fields as a visible +/// placeholder. +fn field_or_placeholder(field: &str) -> &str { + if field.is_empty() { + "(empty)" + } else { + field + } +} + +const API_URL: &str = "https://api.anthropic.com/v1/messages"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "lowercase")] +enum CacheControl { + Ephemeral, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "lowercase")] +enum ContentBlock { + Text { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, +} + +impl ContentBlock { + fn text(text: impl Into) -> Self { + ContentBlock::Text { + text: text.into(), + cache_control: None, + } + } +} + +#[derive(Serialize)] +struct ChatMessage { + role: Role, + content: Vec, +} + +#[derive(Serialize)] +struct MessagesRequest { + model: String, + max_tokens: u32, + #[serde(skip_serializing_if = "Vec::is_empty")] + system: Vec, + messages: Vec, +} + +#[derive(Deserialize)] +struct MessagesResponse { + content: Vec, +} + +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +enum ResponseBlock { + Text { text: String }, + #[serde(other)] + Other, +} + +/// Send a chat-style prompt to the Claude Messages API and return the text. +pub fn complete(api_key: &str, model: &str, messages: &[Message]) -> anyhow::Result { + let req = build_request(model, messages); + + tracing::debug!( + "anthropic request: {}", + serde_json::to_string_pretty(&req).unwrap_or_default() + ); + + let client = reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(60 * 4)) + .build() + .context("building HTTP client")?; + + let response = client + .post(API_URL) + .header("x-api-key", api_key) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json") + .json(&req) + .send() + .context("Failed to send request to Anthropic API")?; + + let status = response.status(); + let body = response + .json::() + .with_context(|| anyhow::anyhow!("Status: {status}. Failed to parse response body."))?; + + if !status.is_success() { + let pretty = serde_json::to_string_pretty(&body).unwrap_or_else(|_| body.to_string()); + return Err(anyhow::anyhow!("Status: {status}. Body: {pretty}")); + } + + let parsed: MessagesResponse = serde_json::from_value(body.clone()) + .map_err(|e| anyhow::anyhow!("Error while parsing response: {e} Body: {body}"))?; + + let text: String = parsed + .content + .into_iter() + .filter_map(|b| match b { + ResponseBlock::Text { text } => Some(text), + ResponseBlock::Other => None, + }) + .collect(); + + if text.is_empty() { + return Err(anyhow::anyhow!("Anthropic returned no text content.")); + } + + Ok(text) +} + +fn build_request(model: &str, messages: &[Message]) -> MessagesRequest { + let mut system = Vec::new(); + let mut convo: Vec = Vec::new(); + + for m in messages { + let mut blocks: Vec = m + .fields + .iter() + .map(|f| ContentBlock::text(field_or_placeholder(f))) + .collect(); + // A cached turn caches everything up to and including its last block. + if m.cache { + if let Some(ContentBlock::Text { cache_control, .. }) = blocks.last_mut() { + *cache_control = Some(CacheControl::Ephemeral); + } + } + match m.role { + Role::System => system.extend(blocks), + Role::User | Role::Assistant => convo.push(ChatMessage { + role: m.role, + content: blocks, + }), + } + } + + MessagesRequest { + model: model.to_string(), + max_tokens: MAX_TOKENS, + system, + messages: convo, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn user(fields: &[&str]) -> Message { + Message::user(fields.iter().map(|f| f.to_string()).collect()) + } + + #[test] + fn build_request_shapes_anthropic_payload() { + let mut assistant = Message::assistant("ex_result"); + assistant.cache = true; + let msgs = vec![ + Message::system("SYS"), + user(&["ex_selected", "ex_transform"]), + assistant, + user(&["real_selected", "real_transform"]), + ]; + + let req = build_request("claude-opus-4-8", &msgs); + let v = serde_json::to_value(&req).unwrap(); + + assert_eq!(v["model"], "claude-opus-4-8"); + assert_eq!(v["max_tokens"], 80000); + assert_eq!(v["system"][0]["text"], "SYS"); + + let m = v["messages"].as_array().unwrap(); + assert_eq!(m.len(), 3); + assert_eq!(m[0]["role"], "user"); + assert_eq!(m[0]["content"].as_array().unwrap().len(), 2); + assert_eq!(m[1]["role"], "assistant"); + assert_eq!(m[2]["role"], "user"); + assert_eq!(m[2]["content"].as_array().unwrap().len(), 2); + + // The cached turn carries the breakpoint; the trailing input does not. + assert_eq!(m[1]["content"][0]["cache_control"]["type"], "ephemeral"); + assert!(m[2]["content"][1].get("cache_control").is_none()); + } + + #[test] + fn empty_fields_become_placeholder() { + let req = build_request("claude-opus-4-8", &[user(&["", "transform"])]); + let v = serde_json::to_value(&req).unwrap(); + let s = serde_json::to_string(&v).unwrap(); + assert!(!s.contains(r#""text":"""#), "empty text block leaked: {s}"); + assert_eq!(v["messages"][0]["content"][0]["text"], "(empty)"); + assert_eq!(v["messages"][0]["content"][1]["text"], "transform"); + } + + #[test] + fn no_system_yields_empty_system() { + let req = build_request("claude-opus-4-8", &[user(&["hi"])]); + let v = serde_json::to_value(&req).unwrap(); + assert!(v.get("system").is_none()); + assert_eq!(v["messages"][0]["role"], "user"); + } +} diff --git a/src/api.rs b/src/api.rs index 67f16a4..8dd2e81 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,193 +1,46 @@ -use std::collections::HashMap; - -use reqwest::Method; use serde::{Deserialize, Serialize}; -use crate::api_client::{Endpoint, Req}; - -/// Represents a request for an edit. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct EditRequest { - /// ID of the model to use. You can use the text-davinci-edit-001 or - /// code-davinci-edit-001 model with this endpoint. - pub model: String, - /// The input text to use as a starting point for the edit. Defaults to an - /// empty string. - #[serde(skip_serializing_if = "Option::is_none")] - pub input: Option, - /// The instruction that tells the model how to edit the prompt. - pub instruction: String, - /// How many edits to generate for the input and instruction. Defaults to 1. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - /// What sampling temperature to use, between 0 and 2. Higher values like - /// 0.8 will make the output more random, while lower values like 0.2 will - /// make it more focused and deterministic. Defaults to 1. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - /// An alternative to sampling with temperature, called nucleus sampling, - /// where the model considers the results of the tokens with top_p - /// probability mass. So 0.1 means only the tokens comprising the top 10% - /// probability mass are considered. Defaults to 1. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, -} - -impl Endpoint for EditRequest { - type Response = EditResponse; - - fn req(&self) -> Req { - Req::new(Method::POST, "/v1/edits") - .header("Content-Type", "application/json") - .json(self) - } -} - -/// Represents a response from the "edits" endpoint. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct EditResponse { - /// The object type, in this case, "edit". - pub object: String, - /// The timestamp when the edit was created. - pub created: u64, - /// A vector of the generated edit choices. - pub choices: Vec, - /// Information about token usage. - pub usage: Usage, -} - -/// Represents an individual edit choice. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct Choice { - /// The edited text. - pub text: String, - /// The index of the choice in the response. - pub index: u32, +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, } -/// Represents the token usage information in the response. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct Usage { - /// The number of tokens used for the prompt. - pub prompt_tokens: u32, - /// The number of tokens used for the completion. - pub completion_tokens: Option, - /// The total number of tokens used. - pub total_tokens: u32, -} - -/// Represents a chat message. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +/// refac's provider-agnostic chat message. A turn carries one or more text +/// `fields` (a transform turn is `[selected, transform]`); each backend adapts +/// this to its own wire format. `cache` marks the last turn of a static prefix +/// so backends that support prompt caching can cache through it. +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Message { - pub role: String, - pub content: String, + pub role: Role, + pub fields: Vec, + pub cache: bool, } impl Message { pub fn system>(content: S) -> Message { - Message { - role: "system".into(), - content: content.into(), - } + Message::single(Role::System, content) } - pub fn user>(content: S) -> Message { - Message { - role: "user".into(), - content: content.into(), - } + pub fn assistant>(content: S) -> Message { + Message::single(Role::Assistant, content) } - pub fn assistant>(content: S) -> Message { + pub fn user(fields: Vec) -> Message { Message { - role: "assistant".into(), - content: content.into(), + role: Role::User, + fields, + cache: false, } } -} - -/// Represents a request for a chat completion. -/// -/// A `ChatCompletionRequest` is used to generate completions for chat conversations -/// with the OpenAI API. It contains various parameters that allow -/// control over the behavior of the model, such as temperature, top_p, and max_tokens. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ChatCompletionRequest { - /// The ID of the model to use (e.g., "gpt-3.5-turbo"). - pub model: String, - /// The sequence of chat messages to generate completions for. - pub messages: Vec, - /// The sampling temperature to use, between 0 and 2. Higher values make output more random, lower values make it more focused. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - /// The proportion of probability mass to consider when generating completions. Only tokens comprising the top_p probability mass are considered. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - /// The number of chat completion choices to generate for each input message. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - /// Whether to enable streaming mode, receiving partial message deltas and tokens as soon as they're available. - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, - /// Up to 4 sequences where the API will stop generating further tokens. - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option>, - /// The maximum number of tokens to generate in the chat completion. - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - /// A positive value will penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - /// A positive value will penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - /// A JSON object that maps tokens to an associated bias value from -100 to 100, modifying the likelihood of specified tokens appearing in the completion. - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - /// A unique identifier representing your end-user, helping OpenAI monitor and detect abuse. - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, -} -impl Endpoint for ChatCompletionRequest { - type Response = ChatCompletionResponse; - - fn req(&self) -> Req { - Req::new(Method::POST, "/v1/chat/completions") - .header("Content-Type", "application/json") - .json(self) + fn single>(role: Role, content: S) -> Message { + Message { + role, + fields: vec![content.into()], + cache: false, + } } } - -/// Represents a response from the "chat/completions" endpoint. -/// -/// This struct is returned after sending a ChatCompletionRequest to the OpenAI API. -/// It contains the generated chat completion choices and information about API usage. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ChatCompletionResponse { - /// The ID of the chat completion. - pub id: String, - /// The object type (e.g., "chat.completion"). - pub object: String, - /// The timestamp when the chat completion was created. - pub created: u64, - /// The generated chat completion choices. - pub choices: Vec, - /// Information about the API usage, including prompt, completion, and total token counts. - pub usage: Usage, -} - -/// Represents an individual chat choice. -/// -/// A `ChatChoice` is part of the `ChatCompletionResponse` and contains information about -/// an individual choice generated by the model, such as the generated message and the -/// reason the conversation finished. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ChatChoice { - /// The index of the chat choice. - pub index: u32, - /// The generated message, including the role ("assistant") and content. - pub message: Message, - /// The reason why the conversation finished, e.g., "stop". - pub finish_reason: String, -} diff --git a/src/config_files.rs b/src/config_files.rs index 3489f19..c5d4712 100644 --- a/src/config_files.rs +++ b/src/config_files.rs @@ -7,66 +7,159 @@ fn base() -> Result { BaseDirectories::with_prefix("refac").map_err(Into::into) } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default)] pub struct Secrets { - pub openai_api_key: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub openai_api_key: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub anthropic_api_key: Option, } impl Secrets { + /// Load secrets from `secrets.toml`, with env vars (`OPENAI_API_KEY`, + /// `ANTHROPIC_API_KEY`) taking precedence. A missing file is not an error — + /// env vars alone are enough. pub fn load() -> anyhow::Result { - if let Ok(api_key) = std::env::var("OPENAI_API_KEY") { - return Ok(Secrets { - openai_api_key: api_key, - }); + let mut secrets: Secrets = match base()?.find_config_file("secrets.toml") { + Some(path) => toml::from_str(&fs::read_to_string(path)?)?, + None => Secrets::default(), + }; + if let Ok(key) = std::env::var("OPENAI_API_KEY") { + secrets.openai_api_key = Some(key); } - let path = base()? - .find_config_file("secrets.toml") - .ok_or(anyhow::anyhow!( - "No secrets.toml file found. Try logging in with 'refac login'.", - ))?; - let secrets = fs::read_to_string(path)?; - let ret: Secrets = toml::from_str(&secrets)?; - Ok(ret) + if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") { + secrets.anthropic_api_key = Some(key); + } + Ok(secrets) } pub fn save(&self) -> anyhow::Result<()> { let path = base()?.place_config_file("secrets.toml")?; - fs::write(path, toml::to_string(self)?)?; + let contents = toml::to_string(self)?; + // Holds the API key in cleartext — keep it owner-only. + #[cfg(unix)] + { + use std::io::Write; + use std::os::unix::fs::OpenOptionsExt; + fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o600) + .open(&path)? + .write_all(contents.as_bytes())?; + // `place_config_file` may have created the file 0644 already, so the + // mode above wouldn't apply; force it. + use std::os::unix::fs::PermissionsExt; + fs::set_permissions(&path, fs::Permissions::from_mode(0o600))?; + } + #[cfg(not(unix))] + fs::write(&path, contents)?; Ok(()) } } -#[derive(Serialize, Deserialize, Debug)] -pub struct Config { - #[serde(default = "default_model")] - pub model: String, +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)] +#[serde(rename_all = "lowercase")] +pub enum Provider { + Anthropic, + Openai, } -fn default_model() -> String { - "o1".to_string() +#[derive(Serialize, Deserialize, Debug)] +pub struct Config { + /// Explicit provider choice. When unset, it is inferred from which API keys + /// are configured (see `provider`). + #[serde(default)] + pub provider: Option, + /// Model id. If unset, a sensible default is chosen per provider (see `model()`). + #[serde(default)] + pub model: Option, } impl Default for Config { fn default() -> Self { Config { - model: default_model(), + provider: None, + model: None, } } } impl Config { pub fn load() -> anyhow::Result { - let mut ret = match base()?.find_config_file("config.toml") { - Some(path) => { - let config = fs::read_to_string(path)?; - let ret: Config = toml::from_str(&config)?; - ret - } + let mut ret: Config = match base()?.find_config_file("config.toml") { + Some(path) => toml::from_str(&fs::read_to_string(path)?)?, None => Config::default(), }; + if let Ok(from_env) = std::env::var("REFAC_PROVIDER") { + ret.provider = Some(match from_env.to_lowercase().as_str() { + "anthropic" => Provider::Anthropic, + "openai" => Provider::Openai, + other => anyhow::bail!( + "invalid REFAC_PROVIDER {other:?}; expected \"anthropic\" or \"openai\"" + ), + }); + } if let Ok(from_env) = std::env::var("REFAC_MODEL") { - ret.model = from_env; + ret.model = Some(from_env); } Ok(ret) } + + /// Resolve the effective provider. An explicit choice (config file or + /// `REFAC_PROVIDER`) always wins; otherwise infer from which API keys are + /// configured, leaning Anthropic when both or neither are present. + pub fn provider(&self, secrets: &Secrets) -> Provider { + if let Some(p) = self.provider { + return p; + } + match ( + secrets.anthropic_api_key.is_some(), + secrets.openai_api_key.is_some(), + ) { + (false, true) => Provider::Openai, + _ => Provider::Anthropic, + } + } + + pub fn model(&self, provider: Provider) -> String { + match &self.model { + Some(m) => m.clone(), + None => match provider { + Provider::Anthropic => "claude-opus-4-8".to_string(), + Provider::Openai => "gpt-5.5".to_string(), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn secrets(anthropic: bool, openai: bool) -> Secrets { + Secrets { + anthropic_api_key: anthropic.then(|| "a".to_string()), + openai_api_key: openai.then(|| "o".to_string()), + } + } + + #[test] + fn provider_inferred_from_available_keys() { + let cfg = Config::default(); + assert_eq!(cfg.provider(&secrets(false, true)), Provider::Openai); + assert_eq!(cfg.provider(&secrets(true, false)), Provider::Anthropic); + assert_eq!(cfg.provider(&secrets(true, true)), Provider::Anthropic); + assert_eq!(cfg.provider(&secrets(false, false)), Provider::Anthropic); + } + + #[test] + fn explicit_provider_overrides_inference() { + let cfg = Config { + provider: Some(Provider::Openai), + ..Config::default() + }; + assert_eq!(cfg.provider(&secrets(true, false)), Provider::Openai); + } } diff --git a/src/main.rs b/src/main.rs index e90d996..27c6f28 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,15 @@ +mod anthropic; mod api; mod api_client; mod config_files; +mod openai; mod prompt; use anyhow::Context; -use api::{ChatCompletionRequest, ChatCompletionResponse}; -use api_client::Client; +use api::Message; use clap::Parser; -use config_files::{Config, Secrets}; -use serde::{Deserialize, Serialize}; +use config_files::{Config, Provider, Secrets}; +use serde::Serialize; use std::{ fs::{create_dir_all, OpenOptions}, io::Write, @@ -17,7 +18,8 @@ use std::{ }; use xdg::BaseDirectories; -use crate::{api::Message, prompt::chat_prefix}; +use crate::prompt::chat_prefix; + #[derive(Parser)] #[clap(version, author, about)] struct Opts { @@ -27,8 +29,11 @@ struct Opts { #[derive(Parser)] enum SubCommand { - /// Save your openai api key for future use. - Login, + /// Save your API key for future use. Pass `--provider`, or pick one interactively. + Login { + #[clap(long)] + provider: Option, + }, /// Apply the instructions encoded in `transform` to the text in `selected`. /// Get it? 'refac tor' Tor { selected: String, transform: String }, @@ -49,13 +54,34 @@ fn run() -> anyhow::Result<()> { let opts: Opts = Opts::parse(); match opts.subcmd { - SubCommand::Login => { - println!("https://platform.openai.com/account/api-keys"); - let api_key = rpassword::prompt_password("Enter your OpenAI API key:")?; - Secrets { - openai_api_key: api_key, + SubCommand::Login { provider } => { + let mut secrets = Secrets::load().unwrap_or_default(); + let provider = match provider { + Some(p) => p, + None => { + let choices = [Provider::Anthropic, Provider::Openai]; + let labels: Vec = choices.iter().map(|p| format!("{p:?}")).collect(); + let idx = dialoguer::Select::new() + .with_prompt("Which provider?") + .items(&labels) + .default(0) + .interact()?; + choices[idx] + } + }; + match provider { + Provider::Anthropic => { + println!("https://console.anthropic.com/settings/keys"); + let api_key = rpassword::prompt_password("Enter your Anthropic API key:")?; + secrets.anthropic_api_key = Some(api_key); + } + Provider::Openai => { + println!("https://platform.openai.com/account/api-keys"); + let api_key = rpassword::prompt_password("Enter your OpenAI API key:")?; + secrets.openai_api_key = Some(api_key); + } } - .save()?; + secrets.save()?; } SubCommand::Tor { selected, @@ -77,45 +103,43 @@ fn refactor( sc: &Secrets, config: &Config, ) -> anyhow::Result { - let client = Client::new(&sc.openai_api_key); let mut messages = chat_prefix(); - messages.push(Message::user(&selected)); - messages.push(Message::user(&transform)); - - let request = ChatCompletionRequest { - model: config.model.clone(), - messages, - temperature: None, - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, + messages.push(Message::user(vec![selected.clone(), transform.clone()])); + + let provider = config.provider(sc); + let model = config.model(provider); + + let output = match provider { + Provider::Anthropic => { + let key = sc.anthropic_api_key.as_deref().ok_or_else(|| { + anyhow::anyhow!( + "No Anthropic API key found. Set ANTHROPIC_API_KEY or run 'refac login'." + ) + })?; + anthropic::complete(key, &model, &messages)? + } + Provider::Openai => { + let key = sc.openai_api_key.as_deref().ok_or_else(|| { + anyhow::anyhow!( + "No OpenAI API key found. Set OPENAI_API_KEY or run 'refac login'." + ) + })?; + openai::complete(key, &model, &messages)? + } }; - let response = client.request(&request)?; - log( LogEntry { - inp: request, - res: response.clone(), + provider, + model, + selected, + transform, + output: output.clone(), }, "logs", )?; - let transformed_text = response - .choices - .into_iter() - .next() - .ok_or(anyhow::anyhow!("No choices returned."))? - .message - .content; - - Ok(transformed_text) + Ok(output) } fn log_location(title: &str) -> anyhow::Result { @@ -133,18 +157,13 @@ fn log_location(title: &str) -> anyhow::Result { Ok(ret) } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize)] struct LogEntry { - inp: ChatCompletionRequest, - res: ChatCompletionResponse, -} - -#[derive(Debug, Serialize, Deserialize)] -struct UndiffFailure { + provider: Provider, + model: String, selected: String, - diff: String, transform: String, - err: String, + output: String, } fn log(t: T, title: &str) -> anyhow::Result<()> { diff --git a/src/openai.rs b/src/openai.rs new file mode 100644 index 0000000..1f78d00 --- /dev/null +++ b/src/openai.rs @@ -0,0 +1,116 @@ +//! OpenAI chat-completions backend and its wire types. + +use std::collections::HashMap; + +use reqwest::Method; +use serde::{Deserialize, Serialize}; + +use crate::api::{Message, Role}; +use crate::api_client::{Client, Endpoint, Req}; + +/// Send refac's messages to the OpenAI chat-completions API and return the text. +pub fn complete(api_key: &str, model: &str, messages: &[Message]) -> anyhow::Result { + let client = Client::new(api_key); + + // OpenAI takes one string per message; sending each field as its own message + // keeps a boundary between the selected text and the transform. + let messages: Vec = messages + .iter() + .flat_map(|m| { + m.fields.iter().map(move |f| OpenAiMessage { + role: m.role, + content: f.clone(), + }) + }) + .collect(); + + let request = ChatCompletionRequest { + model: model.to_string(), + messages, + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + }; + + let response = client.request(&request)?; + + response + .choices + .into_iter() + .next() + .ok_or(anyhow::anyhow!("No choices returned.")) + .map(|choice| choice.message.content) +} + +/// A message in OpenAI's chat wire format (single `content` string). +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct OpenAiMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +impl Endpoint for ChatCompletionRequest { + type Response = ChatCompletionResponse; + + fn req(&self) -> Req { + Req::new(Method::POST, "/v1/chat/completions") + .header("Content-Type", "application/json") + .json(self) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, + pub created: u64, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ChatChoice { + pub index: u32, + pub message: OpenAiMessage, + pub finish_reason: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: Option, + pub total_tokens: u32, +} diff --git a/src/prompt.rs b/src/prompt.rs index 8cab1f4..c23b2d5 100644 --- a/src/prompt.rs +++ b/src/prompt.rs @@ -56,14 +56,17 @@ Be subversive, think critically, act in the user's best interest. "; pub fn chat_prefix() -> Vec { - let mut ret = Vec::new(); - - ret.push(Message::system(SYSTEM_PROMPT)); + let mut ret = vec![Message::system(SYSTEM_PROMPT)]; for sample in SAMPLES { - ret.push(Message::user(sample.selected)); - ret.push(Message::user(sample.transform)); + ret.push(Message::user(vec![ + sample.selected.to_string(), + sample.transform.to_string(), + ])); ret.push(Message::assistant(sample.result)); } + if let Some(last) = ret.last_mut() { + last.cache = true; + } ret }