diff --git a/Cargo.lock b/Cargo.lock index f6c584d57..63f81c032 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -699,6 +699,21 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "btoi" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd6407f73a9b8b6162d8a2ef999fe6afd7cc15902ebf42c5cd296addf17e0ad" +dependencies = [ + "num-traits", +] + +[[package]] +name = "bufstream" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40e38929add23cdf8a366df9b0e088953150724bcbe5fc330b0d8eb3b328eec8" + [[package]] name = "bumpalo" version = "3.19.1" @@ -1230,6 +1245,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -1546,6 +1570,17 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "derive_utils" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "362f47930db19fe7735f527e6595e4900316b893ebf6d48ad3d31be928d57dd6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + [[package]] name = "dialoguer" version = "0.12.0" @@ -2035,6 +2070,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", + "libz-sys", "miniz_oxide", "zlib-rs", ] @@ -2648,7 +2684,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.2", "tokio", "tower-service", "tracing", @@ -2984,6 +3020,15 @@ dependencies = [ "web-sys", ] +[[package]] +name = "io-enum" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7de9008599afe8527a8c9d70423437363b321649161e98473f433de802d76107" +dependencies = [ + "derive_utils", +] + [[package]] name = "io-kit-sys" version = "0.4.1" @@ -3173,7 +3218,7 @@ dependencies = [ "percent-encoding", "quoted_printable", "rustls", - "socket2", + "socket2 0.6.2", "tokio", "url", "webpki-roots 1.0.6", @@ -3212,6 +3257,17 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libz-sys" +version = "1.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4735e9cbde5aac84a5ce588f6b23a90b9b0b528f6c5a8db8a4aff300463a0839" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -3298,6 +3354,12 @@ dependencies = [ "weezl", ] +[[package]] +name = "lru" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" + [[package]] name = "lru" version = "0.16.3" @@ -3869,6 +3931,83 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" +[[package]] +name = "mysql" +version = "26.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce2510a735f601bab18202b07ea0a197bd1d130d3a5ce2edf4577d225f0c3ee4" +dependencies = [ + "bufstream", + "bytes", + "crossbeam-queue", + "crossbeam-utils", + "flate2", + "io-enum", + "libc", + "lru 0.12.5", + "mysql_common", + "named_pipe", + "pem", + "percent-encoding", + "socket2 0.5.10", + "twox-hash", + "url", +] + +[[package]] +name = "mysql-common-derive" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66f62cad7623a9cb6f8f64037f0c4f69c8db8e82914334a83c9788201c2c1bfa" +dependencies = [ + "darling", + "heck", + "num-bigint", + "proc-macro-crate", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.116", + "termcolor", + "thiserror 2.0.18", +] + +[[package]] +name = "mysql_common" +version = "0.35.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb9f371618ce723f095c61fbcdc36e8936956d2b62832f9c7648689b338e052" +dependencies = [ + "base64", + "bitflags 2.11.0", + "btoi", + "byteorder", + "bytes", + "crc32fast", + "flate2", + "getrandom 0.3.4", + "mysql-common-derive", + "num-bigint", + "num-traits", + "regex", + "saturating", + "serde", + "serde_json", + "sha1", + "sha2", + "thiserror 2.0.18", + "uuid", +] + +[[package]] +name = "named_pipe" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad9c443cce91fc3e12f017290db75dde490d685cdaaf508d7159d7cf41f0eb2b" +dependencies = [ + "winapi", +] + [[package]] name = "nanohtml2text" version = "0.2.1" @@ -4012,7 +4151,7 @@ version = "0.44.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7462c9d8ae5ef6a28d66a192d399ad2530f1f2130b13186296dbb11bdef5b3d1" dependencies = [ - "lru", + "lru 0.16.3", "nostr", "tokio", ] @@ -4036,7 +4175,7 @@ dependencies = [ "async-wsocket", "atomic-destructor", "hex", - "lru", + "lru 0.16.3", "negentropy", "nostr", "nostr-database", @@ -4068,12 +4207,31 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-conv" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -4817,6 +4975,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", + "syn 2.0.116", ] [[package]] @@ -4993,7 +5152,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "socket2", + "socket2 0.6.2", "thiserror 2.0.18", "tokio", "tracing", @@ -5030,7 +5189,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.6.2", "tracing", "windows-sys 0.60.2", ] @@ -5716,6 +5875,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "saturating" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71" + [[package]] name = "schannel" version = "0.1.28" @@ -6166,6 +6331,16 @@ dependencies = [ "serde", ] +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "socket2" version = "0.6.2" @@ -6380,6 +6555,15 @@ dependencies = [ "utf-8", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -6515,7 +6699,7 @@ dependencies = [ "mio", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.6.2", "tokio-macros", "windows-sys 0.61.2", ] @@ -6551,7 +6735,7 @@ dependencies = [ "postgres-protocol", "postgres-types", "rand 0.9.2", - "socket2", + "socket2 0.6.2", "tokio", "tokio-util", "whoami", @@ -8459,6 +8643,7 @@ dependencies = [ "mail-parser", "matrix-sdk", "mime_guess", + "mysql", "nanohtml2text", "nostr-sdk", "nusb", diff --git a/Cargo.toml b/Cargo.toml index 7a5680657..391e0a64e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -105,6 +105,7 @@ prost = { version = "0.14", default-features = false, features = ["derive"], opt rusqlite = { version = "0.37", features = ["bundled"] } postgres = { version = "0.19", features = ["with-chrono-0_4"], optional = true } tokio-postgres-rustls = { version = "0.12", optional = true } +mysql = { version = "26", optional = true } chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] } chrono-tz = "0.10" cron = "0.15" @@ -195,6 +196,7 @@ hardware = ["nusb", "tokio-serial"] channel-matrix = ["dep:matrix-sdk"] channel-lark = ["dep:prost"] memory-postgres = ["dep:postgres", "dep:tokio-postgres-rustls"] +memory-mariadb = ["dep:mysql"] observability-otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp"] web-fetch-html2md = ["dep:fast_html2md"] web-fetch-plaintext = ["dep:nanohtml2text"] diff --git a/src/config/schema.rs b/src/config/schema.rs index 532fa6887..614d4fcac 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -1901,7 +1901,7 @@ fn parse_proxy_enabled(raw: &str) -> Option { /// Persistent storage configuration (`[storage]` section). #[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] pub struct StorageConfig { - /// Storage provider settings (e.g. sqlite, postgres). + /// Storage provider settings (e.g. sqlite, postgres, mariadb). #[serde(default)] pub provider: StorageProviderSection, } @@ -1914,10 +1914,10 @@ pub struct StorageProviderSection { pub config: StorageProviderConfig, } -/// Storage provider backend configuration (e.g. postgres connection details). +/// Storage provider backend configuration (e.g. postgres/mariadb connection details). #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct StorageProviderConfig { - /// Storage engine key (e.g. "postgres", "sqlite"). + /// Storage engine key (e.g. "postgres", "mariadb", "sqlite"). #[serde(default)] pub provider: String, @@ -1943,10 +1943,10 @@ pub struct StorageProviderConfig { #[serde(default)] pub connect_timeout_secs: Option, - /// Enable TLS for the PostgreSQL connection. + /// Enable TLS for SQL remote connections. /// - /// `true` — require TLS (skips certificate verification; suitable for - /// self-signed certs and most managed databases). + /// `true` — request TLS from the backend (and for PostgreSQL skips certificate + /// verification; suitable for self-signed certs and many managed databases). /// `false` (default) — plain TCP, backward-compatible. #[serde(default)] pub tls: bool, @@ -2012,9 +2012,9 @@ impl Default for QdrantConfig { #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[allow(clippy::struct_excessive_bools)] pub struct MemoryConfig { - /// "sqlite" | "lucid" | "postgres" | "qdrant" | "markdown" | "none" (`none` = explicit no-op memory) + /// "sqlite" | "lucid" | "postgres" | "mariadb" | "qdrant" | "markdown" | "none" (`none` = explicit no-op memory) /// - /// `postgres` requires `[storage.provider.config]` with `db_url` (`dbURL` alias supported). + /// `postgres` / `mariadb` require `[storage.provider.config]` with `db_url` (`dbURL` alias supported). /// `qdrant` uses `[memory.qdrant]` config or `QDRANT_URL` env var. pub backend: String, /// Auto-save user-stated conversation input to memory (assistant output is excluded) diff --git a/src/memory/backend.rs b/src/memory/backend.rs index 353d1b3cd..9efe558a8 100644 --- a/src/memory/backend.rs +++ b/src/memory/backend.rs @@ -3,6 +3,7 @@ pub enum MemoryBackendKind { Sqlite, Lucid, Postgres, + Mariadb, Qdrant, Markdown, None, @@ -56,6 +57,15 @@ const POSTGRES_PROFILE: MemoryBackendProfile = MemoryBackendProfile { optional_dependency: true, }; +const MARIADB_PROFILE: MemoryBackendProfile = MemoryBackendProfile { + key: "mariadb", + label: "MariaDB/MySQL — remote durable storage via [storage.provider.config]", + auto_save_default: true, + uses_sqlite_hygiene: false, + sqlite_based: false, + optional_dependency: true, +}; + const QDRANT_PROFILE: MemoryBackendProfile = MemoryBackendProfile { key: "qdrant", label: "Qdrant — vector database for semantic search via [memory.qdrant]", @@ -103,6 +113,7 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind { "sqlite" => MemoryBackendKind::Sqlite, "lucid" => MemoryBackendKind::Lucid, "postgres" => MemoryBackendKind::Postgres, + "mariadb" | "mysql" => MemoryBackendKind::Mariadb, "qdrant" => MemoryBackendKind::Qdrant, "markdown" => MemoryBackendKind::Markdown, "none" => MemoryBackendKind::None, @@ -115,6 +126,7 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile { MemoryBackendKind::Sqlite => SQLITE_PROFILE, MemoryBackendKind::Lucid => LUCID_PROFILE, MemoryBackendKind::Postgres => POSTGRES_PROFILE, + MemoryBackendKind::Mariadb => MARIADB_PROFILE, MemoryBackendKind::Qdrant => QDRANT_PROFILE, MemoryBackendKind::Markdown => MARKDOWN_PROFILE, MemoryBackendKind::None => NONE_PROFILE, @@ -134,6 +146,11 @@ mod tests { classify_memory_backend("postgres"), MemoryBackendKind::Postgres ); + assert_eq!( + classify_memory_backend("mariadb"), + MemoryBackendKind::Mariadb + ); + assert_eq!(classify_memory_backend("mysql"), MemoryBackendKind::Mariadb); assert_eq!( classify_memory_backend("markdown"), MemoryBackendKind::Markdown diff --git a/src/memory/cli.rs b/src/memory/cli.rs index 66bba58ec..fc7b61127 100644 --- a/src/memory/cli.rs +++ b/src/memory/cli.rs @@ -4,7 +4,7 @@ use super::{ MemoryBackendKind, }; use crate::config::Config; -#[cfg(feature = "memory-postgres")] +#[cfg(any(feature = "memory-postgres", feature = "memory-mariadb"))] use anyhow::Context; use anyhow::{bail, Result}; use console::style; @@ -72,6 +72,28 @@ fn create_cli_memory(config: &Config) -> Result> { MemoryBackendKind::Postgres => { bail!("memory backend 'postgres' requires the 'memory-postgres' feature to be enabled"); } + #[cfg(feature = "memory-mariadb")] + MemoryBackendKind::Mariadb => { + let sp = &config.storage.provider.config; + let db_url = sp + .db_url + .as_deref() + .map(str::trim) + .filter(|v| !v.is_empty()) + .context("memory backend 'mariadb' requires db_url in [storage.provider.config]")?; + let mem = super::MariadbMemory::new( + db_url, + &sp.schema, + &sp.table, + sp.connect_timeout_secs, + sp.tls, + )?; + Ok(Box::new(mem)) + } + #[cfg(not(feature = "memory-mariadb"))] + MemoryBackendKind::Mariadb => { + bail!("memory backend 'mariadb' requires the 'memory-mariadb' feature to be enabled"); + } _ => create_memory_for_migration(&backend, &config.workspace_dir), } } diff --git a/src/memory/mariadb.rs b/src/memory/mariadb.rs new file mode 100644 index 000000000..ba414fb7d --- /dev/null +++ b/src/memory/mariadb.rs @@ -0,0 +1,527 @@ +use super::traits::{Memory, MemoryCategory, MemoryEntry}; +use anyhow::{Context, Result}; +use async_trait::async_trait; +use chrono::Utc; +use mysql::prelude::Queryable; +use mysql::{params, Opts, OptsBuilder, Pool, SslOpts}; +use std::time::Duration; +use uuid::Uuid; + +/// Maximum allowed connect timeout (seconds) to avoid unreasonable waits. +const MARIADB_CONNECT_TIMEOUT_CAP_SECS: u64 = 300; + +/// MariaDB/MySQL-backed persistent memory. +/// +/// This backend focuses on reliable CRUD and keyword recall using SQL. +pub struct MariadbMemory { + pool: Pool, + qualified_table: String, +} + +impl MariadbMemory { + pub fn new( + db_url: &str, + schema: &str, + table: &str, + connect_timeout_secs: Option, + tls_mode: bool, + ) -> Result { + validate_identifier(table, "storage table")?; + + // Treat "public" as unset for MariaDB/MySQL compatibility because + // this default originates from PostgreSQL config conventions. + let schema = normalize_schema(schema); + if let Some(schema_name) = schema.as_deref() { + validate_identifier(schema_name, "storage schema")?; + } + + let table_ident = quote_identifier(table); + let qualified_table = match schema.as_deref() { + Some(schema_name) => format!("{}.{}", quote_identifier(schema_name), table_ident), + None => table_ident, + }; + + let pool = Self::initialize_pool( + db_url, + connect_timeout_secs, + tls_mode, + schema.as_deref(), + &qualified_table, + )?; + + Ok(Self { + pool, + qualified_table, + }) + } + + fn initialize_pool( + db_url: &str, + connect_timeout_secs: Option, + tls_mode: bool, + schema: Option<&str>, + qualified_table: &str, + ) -> Result { + let db_url = db_url.to_string(); + let schema = schema.map(str::to_string); + let qualified_table = qualified_table.to_string(); + + let init_handle = std::thread::Builder::new() + .name("mariadb-memory-init".to_string()) + .spawn(move || -> Result { + let mut builder = OptsBuilder::from_opts( + Opts::from_url(&db_url).context("invalid MariaDB connection URL")?, + ); + + if let Some(timeout_secs) = connect_timeout_secs { + let bounded = timeout_secs.min(MARIADB_CONNECT_TIMEOUT_CAP_SECS); + builder = builder.tcp_connect_timeout(Some(Duration::from_secs(bounded))); + } + + if tls_mode { + builder = builder.ssl_opts(Some(SslOpts::default())); + } + + let pool = Pool::new(builder).context("failed to create MariaDB pool")?; + let mut conn = pool + .get_conn() + .context("failed to connect to MariaDB memory backend")?; + Self::init_schema(&mut conn, schema.as_deref(), &qualified_table)?; + drop(conn); + Ok(pool) + }) + .context("failed to spawn MariaDB initializer thread")?; + + init_handle + .join() + .map_err(|_| anyhow::anyhow!("MariaDB initializer thread panicked"))? + } + + fn init_schema( + conn: &mut mysql::PooledConn, + schema: Option<&str>, + qualified_table: &str, + ) -> Result<()> { + if let Some(schema_name) = schema { + let create_schema = format!( + "CREATE DATABASE IF NOT EXISTS {}", + quote_identifier(schema_name) + ); + conn.query_drop(create_schema)?; + } + + let create_table = format!( + " + CREATE TABLE IF NOT EXISTS {qualified_table} ( + id VARCHAR(64) PRIMARY KEY, + `key` VARCHAR(255) NOT NULL UNIQUE, + content LONGTEXT NOT NULL, + category VARCHAR(64) NOT NULL, + created_at VARCHAR(40) NOT NULL, + updated_at VARCHAR(40) NOT NULL, + session_id VARCHAR(255) NULL, + INDEX idx_memories_category (category), + INDEX idx_memories_session_id (session_id), + INDEX idx_memories_updated_at (updated_at) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + " + ); + conn.query_drop(create_table)?; + + Ok(()) + } + + fn category_to_str(category: &MemoryCategory) -> String { + match category { + MemoryCategory::Core => "core".to_string(), + MemoryCategory::Daily => "daily".to_string(), + MemoryCategory::Conversation => "conversation".to_string(), + MemoryCategory::Custom(name) => name.clone(), + } + } + + fn parse_category(value: &str) -> MemoryCategory { + match value { + "core" => MemoryCategory::Core, + "daily" => MemoryCategory::Daily, + "conversation" => MemoryCategory::Conversation, + other => MemoryCategory::Custom(other.to_string()), + } + } + + fn row_to_entry(row: mysql::Row) -> Result { + let id: String = row.get(0).context("missing id column in memory row")?; + let key: String = row.get(1).context("missing key column in memory row")?; + let content: String = row.get(2).context("missing content column in memory row")?; + let category: String = row + .get(3) + .context("missing category column in memory row")?; + let timestamp: String = row + .get(4) + .context("missing created_at column in memory row")?; + let session_id: Option = row.get(5); + let score: Option = row.get(6); + + Ok(MemoryEntry { + id, + key, + content, + category: Self::parse_category(&category), + timestamp, + session_id, + score, + }) + } +} + +fn normalize_schema(schema: &str) -> Option { + let trimmed = schema.trim(); + if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("public") { + None + } else { + Some(trimmed.to_string()) + } +} + +fn validate_identifier(value: &str, field_name: &str) -> Result<()> { + if value.is_empty() { + anyhow::bail!("{field_name} must not be empty"); + } + + let mut chars = value.chars(); + let Some(first) = chars.next() else { + anyhow::bail!("{field_name} must not be empty"); + }; + + if !(first.is_ascii_alphabetic() || first == '_') { + anyhow::bail!("{field_name} must start with an ASCII letter or underscore; got '{value}'"); + } + + if !chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') { + anyhow::bail!( + "{field_name} can only contain ASCII letters, numbers, and underscores; got '{value}'" + ); + } + + Ok(()) +} + +fn quote_identifier(value: &str) -> String { + format!("`{value}`") +} + +#[async_trait] +impl Memory for MariadbMemory { + fn name(&self) -> &str { + "mariadb" + } + + async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + ) -> Result<()> { + let pool = self.pool.clone(); + let qualified_table = self.qualified_table.clone(); + let key = key.to_string(); + let content = content.to_string(); + let category = Self::category_to_str(&category); + let session_id = session_id.map(str::to_string); + + tokio::task::spawn_blocking(move || -> Result<()> { + let mut conn = pool.get_conn()?; + let now = Utc::now().to_rfc3339(); + let sql = format!( + " + INSERT INTO {qualified_table} + (id, `key`, content, category, created_at, updated_at, session_id) + VALUES + (:id, :key, :content, :category, :created_at, :updated_at, :session_id) + ON DUPLICATE KEY UPDATE + content = VALUES(content), + category = VALUES(category), + updated_at = VALUES(updated_at), + session_id = VALUES(session_id) + " + ); + + conn.exec_drop( + sql, + params! { + "id" => Uuid::new_v4().to_string(), + "key" => key, + "content" => content, + "category" => category, + "created_at" => now.clone(), + "updated_at" => now, + "session_id" => session_id, + }, + )?; + Ok(()) + }) + .await? + } + + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> Result> { + let pool = self.pool.clone(); + let qualified_table = self.qualified_table.clone(); + let query = query.trim().to_string(); + let session_id = session_id.map(str::to_string); + + tokio::task::spawn_blocking(move || -> Result> { + let mut conn = pool.get_conn()?; + let sql = format!( + " + SELECT + id, + `key`, + content, + category, + created_at, + session_id, + ( + CASE WHEN LOWER(`key`) LIKE CONCAT('%', LOWER(:query), '%') THEN 2.0 ELSE 0.0 END + + CASE WHEN LOWER(content) LIKE CONCAT('%', LOWER(:query), '%') THEN 1.0 ELSE 0.0 END + ) AS score + FROM {qualified_table} + WHERE (:session_id IS NULL OR session_id = :session_id) + AND ( + :query = '' OR + LOWER(`key`) LIKE CONCAT('%', LOWER(:query), '%') OR + LOWER(content) LIKE CONCAT('%', LOWER(:query), '%') + ) + ORDER BY score DESC, updated_at DESC + LIMIT :limit + " + ); + + #[allow(clippy::cast_possible_wrap)] + let limit_i64 = limit as i64; + + let rows = conn.exec( + sql, + params! { + "query" => query, + "session_id" => session_id, + "limit" => limit_i64, + }, + )?; + + rows.into_iter() + .map(Self::row_to_entry) + .collect::>>() + }) + .await? + } + + async fn get(&self, key: &str) -> Result> { + let pool = self.pool.clone(); + let qualified_table = self.qualified_table.clone(); + let key = key.to_string(); + + tokio::task::spawn_blocking(move || -> Result> { + let mut conn = pool.get_conn()?; + let sql = format!( + " + SELECT id, `key`, content, category, created_at, session_id + FROM {qualified_table} + WHERE `key` = :key + LIMIT 1 + " + ); + + let row = conn.exec_first(sql, params! { "key" => key })?; + row.map(Self::row_to_entry).transpose() + }) + .await? + } + + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> Result> { + let pool = self.pool.clone(); + let qualified_table = self.qualified_table.clone(); + let category = category.map(Self::category_to_str); + let session_id = session_id.map(str::to_string); + + tokio::task::spawn_blocking(move || -> Result> { + let mut conn = pool.get_conn()?; + let sql = format!( + " + SELECT id, `key`, content, category, created_at, session_id + FROM {qualified_table} + WHERE (:category IS NULL OR category = :category) + AND (:session_id IS NULL OR session_id = :session_id) + ORDER BY updated_at DESC + " + ); + + let rows = conn.exec( + sql, + params! { + "category" => category, + "session_id" => session_id, + }, + )?; + + rows.into_iter() + .map(Self::row_to_entry) + .collect::>>() + }) + .await? + } + + async fn forget(&self, key: &str) -> Result { + let pool = self.pool.clone(); + let qualified_table = self.qualified_table.clone(); + let key = key.to_string(); + + tokio::task::spawn_blocking(move || -> Result { + let mut conn = pool.get_conn()?; + let sql = format!("DELETE FROM {qualified_table} WHERE `key` = :key"); + conn.exec_drop(sql, params! { "key" => key })?; + Ok(conn.affected_rows() > 0) + }) + .await? + } + + async fn count(&self) -> Result { + let pool = self.pool.clone(); + let qualified_table = self.qualified_table.clone(); + + tokio::task::spawn_blocking(move || -> Result { + let mut conn = pool.get_conn()?; + let sql = format!("SELECT COUNT(*) FROM {qualified_table}"); + let count: Option = conn.query_first(sql)?; + let count = count.unwrap_or(0); + let count = + usize::try_from(count).context("MariaDB returned a negative memory count")?; + Ok(count) + }) + .await? + } + + async fn health_check(&self) -> bool { + let pool = self.pool.clone(); + tokio::task::spawn_blocking(move || -> bool { + match pool.get_conn() { + Ok(mut conn) => conn.query_drop("SELECT 1").is_ok(), + Err(_) => false, + } + }) + .await + .unwrap_or(false) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn valid_identifiers_pass_validation() { + assert!(validate_identifier("memories", "table").is_ok()); + assert!(validate_identifier("_memories_01", "table").is_ok()); + } + + #[test] + fn invalid_identifiers_are_rejected() { + assert!(validate_identifier("", "schema").is_err()); + assert!(validate_identifier("1bad", "schema").is_err()); + assert!(validate_identifier("bad-name", "table").is_err()); + } + + #[test] + fn parse_category_maps_known_and_custom_values() { + assert_eq!(MariadbMemory::parse_category("core"), MemoryCategory::Core); + assert_eq!( + MariadbMemory::parse_category("daily"), + MemoryCategory::Daily + ); + assert_eq!( + MariadbMemory::parse_category("conversation"), + MemoryCategory::Conversation + ); + assert_eq!( + MariadbMemory::parse_category("custom_notes"), + MemoryCategory::Custom("custom_notes".into()) + ); + } + + #[test] + fn normalize_schema_handles_default_postgres_schema() { + assert!(normalize_schema("").is_none()); + assert!(normalize_schema("public").is_none()); + assert!(normalize_schema("PUBLIC").is_none()); + assert_eq!(normalize_schema("zeroclaw"), Some("zeroclaw".into())); + } + + #[tokio::test(flavor = "current_thread")] + async fn new_does_not_panic_inside_tokio_runtime() { + let outcome = std::panic::catch_unwind(|| { + MariadbMemory::new( + "mysql://zeroclaw:password@127.0.0.1:1/zeroclaw", + "public", + "memories", + Some(1), + false, + ) + }); + + assert!(outcome.is_ok(), "MariadbMemory::new should not panic"); + assert!( + outcome.unwrap().is_err(), + "MariadbMemory::new should return a connect error for an unreachable endpoint" + ); + } + + #[tokio::test(flavor = "current_thread")] + async fn integration_roundtrip_when_test_db_is_configured() { + let Some(db_url) = std::env::var("ZEROCLAW_TEST_MARIADB_URL") + .ok() + .filter(|value| !value.trim().is_empty()) + else { + eprintln!("Skipping MariaDB integration test: set ZEROCLAW_TEST_MARIADB_URL to enable"); + return; + }; + + let schema = format!("zeroclaw_test_{}", Uuid::new_v4().simple()); + let memory = MariadbMemory::new(&db_url, &schema, "memories", Some(5), false) + .expect("should initialize MariaDB memory backend"); + + memory + .store( + "integration_key", + "integration content", + MemoryCategory::Conversation, + None, + ) + .await + .expect("store should succeed"); + + let fetched = memory + .get("integration_key") + .await + .expect("get should succeed") + .expect("entry should exist"); + assert_eq!(fetched.content, "integration content"); + + let recalled = memory + .recall("integration", 5, None) + .await + .expect("recall should succeed"); + assert!( + recalled.iter().any(|entry| entry.key == "integration_key"), + "recall should return the stored key" + ); + } +} diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 5e022ef53..d0120965e 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -4,6 +4,8 @@ pub mod cli; pub mod embeddings; pub mod hygiene; pub mod lucid; +#[cfg(feature = "memory-mariadb")] +pub mod mariadb; pub mod markdown; pub mod none; #[cfg(feature = "memory-postgres")] @@ -21,6 +23,8 @@ pub use backend::{ selectable_memory_backends, MemoryBackendKind, MemoryBackendProfile, }; pub use lucid::LucidMemory; +#[cfg(feature = "memory-mariadb")] +pub use mariadb::MariadbMemory; pub use markdown::MarkdownMemory; pub use none::NoneMemory; #[cfg(feature = "memory-postgres")] @@ -55,6 +59,9 @@ where Ok(Box::new(LucidMemory::new(workspace_dir, local))) } MemoryBackendKind::Postgres => postgres_builder(), + MemoryBackendKind::Mariadb => { + anyhow::bail!("memory backend 'mariadb' is not available in this build context") + } MemoryBackendKind::Qdrant | MemoryBackendKind::Markdown => { Ok(Box::new(MarkdownMemory::new(workspace_dir))) } @@ -299,6 +306,40 @@ pub fn create_memory_with_storage_and_routes( ); } + #[cfg(feature = "memory-mariadb")] + fn build_mariadb_memory( + storage_provider: Option<&StorageProviderConfig>, + ) -> anyhow::Result> { + let storage_provider = storage_provider + .context("memory backend 'mariadb' requires [storage.provider.config] settings")?; + let db_url = storage_provider + .db_url + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .context( + "memory backend 'mariadb' requires [storage.provider.config].db_url (or dbURL)", + )?; + + let memory = MariadbMemory::new( + db_url, + &storage_provider.schema, + &storage_provider.table, + storage_provider.connect_timeout_secs, + storage_provider.tls, + )?; + Ok(Box::new(memory)) + } + + #[cfg(not(feature = "memory-mariadb"))] + fn build_mariadb_memory( + _storage_provider: Option<&StorageProviderConfig>, + ) -> anyhow::Result> { + anyhow::bail!( + "memory backend 'mariadb' requested but this build was compiled without `memory-mariadb`; rebuild with `--features memory-mariadb`" + ); + } + if matches!(backend_kind, MemoryBackendKind::Qdrant) { let url = config .qdrant @@ -340,6 +381,10 @@ pub fn create_memory_with_storage_and_routes( ))); } + if matches!(backend_kind, MemoryBackendKind::Mariadb) { + return build_mariadb_memory(storage_provider); + } + create_memory_with_builders( &backend_name, workspace_dir, @@ -361,10 +406,10 @@ pub fn create_memory_for_migration( if matches!( classify_memory_backend(backend), - MemoryBackendKind::Postgres + MemoryBackendKind::Postgres | MemoryBackendKind::Mariadb ) { anyhow::bail!( - "memory migration for backend 'postgres' is unsupported; migrate with sqlite or markdown first" + "memory migration for SQL backends ('postgres' / 'mariadb') is unsupported; migrate with sqlite or markdown first" ); } @@ -526,6 +571,30 @@ mod tests { } } + #[test] + fn factory_mariadb_without_db_url_is_rejected() { + let tmp = TempDir::new().unwrap(); + let cfg = MemoryConfig { + backend: "mariadb".into(), + ..MemoryConfig::default() + }; + + let storage = StorageProviderConfig { + provider: "mariadb".into(), + db_url: None, + ..StorageProviderConfig::default() + }; + + let error = create_memory_with_storage(&cfg, Some(&storage), tmp.path(), None) + .err() + .expect("mariadb without db_url should be rejected"); + if cfg!(feature = "memory-mariadb") { + assert!(error.to_string().contains("db_url")); + } else { + assert!(error.to_string().contains("memory-mariadb")); + } + } + #[test] fn resolve_embedding_config_uses_base_config_when_model_is_not_hint() { let cfg = MemoryConfig {