A tool for deriving PKG packet encryption keys for the PS4, written in C++.
#include "DataFlowAnalysis.h"

#include <functional>
#include <iostream>
#include <utility>
| 4 | |
| 5 | namespace ShadPKG::Decompiler::Analysis { |
| 6 | |
| 7 | DataFlowAnalysis::DataFlowAnalysis(std::shared_ptr<AST::FunctionAST> func) |
| 8 | : func_(func) {} |
| 9 | |
| 10 | void DataFlowAnalysis::analyze() { |
| 11 | if (!func_ || !func_->body) |
| 12 | return; |
| 13 | |
| 14 | // Pass 1: Collect constraints/usages |
| 15 | func_->body->accept(this); |
| 16 | |
| 17 | // Pass 2: Apply inferred types to locals |
| 18 | applyTypes(); |
| 19 | |
| 20 | // Pass 3: Track variable usage for dead code elimination |
| 21 | trackVariableUsage(); |
| 22 | |
| 23 | // Pass 4: Eliminate dead code (assignments to unused variables) |
| 24 | eliminateDeadCode(); |
| 25 | } |
| 26 | |
| 27 | void DataFlowAnalysis::inferType(const std::string &varName, |
| 28 | AST::Expression::Type type) { |
| 29 | // Type priority hierarchy: |
| 30 | // Pointer > Int64 > Int32 > Int16 > Int8 > Unknown |
| 31 | // Once a type is inferred, only upgrade to higher priority types |
| 32 | |
| 33 | if (inferredTypes_.find(varName) == inferredTypes_.end()) { |
| 34 | inferredTypes_[varName] = type; |
| 35 | return; |
| 36 | } |
| 37 | |
| 38 | auto current = inferredTypes_[varName]; |
| 39 | |
| 40 | // If current is already Pointer, don't downgrade |
| 41 | if (current == AST::Expression::Type::Pointer) { |
| 42 | return; |
| 43 | } |
| 44 | |
| 45 | // Upgrade to Pointer if needed |
| 46 | if (type == AST::Expression::Type::Pointer) { |
| 47 | inferredTypes_[varName] = type; |
| 48 | return; |
| 49 | } |
| 50 | |
| 51 | // For integer types, upgrade to larger sizes |
| 52 | if (current == AST::Expression::Type::Unknown) { |
| 53 | inferredTypes_[varName] = type; |
| 54 | } else if (type == AST::Expression::Type::Int64 && |
| 55 | (current == AST::Expression::Type::Int32 || |
| 56 | current == AST::Expression::Type::Int16 || |
| 57 | current == AST::Expression::Type::Int8)) { |
| 58 | inferredTypes_[varName] = type; |
| 59 | } else if (type == AST::Expression::Type::Int32 && |
| 60 | (current == AST::Expression::Type::Int16 || |
| 61 | current == AST::Expression::Type::Int8)) { |
| 62 | inferredTypes_[varName] = type; |
| 63 | } else if (type == AST::Expression::Type::Int16 && |
| 64 | current == AST::Expression::Type::Int8) { |
| 65 | inferredTypes_[varName] = type; |
| 66 | } |
| 67 | } |
| 68 | |
| 69 | void DataFlowAnalysis::applyTypes() { |
| 70 | for (auto &local : func_->locals) { |
| 71 | if (inferredTypes_.count(local.name)) { |
| 72 | local.type = inferredTypes_[local.name]; |
| 73 | } |
| 74 | } |
| 75 | } |
| 76 | |
| 77 | // ═══════════════════════════════════════════════════════════════════════════ |
| 78 | // Visitor Implementation |
| 79 | // ═══════════════════════════════════════════════════════════════════════════ |
| 80 | |
| 81 | void DataFlowAnalysis::visit(AST::VariableExpr *node) { |
| 82 | // Base usage - nothing specific unless context provided |
| 83 | } |
| 84 | |
| 85 | void DataFlowAnalysis::visit(AST::BinaryExpr *node) { |
| 86 | node->left->accept(this); |
| 87 | node->right->accept(this); |
| 88 | |
| 89 | // Inference rules |
| 90 | if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(node->left)) { |
| 91 | if (node->op == AST::BinaryExpr::Op::Assign) { |
| 92 | // Assigning pointer? |
| 93 | // If right is a call to malloc? (Ideally Semantic passed here) |
| 94 | } |
| 95 | |
| 96 | // CMP / Logic -> Int/Bool context |
| 97 | if (node->op == AST::BinaryExpr::Op::Eq || |
| 98 | node->op == AST::BinaryExpr::Op::Gt) { |
| 99 | // Usually integers |
| 100 | } |
| 101 | } |
| 102 | } |
| 103 | |
| 104 | void DataFlowAnalysis::visit(AST::UnaryExpr *node) { |
| 105 | node->operand->accept(this); |
| 106 | |
| 107 | if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(node->operand)) { |
| 108 | if (node->op == AST::UnaryExpr::Op::Deref) { |
| 109 | inferType(var->name, AST::Expression::Type::Pointer); |
| 110 | } |
| 111 | } |
| 112 | } |
| 113 | |
| 114 | void DataFlowAnalysis::visit(AST::CallExpr *node) { |
| 115 | for (auto &arg : node->arguments) { |
| 116 | arg->accept(this); |
| 117 | } |
| 118 | // known function signatures would help here |
| 119 | } |
| 120 | |
| 121 | void DataFlowAnalysis::visit(AST::MemoryExpr *node) { |
| 122 | // If we have MemoryExpr( Base + Offset ) |
| 123 | // Base is likely a pointer |
| 124 | } |
| 125 | |
| 126 | void DataFlowAnalysis::visit(AST::CastExpr *node) { |
| 127 | if (node->expr) |
| 128 | node->expr->accept(this); |
| 129 | } |
| 130 | |
| 131 | // Control flow traversal |
| 132 | void DataFlowAnalysis::visit(AST::CompoundStatement *node) { |
| 133 | for (const auto &s : node->statements) |
| 134 | s->accept(this); |
| 135 | } |
| 136 | void DataFlowAnalysis::visit(AST::ExpressionStatement *node) { |
| 137 | node->expression->accept(this); |
| 138 | } |
| 139 | void DataFlowAnalysis::visit(AST::IfStatement *node) { |
| 140 | node->condition->accept(this); |
| 141 | if (node->thenBranch) |
| 142 | node->thenBranch->accept(this); |
| 143 | if (node->elseBranch) |
| 144 | node->elseBranch->accept(this); |
| 145 | } |
| 146 | void DataFlowAnalysis::visit(AST::WhileStatement *node) { |
| 147 | node->condition->accept(this); |
| 148 | if (node->body) |
| 149 | node->body->accept(this); |
| 150 | } |
| 151 | void DataFlowAnalysis::visit(AST::DoWhileStatement *node) { |
| 152 | node->condition->accept(this); |
| 153 | if (node->body) |
| 154 | node->body->accept(this); |
| 155 | } |
| 156 | void DataFlowAnalysis::visit(AST::ForStatement *node) {} |
| 157 | void DataFlowAnalysis::visit(AST::ReturnStatement *node) { |
| 158 | if (node->value) |
| 159 | node->value->accept(this); |
| 160 | } |
| 161 | void DataFlowAnalysis::visit(AST::BreakStatement *node) {} |
| 162 | void DataFlowAnalysis::visit(AST::ContinueStatement *node) {} |
| 163 | void DataFlowAnalysis::visit(AST::GotoStatement *node) {} |
| 164 | void DataFlowAnalysis::visit(AST::LabelStatement *node) {} |
| 165 | |
| 166 | void DataFlowAnalysis::visit(AST::CaseStmt *node) { |
| 167 | if (node->body) |
| 168 | node->body->accept(this); |
| 169 | } |
| 170 | |
| 171 | void DataFlowAnalysis::visit(AST::SwitchStmt *node) { |
| 172 | if (node->condition) |
| 173 | node->condition->accept(this); |
| 174 | for (auto &cse : node->cases) { |
| 175 | cse->accept(this); |
| 176 | } |
| 177 | } |
| 178 | |
| 179 | void DataFlowAnalysis::trackVariableUsage() { |
| 180 | if (!func_ || !func_->body) |
| 181 | return; |
| 182 | |
| 183 | // Forward declaration for mutual recursion |
| 184 | std::function<void(const std::shared_ptr<AST::Expression>&)> trackExpr; |
| 185 | std::function<void(const std::shared_ptr<AST::Statement>&)> trackStmt; |
| 186 | |
| 187 | trackExpr = [&](const std::shared_ptr<AST::Expression> &expr) { |
| 188 | if (!expr) return; |
| 189 | |
| 190 | if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(expr)) { |
| 191 | usedVariables_.insert(var->name); |
| 192 | } else if (auto bin = std::dynamic_pointer_cast<AST::BinaryExpr>(expr)) { |
| 193 | // For assignments, only track RHS as usage |
| 194 | if (bin->op != AST::BinaryExpr::Op::Assign) { |
| 195 | trackExpr(bin->left); |
| 196 | } |
| 197 | trackExpr(bin->right); |
| 198 | } else if (auto unary = std::dynamic_pointer_cast<AST::UnaryExpr>(expr)) { |
| 199 | trackExpr(unary->operand); |
| 200 | } else if (auto call = std::dynamic_pointer_cast<AST::CallExpr>(expr)) { |
| 201 | for (const auto &arg : call->arguments) { |
| 202 | trackExpr(arg); |
| 203 | } |
| 204 | } else if (auto cast = std::dynamic_pointer_cast<AST::CastExpr>(expr)) { |
| 205 | trackExpr(cast->expr); |
| 206 | } |
| 207 | }; |
| 208 | |
| 209 | trackStmt = [&](const std::shared_ptr<AST::Statement> &stmt) { |
| 210 | if (!stmt) return; |
| 211 | |
| 212 | if (auto compound = std::dynamic_pointer_cast<AST::CompoundStatement>(stmt)) { |
| 213 | for (const auto &s : compound->statements) { |
| 214 | trackStmt(s); |
| 215 | } |
| 216 | } else if (auto exprStmt = std::dynamic_pointer_cast<AST::ExpressionStatement>(stmt)) { |
| 217 | // Track assignments |
| 218 | if (auto bin = std::dynamic_pointer_cast<AST::BinaryExpr>(exprStmt->expression)) { |
| 219 | if (bin->op == AST::BinaryExpr::Op::Assign) { |
| 220 | if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(bin->left)) { |
| 221 | assignedVariables_.insert(var->name); |
| 222 | } |
| 223 | } |
| 224 | } |
| 225 | trackExpr(exprStmt->expression); |
| 226 | } else if (auto ifStmt = std::dynamic_pointer_cast<AST::IfStatement>(stmt)) { |
| 227 | trackExpr(ifStmt->condition); |
| 228 | trackStmt(ifStmt->thenBranch); |
| 229 | trackStmt(ifStmt->elseBranch); |
| 230 | } else if (auto whileStmt = std::dynamic_pointer_cast<AST::WhileStatement>(stmt)) { |
| 231 | trackExpr(whileStmt->condition); |
| 232 | trackStmt(whileStmt->body); |
| 233 | } else if (auto doWhile = std::dynamic_pointer_cast<AST::DoWhileStatement>(stmt)) { |
| 234 | trackExpr(doWhile->condition); |
| 235 | trackStmt(doWhile->body); |
| 236 | } else if (auto forStmt = std::dynamic_pointer_cast<AST::ForStatement>(stmt)) { |
| 237 | trackExpr(forStmt->init); |
| 238 | trackExpr(forStmt->condition); |
| 239 | trackExpr(forStmt->step); |
| 240 | trackStmt(forStmt->body); |
| 241 | } else if (auto ret = std::dynamic_pointer_cast<AST::ReturnStatement>(stmt)) { |
| 242 | trackExpr(ret->value); |
| 243 | } |
| 244 | }; |
| 245 | |
| 246 | trackStmt(func_->body); |
| 247 | } |
| 248 | |
| 249 | void DataFlowAnalysis::eliminateDeadCode() { |
| 250 | if (!func_ || !func_->body) |
| 251 | return; |
| 252 | |
| 253 | // Recursively eliminate dead assignments from compound statements |
| 254 | std::function<void(std::shared_ptr<AST::CompoundStatement>&)> eliminateFromCompound; |
| 255 | |
| 256 | eliminateFromCompound = [&](std::shared_ptr<AST::CompoundStatement> &compound) { |
| 257 | if (!compound) return; |
| 258 | |
| 259 | std::vector<std::shared_ptr<AST::Statement>> filtered; |
| 260 | |
| 261 | for (const auto &stmt : compound->statements) { |
| 262 | bool shouldKeep = true; |
| 263 | |
| 264 | // Check if this is a dead assignment |
| 265 | if (auto exprStmt = std::dynamic_pointer_cast<AST::ExpressionStatement>(stmt)) { |
| 266 | if (auto bin = std::dynamic_pointer_cast<AST::BinaryExpr>(exprStmt->expression)) { |
| 267 | if (bin->op == AST::BinaryExpr::Op::Assign) { |
| 268 | if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(bin->left)) { |
| 269 | // If variable is assigned but never used, it's dead code |
| 270 | if (usedVariables_.find(var->name) == usedVariables_.end()) { |
| 271 | shouldKeep = false; |
| 272 | } |
| 273 | } |
| 274 | } |
| 275 | } |
| 276 | } |
| 277 | |
| 278 | // Recursively process nested compound statements |
| 279 | if (auto nestedCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(stmt)) { |
| 280 | eliminateFromCompound(nestedCompound); |
| 281 | } else if (auto ifStmt = std::dynamic_pointer_cast<AST::IfStatement>(stmt)) { |
| 282 | if (auto thenCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(ifStmt->thenBranch)) { |
| 283 | eliminateFromCompound(thenCompound); |
| 284 | } |
| 285 | if (auto elseCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(ifStmt->elseBranch)) { |
| 286 | eliminateFromCompound(elseCompound); |
| 287 | } |
| 288 | } else if (auto whileStmt = std::dynamic_pointer_cast<AST::WhileStatement>(stmt)) { |
| 289 | if (auto bodyCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(whileStmt->body)) { |
| 290 | eliminateFromCompound(bodyCompound); |
| 291 | } |
| 292 | } else if (auto doWhile = std::dynamic_pointer_cast<AST::DoWhileStatement>(stmt)) { |
| 293 | if (auto bodyCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(doWhile->body)) { |
| 294 | eliminateFromCompound(bodyCompound); |
| 295 | } |
| 296 | } else if (auto forStmt = std::dynamic_pointer_cast<AST::ForStatement>(stmt)) { |
| 297 | if (auto bodyCompound = std::dynamic_pointer_cast<AST::CompoundStatement>(forStmt->body)) { |
| 298 | eliminateFromCompound(bodyCompound); |
| 299 | } |
| 300 | } |
| 301 | |
| 302 | if (shouldKeep) { |
| 303 | filtered.push_back(stmt); |
| 304 | } |
| 305 | } |
| 306 | |
| 307 | compound->statements = filtered; |
| 308 | }; |
| 309 | |
| 310 | eliminateFromCompound(func_->body); |
| 311 | } |
| 312 | |
| 313 | } // namespace ShadPKG::Decompiler::Analysis |