Seregon/ShadPKG

A C++ tool for deriving PS4 PKG package encryption keys.

C++/47.3 KB/No license
common/slab_heap.h
ShadPKG / common / slab_heap.h
1// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3 
4#pragma once
5 
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>

#include "common/assert.h"
#include "common/spin_lock.h"
9 
10namespace Common {
11 
12class SlabHeapImpl {
13public:
14 struct Node {
15 Node* next{};
16 };
17 
18public:
19 constexpr SlabHeapImpl() = default;
20 
21 void Initialize() {
22 ASSERT(m_head == nullptr);
23 }
24 
25 Node* GetHead() const {
26 return m_head;
27 }
28 
29 void* Allocate() {
30 m_lock.lock();
31 
32 Node* ret = m_head;
33 if (ret != nullptr) {
34 m_head = ret->next;
35 }
36 
37 m_lock.unlock();
38 return ret;
39 }
40 
41 void Free(void* obj) {
42 m_lock.lock();
43 
44 Node* node = static_cast<Node*>(obj);
45 node->next = m_head;
46 m_head = node;
47 
48 m_lock.unlock();
49 }
50 
51private:
52 std::atomic<Node*> m_head{};
53 Common::SpinLock m_lock;
54};
55 
56class SlabHeapBase : protected SlabHeapImpl {
57private:
58 size_t m_obj_size{};
59 uintptr_t m_peak{};
60 uintptr_t m_start{};
61 uintptr_t m_end{};
62 
63public:
64 constexpr SlabHeapBase() = default;
65 
66 bool Contains(uintptr_t address) const {
67 return m_start <= address && address < m_end;
68 }
69 
70 void Initialize(size_t obj_size, void* memory, size_t memory_size) {
71 // Ensure we don't initialize a slab using null memory.
72 ASSERT(memory != nullptr);
73 
74 // Set our object size.
75 m_obj_size = obj_size;
76 
77 // Initialize the base allocator.
78 SlabHeapImpl::Initialize();
79 
80 // Set our tracking variables.
81 const size_t num_obj = (memory_size / obj_size);
82 m_start = reinterpret_cast<uintptr_t>(memory);
83 m_end = m_start + num_obj * obj_size;
84 m_peak = m_start;
85 
86 // Free the objects.
87 u8* cur = reinterpret_cast<u8*>(m_end);
88 
89 for (size_t i = 0; i < num_obj; i++) {
90 cur -= obj_size;
91 SlabHeapImpl::Free(cur);
92 }
93 }
94 
95 size_t GetSlabHeapSize() const {
96 return (m_end - m_start) / this->GetObjectSize();
97 }
98 
99 size_t GetObjectSize() const {
100 return m_obj_size;
101 }
102 
103 void* Allocate() {
104 void* obj = SlabHeapImpl::Allocate();
105 return obj;
106 }
107 
108 void Free(void* obj) {
109 // Don't allow freeing an object that wasn't allocated from this heap.
110 const bool contained = this->Contains(reinterpret_cast<uintptr_t>(obj));
111 ASSERT(contained);
112 SlabHeapImpl::Free(obj);
113 }
114 
115 size_t GetObjectIndex(const void* obj) const {
116 return (reinterpret_cast<uintptr_t>(obj) - m_start) / this->GetObjectSize();
117 }
118 
119 size_t GetPeakIndex() const {
120 return this->GetObjectIndex(reinterpret_cast<const void*>(m_peak));
121 }
122 
123 uintptr_t GetSlabHeapAddress() const {
124 return m_start;
125 }
126 
127 size_t GetNumRemaining() const {
128 // Only calculate the number of remaining objects under debug configuration.
129 return 0;
130 }
131};
132 
133template <typename T>
134class SlabHeap final : public SlabHeapBase {
135private:
136 using BaseHeap = SlabHeapBase;
137 
138public:
139 constexpr SlabHeap() = default;
140 
141 void Initialize(void* memory, size_t memory_size) {
142 BaseHeap::Initialize(sizeof(T), memory, memory_size);
143 }
144 
145 T* Allocate() {
146 T* obj = static_cast<T*>(BaseHeap::Allocate());
147 
148 if (obj != nullptr) [[likely]] {
149 std::construct_at(obj);
150 }
151 return obj;
152 }
153 
154 void Free(T* obj) {
155 BaseHeap::Free(obj);
156 }
157 
158 size_t GetObjectIndex(const T* obj) const {
159 return BaseHeap::GetObjectIndex(obj);
160 }
161};
162 
163} // namespace Common
164