use axum::{ body::Body, extract::Request, http::{Method, Response, Uri}, response::Response as AxumResponse, Router, }; use tower_http::cors::CorsLayer; use futures::TryStreamExt; mod tool; use tool::ToolRegistry; fn now() -> String { chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string() } #[repr(u8)] #[derive(PartialEq, PartialOrd, Debug, Clone, Copy)] #[allow(dead_code)] pub enum LogLevel { Debug = 0, Warning, Info, Error, OFF = 0xfe, Debuging, } use LogLevel::*; static mut LEVEL: LogLevel = Debug; pub fn log(level: LogLevel, msg: String) { if level < unsafe { LEVEL } { return; } println!("[{:?}] {} {}", level, now(), msg); } // 打印 HTTP 请求的 URL 和请求体 async fn print_request_debug_info(uri: &Uri, method: &Method, body: &[u8]) { log(Debug, format!("HTTP Request: {} {}", method, uri)); if !body.is_empty() { log(Debug, format!("Request Body: {}", String::from_utf8_lossy(body))); } } // 创建一个简单的反向代理处理函数 async fn proxy_handler( uri: Uri, request: Request, ) -> AxumResponse { let client = reqwest::Client::new(); // 获取请求体 let (parts, body) = request.into_parts(); let bytes = match axum::body::to_bytes(body, usize::MAX).await { Ok(b) => b, Err(_) => { return Response::builder() .status(500) .body(Body::empty()) .unwrap(); } }; // 打印调试信息 print_request_debug_info(&uri, &parts.method, &bytes).await; // 构建目标URL let query_part = if let Some(query) = uri.query() { format!("?{}", query) } else { "".to_string() }; let target_url = format!("http://localhost:11434{}{}", uri.path(), query_part); // 转发请求 let mut forwarded_request = client .request(parts.method.clone(), &target_url) .headers(parts.headers.clone()); if !bytes.is_empty() { forwarded_request = forwarded_request.body(bytes.to_vec()); } match forwarded_request.send().await { Ok(response) => { // 获取状态码 let status = response.status(); // 获取响应头 let response_headers = response.headers().clone(); // 在 reqwest v0.13 中,使用 bytes_stream 方法 let body_stream = response.bytes_stream(); let body = Body::from_stream(body_stream.map_err(|e| axum::Error::new(e))); // 构造响应 let mut axum_response = Response::builder() .status(status); // 添加响应头 if let Some(headers) = axum_response.headers_mut() { for (key, value) in response_headers.iter() { headers.append(key.clone(), value.clone()); } } axum_response .body(body) .unwrap_or_else(|_| Response::builder().status(500).body(Body::empty()).unwrap()) }, Err(e) => { log(Error, format!("Request failed: error sending request for url ({}), error: {}", target_url, e)); // 返回 502 Bad Gateway,表示下游服务不可达 Response::builder() .status(502) .header("Content-Type", "application/json") .body(Body::from(format!(r#"{{"error": "Service Unavailable: unable to reach Ollama service at http://localhost:11434. Error: {}"}}"#, e))) .unwrap() } } } #[tokio::main] async fn main() { chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string(); let hosting = "127.0.0.1:18854"; log(Info, format!("welcome to clawclaw reverse proxy")); // 使用默认工具注册器 let tool_router = ToolRegistry::with_default_tools().generate_routes(); // 创建路由,首先合并工具路由(高优先级),然后是通配符路由(低优先级) let app = Router::new() .merge(tool_router) // 合并工具路由,优先级最高 .route("/v1/models", axum::routing::get(|req: Request| async move { let uri = req.uri().clone(); proxy_handler(uri, req).await })) .fallback(|req: Request| async move { let uri = req.uri().clone(); proxy_handler(uri, req).await }) .layer(CorsLayer::new() .allow_origin(tower_http::cors::Any) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) .allow_headers(tower_http::cors::Any)); axum::serve( tokio::net::TcpListener::bind(hosting).await.unwrap(), app ).await.unwrap(); }