A tool for deriving PKG packet encryption keys for PS4, written in C++
| 1 | #include "LoopCorrector.h" |
| 2 | #include "common/logging/log.h" |
| 3 | #include <algorithm> |
| 4 | #include <sstream> |
| 5 | |
| 6 | namespace ShadPKG::Decompiler::Analysis { |
| 7 | |
| 8 | bool LoopCorrector::isTautologyCondition(const std::shared_ptr<AST::Expression> &cond) { |
| 9 | if (!cond) |
| 10 | return false; |
| 11 | |
| 12 | // Check for self-comparison: (a == a), (a != b where a==b) |
| 13 | if (isSelfComparison(cond)) { |
| 14 | LOG_WARN(Decompiler, "Detected tautology condition in loop"); |
| 15 | return true; |
| 16 | } |
| 17 | |
| 18 | // Check for constant true/false conditions |
| 19 | if (auto constExpr = std::dynamic_pointer_cast<AST::ConstantExpr>(cond)) { |
| 20 | // Constants like 1, -1, true are infinite loop conditions |
| 21 | if (constExpr->value != 0) { |
| 22 | LOG_WARN(Decompiler, "Detected constant true condition in loop"); |
| 23 | return true; |
| 24 | } |
| 25 | } |
| 26 | |
| 27 | return false; |
| 28 | } |
| 29 | |
| 30 | bool LoopCorrector::isSelfComparison(const std::shared_ptr<AST::Expression> &expr) { |
| 31 | if (!expr) |
| 32 | return false; |
| 33 | |
| 34 | // Check for BinaryExpr like (reg == reg) |
| 35 | if (auto binExpr = std::dynamic_pointer_cast<AST::BinaryExpr>(expr)) { |
| 36 | // Check equality/inequality operators |
| 37 | if (binExpr->op == AST::BinaryExpr::Op::Equal || |
| 38 | binExpr->op == AST::BinaryExpr::Op::NotEqual) { |
| 39 | |
| 40 | // Extract variable names from both sides |
| 41 | auto leftVars = extractVariables(binExpr->left); |
| 42 | auto rightVars = extractVariables(binExpr->right); |
| 43 | |
| 44 | // If both sides use the same variable, it's a self-comparison |
| 45 | for (const auto &var : leftVars) { |
| 46 | if (rightVars.count(var) > 0) { |
| 47 | return true; |
| 48 | } |
| 49 | } |
| 50 | } |
| 51 | } |
| 52 | |
| 53 | return false; |
| 54 | } |
| 55 | |
| 56 | std::set<std::string> LoopCorrector::extractVariables(const std::shared_ptr<AST::Expression> &expr) { |
| 57 | std::set<std::string> vars; |
| 58 | |
| 59 | if (!expr) |
| 60 | return vars; |
| 61 | |
| 62 | if (auto varExpr = std::dynamic_pointer_cast<AST::VariableExpr>(expr)) { |
| 63 | vars.insert(varExpr->name); |
| 64 | } else if (auto binExpr = std::dynamic_pointer_cast<AST::BinaryExpr>(expr)) { |
| 65 | auto leftVars = extractVariables(binExpr->left); |
| 66 | auto rightVars = extractVariables(binExpr->right); |
| 67 | vars.insert(leftVars.begin(), leftVars.end()); |
| 68 | vars.insert(rightVars.begin(), rightVars.end()); |
| 69 | } else if (auto unaryExpr = std::dynamic_pointer_cast<AST::UnaryExpr>(expr)) { |
| 70 | auto innerVars = extractVariables(unaryExpr->expr); |
| 71 | vars.insert(innerVars.begin(), innerVars.end()); |
| 72 | } else if (auto derefExpr = std::dynamic_pointer_cast<AST::DerefExpr>(expr)) { |
| 73 | auto innerVars = extractVariables(derefExpr->expr); |
| 74 | vars.insert(innerVars.begin(), innerVars.end()); |
| 75 | } else if (auto memberExpr = std::dynamic_pointer_cast<AST::MemberExpr>(expr)) { |
| 76 | auto objVars = extractVariables(memberExpr->object); |
| 77 | vars.insert(objVars.begin(), objVars.end()); |
| 78 | } |
| 79 | |
| 80 | return vars; |
| 81 | } |
| 82 | |
| 83 | bool LoopCorrector::isSentinelValue(int64_t value) { |
| 84 | // Common sentinel values: NULL, -1, max uint64 |
| 85 | return value == 0 || value == -1 || value == 0xffffffffffffffff; |
| 86 | } |
| 87 | |
| 88 | bool LoopCorrector::isPointerWalkPattern(const std::shared_ptr<AST::Expression> &cond, |
| 89 | std::string &baseRegister) { |
| 90 | if (!cond) |
| 91 | return false; |
| 92 | |
| 93 | // Look for patterns like: ptr != nullptr, ptr != 0, ptr != 0xffffffff |
| 94 | if (auto binExpr = std::dynamic_pointer_cast<AST::BinaryExpr>(cond)) { |
| 95 | if (binExpr->op == AST::BinaryExpr::Op::NotEqual || |
| 96 | binExpr->op == AST::BinaryExpr::Op::Equal) { |
| 97 | |
| 98 | // Check if comparing against constant |
| 99 | std::shared_ptr<AST::Expression> ptrSide = nullptr; |
| 100 | std::shared_ptr<AST::ConstantExpr> constSide = nullptr; |
| 101 | |
| 102 | if (auto constRight = std::dynamic_pointer_cast<AST::ConstantExpr>(binExpr->right)) { |
| 103 | if (isSentinelValue(constRight->value)) { |
| 104 | ptrSide = binExpr->left; |
| 105 | constSide = constRight; |
| 106 | } |
| 107 | } else if (auto constLeft = std::dynamic_pointer_cast<AST::ConstantExpr>(binExpr->left)) { |
| 108 | if (isSentinelValue(constLeft->value)) { |
| 109 | ptrSide = binExpr->right; |
| 110 | constSide = constLeft; |
| 111 | } |
| 112 | } |
| 113 | |
| 114 | // Extract pointer register name |
| 115 | if (ptrSide) { |
| 116 | auto vars = extractVariables(ptrSide); |
| 117 | if (vars.size() == 1) { |
| 118 | baseRegister = *vars.begin(); |
| 119 | return true; |
| 120 | } |
| 121 | } |
| 122 | } |
| 123 | } |
| 124 | |
| 125 | return false; |
| 126 | } |
| 127 | |
| 128 | bool LoopCorrector::detectsPointerArithmetic(const std::shared_ptr<AST::Statement> &body, |
| 129 | std::string &ptrReg, int64_t &offset) { |
| 130 | if (!body) |
| 131 | return false; |
| 132 | |
| 133 | // Check if body contains assignments like: ptr += 8, ptr -= 8, ptr = ptr + constant |
| 134 | if (auto exprStmt = std::dynamic_pointer_cast<AST::ExpressionStatement>(body)) { |
| 135 | if (auto assignExpr = std::dynamic_pointer_cast<AST::AssignmentExpr>(exprStmt->expr)) { |
| 136 | // lhs should be a variable |
| 137 | if (auto lhsVar = std::dynamic_pointer_cast<AST::VariableExpr>(assignExpr->lhs)) { |
| 138 | ptrReg = lhsVar->name; |
| 139 | |
| 140 | // Check if rhs is arithmetic: ptr + constant |
| 141 | if (auto rhsBin = std::dynamic_pointer_cast<AST::BinaryExpr>(assignExpr->rhs)) { |
| 142 | if ((rhsBin->op == AST::BinaryExpr::Op::Add || |
| 143 | rhsBin->op == AST::BinaryExpr::Op::Sub) && |
| 144 | std::dynamic_pointer_cast<AST::ConstantExpr>(rhsBin->right)) { |
| 145 | offset = std::dynamic_pointer_cast<AST::ConstantExpr>(rhsBin->right)->value; |
| 146 | if (rhsBin->op == AST::BinaryExpr::Op::Sub) |
| 147 | offset = -offset; |
| 148 | return true; |
| 149 | } |
| 150 | } |
| 151 | } |
| 152 | } |
| 153 | } else if (auto compoundStmt = std::dynamic_pointer_cast<AST::CompoundStatement>(body)) { |
| 154 | // Check the statements in compound body |
| 155 | for (const auto &stmt : compoundStmt->statements) { |
| 156 | if (detectsPointerArithmetic(stmt, ptrReg, offset)) { |
| 157 | return true; |
| 158 | } |
| 159 | } |
| 160 | } |
| 161 | |
| 162 | return false; |
| 163 | } |
| 164 | |
| 165 | std::shared_ptr<AST::VariableDecl> |
| 166 | LoopCorrector::generateSafetyCounter(const std::string &counterName) { |
| 167 | auto decl = std::make_shared<AST::VariableDecl>(); |
| 168 | decl->name = counterName; |
| 169 | decl->type = "int"; |
| 170 | decl->initialValue = std::make_shared<AST::ConstantExpr>(0LL); |
| 171 | return decl; |
| 172 | } |
| 173 | |
| 174 | std::shared_ptr<AST::Expression> |
| 175 | LoopCorrector::generateBoundedCondition(const std::string &ptrReg, int maxIterations) { |
| 176 | // Generate: (ptr != nullptr && ptr != -1 && ++safety < maxIterations) |
| 177 | |
| 178 | // ptr != nullptr |
| 179 | auto ptrNotNull = std::make_shared<AST::BinaryExpr>( |
| 180 | AST::BinaryExpr::Op::NotEqual, |
| 181 | std::make_shared<AST::VariableExpr>(ptrReg), |
| 182 | std::make_shared<AST::ConstantExpr>(0LL)); |
| 183 | |
| 184 | // ptr != -1 |
| 185 | auto ptrNotSentinel = std::make_shared<AST::BinaryExpr>( |
| 186 | AST::BinaryExpr::Op::NotEqual, |
| 187 | std::make_shared<AST::VariableExpr>(ptrReg), |
| 188 | std::make_shared<AST::ConstantExpr>(-1LL)); |
| 189 | |
| 190 | // ++safety < maxIterations |
| 191 | auto safetyCheck = std::make_shared<AST::BinaryExpr>( |
| 192 | AST::BinaryExpr::Op::LessThan, |
| 193 | std::make_shared<AST::UnaryExpr>(AST::UnaryExpr::Op::PreIncrement, |
| 194 | std::make_shared<AST::VariableExpr>("safety")), |
| 195 | std::make_shared<AST::ConstantExpr>((int64_t)maxIterations)); |
| 196 | |
| 197 | // Combine: (ptrNotNull && ptrNotSentinel && safetyCheck) |
| 198 | auto combined1 = std::make_shared<AST::BinaryExpr>( |
| 199 | AST::BinaryExpr::Op::LogicalAnd, ptrNotNull, ptrNotSentinel); |
| 200 | |
| 201 | auto combined2 = std::make_shared<AST::BinaryExpr>( |
| 202 | AST::BinaryExpr::Op::LogicalAnd, combined1, safetyCheck); |
| 203 | |
| 204 | return combined2; |
| 205 | } |
| 206 | |
| 207 | std::shared_ptr<AST::Statement> |
| 208 | LoopCorrector::correctWhileLoop(const std::shared_ptr<AST::WhileStatement> &whileStmt, |
| 209 | const std::shared_ptr<IR::BasicBlock> &headerBB) { |
| 210 | if (!whileStmt || !whileStmt->condition) |
| 211 | return whileStmt; |
| 212 | |
| 213 | // Check for tautology |
| 214 | if (isTautologyCondition(whileStmt->condition)) { |
| 215 | LOG_INFO(Decompiler, "Correcting infinite while loop with tautology condition"); |
| 216 | |
| 217 | std::string ptrReg; |
| 218 | // Try to detect pointer walk pattern and generate proper condition |
| 219 | if (isPointerWalkPattern(whileStmt->condition, ptrReg)) { |
| 220 | LOG_INFO(Decompiler, "Detected pointer walk pattern on register: {}", ptrReg); |
| 221 | |
| 222 | auto boundedCond = generateBoundedCondition(ptrReg, 1000); |
| 223 | auto correctedStmt = std::make_shared<AST::WhileStatement>(boundedCond, whileStmt->body); |
| 224 | return correctedStmt; |
| 225 | } else { |
| 226 | // Generic tautology fix: replace with bounded loop on internal counter |
| 227 | auto boundedCond = generateBoundedCondition("reg_rax", 1000); |
| 228 | auto correctedStmt = std::make_shared<AST::WhileStatement>(boundedCond, whileStmt->body); |
| 229 | return correctedStmt; |
| 230 | } |
| 231 | } |
| 232 | |
| 233 | // Check if loop body has pointer arithmetic |
| 234 | std::string ptrReg; |
| 235 | int64_t offset; |
| 236 | if (detectsPointerArithmetic(whileStmt->body, ptrReg, offset)) { |
| 237 | LOG_INFO(Decompiler, "Detected pointer arithmetic in loop body: {} += {}", ptrReg, offset); |
| 238 | |
| 239 | // Check if condition already validates the pointer |
| 240 | if (!isPointerWalkPattern(whileStmt->condition, ptrReg)) { |
| 241 | // Add bounds checking to the existing condition |
| 242 | auto boundedCond = generateBoundedCondition(ptrReg, 1000); |
| 243 | auto correctedStmt = std::make_shared<AST::WhileStatement>(boundedCond, whileStmt->body); |
| 244 | return correctedStmt; |
| 245 | } |
| 246 | } |
| 247 | |
| 248 | // No correction needed |
| 249 | return whileStmt; |
| 250 | } |
| 251 | |
| 252 | std::shared_ptr<AST::Statement> |
| 253 | LoopCorrector::correctDoWhileLoop(const std::shared_ptr<AST::DoWhileStatement> &doWhileStmt, |
| 254 | const std::shared_ptr<IR::BasicBlock> &latchBB) { |
| 255 | if (!doWhileStmt || !doWhileStmt->condition) |
| 256 | return doWhileStmt; |
| 257 | |
| 258 | // Check for tautology |
| 259 | if (isTautologyCondition(doWhileStmt->condition)) { |
| 260 | LOG_INFO(Decompiler, "Correcting infinite do-while loop with tautology condition"); |
| 261 | |
| 262 | std::string ptrReg; |
| 263 | if (isPointerWalkPattern(doWhileStmt->condition, ptrReg)) { |
| 264 | LOG_INFO(Decompiler, "Detected pointer walk pattern on register: {}", ptrReg); |
| 265 | |
| 266 | auto boundedCond = generateBoundedCondition(ptrReg, 1000); |
| 267 | auto correctedStmt = std::make_shared<AST::DoWhileStatement>(doWhileStmt->body, boundedCond); |
| 268 | return correctedStmt; |
| 269 | } else { |
| 270 | auto boundedCond = generateBoundedCondition("reg_rax", 1000); |
| 271 | auto correctedStmt = std::make_shared<AST::DoWhileStatement>(doWhileStmt->body, boundedCond); |
| 272 | return correctedStmt; |
| 273 | } |
| 274 | } |
| 275 | |
| 276 | // Check if loop body has pointer arithmetic |
| 277 | std::string ptrReg; |
| 278 | int64_t offset; |
| 279 | if (detectsPointerArithmetic(doWhileStmt->body, ptrReg, offset)) { |
| 280 | LOG_INFO(Decompiler, "Detected pointer arithmetic in do-while body: {} += {}", ptrReg, offset); |
| 281 | |
| 282 | if (!isPointerWalkPattern(doWhileStmt->condition, ptrReg)) { |
| 283 | auto boundedCond = generateBoundedCondition(ptrReg, 1000); |
| 284 | auto correctedStmt = std::make_shared<AST::DoWhileStatement>(doWhileStmt->body, boundedCond); |
| 285 | return correctedStmt; |
| 286 | } |
| 287 | } |
| 288 | |
| 289 | return doWhileStmt; |
| 290 | } |
| 291 | |
| 292 | } // namespace ShadPKG::Decompiler::Analysis |
| 293 |