Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 57 additions & 6 deletions core/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2864,23 +2864,66 @@ impl Connection {
self.syms.read().vtab_modules.keys().cloned().collect()
}

/// Returns external (extension) functions: (name, is_aggregate, argc)
pub fn get_syms_functions(&self) -> Vec<(String, bool, i32)> {
/// Returns external (extension) functions: (name, is_aggregate, argc, deterministic)
pub fn get_syms_functions(&self) -> Vec<(String, bool, i32, bool)> {
self.syms
.read()
.functions
.values()
.map(|f| {
let is_agg = matches!(f.func, function::ExtFunc::Aggregate { .. });
let is_agg = f.func.is_aggregate();
let argc = match &f.func {
function::ExtFunc::Aggregate { argc, .. } => *argc as i32,
function::ExtFunc::ContextScalar { argc, .. } => *argc,
function::ExtFunc::Scalar(_) => -1,
};
(f.name.clone(), is_agg, argc)
(
f.name.clone(),
is_agg,
argc,
function::Deterministic::is_deterministic(f.as_ref()),
)
})
.collect()
}

#[allow(clippy::too_many_arguments)]
pub fn register_external_scalar_function(
&self,
name: String,
argc: i32,
deterministic: bool,
context: usize,
callback: crate::ContextScalarFunction,
context_destructor: Option<crate::ContextDestructor>,
value_destructor: Option<crate::ContextValueDestructor>,
) {
assert!(
argc >= -1,
"managed scalar argument count must be -1 (variadic) or non-negative"
);
let normalized_name = crate::util::normalize_ident(&name);
self.syms.write().functions.insert(
normalized_name.clone(),
Arc::new(function::ExternalFunc::new_context_scalar(
normalized_name,
argc,
deterministic,
context,
callback,
context_destructor,
value_destructor,
)),
);
self.bump_prepare_context_generation();
}

pub fn unregister_external_function(&self, name: &str) {
let normalized_name = crate::util::normalize_ident(name);
self.syms.write().functions.remove(&normalized_name);
self.bump_prepare_context_generation();
}

pub(crate) fn database_ptr(&self) -> usize {
Arc::as_ptr(&self.db) as usize
}
Expand Down Expand Up @@ -3391,9 +3434,17 @@ impl SymbolTable {
pub fn resolve_function(
&self,
name: &str,
_arg_count: usize,
arg_count: usize,
) -> Option<Arc<function::ExternalFunc>> {
self.functions.get(name).cloned()
self.functions
.get(name)
.cloned()
.or_else(|| {
self.functions
.get(&crate::util::normalize_ident(name))
.cloned()
})
.filter(|func| func.func.matches_arg_count(arg_count))
}

pub fn extend(&mut self, other: &SymbolTable) {
Expand Down
155 changes: 151 additions & 4 deletions core/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,100 @@ use crate::sync::Arc;
use std::fmt;
use std::fmt::{Debug, Display};
use strum::IntoEnumIterator;
use turso_ext::{FinalizeFunction, InitAggFunction, ScalarFunction, StepFunction};
use turso_ext::{
FinalizeFunction, InitAggFunction, ScalarFunction, StepFunction, Value as ExtValue,
};

use crate::{LimboError, Value};

pub type ContextScalarFunction = unsafe extern "C" fn(
context: usize,
argc: i32,
argv: *const ExtValue,
result: *mut ContextValue,
);
pub type ContextDestructor = unsafe extern "C" fn(context: usize);
pub type ContextValueDestructor = unsafe extern "C" fn(result: *mut ContextValue);

#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ContextValueType {
Null,
Integer,
Float,
Text,
Blob,
Error,
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct ContextValueBytes {
pub ptr: *const u8,
pub len: usize,
}

#[repr(C)]
#[derive(Clone, Copy)]
pub union ContextValueData {
pub int: i64,
pub float: f64,
pub bytes: ContextValueBytes,
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct ContextValue {
pub value_type: ContextValueType,
pub value: ContextValueData,
}

impl ContextValue {
pub fn null() -> Self {
Self {
value_type: ContextValueType::Null,
value: ContextValueData { int: 0 },
}
}

use crate::LimboError;
pub fn into_value(self) -> Result<Value, LimboError> {
// Text/blob/error payloads are callback-owned; copy them before the
// caller invokes the registered value destructor.
match self.value_type {
ContextValueType::Null => Ok(Value::Null),
ContextValueType::Integer => Ok(Value::from_i64(unsafe { self.value.int })),
ContextValueType::Float => Ok(Value::from_f64(unsafe { self.value.float })),
ContextValueType::Text => {
let bytes = unsafe { self.value.bytes };
if bytes.ptr.is_null() {
return Ok(Value::Null);
}
let slice = unsafe { std::slice::from_raw_parts(bytes.ptr, bytes.len) };
let text = std::str::from_utf8(slice)
.map_err(|err| LimboError::ExtensionError(err.to_string()))?;
Ok(Value::build_text(text.to_string()))
}
ContextValueType::Blob => {
let bytes = unsafe { self.value.bytes };
if bytes.ptr.is_null() {
return Ok(Value::Blob(Vec::new()));
}
let slice = unsafe { std::slice::from_raw_parts(bytes.ptr, bytes.len) };
Ok(Value::Blob(slice.to_vec()))
}
ContextValueType::Error => {
let bytes = unsafe { self.value.bytes };
if bytes.ptr.is_null() {
return Err(LimboError::ExtensionError(String::new()));
}
let slice = unsafe { std::slice::from_raw_parts(bytes.ptr, bytes.len) };
let message = std::str::from_utf8(slice)
.map_err(|err| LimboError::ExtensionError(err.to_string()))?;
Err(LimboError::ExtensionError(message.to_string()))
}
}
}
}

pub trait Deterministic: std::fmt::Display {
fn is_deterministic(&self) -> bool;
Expand All @@ -17,14 +108,24 @@ pub struct ExternalFunc {

impl Deterministic for ExternalFunc {
fn is_deterministic(&self) -> bool {
// external functions can be whatever so let's just default to false
false
match self.func {
ExtFunc::ContextScalar { deterministic, .. } => deterministic,
_ => false,
}
}
}

#[derive(Debug, Clone)]
pub enum ExtFunc {
Scalar(ScalarFunction),
ContextScalar {
context: usize,
argc: i32,
deterministic: bool,
callback: ContextScalarFunction,
context_destructor: Option<ContextDestructor>,
value_destructor: Option<ContextValueDestructor>,
},
Aggregate {
argc: usize,
init: InitAggFunction,
Expand All @@ -40,6 +141,17 @@ impl ExtFunc {
}
Err(())
}

pub fn matches_arg_count(&self, arg_count: usize) -> bool {
match self {
Self::ContextScalar { argc, .. } => *argc < 0 || *argc as usize == arg_count,
Self::Scalar(_) | Self::Aggregate { .. } => true,
}
}

pub fn is_aggregate(&self) -> bool {
matches!(self, Self::Aggregate { .. })
}
}

impl ExternalFunc {
Expand All @@ -65,6 +177,41 @@ impl ExternalFunc {
},
}
}

pub fn new_context_scalar(
name: String,
argc: i32,
deterministic: bool,
context: usize,
callback: ContextScalarFunction,
context_destructor: Option<ContextDestructor>,
value_destructor: Option<ContextValueDestructor>,
) -> Self {
Self {
name,
func: ExtFunc::ContextScalar {
context,
argc,
deterministic,
callback,
context_destructor,
value_destructor,
},
}
}
}

impl Drop for ExternalFunc {
fn drop(&mut self) {
if let ExtFunc::ContextScalar {
context,
context_destructor: Some(context_destructor),
..
} = self.func
{
unsafe { context_destructor(context) };
}
}
}

impl Debug for ExternalFunc {
Expand Down
4 changes: 4 additions & 0 deletions core/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ use util::parse_schema_rows;
pub use connection::{resolve_ext_path, Connection, Row, StepResult, SymbolTable};
pub(crate) use connection::{AtomicTransactionState, TransactionState};
pub use error::{io_error, CompletionError, LimboError};
pub use function::{
ContextDestructor, ContextScalarFunction, ContextValue, ContextValueBytes, ContextValueData,
ContextValueDestructor, ContextValueType,
};
#[cfg(all(feature = "fs", target_family = "unix", not(miri)))]
pub use io::UnixIO;
#[cfg(all(feature = "fs", target_os = "linux", feature = "io_uring", not(miri)))]
Expand Down
4 changes: 2 additions & 2 deletions core/storage/buffer_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,11 @@ mod arena {

#[cfg(any(not(unix), miri))]
mod arena {
pub fn alloc(len: usize) -> *mut u8 {
pub unsafe fn alloc(len: usize) -> *mut u8 {
let layout = std::alloc::Layout::from_size_align(len, std::mem::size_of::<u8>()).unwrap();
unsafe { std::alloc::alloc_zeroed(layout) }
}
pub fn dealloc(ptr: *mut u8, len: usize) {
pub unsafe fn dealloc(ptr: *mut u8, len: usize) {
let layout = std::alloc::Layout::from_size_align(len, std::mem::size_of::<u8>()).unwrap();
unsafe { std::alloc::dealloc(ptr, layout) };
}
Expand Down
8 changes: 6 additions & 2 deletions core/translate/pragma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -949,14 +949,18 @@ fn query_pragma(
}

// External (extension) functions
for (name, is_agg, argc) in connection.get_syms_functions() {
for (name, is_agg, argc, deterministic) in connection.get_syms_functions() {
let func_type = if is_agg { "a" } else { "s" };
let mut flags = 0;
if deterministic {
flags |= SQLITE_DETERMINISTIC;
}
program.emit_string8(name, base_reg);
program.emit_int(0, base_reg + 1); // builtin = 0
program.emit_string8(func_type.to_string(), base_reg + 2);
program.emit_string8("utf8".to_string(), base_reg + 3);
program.emit_int(argc as i64, base_reg + 4);
program.emit_int(0, base_reg + 5); // flags = 0 for extensions
program.emit_int(flags, base_reg + 5);
program.emit_result_row(base_reg, 6);
}

Expand Down
30 changes: 30 additions & 0 deletions core/vdbe/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8139,6 +8139,36 @@ pub fn op_function(
}
}
}
ExtFunc::ContextScalar {
context,
callback,
value_destructor,
..
} => {
let mut ext_values = Vec::with_capacity(arg_count);
if arg_count != 0 {
let register_slice = &state.registers[*start_reg..*start_reg + arg_count];
for ov in register_slice.iter() {
ext_values.push(ov.get_value().to_ffi());
}
}

let argv_ptr = if ext_values.is_empty() {
std::ptr::null()
} else {
ext_values.as_ptr()
};
let mut result = crate::function::ContextValue::null();
unsafe { callback(context, arg_count as i32, argv_ptr, &mut result) };
let value = result.into_value();
if let Some(value_destructor) = value_destructor {
unsafe { value_destructor(&mut result) };
}
for ext_value in ext_values {
unsafe { ext_value.__free_internal_type() };
}
state.registers[*dest].set_value(value?);
}
_ => unreachable!("aggregate called in scalar context"),
},
crate::function::Func::Math(math_func) => match math_func.arity() {
Expand Down
Loading
Loading