| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- use axum::{
- body::Body,
- extract::Request,
- http::{Method, Response, Uri},
- response::Response as AxumResponse,
- Router,
- };
- use tower_http::cors::CorsLayer;
- use futures::TryStreamExt;
- mod tool;
- use tool::ToolRegistry;
- fn now() -> String {
- chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()
- }
/// Severity of a log message.
///
/// The derived `PartialOrd` compares variants by discriminant, so declaration
/// order is significant. A threshold comparison against [`LEVEL`] only behaves
/// sensibly when the variants run from least to most severe:
/// `Debug < Info < Warning < Error`.
#[repr(u8)]
#[derive(PartialEq, PartialOrd, Debug, Clone, Copy)]
// Several variants (e.g. `Warning`, `OFF`, `Debuging`) are not constructed in this file.
#[allow(dead_code)]
pub enum LogLevel {
    Debug = 0,
    Info = 1,
    Warning = 2,
    Error = 3,
    /// Threshold value (0xfe): sorts above every ordinary severity, so using it as
    /// the threshold suppresses Debug/Info/Warning/Error.
    OFF = 0xfe,
    /// Discriminant 0xff, i.e. it sorts above `OFF` and is therefore never
    /// filtered by a `level < threshold` check. The spelling is kept for
    /// compatibility with existing callers.
    /// NOTE(review): presumably an "always print" debug channel — confirm.
    Debuging,
}
use LogLevel::*;
/// Minimum severity that gets printed; messages below it are dropped.
///
/// NOTE(review): this is an unsynchronized `static mut`. Nothing in this file
/// writes it; any writer must ensure no concurrent readers exist (e.g. set it
/// once at startup before spawning tasks).
static mut LEVEL: LogLevel = Debug;
- pub fn log(level: LogLevel, msg: String) {
- if level < unsafe { LEVEL } {
- return;
- }
- println!("[{:?}] {} {}", level, now(), msg);
- }
- // 打印 HTTP 请求的 URL 和请求体
- async fn print_request_debug_info(uri: &Uri, method: &Method, body: &[u8]) {
- log(Debug, format!("HTTP Request: {} {}", method, uri));
- if !body.is_empty() {
- log(Debug, format!("Request Body: {}", String::from_utf8_lossy(body)));
- }
- }
- // 创建一个简单的反向代理处理函数
- async fn proxy_handler(
- uri: Uri,
- request: Request<Body>,
- ) -> AxumResponse<Body> {
- let client = reqwest::Client::new();
- // 获取请求体
- let (parts, body) = request.into_parts();
- let bytes = match axum::body::to_bytes(body, usize::MAX).await {
- Ok(b) => b,
- Err(_) => {
- return Response::builder()
- .status(500)
- .body(Body::empty())
- .unwrap();
- }
- };
- // 打印调试信息
- print_request_debug_info(&uri, &parts.method, &bytes).await;
- // 构建目标URL
- let query_part = if let Some(query) = uri.query() {
- format!("?{}", query)
- } else {
- "".to_string()
- };
- let target_url = format!("http://localhost:11434{}{}", uri.path(), query_part);
- // 转发请求
- let mut forwarded_request = client
- .request(parts.method.clone(), &target_url)
- .headers(parts.headers.clone());
- if !bytes.is_empty() {
- forwarded_request = forwarded_request.body(bytes.to_vec());
- }
- match forwarded_request.send().await {
- Ok(response) => {
- // 获取状态码
- let status = response.status();
- // 获取响应头
- let response_headers = response.headers().clone();
- // 在 reqwest v0.13 中,使用 bytes_stream 方法
- let body_stream = response.bytes_stream();
- let body = Body::from_stream(body_stream.map_err(|e| axum::Error::new(e)));
- // 构造响应
- let mut axum_response = Response::builder()
- .status(status);
- // 添加响应头
- if let Some(headers) = axum_response.headers_mut() {
- for (key, value) in response_headers.iter() {
- headers.append(key.clone(), value.clone());
- }
- }
- axum_response
- .body(body)
- .unwrap_or_else(|_| Response::builder().status(500).body(Body::empty()).unwrap())
- },
- Err(e) => {
- log(Error, format!("Request failed: error sending request for url ({}), error: {}", target_url, e));
- // 返回 502 Bad Gateway,表示下游服务不可达
- Response::builder()
- .status(502)
- .header("Content-Type", "application/json")
- .body(Body::from(format!(r#"{{"error": "Service Unavailable: unable to reach Ollama service at http://localhost:11434. Error: {}"}}"#, e)))
- .unwrap()
- }
- }
- }
- #[tokio::main]
- async fn main() {
- chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string();
- let hosting = "127.0.0.1:18854";
- log(Info, format!("welcome to clawclaw reverse proxy"));
- // 使用默认工具注册器
- let tool_router = ToolRegistry::with_default_tools().generate_routes();
- // 创建路由,首先合并工具路由(高优先级),然后是通配符路由(低优先级)
- let app = Router::new()
- .merge(tool_router) // 合并工具路由,优先级最高
- .route("/v1/models", axum::routing::get(|req: Request<Body>| async move {
- let uri = req.uri().clone();
- proxy_handler(uri, req).await
- }))
- .fallback(|req: Request<Body>| async move {
- let uri = req.uri().clone();
- proxy_handler(uri, req).await
- })
- .layer(CorsLayer::new()
- .allow_origin(tower_http::cors::Any)
- .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
- .allow_headers(tower_http::cors::Any));
- axum::serve(
- tokio::net::TcpListener::bind(hosting).await.unwrap(),
- app
- ).await.unwrap();
- }
|