Seregon/StratoSDK

StratoSDK is a framework with a declarative approach similar to Flutter/React, written in and designed entirely for Rust.

Rust/27.3 KB/No license
crates/strato-core/src/state.rs
StratoSDK / crates / strato-core / src / state.rs
1//! Advanced reactive state management system for StratoUI
2//!
3//! Provides reactive state primitives with signals, stores, computed values,
4//! effects, and automatic dependency tracking similar to modern reactive frameworks
5 
6use dashmap::DashMap;
7use parking_lot::{Mutex, RwLock};
8use smallvec::SmallVec;
9use std::any::Any;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13 
// Helper for optional serialization of state values for the inspector.
//
// Relies on method-resolution priority: for `JsonInspector(&v).to_json()`,
// inherent methods are tried before trait methods, so the inherent
// `to_json` (available only when `T: serde::Serialize` holds) wins over the
// `Fallback` trait's default `to_json` whenever the bound is satisfied;
// otherwise the trait default is used. Callers must have `Fallback` in scope.
//
// NOTE(review): in generic code where `T` carries no `Serialize` bound (e.g.
// inside `Signal<T>`), the inherent method presumably can never be selected,
// so the fallback string would always be produced — confirm.
#[cfg(feature = "serde")]
mod serde_helper {
    /// Borrowing wrapper around a value to be rendered as JSON.
    pub struct JsonInspector<'a, T: ?Sized>(pub &'a T);

    /// Fallback rendering for values not known to be serializable.
    pub trait Fallback {
        /// Returns the fixed placeholder `"<unserializable>"`.
        fn to_json(&self) -> String {
            "<unserializable>".into()
        }
    }

    impl<'a, T: ?Sized> Fallback for JsonInspector<'a, T> {}

    impl<'a, T: ?Sized + serde::Serialize> JsonInspector<'a, T> {
        /// Serializes the wrapped value with `serde_json`; if serialization
        /// fails, returns the same placeholder as the fallback.
        pub fn to_json(&self) -> String {
            serde_json::to_string(self.0).unwrap_or_else(|_| "<unserializable>".into())
        }
    }
}
33 
/// Unique identifier for state values.
///
/// Alias for `slotmap::DefaultKey`. `Signal::with_context` allocates one key per
/// signal from a process-wide slot map; clones of a signal share its id.
pub type StateId = slotmap::DefaultKey;
36 
/// Identifier for a reactive computation.
///
/// Ids are drawn from one global counter that starts at 0, so every call to
/// `ComputationId::new` yields a value not handed out before in this process.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ComputationId(u64);

impl ComputationId {
    /// Allocate the next id from the global counter.
    fn new() -> Self {
        static NEXT: AtomicU64 = AtomicU64::new(0);
        let raw = NEXT.fetch_add(1, Ordering::Relaxed);
        ComputationId(raw)
    }
}
47 
/// Callback invoked with a signal's new value, passed as `&dyn Any`.
///
/// Subscribers recover the concrete value with `downcast_ref::<T>()`.
pub type StateCallback = Box<dyn Fn(&dyn Any) + Sync + Send>;

/// Boxed, thread-safe closure taking no arguments, intended for side effects.
pub type EffectFn = Box<dyn Fn() + Sync + Send>;
53 
/// Handle owning a cleanup action for an effect or subscription.
///
/// The action runs only when [`Disposable::dispose`] is called; dropping the
/// handle without disposing it does not run the action.
pub struct Disposable {
    dispose_fn: Box<dyn FnOnce() + Send>,
}

impl Disposable {
    /// Wrap `dispose_fn` so it can be run later via [`Disposable::dispose`].
    pub fn new(dispose_fn: impl FnOnce() + Send + 'static) -> Self {
        let boxed: Box<dyn FnOnce() + Send> = Box::new(dispose_fn);
        Disposable { dispose_fn: boxed }
    }

    /// Consume the handle and run its cleanup action exactly once.
    pub fn dispose(self) {
        let Self { dispose_fn } = self;
        dispose_fn();
    }
}
70 
71/// Reactive context for tracking dependencies
72#[derive(Default)]
73pub struct ReactiveContext {
74 current_computation: Arc<Mutex<Option<ComputationId>>>,
75 dependencies: Arc<RwLock<HashMap<ComputationId, Vec<StateId>>>>,
76 dependents: Arc<RwLock<HashMap<StateId, Vec<ComputationId>>>>,
77}
78 
79impl ReactiveContext {
80 pub fn new() -> Self {
81 Self::default()
82 }
83 
84 /// Track a dependency for the current computation
85 pub fn track_dependency(&self, state_id: StateId) {
86 if let Some(computation_id) = *self.current_computation.lock() {
87 self.dependencies
88 .write()
89 .entry(computation_id)
90 .or_default()
91 .push(state_id);
92 
93 self.dependents
94 .write()
95 .entry(state_id)
96 .or_default()
97 .push(computation_id);
98 }
99 }
100 
101 /// Run a computation with dependency tracking
102 pub fn run_with_tracking<T>(&self, computation_id: ComputationId, f: impl FnOnce() -> T) -> T {
103 let _guard = ComputationGuard::new(self, computation_id);
104 f()
105 }
106 
107 /// Invalidate all computations that depend on a state
108 pub fn invalidate_dependents(&self, state_id: StateId) {
109 if let Some(dependents) = self.dependents.read().get(&state_id) {
110 for &computation_id in dependents {
111 // Trigger recomputation
112 self.recompute(computation_id);
113 }
114 }
115 }
116 
117 fn recompute(&self, _computation_id: ComputationId) {
118 // Implementation for recomputing dependent values
119 // This would trigger the recomputation of computed signals and effects
120 }
121}
122 
123/// RAII guard for computation tracking
124struct ComputationGuard<'a> {
125 context: &'a ReactiveContext,
126 previous: Option<ComputationId>,
127}
128 
129impl<'a> ComputationGuard<'a> {
130 fn new(context: &'a ReactiveContext, computation_id: ComputationId) -> Self {
131 let previous = context.current_computation.lock().replace(computation_id);
132 Self { context, previous }
133 }
134}
135 
136impl<'a> Drop for ComputationGuard<'a> {
137 fn drop(&mut self) {
138 *self.context.current_computation.lock() = self.previous;
139 }
140}
141 
142/// Enhanced signal with automatic dependency tracking
143pub struct Signal<T: Clone + Send + Sync + 'static> {
144 id: StateId,
145 value: Arc<RwLock<T>>,
146 subscribers: Arc<RwLock<SmallVec<[StateCallback; 4]>>>,
147 context: Arc<ReactiveContext>,
148}
149 
150impl<T: Clone + Send + Sync + 'static> Signal<T> {
151 /// Create a new signal with initial value
152 pub fn new(initial: T) -> Self {
153 Self::with_context(initial, Arc::new(ReactiveContext::new()))
154 }
155 
156 /// Create a new signal with a specific reactive context
157 pub fn with_context(initial: T, context: Arc<ReactiveContext>) -> Self {
158 use slotmap::SlotMap;
159 use std::sync::OnceLock;
160 
161 static SLOT_MAP: OnceLock<Mutex<SlotMap<StateId, ()>>> = OnceLock::new();
162 
163 let slot_map = SLOT_MAP.get_or_init(|| Mutex::new(SlotMap::new()));
164 let id = slot_map.lock().insert(());
165 
166 Self {
167 id,
168 value: Arc::new(RwLock::new(initial)),
169 subscribers: Arc::new(RwLock::new(SmallVec::new())),
170 context,
171 }
172 }
173 
174 /// Get current value and track dependency
175 pub fn get(&self) -> T {
176 self.context.track_dependency(self.id);
177 self.value.read().clone()
178 }
179 
180 /// Get current value without tracking dependency
181 pub fn peek(&self) -> T {
182 self.value.read().clone()
183 }
184 
185 /// Set new value and notify subscribers
186 pub fn set(&self, value: T) {
187 {
188 let mut guard = self.value.write();
189 *guard = value.clone();
190 }
191 #[cfg(feature = "serde")]
192 {
193 // Record inspector snapshot if available.
194 use self::serde_helper::{Fallback, JsonInspector};
195 let detail = JsonInspector(&value).to_json();
196 crate::inspector::inspector().record_state_snapshot(self.id, detail);
197 }
198 #[cfg(not(feature = "serde"))]
199 {
200 let type_name = std::any::type_name::<T>();
201 crate::inspector::inspector()
202 .record_state_snapshot(self.id, format!("Updated {}", type_name));
203 }
204 self.notify(&value);
205 self.context.invalidate_dependents(self.id);
206 }
207 
208 /// Update value with a function
209 pub fn update(&self, f: impl FnOnce(&mut T)) {
210 let value = {
211 let mut guard = self.value.write();
212 f(&mut *guard);
213 guard.clone()
214 };
215 #[cfg(feature = "serde")]
216 {
217 use self::serde_helper::{Fallback, JsonInspector};
218 let detail = JsonInspector(&value).to_json();
219 crate::inspector::inspector().record_state_snapshot(self.id, detail);
220 }
221 #[cfg(not(feature = "serde"))]
222 {
223 let type_name = std::any::type_name::<T>();
224 crate::inspector::inspector()
225 .record_state_snapshot(self.id, format!("Updated {}", type_name));
226 }
227 self.notify(&value);
228 self.context.invalidate_dependents(self.id);
229 }
230 
231 /// Subscribe to value changes
232 pub fn subscribe(&self, callback: StateCallback) -> Disposable {
233 let subscribers = Arc::clone(&self.subscribers);
234 let callback_id = {
235 let mut subs = subscribers.write();
236 let id = subs.len();
237 subs.push(callback);
238 id
239 };
240 
241 Disposable::new(move || {
242 // Remove callback by replacing with no-op
243 if let Some(callback) = subscribers.write().get_mut(callback_id) {
244 *callback = Box::new(|_| {});
245 }
246 })
247 }
248 
249 /// Create a computed signal that derives from this signal
250 pub fn computed<U, F>(&self, f: F) -> Signal<U>
251 where
252 U: Clone + Send + Sync + 'static,
253 F: Fn(&T) -> U + Send + Sync + 'static,
254 {
255 let computation_id = ComputationId::new();
256 let computed = Signal::with_context(
257 self.context
258 .run_with_tracking(computation_id, || f(&self.get())),
259 Arc::clone(&self.context),
260 );
261 
262 let computed_clone = computed.clone();
263 let f = Arc::new(f);
264 
265 self.subscribe(Box::new(move |value: &dyn Any| {
266 if let Some(typed_value) = value.downcast_ref::<T>() {
267 let new_value = f(typed_value);
268 computed_clone.set(new_value);
269 }
270 }));
271 
272 computed
273 }
274 
275 /// Create an effect that runs when the signal changes
276 pub fn effect<F>(&self, f: F) -> Disposable
277 where
278 F: Fn(&T) + Send + Sync + 'static,
279 {
280 // Run effect immediately
281 f(&self.get());
282 
283 // Subscribe to future changes
284 self.subscribe(Box::new(move |value: &dyn Any| {
285 if let Some(typed_value) = value.downcast_ref::<T>() {
286 f(typed_value);
287 }
288 }))
289 }
290 
291 /// Create a derived signal that transforms this signal's value
292 pub fn map<U, F>(&self, f: F) -> Signal<U>
293 where
294 U: Clone + Send + Sync + 'static,
295 F: Fn(&T) -> U + Send + Sync + 'static,
296 {
297 self.computed(f)
298 }
299 
300 /// Filter signal updates based on a predicate
301 pub fn filter<F>(&self, predicate: F) -> Signal<Option<T>>
302 where
303 F: Fn(&T) -> bool + Send + Sync + 'static,
304 {
305 self.computed(move |value| {
306 if predicate(value) {
307 Some(value.clone())
308 } else {
309 None
310 }
311 })
312 }
313 
314 /// Notify all subscribers
315 fn notify(&self, value: &T) {
316 let subscribers = self.subscribers.read();
317 for callback in subscribers.iter() {
318 callback(value as &dyn Any);
319 }
320 }
321}
322 
323impl<T: Clone + Send + Sync + 'static + std::fmt::Debug> std::fmt::Debug for Signal<T> {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 f.debug_struct("Signal")
326 .field("id", &self.id)
327 .field("value", &self.peek())
328 .finish()
329 }
330}
331 
332impl<T: Clone + Send + Sync + 'static> Clone for Signal<T> {
333 fn clone(&self) -> Self {
334 Self {
335 id: self.id,
336 value: Arc::clone(&self.value),
337 subscribers: Arc::clone(&self.subscribers),
338 context: Arc::clone(&self.context),
339 }
340 }
341}
342 
343/// Store for managing multiple related state values
344pub struct Store {
345 states: DashMap<String, Box<dyn Any + Send + Sync>>,
346 context: Arc<ReactiveContext>,
347}
348 
349impl Store {
350 /// Create a new store
351 pub fn new() -> Self {
352 Self {
353 states: DashMap::new(),
354 context: Arc::new(ReactiveContext::new()),
355 }
356 }
357 
358 /// Add a signal to the store
359 pub fn add_signal<T: Clone + Send + Sync + 'static>(&self, key: &str, initial: T) -> Signal<T> {
360 let signal = Signal::with_context(initial, Arc::clone(&self.context));
361 self.states
362 .insert(key.to_string(), Box::new(signal.clone()));
363 signal
364 }
365 
366 /// Get a signal from the store
367 pub fn get_signal<T: Clone + Send + Sync + 'static>(&self, key: &str) -> Option<Signal<T>> {
368 self.states
369 .get(key)
370 .and_then(|entry| entry.value().downcast_ref::<Signal<T>>().cloned())
371 }
372 
373 /// Create a computed value that depends on multiple signals in the store
374 pub fn computed<T, F>(&self, f: F) -> Signal<T>
375 where
376 T: Clone + Send + Sync + 'static,
377 F: Fn(&Store) -> T + Send + Sync + 'static,
378 {
379 let computation_id = ComputationId::new();
380 let initial_value = self.context.run_with_tracking(computation_id, || f(self));
381 
382 Signal::with_context(initial_value, Arc::clone(&self.context))
383 }
384 
385 /// Remove a signal from the store
386 pub fn remove(&self, key: &str) -> bool {
387 self.states.remove(key).is_some()
388 }
389 
390 /// Clear all signals from the store
391 pub fn clear(&self) {
392 self.states.clear();
393 }
394 
395 /// Get the number of signals in the store
396 pub fn len(&self) -> usize {
397 self.states.len()
398 }
399 
400 /// Check if the store is empty
401 pub fn is_empty(&self) -> bool {
402 self.states.is_empty()
403 }
404}
405 
406impl Default for Store {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411 
/// Queue of state updates to be run together.
///
/// Updates run in insertion order when [`Batch::execute`] is called. A batch
/// only defers the updates. Nothing here suppresses or merges notifications:
/// an update that calls `Signal::set` notifies that signal's subscribers
/// immediately, exactly as it would outside a batch.
pub struct Batch {
    updates: Vec<Box<dyn FnOnce() + Send>>,
}

impl Batch {
    /// Create an empty batch.
    pub fn new() -> Self {
        Self {
            updates: Vec::new(),
        }
    }

    /// Queue `update` to run when the batch is executed.
    pub fn add<F>(&mut self, update: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.updates.push(Box::new(update));
    }

    /// Number of queued updates.
    pub fn len(&self) -> usize {
        self.updates.len()
    }

    /// Whether no updates are queued.
    pub fn is_empty(&self) -> bool {
        self.updates.is_empty()
    }

    /// Consume the batch and run every queued update, in the order added.
    pub fn execute(self) {
        for update in self.updates {
            update();
        }
    }
}

impl Default for Batch {
    fn default() -> Self {
        Self::new()
    }
}
446 
447/// Global reactive context for the application
448static GLOBAL_CONTEXT: std::sync::OnceLock<Arc<ReactiveContext>> = std::sync::OnceLock::new();
449 
450/// Get the global reactive context
451pub fn global_context() -> Arc<ReactiveContext> {
452 GLOBAL_CONTEXT
453 .get_or_init(|| Arc::new(ReactiveContext::new()))
454 .clone()
455}
456 
457/// Create a signal with the global context
458pub fn signal<T: Clone + Send + Sync + 'static>(initial: T) -> Signal<T> {
459 Signal::with_context(initial, global_context())
460}
461 
462/// Create a computed signal with the global context
463pub fn computed<T, F>(f: F) -> Signal<T>
464where
465 T: Clone + Send + Sync + 'static,
466 F: Fn() -> T + Send + Sync + 'static,
467{
468 let computation_id = ComputationId::new();
469 let context = global_context();
470 let initial_value = context.run_with_tracking(computation_id, f);
471 
472 Signal::with_context(initial_value, context)
473}
474 
475/// Create an effect with the global context
476pub fn effect<F>(f: F) -> Disposable
477where
478 F: Fn() + Send + Sync + 'static,
479{
480 // Run effect immediately
481 f();
482 
483 // Return a disposable that does nothing for now
484 
485 Disposable::new(|| {})
486}
487 
/// Object-safe view of a clonable, thread-safe value for type-erased state storage.
pub trait State: Send + Sync {
    /// Borrow the value as [`Any`] so callers can downcast it.
    fn as_any(&self) -> &dyn Any;
    /// Clone the value into a new boxed trait object.
    fn clone_state(&self) -> Box<dyn State>;
}

impl<T> State for T
where
    T: Clone + Send + Sync + 'static,
{
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    fn clone_state(&self) -> Box<dyn State> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
503 
// Unit tests for signals, derived signals, the store and batches.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI32, Ordering};

    /// `set` replaces the value observed by `get`.
    #[test]
    fn test_signal_basic() {
        let sig = Signal::new(42);
        assert_eq!(sig.get(), 42);

        sig.set(100);
        assert_eq!(sig.get(), 100);
    }

    /// A subscriber receives the new value when the signal is set.
    #[test]
    fn test_signal_subscribe() {
        let seen = Arc::new(AtomicI32::new(0));
        let sink = Arc::clone(&seen);
        let sig = Signal::new(0);

        let _subscription = sig.subscribe(Box::new(move |raw: &dyn Any| {
            if let Some(v) = raw.downcast_ref::<i32>() {
                sink.store(*v, Ordering::Relaxed);
            }
        }));

        sig.set(42);
        assert_eq!(seen.load(Ordering::Relaxed), 42);
    }

    /// A computed signal follows its source.
    #[test]
    fn test_computed_signal() {
        let base = Signal::new(10);
        let doubled = base.computed(|x: &i32| x * 2);
        assert_eq!(doubled.get(), 20);

        base.set(15);
        assert_eq!(doubled.get(), 30);
    }

    /// Signals added to a store can be fetched back by key and type.
    #[test]
    fn test_store() {
        let store = Store::new();
        let counter = store.add_signal("counter", 0);
        let name = store.add_signal("name", String::from("test"));

        assert_eq!(counter.get(), 0);
        assert_eq!(name.get(), "test");

        counter.set(42);
        let fetched = store
            .get_signal::<i32>("counter")
            .expect("counter should be stored as Signal<i32>");
        assert_eq!(fetched.get(), 42);
    }

    /// Every queued update is applied by `execute`.
    #[test]
    fn test_batch_updates() {
        let (signal1, signal2) = (Signal::new(0), Signal::new(0));
        let mut batch = Batch::new();
        {
            let s1 = signal1.clone();
            batch.add(move || s1.set(10));
        }
        {
            let s2 = signal2.clone();
            batch.add(move || s2.set(20));
        }

        batch.execute();

        assert_eq!(signal1.get(), 10);
        assert_eq!(signal2.get(), 20);
    }

    /// `map` derives a transformed signal that follows its source.
    #[test]
    fn test_signal_map() {
        let source = Signal::new(5);
        let text = source.map(|x: &i32| x.to_string());
        assert_eq!(text.get(), "5");

        source.set(10);
        assert_eq!(text.get(), "10");
    }

    /// `filter` yields `Some` only while the predicate holds.
    #[test]
    fn test_signal_filter() {
        let source = Signal::new(5);
        let big = source.filter(|x: &i32| *x > 10);
        assert_eq!(big.get(), None);

        source.set(15);
        assert_eq!(big.get(), Some(15));
    }
}
598