diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index a9a805002f..2083eebc1c 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -46,7 +46,7 @@ jobs: - uses: Swatinem/rust-cache@v1 with: key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }} - + - uses: actions-rs/cargo@v1 with: command: check @@ -144,6 +144,8 @@ jobs: steps: - uses: actions/checkout@v2 + - run: mkdir /tmp/sqlite3-lib && wget -O /tmp/sqlite3-lib/ipaddr.so https://github.com/nalgeon/sqlean/releases/download/0.15.2/ipaddr.so + - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -164,6 +166,8 @@ jobs: --test-threads=1 env: DATABASE_URL: sqlite://tests/sqlite/sqlite.db + RUSTFLAGS: --cfg sqlite_ipaddr + LD_LIBRARY_PATH: /tmp/sqlite3-lib postgres: name: Postgres diff --git a/sqlx-core/src/sqlite/connection/establish.rs b/sqlx-core/src/sqlite/connection/establish.rs index ead59a9dfa..9cdf110b78 100644 --- a/sqlx-core/src/sqlite/connection/establish.rs +++ b/sqlx-core/src/sqlite/connection/establish.rs @@ -3,25 +3,46 @@ use crate::error::Error; use crate::sqlite::connection::handle::ConnectionHandle; use crate::sqlite::connection::{ConnectionState, Statements}; use crate::sqlite::{SqliteConnectOptions, SqliteError}; +use indexmap::IndexMap; +use libc::c_void; use libsqlite3_sys::{ - sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK, + sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free, + sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, }; -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::io; -use std::ptr::{null, null_mut}; +use std::os::raw::c_int; +use std::ptr::{addr_of_mut, null, null_mut}; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; static THREAD_ID: AtomicU64 = AtomicU64::new(0); +enum SqliteLoadExtensionMode { + /// Enables only the C-API, leaving the SQL function disabled. + Enable, + /// Disables both the C-API and the SQL function. + DisableAll, +} + +impl SqliteLoadExtensionMode { + fn as_int(self) -> c_int { + match self { + SqliteLoadExtensionMode::Enable => 1, + SqliteLoadExtensionMode::DisableAll => 0, + } + } +} + pub struct EstablishParams { filename: CString, open_flags: i32, busy_timeout: Duration, statement_cache_capacity: usize, log_settings: LogSettings, + extensions: IndexMap>, pub(crate) thread_name: String, pub(crate) command_channel_size: usize, } @@ -89,17 +110,67 @@ impl EstablishParams { ) })?; + let extensions = options + .extensions + .iter() + .map(|(name, entry)| { + let entry = entry + .as_ref() + .map(|e| { + CString::new(e.as_bytes()).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "extension entrypoint names passed to SQLite must not contain nul bytes" + ) + }) + }) + .transpose()?; + Ok(( + CString::new(name.as_bytes()).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "extension names passed to SQLite must not contain nul bytes", + ) + })?, + entry, + )) + }) + .collect::>, io::Error>>()?; + Ok(Self { filename, open_flags: flags, busy_timeout: options.busy_timeout, statement_cache_capacity: options.statement_cache_capacity, log_settings: options.log_settings.clone(), + extensions, thread_name: (options.thread_name)(THREAD_ID.fetch_add(1, Ordering::AcqRel)), command_channel_size: options.command_channel_size, }) } + // Enable extension loading via the db_config function, as recommended by the docs rather + // than the more obvious `sqlite3_enable_load_extension` + // https://www.sqlite.org/c3ref/db_config.html + // https://www.sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension + unsafe fn sqlite3_set_load_extension( + db: *mut sqlite3, + mode: SqliteLoadExtensionMode, + ) -> Result<(), Error> { + let status = sqlite3_db_config( + db, + SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, + mode.as_int(), + null::(), + ); + + if status != SQLITE_OK { + return Err(Error::Database(Box::new(SqliteError::new(db)))); + } + + Ok(()) + } + pub(crate) fn establish(&self) -> Result { let mut handle = null_mut(); @@ -131,6 +202,57 @@ impl EstablishParams { sqlite3_extended_result_codes(handle.as_ptr(), 1); } + if !self.extensions.is_empty() { + // Enable loading extensions + unsafe { + Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?; + } + + for ext in self.extensions.iter() { + // `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer + // rather than by calling `sqlite3_errmsg` + let mut error = null_mut(); + status = unsafe { + sqlite3_load_extension( + handle.as_ptr(), + ext.0.as_ptr(), + ext.1.as_ref().map_or(null(), |e| e.as_ptr()), + addr_of_mut!(error), + ) + }; + + if status != SQLITE_OK { + // SAFETY: We become responsible for any memory allocation at `&error`, so test + // for null and take an RAII version for returns + let err_msg = if !error.is_null() { + unsafe { + let e = CStr::from_ptr(error).into(); + sqlite3_free(error as *mut c_void); + e + } + } else { + CString::new("Unknown error when loading extension") + .expect("text should be representable as a CString") + }; + return Err(Error::Database(Box::new(SqliteError::extension( + handle.as_ptr(), + &err_msg, + )))); + } + } + + // Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION + // on by disabling the flag again once we've loaded all the requested modules. + // Fail-fast (via `?`) if disabling the extension loader didn't work for some reason, + // avoids an unexpected state going undetected. + unsafe { + Self::sqlite3_set_load_extension( + handle.as_ptr(), + SqliteLoadExtensionMode::DisableAll, + )?; + } + } + // Configure a busy timeout // This causes SQLite to automatically sleep in increasing intervals until the time // when there is something locked during [sqlite3_step]. diff --git a/sqlx-core/src/sqlite/error.rs b/sqlx-core/src/sqlite/error.rs index 1c507a7839..227f37f677 100644 --- a/sqlx-core/src/sqlite/error.rs +++ b/sqlx-core/src/sqlite/error.rs @@ -35,6 +35,13 @@ impl SqliteError { message: message.to_owned(), } } + + /// For errors during extension load, the error message is supplied via a separate pointer + pub(crate) fn extension(handle: *mut sqlite3, error_msg: &CStr) -> Self { + let mut err = Self::new(handle); + err.message = unsafe { from_utf8_unchecked(error_msg.to_bytes()).to_owned() }; + err + } } impl Display for SqliteError { diff --git a/sqlx-core/src/sqlite/options/mod.rs b/sqlx-core/src/sqlite/options/mod.rs index f7a815d314..7070ec4dda 100644 --- a/sqlx-core/src/sqlite/options/mod.rs +++ b/sqlx-core/src/sqlite/options/mod.rs @@ -66,6 +66,11 @@ pub struct SqliteConnectOptions { pub(crate) vfs: Option>, pub(crate) pragmas: IndexMap, Option>>, + /// Extensions are specified as a pair of , the majority + /// of SQLite extensions will use the default entry points specified in the docs, these should + /// be added to the map with a `None` value. + /// + pub(crate) extensions: IndexMap, Option>>, pub(crate) command_channel_size: usize, pub(crate) row_channel_size: usize, @@ -174,6 +179,7 @@ impl SqliteConnectOptions { immutable: false, vfs: None, pragmas, + extensions: Default::default(), collations: Default::default(), serialized: false, thread_name: Arc::new(DebugFn(|id| format!("sqlx-sqlite-worker-{}", id))), @@ -414,4 +420,42 @@ impl SqliteConnectOptions { self.vfs = Some(vfs_name.into()); self } + + /// Load an [extension](https://www.sqlite.org/loadext.html) at run-time when the database connection + /// is established, using the default entry point. + /// + /// Most common SQLite extensions can be loaded using this method, for extensions where you need + /// to specify the entry point, use [`extension_with_entrypoint`][`Self::extension_with_entrypoint`] instead. + /// + /// Multiple extensions can be loaded by calling the method repeatedly on the options struct, they + /// will be loaded in the order they are added. + /// ```rust,no_run + /// # use sqlx_core::error::Error; + /// use std::str::FromStr; + /// use sqlx::sqlite::SqliteConnectOptions; + /// # fn options() -> Result { + /// let options = SqliteConnectOptions::from_str("sqlite://data.db")? + /// .extension("vsv") + /// .extension("mod_spatialite"); + /// # Ok(options) + /// # } + /// ``` + pub fn extension(mut self, extension_name: impl Into>) -> Self { + self.extensions.insert(extension_name.into(), None); + self + } + + /// Load an extension with a specified entry point. + /// + /// Useful when using non-standard extensions, or when developing your own, the second argument + /// specifies where SQLite should expect to find the extension init routine. + pub fn extension_with_entrypoint( + mut self, + extension_name: impl Into>, + entry_point: impl Into>, + ) -> Self { + self.extensions + .insert(extension_name.into(), Some(entry_point.into())); + self + } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index cdcff0508c..914d4ff88a 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -204,6 +204,21 @@ async fn it_executes_with_pool() -> anyhow::Result<()> { Ok(()) } +#[cfg(sqlite_ipaddr)] +#[sqlx_macros::test] +async fn it_opens_with_extension() -> anyhow::Result<()> { + use std::str::FromStr; + + let opts = SqliteConnectOptions::from_str(&dotenvy::var("DATABASE_URL")?)?.extension("ipaddr"); + + let mut conn = SqliteConnection::connect_with(&opts).await?; + conn.execute("SELECT ipmasklen('192.168.16.12/24');") + .await?; + conn.close().await?; + + Ok(()) +} + #[sqlx_macros::test] async fn it_opens_in_memory() -> anyhow::Result<()> { // If the filename is ":memory:", then a private, temporary in-memory database diff --git a/tests/x.py b/tests/x.py index 6b8785d83f..75791b5392 100755 --- a/tests/x.py +++ b/tests/x.py @@ -5,6 +5,8 @@ import sys import time import argparse +import platform +import urllib.request from glob import glob from docker import start_database @@ -23,6 +25,36 @@ dir_tests = os.path.join(dir_workspace, "tests") +def maybe_fetch_sqlite_extension(): + """ + For supported platforms, if we're testing SQLite and the file isn't + already present, grab a simple extension for testing. + + Returns the extension name if it was downloaded successfully or `None` if not. + """ + BASE_URL = "https://github.com/nalgeon/sqlean/releases/download/0.15.2/" + if platform.system() == "Darwin": + if platform.machine() == "arm64": + download_url = BASE_URL + "/ipaddr.arm64.dylib" + filename = "ipaddr.dylib" + else: + download_url = BASE_URL + "/ipaddr.dylib" + filename = "ipaddr.dylib" + elif platform.system() == "Linux": + download_url = BASE_URL + "/ipaddr.so" + filename = "ipaddr.so" + else: + # Unsupported OS + return None + + if not os.path.exists(filename): + content = urllib.request.urlopen(download_url).read() + with open(filename, "wb") as fd: + fd.write(content) + + return filename.split(".")[0] + + def run(command, comment=None, env=None, service=None, tag=None, args=None, database_url_args=None): if argv.list_targets: if tag: @@ -41,6 +73,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data environ = env or {} + if service == "sqlite": + if maybe_fetch_sqlite_extension() is not None: + if environ.get("RUSTFLAGS"): + environ["RUSTFLAGS"] += " --cfg sqlite_ipaddr" + else: + environ["RUSTFLAGS"] = "--cfg sqlite_ipaddr" + if service is not None: database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" else "sqlx", cwd=dir_tests)