diff --git a/compat.js b/compat.js index 5fc1fc1..a04aba9 100644 --- a/compat.js +++ b/compat.js @@ -196,6 +196,77 @@ class Database { } } + /** + * Executes a batch of SQL statements sequentially, returning one + * result object per input statement. + * + * When `mode` is provided and the connection is not already inside a + * transaction, the batch is wrapped in a `BEGIN ` / `COMMIT` + * transaction that is rolled back if any statement fails. + * + * @param {Array }>} statements - The statements to execute. + * @param {string | { mode?: string, raw?: boolean }} [options] - Optional + * transaction mode or batch options. When `mode` is provided and the + * connection is not already inside a transaction, the statements run inside + * a transaction. When `raw` is true, reader rows are returned as arrays. + * + * Batch mutation result sets intentionally expose `rowsAffected` only. Unlike + * `Statement.run()`, they do not include `lastInsertRowid`. + * @returns {Array<{ columns: string[], columnTypes: string[], rows: Array | any[]>, rowsAffected: number }>} + */ + batch(statements, options) { + if (!Array.isArray(statements)) { + throw new TypeError("Expected first argument to be an array of statements"); + } + + const { mode, raw } = normalizeBatchOptions(options); + const wrap = mode != null && !this.inTransaction; + if (wrap) { + this.exec(`BEGIN ${normalizeBatchMode(mode)}`); + } + + const results = []; + try { + for (const statement of statements) { + const sql = typeof statement === "string" ? statement : statement.sql; + const args = typeof statement === "string" ? undefined : statement.args; + + const stmt = this.prepare(sql); + const cols = stmt.columns(); + const columnNames = cols.map((c) => c.name); + const columnTypes = cols.map((c) => c.type ?? ""); + + if (columnNames.length > 0) { + // Reader statement: collect the returned rows. + if (raw) { + stmt.raw(true); + } + const rows = args !== undefined ? stmt.all(args) : stmt.all(); + results.push(makeResultSet(columnNames, columnTypes, rows, 0)); + } else { + // Mutating statement: report affected rows only; batch results do not + // expose Statement.run()'s lastInsertRowid by design. + const info = args !== undefined ? stmt.run(args) : stmt.run(); + results.push(makeResultSet(columnNames, columnTypes, [], info.changes)); + } + } + + if (wrap) { + this.exec("COMMIT"); + } + } catch (err) { + if (wrap) { + try { + this.exec("ROLLBACK"); + } catch (_) { + // ignore rollback failures and surface the original error + } + } + throw convertError(err); + } + return results; + } + /** * Interrupts the database connection. */ @@ -388,3 +459,42 @@ module.exports = Database; module.exports.SqliteError = SqliteError; module.exports.Authorization = Authorization; module.exports.Action = Action; + +function normalizeBatchMode(mode) { + switch (String(mode).toLowerCase()) { + case "write": + return "IMMEDIATE"; + case "read": + return "DEFERRED"; + case "deferred": + return "DEFERRED"; + case "immediate": + return "IMMEDIATE"; + case "exclusive": + return "EXCLUSIVE"; + default: + return String(mode).toUpperCase(); + } +} + +function normalizeBatchOptions(options) { + if (options != null && typeof options === "object") { + return { + mode: options.mode, + raw: options.raw === true, + }; + } + return { + mode: options, + raw: false, + }; +} + +function makeResultSet(columns, columnTypes, rows, rowsAffected) { + return { + columns, + columnTypes, + rows, + rowsAffected, + }; +} diff --git a/docs/api.md b/docs/api.md index eb1e55d..a449290 100644 --- a/docs/api.md +++ b/docs/api.md @@ -262,6 +262,36 @@ Executes a SQL statement. | sql | string | The SQL statement string to execute. | | queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). | +### batch(statements, [options]) ⇒ array of ResultSet + +Executes a batch of SQL statements sequentially and returns one `ResultSet` +per input statement. Each statement may be a SQL string or an object of the +form `{ sql, args }`, where `args` is an array (positional) or an object +(named) of bind parameters. + +`options` may be a transaction mode string for compatibility, or an object with +`mode` and `raw` fields. Set `raw: true` to return reader rows in the same array +form as `Statement.raw().all()`. + +| Param | Type | Description | +| ---------- | ------------------- | ------------------------------------ | +| statements | array | The statements to execute. | +| options | string \| object | Optional transaction mode string or `{ mode, raw }` object. When `mode` is provided and the connection is not already in a transaction, the batch runs inside a transaction that is rolled back if any statement fails. When `raw` is true, reader rows are arrays. | + +Each `ResultSet` has the following shape: + +| Field | Type | Description | +| --------------- | ----------------------------- | --------------------------------------------- | +| columns | string[] | The column names of the result. | +| columnTypes | string[] | The declared column types of the result. | +| rows | Row[] | Rows returned by `Statement.all()`. | +| rowsAffected | number | The number of rows changed by the statement. | + +Mutation result sets intentionally expose `rowsAffected` only. Unlike +`Statement.run()`, `batch()` result sets do not include `lastInsertRowid`. + +**Note:** This is an extension in libSQL and not available in `better-sqlite3`. + ### interrupt() ⇒ this Cancel ongoing operations and make them return at earliest opportunity. diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js index a7ac3c2..0bb42b5 100644 --- a/integration-tests/tests/async.test.js +++ b/integration-tests/tests/async.test.js @@ -745,6 +745,79 @@ test.serial("Database.run() forwards queryOptions", async (t) => { ); }); +test.serial("Database.batch() returns per-statement result sets", async (t) => { + const db = t.context.db; + + const results = await db.batch([ + { sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)", args: [3, "Carol", "carol@example.org"] }, + { sql: "UPDATE users SET email = ? WHERE id = ?", args: ["alice@new.org", 1] }, + "SELECT id, name FROM users ORDER BY id", + ]); + + t.true(Array.isArray(results)); + t.is(results.length, 3); + + // INSERT + t.deepEqual(results[0].columns, []); + t.deepEqual(results[0].columnTypes, []); + t.deepEqual(results[0].rows, []); + t.is(results[0].rowsAffected, 1); + + // UPDATE + t.is(results[1].rowsAffected, 1); + + // SELECT + t.deepEqual(results[2].columns, ["id", "name"]); + t.is(results[2].rowsAffected, 0); + t.is(results[2].rows.length, 3); + + // Default Statement.all() row shape + const row = results[2].rows[0]; + t.false(Array.isArray(row)); + t.is(row.id, 1); + t.is(row.name, "Alice"); + + t.is(results[0].toJSON, undefined); +}); + +test.serial("Database.batch() with named args", async (t) => { + const db = t.context.db; + const results = await db.batch([ + { sql: "SELECT * FROM users WHERE id = :id", args: { id: 2 } }, + ]); + t.is(results.length, 1); + t.is(results[0].rows.length, 1); + t.is(results[0].rows[0].name, "Bob"); +}); + +test.serial("Database.batch() with raw rows", async (t) => { + const db = t.context.db; + const results = await db.batch([ + { sql: "SELECT id, name FROM users WHERE id = ?", args: [1] }, + ], { raw: true }); + t.is(results.length, 1); + t.deepEqual(results[0].rows, [[1, "Alice"]]); +}); + +test.serial("Database.batch() rolls back on error when given a mode", async (t) => { + const db = t.context.db; + await t.throwsAsync(async () => { + await db.batch([ + { sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)", args: [10, "Dan", "dan@example.org"] }, + { sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)", args: [1, "Dup", "dup@example.org"] }, + ], "write"); + }); + // The first insert must have been rolled back. + const stmt = await db.prepare("SELECT count(*) AS c FROM users WHERE id = 10"); + const row = await stmt.get(); + t.is(row.c, 0); +}); + +test.serial("Database.batch() rejects non-array argument", async (t) => { + const db = t.context.db; + await t.throwsAsync(() => db.batch("SELECT 1"), { instanceOf: TypeError }); +}); + const connect = async (path_opt, options = {}) => { const path = path_opt ?? "hello.db"; const provider = process.env.PROVIDER; diff --git a/integration-tests/tests/sync.test.js b/integration-tests/tests/sync.test.js index b2ee563..65259e6 100644 --- a/integration-tests/tests/sync.test.js +++ b/integration-tests/tests/sync.test.js @@ -667,6 +667,98 @@ test.serial("Statement.reader [DELETE RETURNING is true]", async (t) => { t.is(stmt.reader, true); }); +test.serial("Database.batch() returns per-statement result sets", async (t) => { + if (t.context.provider !== "libsql") { + t.pass(); + return; + } + const db = t.context.db; + + const results = db.batch([ + { sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)", args: [3, "Carol", "carol@example.org"] }, + { sql: "UPDATE users SET email = ? WHERE id = ?", args: ["alice@new.org", 1] }, + "SELECT id, name FROM users ORDER BY id", + ]); + + t.true(Array.isArray(results)); + t.is(results.length, 3); + + // INSERT + t.deepEqual(results[0].columns, []); + t.deepEqual(results[0].columnTypes, []); + t.deepEqual(results[0].rows, []); + t.is(results[0].rowsAffected, 1); + + // UPDATE + t.is(results[1].rowsAffected, 1); + + // SELECT + t.deepEqual(results[2].columns, ["id", "name"]); + t.is(results[2].rowsAffected, 0); + t.is(results[2].rows.length, 3); + + // Default Statement.all() row shape + const row = results[2].rows[0]; + t.false(Array.isArray(row)); + t.is(row.id, 1); + t.is(row.name, "Alice"); + + t.is(results[0].toJSON, undefined); +}); + +test.serial("Database.batch() with named args", async (t) => { + if (t.context.provider !== "libsql") { + t.pass(); + return; + } + const db = t.context.db; + const results = db.batch([ + { sql: "SELECT * FROM users WHERE id = :id", args: { id: 2 } }, + ]); + t.is(results.length, 1); + t.is(results[0].rows.length, 1); + t.is(results[0].rows[0].name, "Bob"); +}); + +test.serial("Database.batch() with raw rows", async (t) => { + if (t.context.provider !== "libsql") { + t.pass(); + return; + } + const db = t.context.db; + const results = db.batch([ + { sql: "SELECT id, name FROM users WHERE id = ?", args: [1] }, + ], { raw: true }); + t.is(results.length, 1); + t.deepEqual(results[0].rows, [[1, "Alice"]]); +}); + +test.serial("Database.batch() rolls back on error when given a mode", async (t) => { + if (t.context.provider !== "libsql") { + t.pass(); + return; + } + const db = t.context.db; + t.throws(() => { + db.batch([ + { sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)", args: [10, "Dan", "dan@example.org"] }, + { sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)", args: [1, "Dup", "dup@example.org"] }, + ], "write"); + }); + // The first insert must have been rolled back. + const row = db.prepare("SELECT count(*) AS c FROM users WHERE id = 10").get(); + t.is(row.c, 0); +}); + +test.serial("Database.batch() rejects non-array argument", async (t) => { + if (t.context.provider !== "libsql") { + t.pass(); + return; + } + const db = t.context.db; + t.throws(() => db.batch("SELECT 1"), { instanceOf: TypeError }); +}); + const connect = async (path_opt, options = {}) => { const path = path_opt ?? "hello.db"; const provider = process.env.PROVIDER; diff --git a/promise.js b/promise.js index f020de7..1621627 100644 --- a/promise.js +++ b/promise.js @@ -284,6 +284,77 @@ class Database { } } + /** + * Executes a batch of SQL statements sequentially, returning one + * result object per input statement. + * + * When `mode` is provided and the connection is not already inside a + * transaction, the batch is wrapped in a `BEGIN ` / `COMMIT` + * transaction that is rolled back if any statement fails. + * + * @param {Array }>} statements - The statements to execute. + * @param {string | { mode?: string, raw?: boolean }} [options] - Optional + * transaction mode or batch options. When `mode` is provided and the + * connection is not already inside a transaction, the statements run inside + * a transaction. When `raw` is true, reader rows are returned as arrays. + * + * Batch mutation result sets intentionally expose `rowsAffected` only. Unlike + * `Statement.run()`, they do not include `lastInsertRowid`. + * @returns {Promise | any[]>, rowsAffected: number }>>} + */ + async batch(statements, options) { + if (!Array.isArray(statements)) { + throw new TypeError("Expected first argument to be an array of statements"); + } + + const { mode, raw } = normalizeBatchOptions(options); + const wrap = mode != null && !this.inTransaction; + if (wrap) { + await this.exec(`BEGIN ${normalizeBatchMode(mode)}`); + } + + const results = []; + try { + for (const statement of statements) { + const sql = typeof statement === "string" ? statement : statement.sql; + const args = typeof statement === "string" ? undefined : statement.args; + + const stmt = await this.prepare(sql); + const cols = stmt.columns(); + const columnNames = cols.map((c) => c.name); + const columnTypes = cols.map((c) => c.type ?? ""); + + if (columnNames.length > 0) { + // Reader statement: collect the returned rows. + if (raw) { + stmt.raw(true); + } + const rows = args !== undefined ? await stmt.all(args) : await stmt.all(); + results.push(makeResultSet(columnNames, columnTypes, rows, 0)); + } else { + // Mutating statement: report affected rows only; batch results do not + // expose Statement.run()'s lastInsertRowid by design. + const info = args !== undefined ? await stmt.run(args) : await stmt.run(); + results.push(makeResultSet(columnNames, columnTypes, [], info.changes)); + } + } + + if (wrap) { + await this.exec("COMMIT"); + } + } catch (err) { + if (wrap) { + try { + await this.exec("ROLLBACK"); + } catch (_) { + // ignore rollback failures and surface the original error + } + } + throw convertError(err); + } + return results; + } + /** * Interrupts the database connection. */ @@ -484,3 +555,42 @@ module.exports = { Statement, connect, }; + +function normalizeBatchMode(mode) { + switch (String(mode).toLowerCase()) { + case "write": + return "IMMEDIATE"; + case "read": + return "DEFERRED"; + case "deferred": + return "DEFERRED"; + case "immediate": + return "IMMEDIATE"; + case "exclusive": + return "EXCLUSIVE"; + default: + return String(mode).toUpperCase(); + } +} + +function normalizeBatchOptions(options) { + if (options != null && typeof options === "object") { + return { + mode: options.mode, + raw: options.raw === true, + }; + } + return { + mode: options, + raw: false, + }; +} + +function makeResultSet(columns, columnTypes, rows, rowsAffected) { + return { + columns, + columnTypes, + rows, + rowsAffected, + }; +}