Skip to content

Commit

Permalink
Add extension support for SQLite (#2062)
Browse files Browse the repository at this point in the history
* Add extension support for SQLite

While SQLite supports loading extensions at run-time via either the C
API or the SQL interface, they strongly recommend [1] only enabling the C
API so that SQL injections don't allow attackers to run arbitrary
extension code.

Here we take the most conservative approach, we enable only the C
function, and then only when the user requests extensions be loaded in
their `SqliteConnectOptions`, and disable it again once we're done
loading those requested modules. We don't add any support for loading
extensions via environment variables or connection strings.

Extensions in the options are stored as an IndexMap as the load order
can have side effects, they will be loaded in the order they are
supplied by the caller.

Extensions with custom entry points are supported, but a default API
is exposed as most users will interact with extensions using the
defaults.

[1]: https://sqlite.org/c3ref/enable_load_extension.html

* Add extension testing for SQlite

Extends x.py to download an appropriate shared object file for supported
operating systems, and uses wget to fetch one into the GitHub Actions
context for use by CI.

Overriding LD_LIBRARY_PATH for only this specific DB minimises the
impact on the rest of the suite.
  • Loading branch information
bradfier authored Sep 1, 2022
1 parent 9de70d2 commit 20877d8
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 4 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/sqlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- uses: Swatinem/rust-cache@v1
with:
key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }}

- uses: actions-rs/cargo@v1
with:
command: check
Expand Down Expand Up @@ -144,6 +144,8 @@ jobs:
steps:
- uses: actions/checkout@v2

- run: mkdir /tmp/sqlite3-lib && wget -O /tmp/sqlite3-lib/ipaddr.so https://github.com/nalgeon/sqlean/releases/download/0.15.2/ipaddr.so

- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -164,6 +166,8 @@ jobs:
--test-threads=1
env:
DATABASE_URL: sqlite://tests/sqlite/sqlite.db
RUSTFLAGS: --cfg sqlite_ipaddr
LD_LIBRARY_PATH: /tmp/sqlite3-lib

postgres:
name: Postgres
Expand Down
128 changes: 125 additions & 3 deletions sqlx-core/src/sqlite/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,46 @@ use crate::error::Error;
use crate::sqlite::connection::handle::ConnectionHandle;
use crate::sqlite::connection::{ConnectionState, Statements};
use crate::sqlite::{SqliteConnectOptions, SqliteError};
use indexmap::IndexMap;
use libc::c_void;
use libsqlite3_sys::{
sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK,
sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free,
sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK,
SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX,
SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
};
use std::ffi::CString;
use std::ffi::{CStr, CString};
use std::io;
use std::ptr::{null, null_mut};
use std::os::raw::c_int;
use std::ptr::{addr_of_mut, null, null_mut};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

static THREAD_ID: AtomicU64 = AtomicU64::new(0);

enum SqliteLoadExtensionMode {
/// Enables only the C-API, leaving the SQL function disabled.
Enable,
/// Disables both the C-API and the SQL function.
DisableAll,
}

impl SqliteLoadExtensionMode {
fn as_int(self) -> c_int {
match self {
SqliteLoadExtensionMode::Enable => 1,
SqliteLoadExtensionMode::DisableAll => 0,
}
}
}

pub struct EstablishParams {
filename: CString,
open_flags: i32,
busy_timeout: Duration,
statement_cache_capacity: usize,
log_settings: LogSettings,
extensions: IndexMap<CString, Option<CString>>,
pub(crate) thread_name: String,
pub(crate) command_channel_size: usize,
}
Expand Down Expand Up @@ -89,17 +110,67 @@ impl EstablishParams {
)
})?;

let extensions = options
.extensions
.iter()
.map(|(name, entry)| {
let entry = entry
.as_ref()
.map(|e| {
CString::new(e.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension entrypoint names passed to SQLite must not contain nul bytes"
)
})
})
.transpose()?;
Ok((
CString::new(name.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension names passed to SQLite must not contain nul bytes",
)
})?,
entry,
))
})
.collect::<Result<IndexMap<CString, Option<CString>>, io::Error>>()?;

Ok(Self {
filename,
open_flags: flags,
busy_timeout: options.busy_timeout,
statement_cache_capacity: options.statement_cache_capacity,
log_settings: options.log_settings.clone(),
extensions,
thread_name: (options.thread_name)(THREAD_ID.fetch_add(1, Ordering::AcqRel)),
command_channel_size: options.command_channel_size,
})
}

// Enable extension loading via the db_config function, as recommended by the docs rather
// than the more obvious `sqlite3_enable_load_extension`
// https://www.sqlite.org/c3ref/db_config.html
// https://www.sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension
unsafe fn sqlite3_set_load_extension(
db: *mut sqlite3,
mode: SqliteLoadExtensionMode,
) -> Result<(), Error> {
let status = sqlite3_db_config(
db,
SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
mode.as_int(),
null::<i32>(),
);

if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(db))));
}

Ok(())
}

pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
let mut handle = null_mut();

Expand Down Expand Up @@ -131,6 +202,57 @@ impl EstablishParams {
sqlite3_extended_result_codes(handle.as_ptr(), 1);
}

if !self.extensions.is_empty() {
// Enable loading extensions
unsafe {
Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?;
}

for ext in self.extensions.iter() {
// `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer
// rather than by calling `sqlite3_errmsg`
let mut error = null_mut();
status = unsafe {
sqlite3_load_extension(
handle.as_ptr(),
ext.0.as_ptr(),
ext.1.as_ref().map_or(null(), |e| e.as_ptr()),
addr_of_mut!(error),
)
};

if status != SQLITE_OK {
// SAFETY: We become responsible for any memory allocation at `&error`, so test
// for null and take an RAII version for returns
let err_msg = if !error.is_null() {
unsafe {
let e = CStr::from_ptr(error).into();
sqlite3_free(error as *mut c_void);
e
}
} else {
CString::new("Unknown error when loading extension")
.expect("text should be representable as a CString")
};
return Err(Error::Database(Box::new(SqliteError::extension(
handle.as_ptr(),
&err_msg,
))));
}
}

// Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION
// on by disabling the flag again once we've loaded all the requested modules.
// Fail-fast (via `?`) if disabling the extension loader didn't work for some reason,
// avoids an unexpected state going undetected.
unsafe {
Self::sqlite3_set_load_extension(
handle.as_ptr(),
SqliteLoadExtensionMode::DisableAll,
)?;
}
}

// Configure a busy timeout
// This causes SQLite to automatically sleep in increasing intervals until the time
// when there is something locked during [sqlite3_step].
Expand Down
7 changes: 7 additions & 0 deletions sqlx-core/src/sqlite/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ impl SqliteError {
message: message.to_owned(),
}
}

/// For errors during extension load, the error message is supplied via a separate pointer
pub(crate) fn extension(handle: *mut sqlite3, error_msg: &CStr) -> Self {
let mut err = Self::new(handle);
err.message = unsafe { from_utf8_unchecked(error_msg.to_bytes()).to_owned() };
err
}
}

impl Display for SqliteError {
Expand Down
44 changes: 44 additions & 0 deletions sqlx-core/src/sqlite/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ pub struct SqliteConnectOptions {
pub(crate) vfs: Option<Cow<'static, str>>,

pub(crate) pragmas: IndexMap<Cow<'static, str>, Option<Cow<'static, str>>>,
/// Extensions are specified as a pair of <Extension Name : Optional Entry Point>, the majority
/// of SQLite extensions will use the default entry points specified in the docs, these should
/// be added to the map with a `None` value.
/// <https://www.sqlite.org/loadext.html#loading_an_extension>
pub(crate) extensions: IndexMap<Cow<'static, str>, Option<Cow<'static, str>>>,

pub(crate) command_channel_size: usize,
pub(crate) row_channel_size: usize,
Expand Down Expand Up @@ -174,6 +179,7 @@ impl SqliteConnectOptions {
immutable: false,
vfs: None,
pragmas,
extensions: Default::default(),
collations: Default::default(),
serialized: false,
thread_name: Arc::new(DebugFn(|id| format!("sqlx-sqlite-worker-{}", id))),
Expand Down Expand Up @@ -414,4 +420,42 @@ impl SqliteConnectOptions {
self.vfs = Some(vfs_name.into());
self
}

/// Load an [extension](https://www.sqlite.org/loadext.html) at run-time when the database connection
/// is established, using the default entry point.
///
/// Most common SQLite extensions can be loaded using this method, for extensions where you need
/// to specify the entry point, use [`extension_with_entrypoint`][`Self::extension_with_entrypoint`] instead.
///
/// Multiple extensions can be loaded by calling the method repeatedly on the options struct, they
/// will be loaded in the order they are added.
/// ```rust,no_run
/// # use sqlx_core::error::Error;
/// use std::str::FromStr;
/// use sqlx::sqlite::SqliteConnectOptions;
/// # fn options() -> Result<SqliteConnectOptions, Error> {
/// let options = SqliteConnectOptions::from_str("sqlite://data.db")?
/// .extension("vsv")
/// .extension("mod_spatialite");
/// # Ok(options)
/// # }
/// ```
pub fn extension(mut self, extension_name: impl Into<Cow<'static, str>>) -> Self {
self.extensions.insert(extension_name.into(), None);
self
}

/// Load an extension with a specified entry point.
///
/// Useful when using non-standard extensions, or when developing your own, the second argument
/// specifies where SQLite should expect to find the extension init routine.
pub fn extension_with_entrypoint(
mut self,
extension_name: impl Into<Cow<'static, str>>,
entry_point: impl Into<Cow<'static, str>>,
) -> Self {
self.extensions
.insert(extension_name.into(), Some(entry_point.into()));
self
}
}
15 changes: 15 additions & 0 deletions tests/sqlite/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,21 @@ async fn it_executes_with_pool() -> anyhow::Result<()> {
Ok(())
}

#[cfg(sqlite_ipaddr)]
#[sqlx_macros::test]
async fn it_opens_with_extension() -> anyhow::Result<()> {
use std::str::FromStr;

let opts = SqliteConnectOptions::from_str(&dotenvy::var("DATABASE_URL")?)?.extension("ipaddr");

let mut conn = SqliteConnection::connect_with(&opts).await?;
conn.execute("SELECT ipmasklen('192.168.16.12/24');")
.await?;
conn.close().await?;

Ok(())
}

#[sqlx_macros::test]
async fn it_opens_in_memory() -> anyhow::Result<()> {
// If the filename is ":memory:", then a private, temporary in-memory database
Expand Down
39 changes: 39 additions & 0 deletions tests/x.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sys
import time
import argparse
import platform
import urllib.request
from glob import glob
from docker import start_database

Expand All @@ -23,6 +25,36 @@
dir_tests = os.path.join(dir_workspace, "tests")


def maybe_fetch_sqlite_extension():
"""
For supported platforms, if we're testing SQLite and the file isn't
already present, grab a simple extension for testing.
Returns the extension name if it was downloaded successfully or `None` if not.
"""
BASE_URL = "https://github.com/nalgeon/sqlean/releases/download/0.15.2/"
if platform.system() == "Darwin":
if platform.machine() == "arm64":
download_url = BASE_URL + "/ipaddr.arm64.dylib"
filename = "ipaddr.dylib"
else:
download_url = BASE_URL + "/ipaddr.dylib"
filename = "ipaddr.dylib"
elif platform.system() == "Linux":
download_url = BASE_URL + "/ipaddr.so"
filename = "ipaddr.so"
else:
# Unsupported OS
return None

if not os.path.exists(filename):
content = urllib.request.urlopen(download_url).read()
with open(filename, "wb") as fd:
fd.write(content)

return filename.split(".")[0]


def run(command, comment=None, env=None, service=None, tag=None, args=None, database_url_args=None):
if argv.list_targets:
if tag:
Expand All @@ -41,6 +73,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data

environ = env or {}

if service == "sqlite":
if maybe_fetch_sqlite_extension() is not None:
if environ.get("RUSTFLAGS"):
environ["RUSTFLAGS"] += " --cfg sqlite_ipaddr"
else:
environ["RUSTFLAGS"] = "--cfg sqlite_ipaddr"

if service is not None:
database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" else "sqlx", cwd=dir_tests)

Expand Down

0 comments on commit 20877d8

Please sign in to comment.