StratoSDK is a framework that takes a declarative approach similar to Flutter and React, and is written in, and designed entirely for, Rust.
| 1 | //! Advanced shader management system |
| 2 | //! |
| 3 | //! This module provides comprehensive shader management including: |
| 4 | //! - Hot-reload with file watching and automatic recompilation |
| 5 | //! - Dynamic shader compilation with macro support |
| 6 | //! - Intelligent caching with dependency tracking |
| 7 | //! - Shader variant generation and specialization |
| 8 | //! - Cross-platform shader compilation (HLSL, GLSL, WGSL) |
| 9 | //! - Performance profiling and optimization hints |
| 10 | //! - Shader debugging and validation tools |
| 11 | //! - Modular shader composition system |
| 12 | |
| 13 | use anyhow::{bail, Context, Result}; |
| 14 | use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher}; |
| 15 | use parking_lot::{Mutex, RwLock}; |
| 16 | use regex::Regex; |
| 17 | use serde::{Deserialize, Serialize}; |
| 18 | use sha2::{Digest, Sha256}; |
| 19 | use std::collections::{HashMap, HashSet}; |
| 20 | use std::fs; |
| 21 | use std::path::{Path, PathBuf}; |
| 22 | use std::sync::mpsc; |
| 23 | use std::sync::{ |
| 24 | atomic::{AtomicBool, AtomicU64, Ordering}, |
| 25 | Arc, |
| 26 | }; |
| 27 | use std::time::{Duration, Instant, SystemTime}; |
| 28 | use tracing::{info, instrument}; |
| 29 | use wgpu::*; |
| 30 | |
| 31 | use crate::device::ManagedDevice; |
| 32 | use crate::resources::ResourceHandle; |
| 33 | |
| 34 | /// Shader stage type |
| 35 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] |
| 36 | pub enum ShaderStage { |
| 37 | Vertex, |
| 38 | Fragment, |
| 39 | Compute, |
| 40 | } |
| 41 | |
| 42 | /// Shader language |
| 43 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] |
| 44 | pub enum ShaderLanguage { |
| 45 | WGSL, |
| 46 | GLSL, |
| 47 | HLSL, |
| 48 | SPIRV, |
| 49 | } |
| 50 | |
| 51 | /// Shader compilation target |
| 52 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] |
| 53 | pub enum CompilationTarget { |
| 54 | Vulkan, |
| 55 | Metal, |
| 56 | DirectX12, |
| 57 | OpenGL, |
| 58 | WebGPU, |
| 59 | } |
| 60 | |
| 61 | /// Shader macro definition |
| 62 | #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] |
| 63 | pub struct ShaderMacro { |
| 64 | pub name: String, |
| 65 | pub value: Option<String>, |
| 66 | } |
| 67 | |
| 68 | /// Shader variant configuration |
| 69 | #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] |
| 70 | pub struct ShaderVariant { |
| 71 | pub macros: Vec<ShaderMacro>, |
| 72 | pub features: Vec<String>, |
| 73 | pub optimization_level: u32, |
| 74 | } |
| 75 | |
| 76 | /// Shader source information |
| 77 | #[derive(Debug, Clone)] |
| 78 | pub struct ShaderSource { |
| 79 | pub path: PathBuf, |
| 80 | pub content: String, |
| 81 | pub language: ShaderLanguage, |
| 82 | pub stage: ShaderStage, |
| 83 | pub includes: HashSet<PathBuf>, |
| 84 | pub dependencies: HashSet<PathBuf>, |
| 85 | pub last_modified: SystemTime, |
| 86 | pub content_hash: [u8; 32], |
| 87 | } |
| 88 | |
| 89 | /// Compiled shader module |
| 90 | #[derive(Debug)] |
| 91 | pub struct CompiledShader { |
| 92 | pub module: ShaderModule, |
| 93 | pub source_hash: [u8; 32], |
| 94 | pub variant: ShaderVariant, |
| 95 | pub compilation_time: Instant, |
| 96 | pub spirv_size: usize, |
| 97 | pub validation_errors: Vec<String>, |
| 98 | pub optimization_applied: bool, |
| 99 | pub usage_count: AtomicU64, |
| 100 | pub last_used: RwLock<Instant>, |
| 101 | } |
| 102 | |
/// Running counters describing shader compilation and cache behaviour.
#[derive(Clone, Debug, Default)]
pub struct CompilationStats {
    /// Compilations attempted.
    pub total_compilations: u64,
    /// Compilations whose validation reported no errors.
    pub successful_compilations: u64,
    /// Compilations whose validation reported errors.
    pub failed_compilations: u64,
    /// Requests served from the compiled-shader cache.
    pub cache_hits: u64,
    /// Requests that required a fresh compilation.
    pub cache_misses: u64,
    /// Cache invalidations triggered by file-change events.
    pub hot_reloads: u64,
    /// `total_compilation_time / total_compilations`.
    pub average_compilation_time: Duration,
    /// Sum of all compilation durations.
    pub total_compilation_time: Duration,
}
| 115 | |
/// One file in the shader dependency graph.
#[derive(Clone, Debug)]
pub struct DependencyNode {
    /// File this node describes.
    pub path: PathBuf,
    /// Files that depend on (include) this one.
    pub dependents: HashSet<PathBuf>,
    /// Files this one depends on (includes).
    pub dependencies: HashSet<PathBuf>,
    /// Modification time recorded when the node was last updated.
    pub last_modified: SystemTime,
}
| 124 | |
/// A file-system change relevant to loaded shaders.
#[derive(Clone, Debug)]
pub enum HotReloadEvent {
    /// A watched file was modified.
    FileChanged(PathBuf),
    /// A watched file was removed.
    FileDeleted(PathBuf),
    /// A file was created in a watched directory.
    FileCreated(PathBuf),
    /// A file changed together with the set of files depending on it.
    /// NOTE(review): not produced by the manager in this file — confirm whether it is still needed.
    DependencyChanged(PathBuf, HashSet<PathBuf>),
}
| 133 | |
| 134 | /// Shader manager with advanced features |
| 135 | pub struct ShaderManager { |
| 136 | device: Arc<ManagedDevice>, |
| 137 | shader_cache: RwLock<HashMap<([u8; 32], ShaderVariant), Arc<CompiledShader>>>, |
| 138 | source_cache: RwLock<HashMap<PathBuf, Arc<ShaderSource>>>, |
| 139 | dependency_graph: RwLock<HashMap<PathBuf, DependencyNode>>, |
| 140 | compilation_stats: RwLock<CompilationStats>, |
| 141 | |
| 142 | // Hot-reload system |
| 143 | file_watcher: Arc<Mutex<Option<notify::RecommendedWatcher>>>, |
| 144 | hot_reload_enabled: AtomicBool, |
| 145 | hot_reload_receiver: Arc<Mutex<Option<mpsc::Receiver<Event>>>>, |
| 146 | watched_directories: RwLock<HashSet<PathBuf>>, |
| 147 | |
| 148 | // Shader preprocessing |
| 149 | include_directories: RwLock<Vec<PathBuf>>, |
| 150 | global_macros: RwLock<Vec<ShaderMacro>>, |
| 151 | preprocessor_cache: RwLock<HashMap<String, String>>, |
| 152 | |
| 153 | // Performance tracking |
| 154 | compilation_queue: Mutex<Vec<(PathBuf, ShaderVariant)>>, |
| 155 | background_compilation: AtomicBool, |
| 156 | |
| 157 | // Validation and debugging |
| 158 | validation_enabled: AtomicBool, |
| 159 | debug_info_enabled: AtomicBool, |
| 160 | optimization_enabled: AtomicBool, |
| 161 | } |
| 162 | |
| 163 | impl ShaderManager { |
| 164 | /// Create a new shader manager |
| 165 | pub fn new(device: Arc<ManagedDevice>) -> Result<Self> { |
| 166 | let (tx, rx) = mpsc::channel(); |
| 167 | let watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| { |
| 168 | if let Ok(event) = res { |
| 169 | let _ = tx.send(event); |
| 170 | } |
| 171 | })?; |
| 172 | |
| 173 | Ok(Self { |
| 174 | device, |
| 175 | shader_cache: RwLock::new(HashMap::new()), |
| 176 | source_cache: RwLock::new(HashMap::new()), |
| 177 | dependency_graph: RwLock::new(HashMap::new()), |
| 178 | compilation_stats: RwLock::new(CompilationStats::default()), |
| 179 | |
| 180 | file_watcher: Arc::new(Mutex::new(Some(watcher))), |
| 181 | hot_reload_enabled: AtomicBool::new(true), |
| 182 | hot_reload_receiver: Arc::new(Mutex::new(Some(rx))), |
| 183 | watched_directories: RwLock::new(HashSet::new()), |
| 184 | |
| 185 | include_directories: RwLock::new(Vec::new()), |
| 186 | global_macros: RwLock::new(Vec::new()), |
| 187 | preprocessor_cache: RwLock::new(HashMap::new()), |
| 188 | |
| 189 | compilation_queue: Mutex::new(Vec::new()), |
| 190 | background_compilation: AtomicBool::new(true), |
| 191 | |
| 192 | validation_enabled: AtomicBool::new(true), |
| 193 | debug_info_enabled: AtomicBool::new(false), |
| 194 | optimization_enabled: AtomicBool::new(true), |
| 195 | }) |
| 196 | } |
| 197 | |
| 198 | /// Load and compile a shader |
| 199 | #[instrument(skip(self))] |
| 200 | pub fn load_shader( |
| 201 | &self, |
| 202 | path: impl AsRef<Path> + std::fmt::Debug, |
| 203 | stage: ShaderStage, |
| 204 | variant: ShaderVariant, |
| 205 | ) -> Result<Arc<CompiledShader>> { |
| 206 | let path = path.as_ref().to_path_buf(); |
| 207 | |
| 208 | // Load source if not cached |
| 209 | let source = self.load_source(&path, stage)?; |
| 210 | |
| 211 | // Check cache first |
| 212 | let cache_key = (source.content_hash, variant.clone()); |
| 213 | if let Some(cached) = self.shader_cache.read().get(&cache_key) { |
| 214 | let mut stats = self.compilation_stats.write(); |
| 215 | stats.cache_hits += 1; |
| 216 | cached.usage_count.fetch_add(1, Ordering::Relaxed); |
| 217 | *cached.last_used.write() = Instant::now(); |
| 218 | return Ok(cached.clone()); |
| 219 | } |
| 220 | |
| 221 | // Compile shader |
| 222 | let compiled = self.compile_shader(&source, variant)?; |
| 223 | |
| 224 | // Cache the result |
| 225 | self.shader_cache |
| 226 | .write() |
| 227 | .insert(cache_key, compiled.clone()); |
| 228 | |
| 229 | let mut stats = self.compilation_stats.write(); |
| 230 | stats.cache_misses += 1; |
| 231 | |
| 232 | Ok(compiled) |
| 233 | } |
| 234 | |
| 235 | /// Load shader source from file |
| 236 | fn load_source(&self, path: &Path, stage: ShaderStage) -> Result<Arc<ShaderSource>> { |
| 237 | // Check source cache |
| 238 | if let Some(cached) = self.source_cache.read().get(path) { |
| 239 | // Check if file has been modified |
| 240 | let metadata = fs::metadata(path)?; |
| 241 | if metadata.modified()? <= cached.last_modified { |
| 242 | return Ok(cached.clone()); |
| 243 | } |
| 244 | } |
| 245 | |
| 246 | // Read file content |
| 247 | let content = fs::read_to_string(path) |
| 248 | .with_context(|| format!("Failed to read shader file: {}", path.display()))?; |
| 249 | |
| 250 | // Detect language |
| 251 | let language = self.detect_shader_language(path, &content)?; |
| 252 | |
| 253 | // Preprocess shader |
| 254 | let processed_content = self.preprocess_shader(&content, path)?; |
| 255 | |
| 256 | // Calculate content hash |
| 257 | let mut hasher = Sha256::new(); |
| 258 | hasher.update(&processed_content); |
| 259 | let content_hash: [u8; 32] = hasher.finalize().into(); |
| 260 | |
| 261 | // Extract dependencies |
| 262 | let (includes, dependencies) = self.extract_dependencies(&processed_content, path)?; |
| 263 | |
| 264 | let source = Arc::new(ShaderSource { |
| 265 | path: path.to_path_buf(), |
| 266 | content: processed_content, |
| 267 | language, |
| 268 | stage, |
| 269 | includes, |
| 270 | dependencies: dependencies.clone(), |
| 271 | last_modified: fs::metadata(path)?.modified()?, |
| 272 | content_hash, |
| 273 | }); |
| 274 | |
| 275 | // Update dependency graph |
| 276 | self.update_dependency_graph(path, dependencies); |
| 277 | |
| 278 | // Cache the source |
| 279 | self.source_cache |
| 280 | .write() |
| 281 | .insert(path.to_path_buf(), source.clone()); |
| 282 | |
| 283 | // Watch for changes if hot-reload is enabled |
| 284 | if self.hot_reload_enabled.load(Ordering::Relaxed) { |
| 285 | self.watch_file(path)?; |
| 286 | } |
| 287 | |
| 288 | Ok(source) |
| 289 | } |
| 290 | |
| 291 | /// Detect shader language from file extension and content |
| 292 | fn detect_shader_language(&self, path: &Path, content: &str) -> Result<ShaderLanguage> { |
| 293 | if let Some(ext) = path.extension().and_then(|e| e.to_str()) { |
| 294 | match ext.to_lowercase().as_str() { |
| 295 | "wgsl" => return Ok(ShaderLanguage::WGSL), |
| 296 | "glsl" | "vert" | "frag" | "comp" => return Ok(ShaderLanguage::GLSL), |
| 297 | "hlsl" | "fx" => return Ok(ShaderLanguage::HLSL), |
| 298 | "spv" => return Ok(ShaderLanguage::SPIRV), |
| 299 | _ => {} |
| 300 | } |
| 301 | } |
| 302 | |
| 303 | // Try to detect from content |
| 304 | if content.contains("@vertex") |
| 305 | || content.contains("@fragment") |
| 306 | || content.contains("@compute") |
| 307 | { |
| 308 | Ok(ShaderLanguage::WGSL) |
| 309 | } else if content.contains("#version") || content.contains("gl_") { |
| 310 | Ok(ShaderLanguage::GLSL) |
| 311 | } else if content.contains("cbuffer") || content.contains("SV_") { |
| 312 | Ok(ShaderLanguage::HLSL) |
| 313 | } else { |
| 314 | // Default to WGSL for new shaders |
| 315 | Ok(ShaderLanguage::WGSL) |
| 316 | } |
| 317 | } |
| 318 | |
| 319 | /// Preprocess shader with includes and macros |
| 320 | fn preprocess_shader(&self, content: &str, base_path: &Path) -> Result<String> { |
| 321 | let mut processed = content.to_string(); |
| 322 | |
| 323 | // Apply global macros |
| 324 | for macro_def in self.global_macros.read().iter() { |
| 325 | let replacement = match ¯o_def.value { |
| 326 | Some(value) => format!("#define {} {}", macro_def.name, value), |
| 327 | None => format!("#define {}", macro_def.name), |
| 328 | }; |
| 329 | |
| 330 | let pattern = format!(r"#define\s+{}\s*.*", regex::escape(¯o_def.name)); |
| 331 | let re = Regex::new(&pattern)?; |
| 332 | processed = re.replace_all(&processed, replacement.as_str()).to_string(); |
| 333 | } |
| 334 | |
| 335 | // Process includes |
| 336 | processed = self.process_includes(&processed, base_path)?; |
| 337 | |
| 338 | Ok(processed) |
| 339 | } |
| 340 | |
| 341 | /// Process #include directives |
| 342 | fn process_includes(&self, content: &str, base_path: &Path) -> Result<String> { |
| 343 | let include_re = Regex::new(r#"#include\s+"([^"]+)""#)?; |
| 344 | let mut processed = content.to_string(); |
| 345 | let mut included_files = HashSet::new(); |
| 346 | |
| 347 | // Recursive include processing |
| 348 | loop { |
| 349 | let mut found_include = false; |
| 350 | |
| 351 | for cap in include_re.captures_iter(&processed.clone()) { |
| 352 | let include_path = &cap[1]; |
| 353 | let full_path = self.resolve_include_path(include_path, base_path)?; |
| 354 | |
| 355 | if included_files.contains(&full_path) { |
| 356 | // Avoid circular includes |
| 357 | continue; |
| 358 | } |
| 359 | |
| 360 | let include_content = fs::read_to_string(&full_path).with_context(|| { |
| 361 | format!("Failed to read include file: {}", full_path.display()) |
| 362 | })?; |
| 363 | |
| 364 | processed = processed.replace(&cap[0], &include_content); |
| 365 | included_files.insert(full_path); |
| 366 | found_include = true; |
| 367 | break; |
| 368 | } |
| 369 | |
| 370 | if !found_include { |
| 371 | break; |
| 372 | } |
| 373 | } |
| 374 | |
| 375 | Ok(processed) |
| 376 | } |
| 377 | |
| 378 | /// Resolve include path relative to base path and include directories |
| 379 | fn resolve_include_path(&self, include_path: &str, base_path: &Path) -> Result<PathBuf> { |
| 380 | let include_path = Path::new(include_path); |
| 381 | |
| 382 | // Try relative to current file |
| 383 | if let Some(parent) = base_path.parent() { |
| 384 | let full_path = parent.join(include_path); |
| 385 | if full_path.exists() { |
| 386 | return Ok(full_path); |
| 387 | } |
| 388 | } |
| 389 | |
| 390 | // Try include directories |
| 391 | for include_dir in self.include_directories.read().iter() { |
| 392 | let full_path = include_dir.join(include_path); |
| 393 | if full_path.exists() { |
| 394 | return Ok(full_path); |
| 395 | } |
| 396 | } |
| 397 | |
| 398 | bail!("Include file not found: {}", include_path.display()); |
| 399 | } |
| 400 | |
| 401 | /// Extract shader dependencies from content |
| 402 | fn extract_dependencies( |
| 403 | &self, |
| 404 | content: &str, |
| 405 | base_path: &Path, |
| 406 | ) -> Result<(HashSet<PathBuf>, HashSet<PathBuf>)> { |
| 407 | let include_re = Regex::new(r#"#include\s+"([^"]+)""#)?; |
| 408 | let mut includes = HashSet::new(); |
| 409 | let mut dependencies = HashSet::new(); |
| 410 | |
| 411 | for cap in include_re.captures_iter(content) { |
| 412 | let include_path = &cap[1]; |
| 413 | if let Ok(full_path) = self.resolve_include_path(include_path, base_path) { |
| 414 | includes.insert(full_path.clone()); |
| 415 | dependencies.insert(full_path); |
| 416 | } |
| 417 | } |
| 418 | |
| 419 | Ok((includes, dependencies)) |
| 420 | } |
| 421 | |
| 422 | /// Update dependency graph |
| 423 | fn update_dependency_graph(&self, path: &Path, dependencies: HashSet<PathBuf>) { |
| 424 | let mut graph = self.dependency_graph.write(); |
| 425 | |
| 426 | // Update current node |
| 427 | let node = graph |
| 428 | .entry(path.to_path_buf()) |
| 429 | .or_insert_with(|| DependencyNode { |
| 430 | path: path.to_path_buf(), |
| 431 | dependents: HashSet::new(), |
| 432 | dependencies: HashSet::new(), |
| 433 | last_modified: SystemTime::now(), |
| 434 | }); |
| 435 | |
| 436 | node.dependencies = dependencies.clone(); |
| 437 | node.last_modified = fs::metadata(path) |
| 438 | .ok() |
| 439 | .and_then(|m| m.modified().ok()) |
| 440 | .unwrap_or(SystemTime::now()); |
| 441 | |
| 442 | // Update dependent nodes |
| 443 | for dep_path in dependencies { |
| 444 | let dep_node = graph |
| 445 | .entry(dep_path.clone()) |
| 446 | .or_insert_with(|| DependencyNode { |
| 447 | path: dep_path.clone(), |
| 448 | dependents: HashSet::new(), |
| 449 | dependencies: HashSet::new(), |
| 450 | last_modified: SystemTime::now(), |
| 451 | }); |
| 452 | |
| 453 | dep_node.dependents.insert(path.to_path_buf()); |
| 454 | } |
| 455 | } |
| 456 | |
| 457 | /// Compile shader with variant |
| 458 | #[instrument(skip(self, source))] |
| 459 | fn compile_shader( |
| 460 | &self, |
| 461 | source: &ShaderSource, |
| 462 | variant: ShaderVariant, |
| 463 | ) -> Result<Arc<CompiledShader>> { |
| 464 | let start_time = Instant::now(); |
| 465 | |
| 466 | // Apply variant macros |
| 467 | let mut shader_source = source.content.clone(); |
| 468 | for macro_def in &variant.macros { |
| 469 | let definition = match ¯o_def.value { |
| 470 | Some(value) => format!("#define {} {}\n", macro_def.name, value), |
| 471 | None => format!("#define {}\n", macro_def.name), |
| 472 | }; |
| 473 | shader_source = format!("{}{}", definition, shader_source); |
| 474 | } |
| 475 | |
| 476 | // Create shader module descriptor |
| 477 | let descriptor = ShaderModuleDescriptor { |
| 478 | label: Some(&format!("Shader-{}", source.path.display())), |
| 479 | source: wgpu::ShaderSource::Wgsl(shader_source.clone().into()), |
| 480 | }; |
| 481 | |
| 482 | // Compile shader |
| 483 | let module = self.device.device.create_shader_module(descriptor); |
| 484 | |
| 485 | let compilation_time = start_time.elapsed(); |
| 486 | |
| 487 | // Validate if enabled |
| 488 | let mut validation_errors = Vec::new(); |
| 489 | if self.validation_enabled.load(Ordering::Relaxed) { |
| 490 | validation_errors = self.validate_shader(&module, &source)?; |
| 491 | } |
| 492 | |
| 493 | let compiled = Arc::new(CompiledShader { |
| 494 | module, |
| 495 | source_hash: source.content_hash, |
| 496 | variant, |
| 497 | compilation_time: start_time, |
| 498 | spirv_size: shader_source.len(), // Approximation |
| 499 | validation_errors: validation_errors.clone(), |
| 500 | optimization_applied: self.optimization_enabled.load(Ordering::Relaxed), |
| 501 | usage_count: AtomicU64::new(1), |
| 502 | last_used: RwLock::new(Instant::now()), |
| 503 | }); |
| 504 | |
| 505 | // Update statistics |
| 506 | let mut stats = self.compilation_stats.write(); |
| 507 | stats.total_compilations += 1; |
| 508 | if validation_errors.is_empty() { |
| 509 | stats.successful_compilations += 1; |
| 510 | } else { |
| 511 | stats.failed_compilations += 1; |
| 512 | } |
| 513 | stats.total_compilation_time += compilation_time; |
| 514 | stats.average_compilation_time = |
| 515 | stats.total_compilation_time / stats.total_compilations as u32; |
| 516 | |
| 517 | info!( |
| 518 | "Compiled shader: {} in {:?}", |
| 519 | source.path.display(), |
| 520 | compilation_time |
| 521 | ); |
| 522 | |
| 523 | Ok(compiled) |
| 524 | } |
| 525 | |
| 526 | /// Validate compiled shader |
| 527 | fn validate_shader( |
| 528 | &self, |
| 529 | _module: &ShaderModule, |
| 530 | _source: &ShaderSource, |
| 531 | ) -> Result<Vec<String>> { |
| 532 | // Placeholder for shader validation |
| 533 | |
| 534 | Ok(Vec::new()) |
| 535 | } |
| 536 | |
| 537 | /// Watch file for changes |
| 538 | fn watch_file(&self, path: &Path) -> Result<()> { |
| 539 | if let Some(parent) = path.parent() { |
| 540 | let mut watched = self.watched_directories.write(); |
| 541 | if !watched.contains(parent) { |
| 542 | if let Some(ref mut watcher) = *self.file_watcher.lock() { |
| 543 | watcher.watch(parent, RecursiveMode::NonRecursive)?; |
| 544 | watched.insert(parent.to_path_buf()); |
| 545 | } |
| 546 | } |
| 547 | } |
| 548 | Ok(()) |
| 549 | } |
| 550 | |
| 551 | /// Process hot-reload events |
| 552 | pub fn process_hot_reload_events(&self) -> Result<Vec<HotReloadEvent>> { |
| 553 | let mut events = Vec::new(); |
| 554 | |
| 555 | if let Some(ref receiver) = *self.hot_reload_receiver.lock() { |
| 556 | while let Ok(event) = receiver.try_recv() { |
| 557 | match event.kind { |
| 558 | notify::EventKind::Modify(_) => { |
| 559 | for path in event.paths { |
| 560 | if self.is_shader_file(&path) { |
| 561 | self.invalidate_shader_cache(&path); |
| 562 | events.push(HotReloadEvent::FileChanged(path)); |
| 563 | } |
| 564 | } |
| 565 | } |
| 566 | notify::EventKind::Remove(_) => { |
| 567 | for path in event.paths { |
| 568 | if self.is_shader_file(&path) { |
| 569 | self.remove_from_cache(&path); |
| 570 | events.push(HotReloadEvent::FileDeleted(path)); |
| 571 | } |
| 572 | } |
| 573 | } |
| 574 | notify::EventKind::Create(_) => { |
| 575 | for path in event.paths { |
| 576 | if self.is_shader_file(&path) { |
| 577 | events.push(HotReloadEvent::FileCreated(path)); |
| 578 | } |
| 579 | } |
| 580 | } |
| 581 | _ => {} |
| 582 | } |
| 583 | } |
| 584 | } |
| 585 | |
| 586 | Ok(events) |
| 587 | } |
| 588 | |
| 589 | /// Check if file is a shader file |
| 590 | fn is_shader_file(&self, path: &Path) -> bool { |
| 591 | if let Some(ext) = path.extension().and_then(|e| e.to_str()) { |
| 592 | matches!( |
| 593 | ext.to_lowercase().as_str(), |
| 594 | "wgsl" | "glsl" | "hlsl" | "vert" | "frag" | "comp" |
| 595 | ) |
| 596 | } else { |
| 597 | false |
| 598 | } |
| 599 | } |
| 600 | |
| 601 | /// Invalidate shader cache for a file |
| 602 | fn invalidate_shader_cache(&self, path: &Path) { |
| 603 | // Remove from source cache |
| 604 | self.source_cache.write().remove(path); |
| 605 | |
| 606 | // Find and remove dependent shaders from compiled cache |
| 607 | let dependents = self.find_dependents(path); |
| 608 | let mut cache = self.shader_cache.write(); |
| 609 | |
| 610 | cache.retain(|_, compiled| !dependents.contains(&compiled.source_hash)); |
| 611 | |
| 612 | let mut stats = self.compilation_stats.write(); |
| 613 | stats.hot_reloads += 1; |
| 614 | |
| 615 | info!("Invalidated shader cache for: {}", path.display()); |
| 616 | } |
| 617 | |
| 618 | /// Remove shader from cache |
| 619 | fn remove_from_cache(&self, path: &Path) { |
| 620 | self.source_cache.write().remove(path); |
| 621 | self.dependency_graph.write().remove(path); |
| 622 | } |
| 623 | |
| 624 | /// Find all shaders dependent on a file |
| 625 | fn find_dependents(&self, path: &Path) -> HashSet<[u8; 32]> { |
| 626 | let mut dependents = HashSet::new(); |
| 627 | let graph = self.dependency_graph.read(); |
| 628 | |
| 629 | if let Some(node) = graph.get(path) { |
| 630 | for dependent_path in &node.dependents { |
| 631 | if let Some(source) = self.source_cache.read().get(dependent_path) { |
| 632 | dependents.insert(source.content_hash); |
| 633 | } |
| 634 | } |
| 635 | } |
| 636 | |
| 637 | dependents |
| 638 | } |
| 639 | |
| 640 | /// Add include directory |
| 641 | pub fn add_include_directory(&self, path: impl AsRef<Path>) { |
| 642 | self.include_directories |
| 643 | .write() |
| 644 | .push(path.as_ref().to_path_buf()); |
| 645 | } |
| 646 | |
| 647 | /// Add global macro |
| 648 | pub fn add_global_macro(&self, name: impl Into<String>, value: Option<String>) { |
| 649 | self.global_macros.write().push(ShaderMacro { |
| 650 | name: name.into(), |
| 651 | value, |
| 652 | }); |
| 653 | } |
| 654 | |
| 655 | /// Enable or disable hot-reload |
| 656 | pub fn set_hot_reload_enabled(&self, enabled: bool) { |
| 657 | self.hot_reload_enabled.store(enabled, Ordering::Relaxed); |
| 658 | } |
| 659 | |
| 660 | /// Enable or disable validation |
| 661 | pub fn set_validation_enabled(&self, enabled: bool) { |
| 662 | self.validation_enabled.store(enabled, Ordering::Relaxed); |
| 663 | } |
| 664 | |
| 665 | /// Enable or disable optimization |
| 666 | pub fn set_optimization_enabled(&self, enabled: bool) { |
| 667 | self.optimization_enabled.store(enabled, Ordering::Relaxed); |
| 668 | } |
| 669 | |
| 670 | /// Initialize the shader manager (placeholder for integration) |
| 671 | pub fn initialize(&self) -> Result<()> { |
| 672 | info!("Shader manager initialized"); |
| 673 | Ok(()) |
| 674 | } |
| 675 | |
| 676 | /// Check for shader reloads (integration method) |
| 677 | pub fn check_for_reloads(&self) -> Result<()> { |
| 678 | let _events = self.process_hot_reload_events()?; |
| 679 | Ok(()) |
| 680 | } |
| 681 | |
| 682 | /// Get compilation statistics |
| 683 | pub fn get_stats(&self) -> CompilationStats { |
| 684 | self.compilation_stats.read().clone() |
| 685 | } |
| 686 | |
| 687 | /// Clear all caches |
| 688 | pub fn clear_caches(&self) { |
| 689 | self.shader_cache.write().clear(); |
| 690 | self.source_cache.write().clear(); |
| 691 | self.preprocessor_cache.write().clear(); |
| 692 | |
| 693 | info!("Cleared all shader caches"); |
| 694 | } |
| 695 | |
| 696 | /// Get cache statistics |
| 697 | pub fn get_cache_stats(&self) -> (usize, usize, usize) { |
| 698 | let shader_cache_size = self.shader_cache.read().len(); |
| 699 | let source_cache_size = self.source_cache.read().len(); |
| 700 | let preprocessor_cache_size = self.preprocessor_cache.read().len(); |
| 701 | |
| 702 | ( |
| 703 | shader_cache_size, |
| 704 | source_cache_size, |
| 705 | preprocessor_cache_size, |
| 706 | ) |
| 707 | } |
| 708 | } |
| 709 | |
| 710 | impl Drop for ShaderManager { |
| 711 | fn drop(&mut self) { |
| 712 | // File watcher cleanup disabled for compatibility |
| 713 | } |
| 714 | } |
| 715 | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// A macro definition keeps exactly the name and value it was built with.
    #[test]
    fn test_shader_macro_creation() {
        let macro_def = ShaderMacro {
            name: "MAX_LIGHTS".to_string(),
            value: Some("16".to_string()),
        };

        assert_eq!(macro_def.name, "MAX_LIGHTS");
        assert_eq!(macro_def.value, Some("16".to_string()));
    }

    /// Structurally identical variants compare equal, so they share a cache key.
    #[test]
    fn test_shader_variant_equality() {
        let variant1 = ShaderVariant {
            macros: vec![ShaderMacro {
                name: "TEST".to_string(),
                value: None,
            }],
            features: Vec::new(),
            optimization_level: 2,
        };

        let variant2 = ShaderVariant {
            macros: vec![ShaderMacro {
                name: "TEST".to_string(),
                value: None,
            }],
            features: Vec::new(),
            optimization_level: 2,
        };

        assert_eq!(variant1, variant2);
    }

    /// Variants differing only in optimization level are distinct cache keys.
    #[test]
    fn test_shader_variant_distinguishes_optimization_level() {
        let base = ShaderVariant {
            macros: Vec::new(),
            features: Vec::new(),
            optimization_level: 0,
        };
        let tuned = ShaderVariant {
            optimization_level: 3,
            ..base.clone()
        };

        assert_ne!(base, tuned);
        let keys: HashSet<ShaderVariant> = vec![base.clone(), tuned, base].into_iter().collect();
        assert_eq!(keys.len(), 2);
    }

    /// Every shader language is a distinct, hashable value.
    #[test]
    fn test_language_variants_are_distinct() {
        let languages = [
            ShaderLanguage::WGSL,
            ShaderLanguage::GLSL,
            ShaderLanguage::HLSL,
            ShaderLanguage::SPIRV,
        ];
        let unique: HashSet<ShaderLanguage> = languages.iter().copied().collect();
        assert_eq!(unique.len(), languages.len());
    }
}
| 759 |