Seregon/ShadPKG

A C++ tool for deriving PKG package encryption keys for the PS4.

C++/47.3 KB/No license
core/decompiler/analysis/LoopCorrector.cpp
ShadPKG / core / decompiler / analysis / LoopCorrector.cpp
1#include "LoopCorrector.h"
2#include "common/logging/log.h"
3#include <algorithm>
4#include <sstream>
5 
6namespace ShadPKG::Decompiler::Analysis {
7 
8bool LoopCorrector::isTautologyCondition(const std::shared_ptr<AST::Expression> &cond) {
9 if (!cond)
10 return false;
11 
12 // Check for self-comparison: (a == a), (a != b where a==b)
13 if (isSelfComparison(cond)) {
14 LOG_WARN(Decompiler, "Detected tautology condition in loop");
15 return true;
16 }
17 
18 // Check for constant true/false conditions
19 if (auto constExpr = std::dynamic_pointer_cast<AST::ConstantExpr>(cond)) {
20 // Constants like 1, -1, true are infinite loop conditions
21 if (constExpr->value != 0) {
22 LOG_WARN(Decompiler, "Detected constant true condition in loop");
23 return true;
24 }
25 }
26 
27 return false;
28}
29 
30bool LoopCorrector::isSelfComparison(const std::shared_ptr<AST::Expression> &expr) {
31 if (!expr)
32 return false;
33 
34 // Check for BinaryExpr like (reg == reg)
35 if (auto binExpr = std::dynamic_pointer_cast<AST::BinaryExpr>(expr)) {
36 // Check equality/inequality operators
37 if (binExpr->op == AST::BinaryExpr::Op::Equal ||
38 binExpr->op == AST::BinaryExpr::Op::NotEqual) {
39
40 // Extract variable names from both sides
41 auto leftVars = extractVariables(binExpr->left);
42 auto rightVars = extractVariables(binExpr->right);
43 
44 // If both sides use the same variable, it's a self-comparison
45 for (const auto &var : leftVars) {
46 if (rightVars.count(var) > 0) {
47 return true;
48 }
49 }
50 }
51 }
52 
53 return false;
54}
55 
56std::set<std::string> LoopCorrector::extractVariables(const std::shared_ptr<AST::Expression> &expr) {
57 std::set<std::string> vars;
58 
59 if (!expr)
60 return vars;
61 
62 if (auto varExpr = std::dynamic_pointer_cast<AST::VariableExpr>(expr)) {
63 vars.insert(varExpr->name);
64 } else if (auto binExpr = std::dynamic_pointer_cast<AST::BinaryExpr>(expr)) {
65 auto leftVars = extractVariables(binExpr->left);
66 auto rightVars = extractVariables(binExpr->right);
67 vars.insert(leftVars.begin(), leftVars.end());
68 vars.insert(rightVars.begin(), rightVars.end());
69 } else if (auto unaryExpr = std::dynamic_pointer_cast<AST::UnaryExpr>(expr)) {
70 auto innerVars = extractVariables(unaryExpr->expr);
71 vars.insert(innerVars.begin(), innerVars.end());
72 } else if (auto derefExpr = std::dynamic_pointer_cast<AST::DerefExpr>(expr)) {
73 auto innerVars = extractVariables(derefExpr->expr);
74 vars.insert(innerVars.begin(), innerVars.end());
75 } else if (auto memberExpr = std::dynamic_pointer_cast<AST::MemberExpr>(expr)) {
76 auto objVars = extractVariables(memberExpr->object);
77 vars.insert(objVars.begin(), objVars.end());
78 }
79 
80 return vars;
81}
82 
83bool LoopCorrector::isSentinelValue(int64_t value) {
84 // Common sentinel values: NULL, -1, max uint64
85 return value == 0 || value == -1 || value == 0xffffffffffffffff;
86}
87 
88bool LoopCorrector::isPointerWalkPattern(const std::shared_ptr<AST::Expression> &cond,
89 std::string &baseRegister) {
90 if (!cond)
91 return false;
92 
93 // Look for patterns like: ptr != nullptr, ptr != 0, ptr != 0xffffffff
94 if (auto binExpr = std::dynamic_pointer_cast<AST::BinaryExpr>(cond)) {
95 if (binExpr->op == AST::BinaryExpr::Op::NotEqual ||
96 binExpr->op == AST::BinaryExpr::Op::Equal) {
97
98 // Check if comparing against constant
99 std::shared_ptr<AST::Expression> ptrSide = nullptr;
100 std::shared_ptr<AST::ConstantExpr> constSide = nullptr;
101 
102 if (auto constRight = std::dynamic_pointer_cast<AST::ConstantExpr>(binExpr->right)) {
103 if (isSentinelValue(constRight->value)) {
104 ptrSide = binExpr->left;
105 constSide = constRight;
106 }
107 } else if (auto constLeft = std::dynamic_pointer_cast<AST::ConstantExpr>(binExpr->left)) {
108 if (isSentinelValue(constLeft->value)) {
109 ptrSide = binExpr->right;
110 constSide = constLeft;
111 }
112 }
113 
114 // Extract pointer register name
115 if (ptrSide) {
116 auto vars = extractVariables(ptrSide);
117 if (vars.size() == 1) {
118 baseRegister = *vars.begin();
119 return true;
120 }
121 }
122 }
123 }
124 
125 return false;
126}
127 
128bool LoopCorrector::detectsPointerArithmetic(const std::shared_ptr<AST::Statement> &body,
129 std::string &ptrReg, int64_t &offset) {
130 if (!body)
131 return false;
132 
133 // Check if body contains assignments like: ptr += 8, ptr -= 8, ptr = ptr + constant
134 if (auto exprStmt = std::dynamic_pointer_cast<AST::ExpressionStatement>(body)) {
135 if (auto assignExpr = std::dynamic_pointer_cast<AST::AssignmentExpr>(exprStmt->expr)) {
136 // lhs should be a variable
137 if (auto lhsVar = std::dynamic_pointer_cast<AST::VariableExpr>(assignExpr->lhs)) {
138 ptrReg = lhsVar->name;
139
140 // Check if rhs is arithmetic: ptr + constant
141 if (auto rhsBin = std::dynamic_pointer_cast<AST::BinaryExpr>(assignExpr->rhs)) {
142 if ((rhsBin->op == AST::BinaryExpr::Op::Add ||
143 rhsBin->op == AST::BinaryExpr::Op::Sub) &&
144 std::dynamic_pointer_cast<AST::ConstantExpr>(rhsBin->right)) {
145 offset = std::dynamic_pointer_cast<AST::ConstantExpr>(rhsBin->right)->value;
146 if (rhsBin->op == AST::BinaryExpr::Op::Sub)
147 offset = -offset;
148 return true;
149 }
150 }
151 }
152 }
153 } else if (auto compoundStmt = std::dynamic_pointer_cast<AST::CompoundStatement>(body)) {
154 // Check the statements in compound body
155 for (const auto &stmt : compoundStmt->statements) {
156 if (detectsPointerArithmetic(stmt, ptrReg, offset)) {
157 return true;
158 }
159 }
160 }
161 
162 return false;
163}
164 
165std::shared_ptr<AST::VariableDecl>
166LoopCorrector::generateSafetyCounter(const std::string &counterName) {
167 auto decl = std::make_shared<AST::VariableDecl>();
168 decl->name = counterName;
169 decl->type = "int";
170 decl->initialValue = std::make_shared<AST::ConstantExpr>(0LL);
171 return decl;
172}
173 
174std::shared_ptr<AST::Expression>
175LoopCorrector::generateBoundedCondition(const std::string &ptrReg, int maxIterations) {
176 // Generate: (ptr != nullptr && ptr != -1 && ++safety < maxIterations)
177
178 // ptr != nullptr
179 auto ptrNotNull = std::make_shared<AST::BinaryExpr>(
180 AST::BinaryExpr::Op::NotEqual,
181 std::make_shared<AST::VariableExpr>(ptrReg),
182 std::make_shared<AST::ConstantExpr>(0LL));
183 
184 // ptr != -1
185 auto ptrNotSentinel = std::make_shared<AST::BinaryExpr>(
186 AST::BinaryExpr::Op::NotEqual,
187 std::make_shared<AST::VariableExpr>(ptrReg),
188 std::make_shared<AST::ConstantExpr>(-1LL));
189 
190 // ++safety < maxIterations
191 auto safetyCheck = std::make_shared<AST::BinaryExpr>(
192 AST::BinaryExpr::Op::LessThan,
193 std::make_shared<AST::UnaryExpr>(AST::UnaryExpr::Op::PreIncrement,
194 std::make_shared<AST::VariableExpr>("safety")),
195 std::make_shared<AST::ConstantExpr>((int64_t)maxIterations));
196 
197 // Combine: (ptrNotNull && ptrNotSentinel && safetyCheck)
198 auto combined1 = std::make_shared<AST::BinaryExpr>(
199 AST::BinaryExpr::Op::LogicalAnd, ptrNotNull, ptrNotSentinel);
200 
201 auto combined2 = std::make_shared<AST::BinaryExpr>(
202 AST::BinaryExpr::Op::LogicalAnd, combined1, safetyCheck);
203 
204 return combined2;
205}
206 
207std::shared_ptr<AST::Statement>
208LoopCorrector::correctWhileLoop(const std::shared_ptr<AST::WhileStatement> &whileStmt,
209 const std::shared_ptr<IR::BasicBlock> &headerBB) {
210 if (!whileStmt || !whileStmt->condition)
211 return whileStmt;
212 
213 // Check for tautology
214 if (isTautologyCondition(whileStmt->condition)) {
215 LOG_INFO(Decompiler, "Correcting infinite while loop with tautology condition");
216 
217 std::string ptrReg;
218 // Try to detect pointer walk pattern and generate proper condition
219 if (isPointerWalkPattern(whileStmt->condition, ptrReg)) {
220 LOG_INFO(Decompiler, "Detected pointer walk pattern on register: {}", ptrReg);
221
222 auto boundedCond = generateBoundedCondition(ptrReg, 1000);
223 auto correctedStmt = std::make_shared<AST::WhileStatement>(boundedCond, whileStmt->body);
224 return correctedStmt;
225 } else {
226 // Generic tautology fix: replace with bounded loop on internal counter
227 auto boundedCond = generateBoundedCondition("reg_rax", 1000);
228 auto correctedStmt = std::make_shared<AST::WhileStatement>(boundedCond, whileStmt->body);
229 return correctedStmt;
230 }
231 }
232 
233 // Check if loop body has pointer arithmetic
234 std::string ptrReg;
235 int64_t offset;
236 if (detectsPointerArithmetic(whileStmt->body, ptrReg, offset)) {
237 LOG_INFO(Decompiler, "Detected pointer arithmetic in loop body: {} += {}", ptrReg, offset);
238
239 // Check if condition already validates the pointer
240 if (!isPointerWalkPattern(whileStmt->condition, ptrReg)) {
241 // Add bounds checking to the existing condition
242 auto boundedCond = generateBoundedCondition(ptrReg, 1000);
243 auto correctedStmt = std::make_shared<AST::WhileStatement>(boundedCond, whileStmt->body);
244 return correctedStmt;
245 }
246 }
247 
248 // No correction needed
249 return whileStmt;
250}
251 
252std::shared_ptr<AST::Statement>
253LoopCorrector::correctDoWhileLoop(const std::shared_ptr<AST::DoWhileStatement> &doWhileStmt,
254 const std::shared_ptr<IR::BasicBlock> &latchBB) {
255 if (!doWhileStmt || !doWhileStmt->condition)
256 return doWhileStmt;
257 
258 // Check for tautology
259 if (isTautologyCondition(doWhileStmt->condition)) {
260 LOG_INFO(Decompiler, "Correcting infinite do-while loop with tautology condition");
261 
262 std::string ptrReg;
263 if (isPointerWalkPattern(doWhileStmt->condition, ptrReg)) {
264 LOG_INFO(Decompiler, "Detected pointer walk pattern on register: {}", ptrReg);
265
266 auto boundedCond = generateBoundedCondition(ptrReg, 1000);
267 auto correctedStmt = std::make_shared<AST::DoWhileStatement>(doWhileStmt->body, boundedCond);
268 return correctedStmt;
269 } else {
270 auto boundedCond = generateBoundedCondition("reg_rax", 1000);
271 auto correctedStmt = std::make_shared<AST::DoWhileStatement>(doWhileStmt->body, boundedCond);
272 return correctedStmt;
273 }
274 }
275 
276 // Check if loop body has pointer arithmetic
277 std::string ptrReg;
278 int64_t offset;
279 if (detectsPointerArithmetic(doWhileStmt->body, ptrReg, offset)) {
280 LOG_INFO(Decompiler, "Detected pointer arithmetic in do-while body: {} += {}", ptrReg, offset);
281
282 if (!isPointerWalkPattern(doWhileStmt->condition, ptrReg)) {
283 auto boundedCond = generateBoundedCondition(ptrReg, 1000);
284 auto correctedStmt = std::make_shared<AST::DoWhileStatement>(doWhileStmt->body, boundedCond);
285 return correctedStmt;
286 }
287 }
288 
289 return doWhileStmt;
290}
291 
292} // namespace ShadPKG::Decompiler::Analysis
293