feat(memory): add MariaDB backend support (#2788)

This commit is contained in:
argenis de la rosa 2026-03-04 21:37:41 -05:00
parent a00ae631e6
commit 79ab8cdb0f
7 changed files with 841 additions and 19 deletions

201
Cargo.lock generated
View File

@ -699,6 +699,21 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "btoi"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd6407f73a9b8b6162d8a2ef999fe6afd7cc15902ebf42c5cd296addf17e0ad"
dependencies = [
"num-traits",
]
[[package]]
name = "bufstream"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40e38929add23cdf8a366df9b0e088953150724bcbe5fc330b0d8eb3b328eec8"
[[package]]
name = "bumpalo"
version = "3.19.1"
@ -1230,6 +1245,15 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
@ -1546,6 +1570,17 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "derive_utils"
version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "362f47930db19fe7735f527e6595e4900316b893ebf6d48ad3d31be928d57dd6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
]
[[package]]
name = "dialoguer"
version = "0.12.0"
@ -2035,6 +2070,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c"
dependencies = [
"crc32fast",
"libz-sys",
"miniz_oxide",
"zlib-rs",
]
@ -2648,7 +2684,7 @@ dependencies = [
"libc",
"percent-encoding",
"pin-project-lite",
"socket2",
"socket2 0.6.2",
"tokio",
"tower-service",
"tracing",
@ -2984,6 +3020,15 @@ dependencies = [
"web-sys",
]
[[package]]
name = "io-enum"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7de9008599afe8527a8c9d70423437363b321649161e98473f433de802d76107"
dependencies = [
"derive_utils",
]
[[package]]
name = "io-kit-sys"
version = "0.4.1"
@ -3173,7 +3218,7 @@ dependencies = [
"percent-encoding",
"quoted_printable",
"rustls",
"socket2",
"socket2 0.6.2",
"tokio",
"url",
"webpki-roots 1.0.6",
@ -3212,6 +3257,17 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "libz-sys"
version = "1.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4735e9cbde5aac84a5ce588f6b23a90b9b0b528f6c5a8db8a4aff300463a0839"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "linux-raw-sys"
version = "0.12.1"
@ -3298,6 +3354,12 @@ dependencies = [
"weezl",
]
[[package]]
name = "lru"
version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
[[package]]
name = "lru"
version = "0.16.3"
@ -3869,6 +3931,83 @@ version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
[[package]]
name = "mysql"
version = "26.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce2510a735f601bab18202b07ea0a197bd1d130d3a5ce2edf4577d225f0c3ee4"
dependencies = [
"bufstream",
"bytes",
"crossbeam-queue",
"crossbeam-utils",
"flate2",
"io-enum",
"libc",
"lru 0.12.5",
"mysql_common",
"named_pipe",
"pem",
"percent-encoding",
"socket2 0.5.10",
"twox-hash",
"url",
]
[[package]]
name = "mysql-common-derive"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66f62cad7623a9cb6f8f64037f0c4f69c8db8e82914334a83c9788201c2c1bfa"
dependencies = [
"darling",
"heck",
"num-bigint",
"proc-macro-crate",
"proc-macro-error2",
"proc-macro2",
"quote",
"syn 2.0.116",
"termcolor",
"thiserror 2.0.18",
]
[[package]]
name = "mysql_common"
version = "0.35.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbb9f371618ce723f095c61fbcdc36e8936956d2b62832f9c7648689b338e052"
dependencies = [
"base64",
"bitflags 2.11.0",
"btoi",
"byteorder",
"bytes",
"crc32fast",
"flate2",
"getrandom 0.3.4",
"mysql-common-derive",
"num-bigint",
"num-traits",
"regex",
"saturating",
"serde",
"serde_json",
"sha1",
"sha2",
"thiserror 2.0.18",
"uuid",
]
[[package]]
name = "named_pipe"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad9c443cce91fc3e12f017290db75dde490d685cdaaf508d7159d7cf41f0eb2b"
dependencies = [
"winapi",
]
[[package]]
name = "nanohtml2text"
version = "0.2.1"
@ -4012,7 +4151,7 @@ version = "0.44.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7462c9d8ae5ef6a28d66a192d399ad2530f1f2130b13186296dbb11bdef5b3d1"
dependencies = [
"lru",
"lru 0.16.3",
"nostr",
"tokio",
]
@ -4036,7 +4175,7 @@ dependencies = [
"async-wsocket",
"atomic-destructor",
"hex",
"lru",
"lru 0.16.3",
"negentropy",
"nostr",
"nostr-database",
@ -4068,12 +4207,31 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-conv"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@ -4817,6 +4975,7 @@ dependencies = [
"proc-macro-error-attr2",
"proc-macro2",
"quote",
"syn 2.0.116",
]
[[package]]
@ -4993,7 +5152,7 @@ dependencies = [
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"socket2 0.6.2",
"thiserror 2.0.18",
"tokio",
"tracing",
@ -5030,7 +5189,7 @@ dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"socket2 0.6.2",
"tracing",
"windows-sys 0.60.2",
]
@ -5716,6 +5875,12 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "saturating"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71"
[[package]]
name = "schannel"
version = "0.1.28"
@ -6166,6 +6331,16 @@ dependencies = [
"serde",
]
[[package]]
name = "socket2"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "socket2"
version = "0.6.2"
@ -6380,6 +6555,15 @@ dependencies = [
"utf-8",
]
[[package]]
name = "termcolor"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
dependencies = [
"winapi-util",
]
[[package]]
name = "thiserror"
version = "1.0.69"
@ -6515,7 +6699,7 @@ dependencies = [
"mio",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"socket2 0.6.2",
"tokio-macros",
"windows-sys 0.61.2",
]
@ -6551,7 +6735,7 @@ dependencies = [
"postgres-protocol",
"postgres-types",
"rand 0.9.2",
"socket2",
"socket2 0.6.2",
"tokio",
"tokio-util",
"whoami",
@ -8459,6 +8643,7 @@ dependencies = [
"mail-parser",
"matrix-sdk",
"mime_guess",
"mysql",
"nanohtml2text",
"nostr-sdk",
"nusb",

View File

@ -105,6 +105,7 @@ prost = { version = "0.14", default-features = false, features = ["derive"], opt
rusqlite = { version = "0.37", features = ["bundled"] }
postgres = { version = "0.19", features = ["with-chrono-0_4"], optional = true }
tokio-postgres-rustls = { version = "0.12", optional = true }
mysql = { version = "26", optional = true }
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
chrono-tz = "0.10"
cron = "0.15"
@ -195,6 +196,7 @@ hardware = ["nusb", "tokio-serial"]
channel-matrix = ["dep:matrix-sdk"]
channel-lark = ["dep:prost"]
memory-postgres = ["dep:postgres", "dep:tokio-postgres-rustls"]
memory-mariadb = ["dep:mysql"]
observability-otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp"]
web-fetch-html2md = ["dep:fast_html2md"]
web-fetch-plaintext = ["dep:nanohtml2text"]

View File

@ -1901,7 +1901,7 @@ fn parse_proxy_enabled(raw: &str) -> Option<bool> {
/// Persistent storage configuration (`[storage]` section).
#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)]
pub struct StorageConfig {
/// Storage provider settings (e.g. sqlite, postgres).
/// Storage provider settings (e.g. sqlite, postgres, mariadb).
#[serde(default)]
pub provider: StorageProviderSection,
}
@ -1914,10 +1914,10 @@ pub struct StorageProviderSection {
pub config: StorageProviderConfig,
}
/// Storage provider backend configuration (e.g. postgres connection details).
/// Storage provider backend configuration (e.g. postgres/mariadb connection details).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct StorageProviderConfig {
/// Storage engine key (e.g. "postgres", "sqlite").
/// Storage engine key (e.g. "postgres", "mariadb", "sqlite").
#[serde(default)]
pub provider: String,
@ -1943,10 +1943,10 @@ pub struct StorageProviderConfig {
#[serde(default)]
pub connect_timeout_secs: Option<u64>,
/// Enable TLS for the PostgreSQL connection.
/// Enable TLS for SQL remote connections.
///
/// `true` — require TLS (skips certificate verification; suitable for
/// self-signed certs and most managed databases).
/// `true` — request TLS from the backend (and for PostgreSQL skips certificate
/// verification; suitable for self-signed certs and many managed databases).
/// `false` (default) — plain TCP, backward-compatible.
#[serde(default)]
pub tls: bool,
@ -2012,9 +2012,9 @@ impl Default for QdrantConfig {
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[allow(clippy::struct_excessive_bools)]
pub struct MemoryConfig {
/// "sqlite" | "lucid" | "postgres" | "qdrant" | "markdown" | "none" (`none` = explicit no-op memory)
/// "sqlite" | "lucid" | "postgres" | "mariadb" | "qdrant" | "markdown" | "none" (`none` = explicit no-op memory)
///
/// `postgres` requires `[storage.provider.config]` with `db_url` (`dbURL` alias supported).
/// `postgres` / `mariadb` require `[storage.provider.config]` with `db_url` (`dbURL` alias supported).
/// `qdrant` uses `[memory.qdrant]` config or `QDRANT_URL` env var.
pub backend: String,
/// Auto-save user-stated conversation input to memory (assistant output is excluded)

View File

@ -3,6 +3,7 @@ pub enum MemoryBackendKind {
Sqlite,
Lucid,
Postgres,
Mariadb,
Qdrant,
Markdown,
None,
@ -56,6 +57,15 @@ const POSTGRES_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
optional_dependency: true,
};
const MARIADB_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
key: "mariadb",
label: "MariaDB/MySQL — remote durable storage via [storage.provider.config]",
auto_save_default: true,
uses_sqlite_hygiene: false,
sqlite_based: false,
optional_dependency: true,
};
const QDRANT_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
key: "qdrant",
label: "Qdrant — vector database for semantic search via [memory.qdrant]",
@ -103,6 +113,7 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
"sqlite" => MemoryBackendKind::Sqlite,
"lucid" => MemoryBackendKind::Lucid,
"postgres" => MemoryBackendKind::Postgres,
"mariadb" | "mysql" => MemoryBackendKind::Mariadb,
"qdrant" => MemoryBackendKind::Qdrant,
"markdown" => MemoryBackendKind::Markdown,
"none" => MemoryBackendKind::None,
@ -115,6 +126,7 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
MemoryBackendKind::Sqlite => SQLITE_PROFILE,
MemoryBackendKind::Lucid => LUCID_PROFILE,
MemoryBackendKind::Postgres => POSTGRES_PROFILE,
MemoryBackendKind::Mariadb => MARIADB_PROFILE,
MemoryBackendKind::Qdrant => QDRANT_PROFILE,
MemoryBackendKind::Markdown => MARKDOWN_PROFILE,
MemoryBackendKind::None => NONE_PROFILE,
@ -134,6 +146,11 @@ mod tests {
classify_memory_backend("postgres"),
MemoryBackendKind::Postgres
);
assert_eq!(
classify_memory_backend("mariadb"),
MemoryBackendKind::Mariadb
);
assert_eq!(classify_memory_backend("mysql"), MemoryBackendKind::Mariadb);
assert_eq!(
classify_memory_backend("markdown"),
MemoryBackendKind::Markdown

View File

@ -4,7 +4,7 @@ use super::{
MemoryBackendKind,
};
use crate::config::Config;
#[cfg(feature = "memory-postgres")]
#[cfg(any(feature = "memory-postgres", feature = "memory-mariadb"))]
use anyhow::Context;
use anyhow::{bail, Result};
use console::style;
@ -72,6 +72,28 @@ fn create_cli_memory(config: &Config) -> Result<Box<dyn Memory>> {
MemoryBackendKind::Postgres => {
bail!("memory backend 'postgres' requires the 'memory-postgres' feature to be enabled");
}
#[cfg(feature = "memory-mariadb")]
MemoryBackendKind::Mariadb => {
let sp = &config.storage.provider.config;
let db_url = sp
.db_url
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
.context("memory backend 'mariadb' requires db_url in [storage.provider.config]")?;
let mem = super::MariadbMemory::new(
db_url,
&sp.schema,
&sp.table,
sp.connect_timeout_secs,
sp.tls,
)?;
Ok(Box::new(mem))
}
#[cfg(not(feature = "memory-mariadb"))]
MemoryBackendKind::Mariadb => {
bail!("memory backend 'mariadb' requires the 'memory-mariadb' feature to be enabled");
}
_ => create_memory_for_migration(&backend, &config.workspace_dir),
}
}

527
src/memory/mariadb.rs Normal file
View File

@ -0,0 +1,527 @@
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use anyhow::{Context, Result};
use async_trait::async_trait;
use chrono::Utc;
use mysql::prelude::Queryable;
use mysql::{params, Opts, OptsBuilder, Pool, SslOpts};
use std::time::Duration;
use uuid::Uuid;
/// Maximum allowed connect timeout (seconds) to avoid unreasonable waits.
const MARIADB_CONNECT_TIMEOUT_CAP_SECS: u64 = 300;
/// MariaDB/MySQL-backed persistent memory.
///
/// This backend focuses on reliable CRUD and keyword recall using SQL.
pub struct MariadbMemory {
pool: Pool,
qualified_table: String,
}
impl MariadbMemory {
pub fn new(
db_url: &str,
schema: &str,
table: &str,
connect_timeout_secs: Option<u64>,
tls_mode: bool,
) -> Result<Self> {
validate_identifier(table, "storage table")?;
// Treat "public" as unset for MariaDB/MySQL compatibility because
// this default originates from PostgreSQL config conventions.
let schema = normalize_schema(schema);
if let Some(schema_name) = schema.as_deref() {
validate_identifier(schema_name, "storage schema")?;
}
let table_ident = quote_identifier(table);
let qualified_table = match schema.as_deref() {
Some(schema_name) => format!("{}.{}", quote_identifier(schema_name), table_ident),
None => table_ident,
};
let pool = Self::initialize_pool(
db_url,
connect_timeout_secs,
tls_mode,
schema.as_deref(),
&qualified_table,
)?;
Ok(Self {
pool,
qualified_table,
})
}
fn initialize_pool(
db_url: &str,
connect_timeout_secs: Option<u64>,
tls_mode: bool,
schema: Option<&str>,
qualified_table: &str,
) -> Result<Pool> {
let db_url = db_url.to_string();
let schema = schema.map(str::to_string);
let qualified_table = qualified_table.to_string();
let init_handle = std::thread::Builder::new()
.name("mariadb-memory-init".to_string())
.spawn(move || -> Result<Pool> {
let mut builder = OptsBuilder::from_opts(
Opts::from_url(&db_url).context("invalid MariaDB connection URL")?,
);
if let Some(timeout_secs) = connect_timeout_secs {
let bounded = timeout_secs.min(MARIADB_CONNECT_TIMEOUT_CAP_SECS);
builder = builder.tcp_connect_timeout(Some(Duration::from_secs(bounded)));
}
if tls_mode {
builder = builder.ssl_opts(Some(SslOpts::default()));
}
let pool = Pool::new(builder).context("failed to create MariaDB pool")?;
let mut conn = pool
.get_conn()
.context("failed to connect to MariaDB memory backend")?;
Self::init_schema(&mut conn, schema.as_deref(), &qualified_table)?;
drop(conn);
Ok(pool)
})
.context("failed to spawn MariaDB initializer thread")?;
init_handle
.join()
.map_err(|_| anyhow::anyhow!("MariaDB initializer thread panicked"))?
}
fn init_schema(
conn: &mut mysql::PooledConn,
schema: Option<&str>,
qualified_table: &str,
) -> Result<()> {
if let Some(schema_name) = schema {
let create_schema = format!(
"CREATE DATABASE IF NOT EXISTS {}",
quote_identifier(schema_name)
);
conn.query_drop(create_schema)?;
}
let create_table = format!(
"
CREATE TABLE IF NOT EXISTS {qualified_table} (
id VARCHAR(64) PRIMARY KEY,
`key` VARCHAR(255) NOT NULL UNIQUE,
content LONGTEXT NOT NULL,
category VARCHAR(64) NOT NULL,
created_at VARCHAR(40) NOT NULL,
updated_at VARCHAR(40) NOT NULL,
session_id VARCHAR(255) NULL,
INDEX idx_memories_category (category),
INDEX idx_memories_session_id (session_id),
INDEX idx_memories_updated_at (updated_at)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
"
);
conn.query_drop(create_table)?;
Ok(())
}
fn category_to_str(category: &MemoryCategory) -> String {
match category {
MemoryCategory::Core => "core".to_string(),
MemoryCategory::Daily => "daily".to_string(),
MemoryCategory::Conversation => "conversation".to_string(),
MemoryCategory::Custom(name) => name.clone(),
}
}
fn parse_category(value: &str) -> MemoryCategory {
match value {
"core" => MemoryCategory::Core,
"daily" => MemoryCategory::Daily,
"conversation" => MemoryCategory::Conversation,
other => MemoryCategory::Custom(other.to_string()),
}
}
fn row_to_entry(row: mysql::Row) -> Result<MemoryEntry> {
let id: String = row.get(0).context("missing id column in memory row")?;
let key: String = row.get(1).context("missing key column in memory row")?;
let content: String = row.get(2).context("missing content column in memory row")?;
let category: String = row
.get(3)
.context("missing category column in memory row")?;
let timestamp: String = row
.get(4)
.context("missing created_at column in memory row")?;
let session_id: Option<String> = row.get(5);
let score: Option<f64> = row.get(6);
Ok(MemoryEntry {
id,
key,
content,
category: Self::parse_category(&category),
timestamp,
session_id,
score,
})
}
}
fn normalize_schema(schema: &str) -> Option<String> {
let trimmed = schema.trim();
if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("public") {
None
} else {
Some(trimmed.to_string())
}
}
fn validate_identifier(value: &str, field_name: &str) -> Result<()> {
if value.is_empty() {
anyhow::bail!("{field_name} must not be empty");
}
let mut chars = value.chars();
let Some(first) = chars.next() else {
anyhow::bail!("{field_name} must not be empty");
};
if !(first.is_ascii_alphabetic() || first == '_') {
anyhow::bail!("{field_name} must start with an ASCII letter or underscore; got '{value}'");
}
if !chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') {
anyhow::bail!(
"{field_name} can only contain ASCII letters, numbers, and underscores; got '{value}'"
);
}
Ok(())
}
fn quote_identifier(value: &str) -> String {
format!("`{value}`")
}
#[async_trait]
impl Memory for MariadbMemory {
fn name(&self) -> &str {
"mariadb"
}
async fn store(
&self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
) -> Result<()> {
let pool = self.pool.clone();
let qualified_table = self.qualified_table.clone();
let key = key.to_string();
let content = content.to_string();
let category = Self::category_to_str(&category);
let session_id = session_id.map(str::to_string);
tokio::task::spawn_blocking(move || -> Result<()> {
let mut conn = pool.get_conn()?;
let now = Utc::now().to_rfc3339();
let sql = format!(
"
INSERT INTO {qualified_table}
(id, `key`, content, category, created_at, updated_at, session_id)
VALUES
(:id, :key, :content, :category, :created_at, :updated_at, :session_id)
ON DUPLICATE KEY UPDATE
content = VALUES(content),
category = VALUES(category),
updated_at = VALUES(updated_at),
session_id = VALUES(session_id)
"
);
conn.exec_drop(
sql,
params! {
"id" => Uuid::new_v4().to_string(),
"key" => key,
"content" => content,
"category" => category,
"created_at" => now.clone(),
"updated_at" => now,
"session_id" => session_id,
},
)?;
Ok(())
})
.await?
}
async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
) -> Result<Vec<MemoryEntry>> {
let pool = self.pool.clone();
let qualified_table = self.qualified_table.clone();
let query = query.trim().to_string();
let session_id = session_id.map(str::to_string);
tokio::task::spawn_blocking(move || -> Result<Vec<MemoryEntry>> {
let mut conn = pool.get_conn()?;
let sql = format!(
"
SELECT
id,
`key`,
content,
category,
created_at,
session_id,
(
CASE WHEN LOWER(`key`) LIKE CONCAT('%', LOWER(:query), '%') THEN 2.0 ELSE 0.0 END +
CASE WHEN LOWER(content) LIKE CONCAT('%', LOWER(:query), '%') THEN 1.0 ELSE 0.0 END
) AS score
FROM {qualified_table}
WHERE (:session_id IS NULL OR session_id = :session_id)
AND (
:query = '' OR
LOWER(`key`) LIKE CONCAT('%', LOWER(:query), '%') OR
LOWER(content) LIKE CONCAT('%', LOWER(:query), '%')
)
ORDER BY score DESC, updated_at DESC
LIMIT :limit
"
);
#[allow(clippy::cast_possible_wrap)]
let limit_i64 = limit as i64;
let rows = conn.exec(
sql,
params! {
"query" => query,
"session_id" => session_id,
"limit" => limit_i64,
},
)?;
rows.into_iter()
.map(Self::row_to_entry)
.collect::<Result<Vec<MemoryEntry>>>()
})
.await?
}
async fn get(&self, key: &str) -> Result<Option<MemoryEntry>> {
let pool = self.pool.clone();
let qualified_table = self.qualified_table.clone();
let key = key.to_string();
tokio::task::spawn_blocking(move || -> Result<Option<MemoryEntry>> {
let mut conn = pool.get_conn()?;
let sql = format!(
"
SELECT id, `key`, content, category, created_at, session_id
FROM {qualified_table}
WHERE `key` = :key
LIMIT 1
"
);
let row = conn.exec_first(sql, params! { "key" => key })?;
row.map(Self::row_to_entry).transpose()
})
.await?
}
async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> Result<Vec<MemoryEntry>> {
let pool = self.pool.clone();
let qualified_table = self.qualified_table.clone();
let category = category.map(Self::category_to_str);
let session_id = session_id.map(str::to_string);
tokio::task::spawn_blocking(move || -> Result<Vec<MemoryEntry>> {
let mut conn = pool.get_conn()?;
let sql = format!(
"
SELECT id, `key`, content, category, created_at, session_id
FROM {qualified_table}
WHERE (:category IS NULL OR category = :category)
AND (:session_id IS NULL OR session_id = :session_id)
ORDER BY updated_at DESC
"
);
let rows = conn.exec(
sql,
params! {
"category" => category,
"session_id" => session_id,
},
)?;
rows.into_iter()
.map(Self::row_to_entry)
.collect::<Result<Vec<MemoryEntry>>>()
})
.await?
}
async fn forget(&self, key: &str) -> Result<bool> {
let pool = self.pool.clone();
let qualified_table = self.qualified_table.clone();
let key = key.to_string();
tokio::task::spawn_blocking(move || -> Result<bool> {
let mut conn = pool.get_conn()?;
let sql = format!("DELETE FROM {qualified_table} WHERE `key` = :key");
conn.exec_drop(sql, params! { "key" => key })?;
Ok(conn.affected_rows() > 0)
})
.await?
}
async fn count(&self) -> Result<usize> {
let pool = self.pool.clone();
let qualified_table = self.qualified_table.clone();
tokio::task::spawn_blocking(move || -> Result<usize> {
let mut conn = pool.get_conn()?;
let sql = format!("SELECT COUNT(*) FROM {qualified_table}");
let count: Option<i64> = conn.query_first(sql)?;
let count = count.unwrap_or(0);
let count =
usize::try_from(count).context("MariaDB returned a negative memory count")?;
Ok(count)
})
.await?
}
async fn health_check(&self) -> bool {
let pool = self.pool.clone();
tokio::task::spawn_blocking(move || -> bool {
match pool.get_conn() {
Ok(mut conn) => conn.query_drop("SELECT 1").is_ok(),
Err(_) => false,
}
})
.await
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn valid_identifiers_pass_validation() {
assert!(validate_identifier("memories", "table").is_ok());
assert!(validate_identifier("_memories_01", "table").is_ok());
}
#[test]
fn invalid_identifiers_are_rejected() {
assert!(validate_identifier("", "schema").is_err());
assert!(validate_identifier("1bad", "schema").is_err());
assert!(validate_identifier("bad-name", "table").is_err());
}
#[test]
fn parse_category_maps_known_and_custom_values() {
assert_eq!(MariadbMemory::parse_category("core"), MemoryCategory::Core);
assert_eq!(
MariadbMemory::parse_category("daily"),
MemoryCategory::Daily
);
assert_eq!(
MariadbMemory::parse_category("conversation"),
MemoryCategory::Conversation
);
assert_eq!(
MariadbMemory::parse_category("custom_notes"),
MemoryCategory::Custom("custom_notes".into())
);
}
#[test]
fn normalize_schema_handles_default_postgres_schema() {
assert!(normalize_schema("").is_none());
assert!(normalize_schema("public").is_none());
assert!(normalize_schema("PUBLIC").is_none());
assert_eq!(normalize_schema("zeroclaw"), Some("zeroclaw".into()));
}
#[tokio::test(flavor = "current_thread")]
async fn new_does_not_panic_inside_tokio_runtime() {
let outcome = std::panic::catch_unwind(|| {
MariadbMemory::new(
"mysql://zeroclaw:password@127.0.0.1:1/zeroclaw",
"public",
"memories",
Some(1),
false,
)
});
assert!(outcome.is_ok(), "MariadbMemory::new should not panic");
assert!(
outcome.unwrap().is_err(),
"MariadbMemory::new should return a connect error for an unreachable endpoint"
);
}
#[tokio::test(flavor = "current_thread")]
async fn integration_roundtrip_when_test_db_is_configured() {
let Some(db_url) = std::env::var("ZEROCLAW_TEST_MARIADB_URL")
.ok()
.filter(|value| !value.trim().is_empty())
else {
eprintln!("Skipping MariaDB integration test: set ZEROCLAW_TEST_MARIADB_URL to enable");
return;
};
let schema = format!("zeroclaw_test_{}", Uuid::new_v4().simple());
let memory = MariadbMemory::new(&db_url, &schema, "memories", Some(5), false)
.expect("should initialize MariaDB memory backend");
memory
.store(
"integration_key",
"integration content",
MemoryCategory::Conversation,
None,
)
.await
.expect("store should succeed");
let fetched = memory
.get("integration_key")
.await
.expect("get should succeed")
.expect("entry should exist");
assert_eq!(fetched.content, "integration content");
let recalled = memory
.recall("integration", 5, None)
.await
.expect("recall should succeed");
assert!(
recalled.iter().any(|entry| entry.key == "integration_key"),
"recall should return the stored key"
);
}
}

View File

@ -4,6 +4,8 @@ pub mod cli;
pub mod embeddings;
pub mod hygiene;
pub mod lucid;
#[cfg(feature = "memory-mariadb")]
pub mod mariadb;
pub mod markdown;
pub mod none;
#[cfg(feature = "memory-postgres")]
@ -21,6 +23,8 @@ pub use backend::{
selectable_memory_backends, MemoryBackendKind, MemoryBackendProfile,
};
pub use lucid::LucidMemory;
#[cfg(feature = "memory-mariadb")]
pub use mariadb::MariadbMemory;
pub use markdown::MarkdownMemory;
pub use none::NoneMemory;
#[cfg(feature = "memory-postgres")]
@ -55,6 +59,9 @@ where
Ok(Box::new(LucidMemory::new(workspace_dir, local)))
}
MemoryBackendKind::Postgres => postgres_builder(),
MemoryBackendKind::Mariadb => {
anyhow::bail!("memory backend 'mariadb' is not available in this build context")
}
MemoryBackendKind::Qdrant | MemoryBackendKind::Markdown => {
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
}
@ -299,6 +306,40 @@ pub fn create_memory_with_storage_and_routes(
);
}
#[cfg(feature = "memory-mariadb")]
fn build_mariadb_memory(
storage_provider: Option<&StorageProviderConfig>,
) -> anyhow::Result<Box<dyn Memory>> {
let storage_provider = storage_provider
.context("memory backend 'mariadb' requires [storage.provider.config] settings")?;
let db_url = storage_provider
.db_url
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.context(
"memory backend 'mariadb' requires [storage.provider.config].db_url (or dbURL)",
)?;
let memory = MariadbMemory::new(
db_url,
&storage_provider.schema,
&storage_provider.table,
storage_provider.connect_timeout_secs,
storage_provider.tls,
)?;
Ok(Box::new(memory))
}
#[cfg(not(feature = "memory-mariadb"))]
fn build_mariadb_memory(
_storage_provider: Option<&StorageProviderConfig>,
) -> anyhow::Result<Box<dyn Memory>> {
anyhow::bail!(
"memory backend 'mariadb' requested but this build was compiled without `memory-mariadb`; rebuild with `--features memory-mariadb`"
);
}
if matches!(backend_kind, MemoryBackendKind::Qdrant) {
let url = config
.qdrant
@ -340,6 +381,10 @@ pub fn create_memory_with_storage_and_routes(
)));
}
if matches!(backend_kind, MemoryBackendKind::Mariadb) {
return build_mariadb_memory(storage_provider);
}
create_memory_with_builders(
&backend_name,
workspace_dir,
@ -361,10 +406,10 @@ pub fn create_memory_for_migration(
if matches!(
classify_memory_backend(backend),
MemoryBackendKind::Postgres
MemoryBackendKind::Postgres | MemoryBackendKind::Mariadb
) {
anyhow::bail!(
"memory migration for backend 'postgres' is unsupported; migrate with sqlite or markdown first"
"memory migration for SQL backends ('postgres' / 'mariadb') is unsupported; migrate with sqlite or markdown first"
);
}
@ -526,6 +571,30 @@ mod tests {
}
}
#[test]
fn factory_mariadb_without_db_url_is_rejected() {
let tmp = TempDir::new().unwrap();
let cfg = MemoryConfig {
backend: "mariadb".into(),
..MemoryConfig::default()
};
let storage = StorageProviderConfig {
provider: "mariadb".into(),
db_url: None,
..StorageProviderConfig::default()
};
let error = create_memory_with_storage(&cfg, Some(&storage), tmp.path(), None)
.err()
.expect("mariadb without db_url should be rejected");
if cfg!(feature = "memory-mariadb") {
assert!(error.to_string().contains("db_url"));
} else {
assert!(error.to_string().contains("memory-mariadb"));
}
}
#[test]
fn resolve_embedding_config_uses_base_config_when_model_is_not_hint() {
let cfg = MemoryConfig {