diff --git a/src/semantic_index/persist.rs b/src/semantic_index/persist.rs index 291c4d2..097d4bf 100644 --- a/src/semantic_index/persist.rs +++ b/src/semantic_index/persist.rs @@ -161,3 +161,135 @@ pub fn save_calls_incremental( tx.commit()?; Ok(inserted) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::registry::test_helpers::WorkspaceRegistry; + use crate::semantic_index::SymbolType; + use std::path::PathBuf; + + fn sample_symbol(name: &str, file: &str) -> CodeSymbol { + CodeSymbol { + symbol_type: SymbolType::Function, + name: name.to_string(), + file_path: PathBuf::from(file), + line_start: 1, + line_end: 10, + signature: Some(format!("fn {}()", name)), + attributes: None, + } + } + + fn sample_call(caller_file: &str, caller_symbol: &str, callee: &str) -> CodeCall { + CodeCall { + caller_file: PathBuf::from(caller_file), + caller_symbol: caller_symbol.to_string(), + caller_line: 5, + callee_name: callee.to_string(), + } + } + + #[test] + fn test_save_symbols_replaces_old() { + let mut conn = WorkspaceRegistry::init_in_memory().unwrap(); + let old = vec![sample_symbol("old_fn", "src/old.rs")]; + let new = vec![sample_symbol("new_fn", "src/new.rs")]; + + save_symbols(&mut conn, "repo1", &old).unwrap(); + save_symbols(&mut conn, "repo1", &new).unwrap(); + + let count: i64 = conn + .query_row("SELECT COUNT(*) FROM code_symbols WHERE repo_id = ?1", ["repo1"], |row| { + row.get(0) + }) + .unwrap(); + assert_eq!(count, 1); + } + + #[test] + fn test_save_symbols_incremental() { + let mut conn = WorkspaceRegistry::init_in_memory().unwrap(); + let first = vec![sample_symbol("fn_a", "src/a.rs")]; + let second = vec![sample_symbol("fn_b", "src/b.rs")]; + + save_symbols_incremental(&mut conn, "repo1", &first).unwrap(); + save_symbols_incremental(&mut conn, "repo1", &second).unwrap(); + + let count: i64 = conn + .query_row("SELECT COUNT(*) FROM code_symbols WHERE repo_id = ?1", ["repo1"], |row| { + row.get(0) + }) + .unwrap(); + assert_eq!(count, 2); + } + + #[test] + fn test_delete_symbols_for_files() { + let mut conn = WorkspaceRegistry::init_in_memory().unwrap(); + let symbols = vec![sample_symbol("fn_a", "src/a.rs"), sample_symbol("fn_b", "src/b.rs")]; + let calls = vec![ + sample_call("src/a.rs", "fn_a", "helper"), + sample_call("src/b.rs", "fn_b", "helper"), + ]; + + save_symbols(&mut conn, "repo1", &symbols).unwrap(); + save_calls(&mut conn, "repo1", &calls).unwrap(); + + delete_symbols_for_files(&mut conn, "repo1", &["src/a.rs".to_string()]).unwrap(); + + let sym_count: i64 = conn + .query_row("SELECT COUNT(*) FROM code_symbols WHERE repo_id = ?1", ["repo1"], |row| { + row.get(0) + }) + .unwrap(); + assert_eq!(sym_count, 1); + + let call_count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM code_call_graph WHERE repo_id = ?1", + ["repo1"], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(call_count, 1); + } + + #[test] + fn test_save_calls_replaces_old() { + let mut conn = WorkspaceRegistry::init_in_memory().unwrap(); + let old = vec![sample_call("src/old.rs", "old_fn", "callee1")]; + let new = vec![sample_call("src/new.rs", "new_fn", "callee2")]; + + save_calls(&mut conn, "repo1", &old).unwrap(); + save_calls(&mut conn, "repo1", &new).unwrap(); + + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM code_call_graph WHERE repo_id = ?1", + ["repo1"], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(count, 1); + } + + #[test] + fn test_save_calls_incremental() { + let mut conn = WorkspaceRegistry::init_in_memory().unwrap(); + let first = vec![sample_call("src/a.rs", "fn_a", "callee1")]; + let second = vec![sample_call("src/b.rs", "fn_b", "callee2")]; + + save_calls_incremental(&mut conn, "repo1", &first).unwrap(); + save_calls_incremental(&mut conn, "repo1", &second).unwrap(); + + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM code_call_graph WHERE repo_id = ?1", + ["repo1"], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(count, 2); + } +}