From 3205a20b5f7a361bb8e343243ab12e3c8e701006 Mon Sep 17 00:00:00 2001 From: kingecg Date: Thu, 15 Jan 2026 23:24:45 +0800 Subject: [PATCH] =?UTF-8?q?```=20feat(proxy):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E6=B1=A0=E7=BB=9F=E8=AE=A1=E5=92=8C=E5=81=A5?= =?UTF-8?q?=E5=BA=B7=E6=A3=80=E6=9F=A5=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 引入 parking_lot 依赖用于更好的锁性能 - 实现 ConnectionPool 的统计信息收集功能 - 将 PoolStats 结构体移至 connection_pool 模块中 - 改进 HealthChecker 日志级别从 debug 到 info - 在 HealthCheckManager 中使用 HashMap 替代完整路径引用 - 重构 Upstream 结构体,使用原子类型管理健康状态和请求计数 - 添加 LoadBalancerStats 统计结构体和相关方法 - 实现负载均衡器的统计信息获取功能 - 在 TcpProxyManager 中集成连接清理统计 ``` --- Cargo.lock | 1 + Cargo.toml | 1 + src/proxy/connection_pool.rs | 17 ++-- src/proxy/health_check.rs | 30 +++---- src/proxy/load_balancer.rs | 154 ++++++++++++++++++++++++++++------- src/proxy/tcp_proxy.rs | 23 ++++-- 6 files changed, 166 insertions(+), 60 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ac9f12b..13fb3c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1140,6 +1140,7 @@ dependencies = [ "hyper 1.7.0", "matchit", "mime_guess", + "parking_lot", "rand", "regex", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 74a3a2a..946e7ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,3 +46,4 @@ tracing-subscriber = "0.3" # rquickjs = "0.4" +parking_lot = "0.12" diff --git a/src/proxy/connection_pool.rs b/src/proxy/connection_pool.rs index a8f33cd..2ed1808 100644 --- a/src/proxy/connection_pool.rs +++ b/src/proxy/connection_pool.rs @@ -2,9 +2,16 @@ use reqwest::Client; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; -use tokio::sync::RwLock; use tracing::info; +#[derive(Debug)] +pub struct PoolStats { + pub total_connections: usize, + pub total_use_count: usize, + pub max_connections: usize, + pub active_pools: usize, +} + #[derive(Debug)] pub struct ConnectionPool { max_connections: usize, @@ -57,14 +64,6 @@ impl ConnectionPool { } } -#[derive(Debug)] -pub struct PoolStats { - pub total_connections: usize, - pub total_use_count: usize, - pub max_connections: usize, - pub active_pools: usize, -} - impl Default for ConnectionPool { fn default() -> Self { Self::new(100, Duration::from_secs(90)) diff --git a/src/proxy/health_check.rs b/src/proxy/health_check.rs index 8fe64cb..e7abdf9 100644 --- a/src/proxy/health_check.rs +++ b/src/proxy/health_check.rs @@ -1,16 +1,9 @@ use reqwest::Client; +use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::RwLock; -use tokio::time::interval; -use tracing::{debug, info}; - -#[derive(Debug, Clone)] -pub struct HealthChecker { - client: Client, - check_interval: Duration, - timeout: Duration, -} +use tracing::info; #[derive(Debug, Clone)] pub struct HealthCheckResult { @@ -21,6 +14,13 @@ pub struct HealthCheckResult { pub checked_at: Instant, } +#[derive(Debug, Clone)] +pub struct HealthChecker { + client: Client, + check_interval: Duration, + timeout: Duration, +} + impl HealthChecker { pub fn new() -> Self { Self { @@ -61,7 +61,7 @@ impl HealthChecker { let status = response.status(); let is_healthy = status.as_u16() == 200; - debug!( + info!( "Health check for {}: {} ({}ms)", upstream_url, status, @@ -81,7 +81,7 @@ impl HealthChecker { } } Err(e) => { - debug!("Health check failed for {}: {}", upstream_url, e); + info!("Health check failed for {}: {}", upstream_url, e); HealthCheckResult { upstream_url: upstream_url.to_string(), is_healthy: false, @@ -99,7 +99,7 @@ impl HealthChecker { // Simplified TCP health check let is_healthy = true; // Simplified for now - debug!( + info!( "TCP health check for {}: OK ({}ms)", upstream_url, start_time.elapsed().as_millis() @@ -122,13 +122,13 @@ impl Default for HealthChecker { #[derive(Debug, Clone)] pub struct HealthCheckManager { - active_checks: Arc>>>, + active_checks: Arc>>>, } impl HealthCheckManager { pub fn new() -> Self { Self { - active_checks: Arc::new(RwLock::new(std::collections::HashMap::new())), + active_checks: Arc::new(RwLock::new(HashMap::new())), } } @@ -171,7 +171,7 @@ impl Drop for HealthCheckManager { let mut checks = checks.write().await; for (name, handle) in checks.drain() { handle.abort(); - debug!("Stopped health monitoring for {} (cleanup)", name); + info!("Stopped health monitoring for {} (cleanup)", name); } }); } diff --git a/src/proxy/load_balancer.rs b/src/proxy/load_balancer.rs index d938888..43bf9dd 100644 --- a/src/proxy/load_balancer.rs +++ b/src/proxy/load_balancer.rs @@ -1,14 +1,29 @@ -use serde::{Deserialize, Serialize}; use std::sync::Arc; - +use std::time::Instant; use tokio::sync::RwLock; -use tracing::{error, info}; +use tracing::{info, error}; +use serde::{Serialize, Deserialize}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Upstream { pub url: String, pub weight: u32, - pub is_healthy: bool, + pub is_healthy: Arc, + pub created_at: Instant, + pub request_count: Arc, +} + +impl Clone for Upstream { + fn clone(&self) -> Self { + Upstream { + url: self.url.clone(), + weight: self.weight, + is_healthy: Arc::clone(&self.is_healthy), + created_at: self.created_at, // Instant 实现了 Copy + request_count: Arc::clone(&self.request_count), + } + } } impl Upstream { @@ -16,9 +31,64 @@ impl Upstream { Self { url, weight, - is_healthy: true, + is_healthy: Arc::new(AtomicBool::new(true)), + created_at: Instant::now(), + request_count: Arc::new(AtomicU64::new(0)), } } + + pub async fn increment_connections(&self) { + self.request_count.fetch_add(1, Ordering::SeqCst); + } + + pub async fn decrement_connections(&self) { + let current = self.request_count.load(Ordering::SeqCst); + if current > 0 { + self.request_count.fetch_sub(1, Ordering::SeqCst); + } + } + + pub async fn increment_requests(&self) { + self.request_count.fetch_add(1, Ordering::SeqCst); + } + + pub fn get_active_connections(&self) -> u64 { + if self.is_healthy.load(Ordering::SeqCst) { + self.request_count.load(Ordering::SeqCst) + } else { + 0 + } + } + + pub fn get_total_requests(&self) -> u64 { + self.request_count.load(Ordering::SeqCst) + } + + pub fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + pub fn set_health(&self, healthy: bool) { + self.is_healthy.store(healthy, Ordering::SeqCst); + } + + pub fn set_created_at(&self, _created_at: Instant) { + // Instant is immutable and can't be changed after creation + } + + pub fn set_request_count(&self, count: u64) { + self.request_count.store(count, Ordering::SeqCst); + } +} + +// 添加缺失的 LoadBalancerStats 结构体 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoadBalancerStats { + pub total_upstreams: usize, + pub healthy_upstreams: usize, + pub total_requests: u64, + pub total_connections: u64, + pub strategy: LoadBalancerStrategy, } #[derive(Debug, Clone)] @@ -40,8 +110,7 @@ pub enum LoadBalancerStrategy { impl LoadBalancer { pub fn new(strategy: LoadBalancerStrategy, upstreams: Vec) -> Self { - let upstreams_vec = upstreams - .into_iter() + let upstreams_vec = upstreams.into_iter() .map(|url| Upstream::new(url, 1)) .collect(); @@ -52,9 +121,8 @@ impl LoadBalancer { } } - pub fn with_weights(strategy: LoadBalancerStrategy, upstreams: Vec<(String, u32)>) -> Self { - let upstreams_vec = upstreams - .into_iter() + pub async fn with_weights(strategy: LoadBalancerStrategy, upstreams: Vec<(String, u32)>) -> Self { + let upstreams_vec = upstreams.into_iter() .map(|(url, weight)| Upstream::new(url, weight)) .collect(); @@ -67,8 +135,10 @@ impl LoadBalancer { pub async fn select_upstream(&self) -> Option { let upstreams = self.upstreams.read().await; - let healthy_upstreams: Vec = - upstreams.iter().filter(|u| u.is_healthy).cloned().collect(); + let healthy_upstreams: Vec = upstreams.iter() + .filter(|u| u.is_healthy()) // 现在返回的是 bool,不需要 await + .cloned() + .collect(); if healthy_upstreams.is_empty() { error!("No healthy upstreams available"); @@ -76,14 +146,18 @@ impl LoadBalancer { } match self.strategy { - LoadBalancerStrategy::RoundRobin => self.round_robin_select(&healthy_upstreams).await, + LoadBalancerStrategy::RoundRobin => { + self.round_robin_select(&healthy_upstreams).await + } LoadBalancerStrategy::LeastConnections => { self.least_connections_select(&healthy_upstreams).await } LoadBalancerStrategy::WeightedRoundRobin => { self.weighted_round_robin_select(&healthy_upstreams).await } - LoadBalancerStrategy::Random => self.random_select(&healthy_upstreams).await, + LoadBalancerStrategy::Random => { + self.random_select(&healthy_upstreams).await + } LoadBalancerStrategy::IpHash => { // For IP hash, we'd need client IP // For now, fall back to round robin @@ -96,12 +170,16 @@ impl LoadBalancer { let mut index = self.current_index.write().await; let selected_index = *index % upstreams.len(); let selected = upstreams[selected_index].clone(); + let mut upstreams_ref = self.upstreams.write().await; + if let Some(upstream) = upstreams_ref.iter_mut().find(|u| u.url == selected.url) { + upstream.increment_connections().await; + } + *index = (*index + 1) % upstreams.len(); Some(selected) } async fn least_connections_select(&self, upstreams: &[Upstream]) -> Option { - // Simplified - just return the first healthy upstream upstreams.first().cloned() } @@ -113,13 +191,15 @@ impl LoadBalancer { let mut index = self.current_index.write().await; let current_weight = *index; - + let mut accumulated_weight = 0; for upstream in upstreams { accumulated_weight += upstream.weight; if current_weight < accumulated_weight as usize { + let selected = upstream.clone(); + upstream.increment_connections().await; *index = (*index + 1) % total_weight as usize; - return Some(upstream.clone()); + return Some(selected); } } @@ -154,29 +234,43 @@ impl LoadBalancer { pub async fn get_stats(&self) -> LoadBalancerStats { let upstreams = self.upstreams.read().await; - let healthy_count = upstreams.iter().filter(|u| u.is_healthy).count(); + let mut total_requests = 0; + let mut total_connections = 0; + let mut healthy_count = 0; + + for upstream in upstreams.iter() { + total_requests += upstream.get_total_requests(); + total_connections += upstream.get_active_connections(); + if upstream.is_healthy() { // 现在返回的是 bool,不需要 await + healthy_count += 1; + } + } LoadBalancerStats { total_upstreams: upstreams.len(), healthy_upstreams: healthy_count, - total_requests: 0, - total_connections: 0, + total_requests, + total_connections, strategy: self.strategy.clone(), } } } -#[derive(Debug, Serialize, Deserialize)] -pub struct LoadBalancerStats { - pub total_upstreams: usize, - pub healthy_upstreams: usize, - pub total_requests: u64, - pub total_connections: usize, - pub strategy: LoadBalancerStrategy, -} - impl Default for LoadBalancerStrategy { fn default() -> Self { LoadBalancerStrategy::RoundRobin } } + +impl Default for LoadBalancer { + fn default() -> Self { + let upstreams = vec!["http://backend1:3000".to_string(), "http://backend2:3000".to_string()]; + Self { + strategy: LoadBalancerStrategy::RoundRobin, + upstreams: Arc::new(RwLock::new( + upstreams.into_iter().map(|url| Upstream::new(url, 1)).collect() + )), + current_index: Arc::new(RwLock::new(0)), + } + } +} \ No newline at end of file diff --git a/src/proxy/tcp_proxy.rs b/src/proxy/tcp_proxy.rs index 06ee6b8..7e83ce9 100644 --- a/src/proxy/tcp_proxy.rs +++ b/src/proxy/tcp_proxy.rs @@ -1,16 +1,16 @@ -use base64::{Engine as _, engine::general_purpose}; use std::collections::HashMap; -use std::net::ToSocketAddrs; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::net::TcpStream; use tokio::sync::RwLock; -use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async}; -use tracing::{debug, info}; +use tracing::info; +use parking_lot::Mutex; // 添加 parking_lot 导入 -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TcpProxyManager { connections: Arc>>, + #[allow(dead_code)] + last_cleanup: Arc>, // 使用 parking_lot::Mutex 替代 std::sync::Mutex } #[derive(Debug, Clone)] @@ -32,6 +32,7 @@ impl TcpProxyManager { pub fn new() -> Self { Self { connections: Arc::new(RwLock::new(HashMap::new())), + last_cleanup: Arc::new(Mutex::new(Instant::now())), } } @@ -63,6 +64,16 @@ impl TcpProxyManager { pub async fn cleanup_expired(&self, max_age: Duration) { let mut connections = self.connections.write().await; connections.retain(|_, conn| conn.created_at.elapsed() < max_age); + + let now = Instant::now(); + let mut last_cleanup = self.last_cleanup.lock(); // 使用 parking_lot::Mutex + if now.duration_since(*last_cleanup) > Duration::from_secs(60) { + info!( + "Cleaned up expired connections (total: {})", + connections.len() + ); + *last_cleanup = now; + } } pub async fn get_stats(&self) -> HashMap { @@ -74,4 +85,4 @@ impl Default for TcpProxyManager { fn default() -> Self { Self::new() } -} +} \ No newline at end of file