// sqlite.rs — bb8-backed SQLite connection pool and its `Datasource` implementation.
  1. use bb8::Pool;
  2. /// SQLite连接池类型
  3. pub type SqlitePool = bb8::Pool<SqliteConnectionManager>;
  /// Connection manager that lets a `bb8` pool open SQLite connections.
  ///
  /// It stores only the database location; each new connection is opened from it.
  #[derive(Debug, Clone)]
  pub struct SqliteConnectionManager {
      /// Location string handed to `rusqlite::Connection::open` when connecting.
      url: String,
  }

  impl SqliteConnectionManager {
      /// Creates a manager for the database located at `url`.
      ///
      /// Accepts anything convertible into a `String` (`String`, `&str`, ...), so
      /// callers that only hold a borrowed string do not have to allocate first.
      pub fn new(url: impl Into<String>) -> Self {
          Self { url: url.into() }
      }
  }
  15. impl bb8::ManageConnection for SqliteConnectionManager {
  16. type Connection = rusqlite::Connection;
  17. type Error = rusqlite::Error;
  18. fn connect(&self) -> impl Future<Output = Result<Self::Connection, Self::Error>> + Send {
  19. let url = self.url.clone();
  20. async move {
  21. let conn = tokio::task::spawn_blocking(move || rusqlite::Connection::open(&url))
  22. .await
  23. .unwrap();
  24. conn
  25. }
  26. }
  27. fn is_valid(&self, _conn: &mut Self::Connection) -> impl Future<Output = Result<(), Self::Error>> + Send {
  28. // 对于SQLite,我们暂时跳过验证
  29. async move {
  30. Ok(())
  31. }
  32. }
  33. fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
  34. // 暂时假设连接未损坏
  35. false
  36. }
  37. }
  38. /// 初始化SQLite连接池
  39. pub async fn init_sqlite_pool(url: &str, max_size: u32) -> Result<SqlitePool, Box<dyn std::error::Error>> {
  40. let manager = SqliteConnectionManager::new(url.to_string());
  41. let pool = Pool::builder()
  42. .max_size(max_size)
  43. .build(manager)
  44. .await?;
  45. Ok(pool)
  46. }
  47. impl crate::datasource::Datasource for SqlitePool{
  48. async fn query<P, T, F>(&self, sql: &str, params: P,f: F) -> Result<T, String>
  49. where
  50. P: rusqlite::Params,
  51. F: FnOnce(&rusqlite::Row<'_>) -> rusqlite::Result<T>,
  52. {
  53. match tokio::time::timeout(std::time::Duration::from_secs(5), self.get()).await {
  54. Ok(Ok(conn)) => conn,
  55. Ok(Err(e)) => return Err(format!("connection err: {}",e.to_string())),
  56. Err(_) => return Err("Timeout".to_string())
  57. }.query_row(sql, params, f).map_err(|e| if e == rusqlite::Error::QueryReturnedNoRows { String::new() } else{ e.to_string()})
  58. }
  59. async fn query_rows<P, T, F>(&self, sql: &str, params: P, mut f: F) -> Result<Vec<T>, String>
  60. where
  61. P: rusqlite::Params,
  62. F: FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<T>,
  63. {
  64. match tokio::time::timeout(std::time::Duration::from_secs(5), self.get()).await {
  65. Ok(Ok(conn)) => conn,
  66. Ok(Err(e)) => return Err(format!("connection err: {}",e.to_string())),
  67. Err(_) => return Err("Timeout".to_string())
  68. }.prepare(sql)
  69. .map_err(|e| e.to_string())
  70. .and_then(|mut stmt| {
  71. let mut results = Vec::new();
  72. let mut rows = stmt.query(params).map_err(|e| e.to_string())?;
  73. while let Some(row) = rows.next().map_err(|e| e.to_string())? {
  74. match f(row) {
  75. Ok(item) => results.push(item),
  76. Err(e) => return Err(e.to_string()),
  77. }
  78. }
  79. Ok(results)
  80. })
  81. }
  82. async fn execute<P>(&self, sql: &str, params:P) -> Result<usize, String>
  83. where
  84. P: rusqlite::Params {
  85. match tokio::time::timeout(std::time::Duration::from_secs(5), self.get()).await {
  86. Ok(Ok(conn)) => conn,
  87. Ok(Err(e)) => return Err(format!("connection err: {}",e.to_string())),
  88. Err(_) => return Err("Timeout".to_string())
  89. }.execute(sql, params).map_err(|e| if e.sqlite_error_code()==Some(rusqlite::ErrorCode::ConstraintViolation) { println!("{e}");String::new() } else{ e.to_string()})
  90. }
  91. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::datasource::Datasource;

      /// Round-trips one row through the pool: create a table, insert, read back.
      #[tokio::test]
      async fn test_sqlite_pool() {
          // A single pooled connection to an in-memory database keeps the test
          // hermetic: every checkout reuses the same connection (and so the same
          // schema), and no pre-existing file on disk is required.
          let pool = init_sqlite_pool(":memory:", 1).await.unwrap();

          pool.execute(
              "create table flow_task_share(did integer, typo integer, ticket integer)",
              (),
          )
          .await
          .unwrap();

          let inserted = pool
              .execute(
                  "insert into flow_task_share(did,typo,ticket)values(?,?,?)",
                  (1, 1, 1),
              )
              .await
              .unwrap();
          assert_eq!(inserted, 1);

          let ticket: i64 = pool
              .query(
                  "select ticket from flow_task_share where did = ?",
                  (1,),
                  |row| row.get(0),
              )
              .await
              .unwrap();
          assert_eq!(ticket, 1);
      }
  }
  104. }