main.rs 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
use std::sync::OnceLock;

use axum::{
    body::Body,
    extract::Request,
    http::{Method, Response, Uri},
    response::Response as AxumResponse,
    Router,
};
use futures::TryStreamExt;
use tower_http::cors::CorsLayer;

mod tool;
use tool::ToolRegistry;
  12. fn now() -> String {
  13. chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()
  14. }
  15. #[repr(u8)]
  16. #[derive(PartialEq, PartialOrd, Debug, Clone, Copy)]
  17. #[allow(dead_code)]
  18. pub enum LogLevel {
  19. Debug = 0,
  20. Warning,
  21. Info,
  22. Error,
  23. OFF = 0xfe,
  24. Debuging,
  25. }
  26. use LogLevel::*;
  27. static mut LEVEL: LogLevel = Debug;
  28. pub fn log(level: LogLevel, msg: String) {
  29. if level < unsafe { LEVEL } {
  30. return;
  31. }
  32. println!("[{:?}] {} {}", level, now(), msg);
  33. }
  34. // 打印 HTTP 请求的 URL 和请求体
  35. async fn print_request_debug_info(uri: &Uri, method: &Method, body: &[u8]) {
  36. log(Debug, format!("HTTP Request: {} {}", method, uri));
  37. if !body.is_empty() {
  38. log(Debug, format!("Request Body: {}", String::from_utf8_lossy(body)));
  39. }
  40. }
  41. // 创建一个简单的反向代理处理函数
  42. async fn proxy_handler(
  43. uri: Uri,
  44. request: Request<Body>,
  45. ) -> AxumResponse<Body> {
  46. let client = reqwest::Client::new();
  47. // 获取请求体
  48. let (parts, body) = request.into_parts();
  49. let bytes = match axum::body::to_bytes(body, usize::MAX).await {
  50. Ok(b) => b,
  51. Err(_) => {
  52. return Response::builder()
  53. .status(500)
  54. .body(Body::empty())
  55. .unwrap();
  56. }
  57. };
  58. // 打印调试信息
  59. print_request_debug_info(&uri, &parts.method, &bytes).await;
  60. // 构建目标URL
  61. let query_part = if let Some(query) = uri.query() {
  62. format!("?{}", query)
  63. } else {
  64. "".to_string()
  65. };
  66. let target_url = format!("http://localhost:11434{}{}", uri.path(), query_part);
  67. // 转发请求
  68. let mut forwarded_request = client
  69. .request(parts.method.clone(), &target_url)
  70. .headers(parts.headers.clone());
  71. if !bytes.is_empty() {
  72. forwarded_request = forwarded_request.body(bytes.to_vec());
  73. }
  74. match forwarded_request.send().await {
  75. Ok(response) => {
  76. // 获取状态码
  77. let status = response.status();
  78. // 获取响应头
  79. let response_headers = response.headers().clone();
  80. // 在 reqwest v0.13 中,使用 bytes_stream 方法
  81. let body_stream = response.bytes_stream();
  82. let body = Body::from_stream(body_stream.map_err(|e| axum::Error::new(e)));
  83. // 构造响应
  84. let mut axum_response = Response::builder()
  85. .status(status);
  86. // 添加响应头
  87. if let Some(headers) = axum_response.headers_mut() {
  88. for (key, value) in response_headers.iter() {
  89. headers.append(key.clone(), value.clone());
  90. }
  91. }
  92. axum_response
  93. .body(body)
  94. .unwrap_or_else(|_| Response::builder().status(500).body(Body::empty()).unwrap())
  95. },
  96. Err(e) => {
  97. log(Error, format!("Request failed: error sending request for url ({}), error: {}", target_url, e));
  98. // 返回 502 Bad Gateway,表示下游服务不可达
  99. Response::builder()
  100. .status(502)
  101. .header("Content-Type", "application/json")
  102. .body(Body::from(format!(r#"{{"error": "Service Unavailable: unable to reach Ollama service at http://localhost:11434. Error: {}"}}"#, e)))
  103. .unwrap()
  104. }
  105. }
  106. }
  107. #[tokio::main]
  108. async fn main() {
  109. chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string();
  110. let hosting = "127.0.0.1:18854";
  111. log(Info, format!("welcome to clawclaw reverse proxy"));
  112. // 使用默认工具注册器
  113. let tool_router = ToolRegistry::with_default_tools().generate_routes();
  114. // 创建路由,首先合并工具路由(高优先级),然后是通配符路由(低优先级)
  115. let app = Router::new()
  116. .merge(tool_router) // 合并工具路由,优先级最高
  117. .route("/v1/models", axum::routing::get(|req: Request<Body>| async move {
  118. let uri = req.uri().clone();
  119. proxy_handler(uri, req).await
  120. }))
  121. .fallback(|req: Request<Body>| async move {
  122. let uri = req.uri().clone();
  123. proxy_handler(uri, req).await
  124. })
  125. .layer(CorsLayer::new()
  126. .allow_origin(tower_http::cors::Any)
  127. .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
  128. .allow_headers(tower_http::cors::Any));
  129. axum::serve(
  130. tokio::net::TcpListener::bind(hosting).await.unwrap(),
  131. app
  132. ).await.unwrap();
  133. }