Seregon/ShadPKG

A tool, written in C++, for deriving the encryption keys of PS4 PKG packages.

C++/47.3 KB/No license
core/decompiler/analysis/DataFlowAnalysis.cpp
ShadPKG / core / decompiler / analysis / DataFlowAnalysis.cpp
#include "DataFlowAnalysis.h"

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
4 
5namespace ShadPKG::Decompiler::Analysis {
6 
7DataFlowAnalysis::DataFlowAnalysis(std::shared_ptr<AST::FunctionAST> func)
8 : func_(func) {}
9 
10void DataFlowAnalysis::analyze() {
11 if (!func_ || !func_->body)
12 return;
13 
14 // Pass 1: Collect constraints/usages
15 func_->body->accept(this);
16 
17 // Pass 2: Apply inferred types to locals
18 applyTypes();
19
20 // Pass 3: Track variable usage for dead code elimination
21 trackVariableUsage();
22
23 // Pass 4: Eliminate dead code (assignments to unused variables)
24 eliminateDeadCode();
25}
26 
27void DataFlowAnalysis::inferType(const std::string &varName,
28 AST::Expression::Type type) {
29 // Type priority hierarchy:
30 // Pointer > Int64 > Int32 > Int16 > Int8 > Unknown
31 // Once a type is inferred, only upgrade to higher priority types
32
33 if (inferredTypes_.find(varName) == inferredTypes_.end()) {
34 inferredTypes_[varName] = type;
35 return;
36 }
37
38 auto current = inferredTypes_[varName];
39
40 // If current is already Pointer, don't downgrade
41 if (current == AST::Expression::Type::Pointer) {
42 return;
43 }
44
45 // Upgrade to Pointer if needed
46 if (type == AST::Expression::Type::Pointer) {
47 inferredTypes_[varName] = type;
48 return;
49 }
50
51 // For integer types, upgrade to larger sizes
52 if (current == AST::Expression::Type::Unknown) {
53 inferredTypes_[varName] = type;
54 } else if (type == AST::Expression::Type::Int64 &&
55 (current == AST::Expression::Type::Int32 ||
56 current == AST::Expression::Type::Int16 ||
57 current == AST::Expression::Type::Int8)) {
58 inferredTypes_[varName] = type;
59 } else if (type == AST::Expression::Type::Int32 &&
60 (current == AST::Expression::Type::Int16 ||
61 current == AST::Expression::Type::Int8)) {
62 inferredTypes_[varName] = type;
63 } else if (type == AST::Expression::Type::Int16 &&
64 current == AST::Expression::Type::Int8) {
65 inferredTypes_[varName] = type;
66 }
67}
68 
69void DataFlowAnalysis::applyTypes() {
70 for (auto &local : func_->locals) {
71 if (inferredTypes_.count(local.name)) {
72 local.type = inferredTypes_[local.name];
73 }
74 }
75}
76 
77// ═══════════════════════════════════════════════════════════════════════════
78// Visitor Implementation
79// ═══════════════════════════════════════════════════════════════════════════
80 
81void DataFlowAnalysis::visit(AST::VariableExpr *node) {
82 // Base usage - nothing specific unless context provided
83}
84 
85void DataFlowAnalysis::visit(AST::BinaryExpr *node) {
86 node->left->accept(this);
87 node->right->accept(this);
88 
89 // Inference rules
90 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(node->left)) {
91 if (node->op == AST::BinaryExpr::Op::Assign) {
92 // Assigning pointer?
93 // If right is a call to malloc? (Ideally Semantic passed here)
94 }
95 
96 // CMP / Logic -> Int/Bool context
97 if (node->op == AST::BinaryExpr::Op::Eq ||
98 node->op == AST::BinaryExpr::Op::Gt) {
99 // Usually integers
100 }
101 }
102}
103 
104void DataFlowAnalysis::visit(AST::UnaryExpr *node) {
105 node->operand->accept(this);
106 
107 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(node->operand)) {
108 if (node->op == AST::UnaryExpr::Op::Deref) {
109 inferType(var->name, AST::Expression::Type::Pointer);
110 }
111 }
112}
113 
114void DataFlowAnalysis::visit(AST::CallExpr *node) {
115 for (auto &arg : node->arguments) {
116 arg->accept(this);
117 }
118 // known function signatures would help here
119}
120 
121void DataFlowAnalysis::visit(AST::MemoryExpr *node) {
122 // If we have MemoryExpr( Base + Offset )
123 // Base is likely a pointer
124}
125 
126void DataFlowAnalysis::visit(AST::CastExpr *node) {
127 if (node->expr)
128 node->expr->accept(this);
129}
130 
131// Control flow traversal
132void DataFlowAnalysis::visit(AST::CompoundStatement *node) {
133 for (const auto &s : node->statements)
134 s->accept(this);
135}
136void DataFlowAnalysis::visit(AST::ExpressionStatement *node) {
137 node->expression->accept(this);
138}
139void DataFlowAnalysis::visit(AST::IfStatement *node) {
140 node->condition->accept(this);
141 if (node->thenBranch)
142 node->thenBranch->accept(this);
143 if (node->elseBranch)
144 node->elseBranch->accept(this);
145}
146void DataFlowAnalysis::visit(AST::WhileStatement *node) {
147 node->condition->accept(this);
148 if (node->body)
149 node->body->accept(this);
150}
151void DataFlowAnalysis::visit(AST::DoWhileStatement *node) {
152 node->condition->accept(this);
153 if (node->body)
154 node->body->accept(this);
155}
156void DataFlowAnalysis::visit(AST::ForStatement *node) {}
157void DataFlowAnalysis::visit(AST::ReturnStatement *node) {
158 if (node->value)
159 node->value->accept(this);
160}
161void DataFlowAnalysis::visit(AST::BreakStatement *node) {}
162void DataFlowAnalysis::visit(AST::ContinueStatement *node) {}
163void DataFlowAnalysis::visit(AST::GotoStatement *node) {}
164void DataFlowAnalysis::visit(AST::LabelStatement *node) {}
165 
166void DataFlowAnalysis::visit(AST::CaseStmt *node) {
167 if (node->body)
168 node->body->accept(this);
169}
170 
171void DataFlowAnalysis::visit(AST::SwitchStmt *node) {
172 if (node->condition)
173 node->condition->accept(this);
174 for (auto &cse : node->cases) {
175 cse->accept(this);
176 }
177}
178 
179void DataFlowAnalysis::trackVariableUsage() {
180 if (!func_ || !func_->body)
181 return;
182 
183 // Forward declaration for mutual recursion
184 std::function<void(const std::shared_ptr<AST::Expression>&)> trackExpr;
185 std::function<void(const std::shared_ptr<AST::Statement>&)> trackStmt;
186 
187 trackExpr = [&](const std::shared_ptr<AST::Expression> &expr) {
188 if (!expr) return;
189
190 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(expr)) {
191 usedVariables_.insert(var->name);
192 } else if (auto bin = std::dynamic_pointer_cast<AST::BinaryExpr>(expr)) {
193 // For assignments, only track RHS as usage
194 if (bin->op != AST::BinaryExpr::Op::Assign) {
195 trackExpr(bin->left);
196 }
197 trackExpr(bin->right);
198 } else if (auto unary = std::dynamic_pointer_cast<AST::UnaryExpr>(expr)) {
199 trackExpr(unary->operand);
200 } else if (auto call = std::dynamic_pointer_cast<AST::CallExpr>(expr)) {
201 for (const auto &arg : call->arguments) {
202 trackExpr(arg);
203 }
204 } else if (auto cast = std::dynamic_pointer_cast<AST::CastExpr>(expr)) {
205 trackExpr(cast->expr);
206 }
207 };
208 
209 trackStmt = [&](const std::shared_ptr<AST::Statement> &stmt) {
210 if (!stmt) return;
211
212 if (auto compound = std::dynamic_pointer_cast<AST::CompoundStatement>(stmt)) {
213 for (const auto &s : compound->statements) {
214 trackStmt(s);
215 }
216 } else if (auto exprStmt = std::dynamic_pointer_cast<AST::ExpressionStatement>(stmt)) {
217 // Track assignments
218 if (auto bin = std::dynamic_pointer_cast<AST::BinaryExpr>(exprStmt->expression)) {
219 if (bin->op == AST::BinaryExpr::Op::Assign) {
220 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(bin->left)) {
221 assignedVariables_.insert(var->name);
222 }
223 }
224 }
225 trackExpr(exprStmt->expression);
226 } else if (auto ifStmt = std::dynamic_pointer_cast<AST::IfStatement>(stmt)) {
227 trackExpr(ifStmt->condition);
228 trackStmt(ifStmt->thenBranch);
229 trackStmt(ifStmt->elseBranch);
230 } else if (auto whileStmt = std::dynamic_pointer_cast<AST::WhileStatement>(stmt)) {
231 trackExpr(whileStmt->condition);
232 trackStmt(whileStmt->body);
233 } else if (auto doWhile = std::dynamic_pointer_cast<AST::DoWhileStatement>(stmt)) {
234 trackExpr(doWhile->condition);
235 trackStmt(doWhile->body);
236 } else if (auto forStmt = std::dynamic_pointer_cast<AST::ForStatement>(stmt)) {
237 trackExpr(forStmt->init);
238 trackExpr(forStmt->condition);
239 trackExpr(forStmt->step);
240 trackStmt(forStmt->body);
241 } else if (auto ret = std::dynamic_pointer_cast<AST::ReturnStatement>(stmt)) {
242 trackExpr(ret->value);
243 }
244 };
245 
246 trackStmt(func_->body);
247}
248 
249void DataFlowAnalysis::eliminateDeadCode() {
250 if (!func_ || !func_->body)
251 return;
252 
253 // Recursively eliminate dead assignments from compound statements
254 std::function<void(std::shared_ptr<AST::CompoundStatement>&)> eliminateFromCompound;
255
256 eliminateFromCompound = [&](std::shared_ptr<AST::CompoundStatement> &compound) {
257 if (!compound) return;
258
259 std::vector<std::shared_ptr<AST::Statement>> filtered;
260
261 for (const auto &stmt : compound->statements) {
262 bool shouldKeep = true;
263
264 // Check if this is a dead assignment
265 if (auto exprStmt = std::dynamic_pointer_cast<AST::ExpressionStatement>(stmt)) {
266 if (auto bin = std::dynamic_pointer_cast<AST::BinaryExpr>(exprStmt->expression)) {
267 if (bin->op == AST::BinaryExpr::Op::Assign) {
268 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(bin->left)) {
269 // If variable is assigned but never used, it's dead code
270 if (usedVariables_.find(var->name) == usedVariables_.end()) {
271 shouldKeep = false;
272 }
273 }
274 }
275 }
276 }
277
278 // Recursively process nested compound statements
279 if (auto nestedCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(stmt)) {
280 eliminateFromCompound(nestedCompound);
281 } else if (auto ifStmt = std::dynamic_pointer_cast<AST::IfStatement>(stmt)) {
282 if (auto thenCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(ifStmt->thenBranch)) {
283 eliminateFromCompound(thenCompound);
284 }
285 if (auto elseCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(ifStmt->elseBranch)) {
286 eliminateFromCompound(elseCompound);
287 }
288 } else if (auto whileStmt = std::dynamic_pointer_cast<AST::WhileStatement>(stmt)) {
289 if (auto bodyCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(whileStmt->body)) {
290 eliminateFromCompound(bodyCompound);
291 }
292 } else if (auto doWhile = std::dynamic_pointer_cast<AST::DoWhileStatement>(stmt)) {
293 if (auto bodyCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(doWhile->body)) {
294 eliminateFromCompound(bodyCompound);
295 }
296 } else if (auto forStmt = std::dynamic_pointer_cast<AST::ForStatement>(stmt)) {
297 if (auto bodyCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(forStmt->body)) {
298 eliminateFromCompound(bodyCompound);
299 }
300 }
301
302 if (shouldKeep) {
303 filtered.push_back(stmt);
304 }
305 }
306
307 compound->statements = filtered;
308 };
309
310 eliminateFromCompound(func_->body);
311}
312 
313} // namespace ShadPKG::Decompiler::Analysis