Seregon/ShadPKG

A tool for deriving PKG packet encryption keys for the PS4, written in C++.

C++/47.3 KB/No license
core/decompiler/codegen/CppEmitter.cpp
ShadPKG / core / decompiler / codegen / CppEmitter.cpp
1#include "CppEmitter.h"
2#include "../DecompilerContext.h"
3#include "../analysis/SymbolDatabase.h"
4#include "../analysis/TypeSystem.h"
5#include <iomanip>
6#include <functional>
7 
8namespace ShadPKG::Decompiler::Codegen {
9 
10// Helper to collect all labels from statements recursively
11static void collectLabels(const std::shared_ptr<AST::Statement> &stmt, std::set<std::string> &labels) {
12 if (!stmt) return;
13 if (auto label = std::dynamic_pointer_cast<AST::LabelStatement>(stmt)) {
14 labels.insert(label->name);
15 } else if (auto compound = std::dynamic_pointer_cast<AST::CompoundStatement>(stmt)) {
16 for (const auto &s : compound->statements) collectLabels(s, labels);
17 } else if (auto ifStmt = std::dynamic_pointer_cast<AST::IfStatement>(stmt)) {
18 collectLabels(ifStmt->thenBranch, labels);
19 collectLabels(ifStmt->elseBranch, labels);
20 } else if (auto whileStmt = std::dynamic_pointer_cast<AST::WhileStatement>(stmt)) {
21 collectLabels(whileStmt->body, labels);
22 } else if (auto doWhile = std::dynamic_pointer_cast<AST::DoWhileStatement>(stmt)) {
23 collectLabels(doWhile->body, labels);
24 } else if (auto forStmt = std::dynamic_pointer_cast<AST::ForStatement>(stmt)) {
25 collectLabels(forStmt->body, labels);
26 } else if (auto switchStmt = std::dynamic_pointer_cast<AST::SwitchStmt>(stmt)) {
27 for (const auto &c : switchStmt->cases) collectLabels(c->body, labels);
28 }
29}
30 
31// Helper to filter registers - only keep those actually used in expressions
32static void filterUsedRegs(const std::shared_ptr<AST::Statement> &stmt, std::set<std::string> &regs) {
33 if (!stmt) return;
34
35 std::set<std::string> mentioned;
36
37 // Forward declarations for mutual recursion
38 std::function<void(const std::shared_ptr<AST::Expression>&)> collectRegsExpr;
39 std::function<void(const std::shared_ptr<AST::Statement>&)> collectRegsStmt;
40
41 collectRegsExpr = [&](const std::shared_ptr<AST::Expression> &expr) {
42 if (!expr) return;
43 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(expr)) {
44 if (var->name.find("reg_") == 0 || var->name.find("xmm") == 0) {
45 mentioned.insert(var->name);
46 }
47 } else if (auto bin = std::dynamic_pointer_cast<AST::BinaryExpr>(expr)) {
48 collectRegsExpr(bin->left);
49 collectRegsExpr(bin->right);
50 } else if (auto unary = std::dynamic_pointer_cast<AST::UnaryExpr>(expr)) {
51 collectRegsExpr(unary->operand);
52 } else if (auto call = std::dynamic_pointer_cast<AST::CallExpr>(expr)) {
53 for (const auto &arg : call->arguments) collectRegsExpr(arg);
54 } else if (auto cast = std::dynamic_pointer_cast<AST::CastExpr>(expr)) {
55 collectRegsExpr(cast->expr);
56 }
57 };
58
59 collectRegsStmt = [&](const std::shared_ptr<AST::Statement> &s) {
60 if (!s) return;
61 if (auto compound = std::dynamic_pointer_cast<AST::CompoundStatement>(s)) {
62 for (const auto &st : compound->statements) collectRegsStmt(st);
63 } else if (auto exprStmt = std::dynamic_pointer_cast<AST::ExpressionStatement>(s)) {
64 collectRegsExpr(exprStmt->expression);
65 } else if (auto ifStmt = std::dynamic_pointer_cast<AST::IfStatement>(s)) {
66 collectRegsExpr(ifStmt->condition);
67 collectRegsStmt(ifStmt->thenBranch);
68 collectRegsStmt(ifStmt->elseBranch);
69 } else if (auto whileStmt = std::dynamic_pointer_cast<AST::WhileStatement>(s)) {
70 collectRegsExpr(whileStmt->condition);
71 collectRegsStmt(whileStmt->body);
72 } else if (auto doWhile = std::dynamic_pointer_cast<AST::DoWhileStatement>(s)) {
73 collectRegsExpr(doWhile->condition);
74 collectRegsStmt(doWhile->body);
75 } else if (auto forStmt = std::dynamic_pointer_cast<AST::ForStatement>(s)) {
76 collectRegsExpr(forStmt->init);
77 collectRegsExpr(forStmt->condition);
78 collectRegsExpr(forStmt->step);
79 collectRegsStmt(forStmt->body);
80 } else if (auto ret = std::dynamic_pointer_cast<AST::ReturnStatement>(s)) {
81 collectRegsExpr(ret->value);
82 }
83 };
84
85 collectRegsStmt(stmt);
86
87 // Keep only registers that are actually mentioned
88 std::set<std::string> filtered;
89 for (const auto &reg : regs) {
90 if (mentioned.count(reg)) {
91 filtered.insert(reg);
92 }
93 }
94 regs = filtered;
95}
96 
97std::string
98CppEmitter::generate(const std::shared_ptr<AST::FunctionAST> &func) {
99 ss_.str("");
100 tokens_.clear();
101 usedRegs_.clear();
102 emittedLabels_.clear();
103 indentLevel_ = 0;
104 currentFunc_ = func;
105 
106 // ┌─────────────────────────────────────────────────────────────────────────┐
107 // │ Pre-pass: collect all labels that exist in this function │
108 // └─────────────────────────────────────────────────────────────────────────┘
109 if (func->body) {
110 for (const auto &stmt : func->body->statements) {
111 collectLabels(stmt, emittedLabels_);
112 }
113 }
114 
115 // ┌─────────────────────────────────────────────────────────────────────────┐
116 // │ Collect used reg_ variables from the body first │
117 // └─────────────────────────────────────────────────────────────────────────┘
118 if (func->body) {
119 for (const auto &stmt : func->body->statements) {
120 collectUsedRegs(stmt);
121 }
122 // Filter to keep only registers actually used in the function body
123 if (func->body && !func->body->statements.empty()) {
124 filterUsedRegs(func->body, usedRegs_);
125 }
126 }
127 
128 // Signature - always use int64_t return type to match declarations
129 std::string returnType = (func->returnType == "void") ? "int64_t" : func->returnType;
130 emit(returnType, TokenType::Type);
131 emit(" ");
132 emit(func->name, TokenType::Function, func->address);
133 emit("(");
134 
135 // Always use 6 parameters to match declarations
136 emit("int64_t a1, int64_t a2, int64_t a3, int64_t a4, int64_t a5, int64_t a6");
137 
138 emit(") {\n");
139 
140 indentLevel_++;
141 
142 // ┌─────────────────────────────────────────────────────────────────────────┐
143 // │ Declare collected reg_ variables │
144 // └─────────────────────────────────────────────────────────────────────────┘
145 for (const auto &regName : usedRegs_) {
146 indent();
147 // Choose type based on register prefix
148 std::string typeName = "int64_t";
149 if (regName.find("xmm") != std::string::npos) {
150 typeName = "double";
151 } else if (regName.find("reg_e") != std::string::npos ||
152 (regName.size() == 3 && regName[0] == 'e') ||
153 (regName.size() > 1 && regName.back() == 'd')) {
154 typeName = "int32_t"; // e* registers and r*d (r15d) are 32-bit
155 }
156 emit(typeName, TokenType::Type);
157 emit(" ");
158 emit(regName, TokenType::Identifier);
159 // Initialize registers to 0 to prevent crashes from uninitialized access
160 if (typeName == "double") {
161 emit(" = 0.0; // register temp\n", TokenType::Comment);
162 } else {
163 emit(" = 0; // register temp\n", TokenType::Comment);
164 }
165 }
166 
167 if (!usedRegs_.empty()) {
168 emit("\n");
169 // Initialize argument registers from function parameters (System V AMD64 ABI)
170 indent();
171 emit("// Initialize registers from arguments (System V AMD64 ABI)\n");
172 if (usedRegs_.count("reg_rdi")) { indent(); emit("reg_rdi = a1;\n"); }
173 if (usedRegs_.count("reg_rsi")) { indent(); emit("reg_rsi = a2;\n"); }
174 if (usedRegs_.count("reg_rdx")) { indent(); emit("reg_rdx = a3;\n"); }
175 if (usedRegs_.count("reg_rcx")) { indent(); emit("reg_rcx = a4;\n"); }
176 if (usedRegs_.count("reg_r8")) { indent(); emit("reg_r8 = a5;\n"); }
177 if (usedRegs_.count("reg_r9")) { indent(); emit("reg_r9 = a6;\n"); }
178
179 // Initialize other commonly used registers to safe values
180 if (usedRegs_.count("reg_rax")) { indent(); emit("reg_rax = (int64_t)&g_ps4_memory[0]; // Safe base pointer\n"); }
181 if (usedRegs_.count("reg_rbx")) { indent(); emit("reg_rbx = (int64_t)&g_ps4_memory[0]; // Safe base pointer\n"); }
182 if (usedRegs_.count("reg_rbp")) { indent(); emit("reg_rbp = (int64_t)&g_ps4_memory[0]; // Safe base pointer\n"); }
183 if (usedRegs_.count("reg_rsp")) { indent(); emit("reg_rsp = (int64_t)&g_ps4_memory[0]; // Safe base pointer\n"); }
184
185 emit("\n");
186 }
187 
188 // Local stack variable declarations
189 for (const auto &var : func->locals) {
190 indent();
191 std::string typeName = "int";
192 if (var.complexType)
193 typeName = var.complexType->toString();
194 emit(typeName, TokenType::Type);
195 emit(" ");
196 emit(var.name, TokenType::Identifier);
197 emit(" = 0; ", TokenType::Text);
198 emit("// [rbp", TokenType::Comment);
199 emit(std::string(var.stackOffset < 0 ? " - " : " + ") + "0x",
200 TokenType::Comment);
201 
202 std::stringstream hexOff;
203 hexOff << std::hex << std::abs(var.stackOffset);
204 emit(hexOff.str(), TokenType::Comment);
205 emit("]\n", TokenType::Comment);
206 }
207 
208 if (!func->locals.empty())
209 emit("\n");
210 
211 // Safety counters (declared before body so gotos can't jump over them)
212 indent();
213 emit("int64_t _loop_limit = 0; // per-loop iteration limiter\n", TokenType::Text);
214 indent();
215 emit("int64_t _goto_count = 0; // total goto steps limiter\n", TokenType::Text);
216 emit("\n");
217 
218 // Body
219 if (func->body) {
220 for (const auto &stmt : func->body->statements) {
221 stmt->accept(this);
222 }
223 }
224 
225 // Escape label for goto-based loop bailout
226 emit("_func_exit:\n", TokenType::Text);
227 // Always emit a fallback return so clang doesn't insert a ud2 trap.
228 // Real return paths are emitted by ReturnStatement visitors above.
229 indent();
230 emit("return 0; // fallback\n", TokenType::Text);
231 indentLevel_--;
232 emit("}\n");
233 
234 return ss_.str();
235}
236 
237void CppEmitter::indent() {
238 for (int i = 0; i < indentLevel_; ++i)
239 emit(" ");
240}
241 
242void CppEmitter::emit(const std::string &str, TokenType type, uint64_t addr) {
243 ss_ << str;
244 tokens_.push_back({str, type, addr});
245}
246 
247void CppEmitter::emitLine(const std::string &str, TokenType type) {
248 indent();
249 emit(str, type);
250 emit("\n");
251}
252 
253std::string CppEmitter::inferTypeFromInstruction(uint64_t addr) {
254 auto irFunc = DecompilerContext::Get().GetFunctionAt(currentFunc_->address);
255 if (!irFunc)
256 return "int64_t";
257 
258 for (const auto &bb : irFunc->basicBlocks) {
259 for (const auto &instr : bb->instructions) {
260 if (instr.address == addr) {
261 std::string disasm = instr.disassembly;
262 // SSE/AVX scalar single
263 if (disasm.find("ss") != std::string::npos)
264 return "float";
265 // SSE/AVX scalar double
266 if (disasm.find("sd") != std::string::npos)
267 return "double";
268 // SSE/AVX packed single
269 if (disasm.find("ps") != std::string::npos)
270 return "float";
271 // SSE/AVX packed double
272 if (disasm.find("pd") != std::string::npos)
273 return "double";
274 
275 // Integer sizes
276 if (disasm.find("byte ptr") != std::string::npos)
277 return "int8_t";
278 if (disasm.find("word ptr") != std::string::npos)
279 return "int16_t";
280 if (disasm.find("dword ptr") != std::string::npos)
281 return "int32_t";
282 if (disasm.find("qword ptr") != std::string::npos)
283 return "int64_t";
284 
285 // Fallback to register sizes
286 for (const auto &op : instr.operands) {
287 if (op.type == IR::Operand::Type::Register) {
288 std::string r = op.regName;
289 if (r[0] == 'e')
290 return "int32_t";
291 if (r[0] == 'r')
292 return "int64_t";
293 if (r.find("xmm") != std::string::npos)
294 return "float";
295 }
296 }
297 }
298 }
299 }
300 return "int64_t";
301}
302 
303 
304// ═══════════════════════════════════════════════════════════════════════════
305// Expressions
306// ═══════════════════════════════════════════════════════════════════════════
307 
308void CppEmitter::visit(AST::ConstantExpr *node) {
309 TokenType type = TokenType::Number;
310 if (node->kind == AST::ConstantExpr::Kind::String)
311 type = TokenType::String;
312 emit(node->toString(), type);
313}
314 
315void CppEmitter::visit(AST::VariableExpr *node) {
316 emit(node->name, TokenType::Identifier);
317}
318 
319// Helper function for constant folding
320static std::shared_ptr<AST::Expression> foldConstants(
321 AST::BinaryExpr::Op op,
322 std::shared_ptr<AST::Expression> left,
323 std::shared_ptr<AST::Expression> right) {
324
325 // Only fold if both operands are constants
326 auto leftConst = std::dynamic_pointer_cast<AST::ConstantExpr>(left);
327 auto rightConst = std::dynamic_pointer_cast<AST::ConstantExpr>(right);
328
329 if (!leftConst || !rightConst) {
330 return nullptr; // Can't fold
331 }
332
333 // Only fold integer constants for now
334 if (leftConst->kind != AST::ConstantExpr::Kind::Integer ||
335 rightConst->kind != AST::ConstantExpr::Kind::Integer) {
336 return nullptr;
337 }
338
339 int64_t lval = leftConst->intValue;
340 int64_t rval = rightConst->intValue;
341 int64_t result = 0;
342 bool canFold = true;
343
344 switch (op) {
345 case AST::BinaryExpr::Op::Add:
346 result = lval + rval;
347 break;
348 case AST::BinaryExpr::Op::Sub:
349 result = lval - rval;
350 break;
351 case AST::BinaryExpr::Op::Mul:
352 result = lval * rval;
353 break;
354 case AST::BinaryExpr::Op::Div:
355 if (rval != 0) result = lval / rval;
356 else canFold = false;
357 break;
358 case AST::BinaryExpr::Op::Mod:
359 if (rval != 0) result = lval % rval;
360 else canFold = false;
361 break;
362 case AST::BinaryExpr::Op::And:
363 result = lval & rval;
364 break;
365 case AST::BinaryExpr::Op::Or:
366 result = lval | rval;
367 break;
368 case AST::BinaryExpr::Op::Xor:
369 result = lval ^ rval;
370 break;
371 case AST::BinaryExpr::Op::Shl:
372 result = lval << rval;
373 break;
374 case AST::BinaryExpr::Op::Shr:
375 result = lval >> rval;
376 break;
377 case AST::BinaryExpr::Op::Eq:
378 result = (lval == rval) ? 1 : 0;
379 break;
380 case AST::BinaryExpr::Op::Ne:
381 result = (lval != rval) ? 1 : 0;
382 break;
383 case AST::BinaryExpr::Op::Lt:
384 result = (lval < rval) ? 1 : 0;
385 break;
386 case AST::BinaryExpr::Op::Le:
387 result = (lval <= rval) ? 1 : 0;
388 break;
389 case AST::BinaryExpr::Op::Gt:
390 result = (lval > rval) ? 1 : 0;
391 break;
392 case AST::BinaryExpr::Op::Ge:
393 result = (lval >= rval) ? 1 : 0;
394 break;
395 default:
396 canFold = false;
397 }
398
399 if (canFold) {
400 return std::make_shared<AST::ConstantExpr>(result, leftConst->isHex || rightConst->isHex);
401 }
402
403 return nullptr;
404}
405 
406void CppEmitter::visit(AST::BinaryExpr *node) {
407 // Attempt constant folding first
408 if (auto folded = foldConstants(node->op, node->left, node->right)) {
409 folded->accept(this);
410 return;
411 }
412
413 // Check for Bitwise operations on doubles/floats (xmm registers)
414 if (node->op == AST::BinaryExpr::Op::Xor ||
415 node->op == AST::BinaryExpr::Op::And ||
416 node->op == AST::BinaryExpr::Op::Or) {
417 // Simple heuristic: check if operands are xmm registers
418 auto checkXmm = [](std::shared_ptr<AST::Expression> e) {
419 if (auto v = std::dynamic_pointer_cast<AST::VariableExpr>(e)) {
420 return v->name.find("xmm") != std::string::npos;
421 }
422 return false;
423 };
424 if (checkXmm(node->left) || checkXmm(node->right)) {
425 // If self-xor (clearing), emit 0.0
426 if (node->op == AST::BinaryExpr::Op::Xor &&
427 node->left->toString() == node->right->toString()) {
428 emit("0.0");
429 return;
430 }
431 // Otherwise emit bitwise cast wrapper using helper functions
432 emit("as_double((as_uint64(", TokenType::Text);
433 node->left->accept(this);
434 emit(") " + AST::BinaryExpr::opToString(node->op) + " as_uint64(",
435 TokenType::Text);
436 node->right->accept(this);
437 emit(")))", TokenType::Text);
438 return;
439 }
440 }
441 
442 // For assignments, don't wrap in parentheses
443 if (node->op == AST::BinaryExpr::Op::Assign) {
444 node->left->accept(this);
445 emit(" = ");
446 node->right->accept(this);
447 return;
448 }
449 
450 // For simple expressions (variable op constant), skip outer parentheses
451 bool leftSimple = std::dynamic_pointer_cast<AST::VariableExpr>(node->left) ||
452 std::dynamic_pointer_cast<AST::ConstantExpr>(node->left);
453 bool rightSimple = std::dynamic_pointer_cast<AST::VariableExpr>(node->right) ||
454 std::dynamic_pointer_cast<AST::ConstantExpr>(node->right);
455
456 if (leftSimple && rightSimple) {
457 node->left->accept(this);
458 emit(" " + AST::BinaryExpr::opToString(node->op) + " ");
459 node->right->accept(this);
460 } else {
461 emit("(");
462 node->left->accept(this);
463 emit(" " + AST::BinaryExpr::opToString(node->op) + " ");
464 node->right->accept(this);
465 emit(")");
466 }
467}
468 
469void CppEmitter::visit(AST::UnaryExpr *node) {
470 if (node->op == AST::UnaryExpr::Op::Deref) {
471 emit("*");
472 // If unwrapping a register variable, we must cast it to a pointer first
473 if (auto var =
474 std::dynamic_pointer_cast<AST::VariableExpr>(node->operand)) {
475 if (var->name.find("reg_") == 0 || var->name.find("r") == 0 ||
476 var->name.find("e") == 0) {
477 emit("(int64_t*)");
478 }
479 }
480 emit("(");
481 node->operand->accept(this);
482 emit(")");
483 } else if (node->op == AST::UnaryExpr::Op::Negate) {
484 emit("-");
485 emit("(");
486 node->operand->accept(this);
487 emit(")");
488 } else if (node->op == AST::UnaryExpr::Op::BitwiseNot) {
489 // Check for bitwise not on xmm
490 if (auto v = std::dynamic_pointer_cast<AST::VariableExpr>(node->operand)) {
491 if (v->name.find("xmm") != std::string::npos) {
492 emit("as_double(~as_uint64(", TokenType::Text);
493 node->operand->accept(this);
494 emit("))", TokenType::Text);
495 return;
496 }
497 }
498 emit("~");
499 node->operand->accept(this);
500 } else {
501 emit(AST::UnaryExpr::opToString(node->op));
502 node->operand->accept(this);
503 }
504}
505 
506void CppEmitter::visit(AST::CallExpr *node) {
507 // Cleanup Assembly instructions if they are wrapped in __asm()
508 if (node->functionName == "__asm" && !node->arguments.empty()) {
509 if (auto constExpr =
510 std::dynamic_pointer_cast<AST::ConstantExpr>(node->arguments[0])) {
511 std::string asmCode = constExpr->strValue;
512 
513 // 1. Filter LEAVE (stack frame management)
514 if (asmCode.find("leave") != std::string::npos) {
515 return; // Do not emit anything for leave
516 }
517 
518 // 2. Translate BSWAP to intrinsic
519 if (asmCode.find("bswap") != std::string::npos) {
520 emit("_byteswap_ulong", TokenType::Function);
521 emit("(");
522 // Fallback to whatever register was in the asm if we can't extract it
523 // perfectly Here we just use a comment placeholder as before but we are
524 // inside the call
525 emit("/* " + asmCode + " */", TokenType::Comment);
526 emit(")");
527 return;
528 }
529 
530 // 3. Translate FISTTP (FPU to Integer with truncation)
531 if (asmCode.find("fisttp") != std::string::npos) {
532 emit("(int)", TokenType::Type);
533 emit("/* " + asmCode + " */", TokenType::Comment);
534 return;
535 }
536 
537 // 4. Translate vcvttss2si (SSE Float to Int)
538 if (asmCode.find("vcvttss2si") != std::string::npos) {
539 // Try to identify destination and source from arguments if possible
540 // Otherwise emit cast
541 emit("(int32_t)", TokenType::Type);
542 emit("/* " + asmCode + " */", TokenType::Comment);
543 return;
544 }
545 }
546 }
547 
548 emit(node->functionName, TokenType::Function, node->targetAddress);
549 emit("(");
550 for (size_t i = 0; i < node->arguments.size(); ++i) {
551 if (i > 0)
552 emit(", ");
553 node->arguments[i]->accept(this);
554 }
555 
556 // Pad with zeros if we missed arguments in analysis
557 // This ensures compilation against the generated header
558 if (node->targetAddress != 0) {
559 int expectedParams =
560 DecompilerContext::Get().GetFunctionParamCount(node->targetAddress);
561 for (size_t i = node->arguments.size(); i < expectedParams; ++i) {
562 if (i > 0)
563 emit(", ");
564 emit("0");
565 }
566 }
567 
568 emit(")");
569}
570 
571void CppEmitter::visit(AST::MemoryExpr *node) {
572 // Check if base is RIP
573 bool isRip = false;
574 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(node->base)) {
575 if (var->name == "rip" || var->name == "RIP") {
576 isRip = true;
577 }
578 }
579 
580 if (isRip && node->offset) {
581 if (auto cnst =
582 std::dynamic_pointer_cast<AST::ConstantExpr>(node->offset)) {
583 uint64_t ripNext = node->sourceAddress + 7;
584 auto irFunc =
585 DecompilerContext::Get().GetFunctionAt(currentFunc_->address);
586 if (irFunc) {
587 for (const auto &bb : irFunc->basicBlocks) {
588 for (size_t i = 0; i < bb->instructions.size(); ++i) {
589 if (bb->instructions[i].address == node->sourceAddress) {
590 if (i + 1 < bb->instructions.size()) {
591 ripNext = bb->instructions[i + 1].address;
592 } else {
593 ripNext = bb->endAddress;
594 }
595 goto found_rip_mem;
596 }
597 }
598 }
599 }
600 found_rip_mem:
601 uint64_t targetAddr = ripNext + cnst->intValue;
602 auto symDb = DecompilerContext::Get().GetSymbolDatabase();
603 std::string symName = symDb ? symDb->getSymbolName(targetAddr) : "";
604 
605 if (!symName.empty() && symName.find("var_") == std::string::npos &&
606 symName.find("sub_") == std::string::npos &&
607 symName.find("loc_") == std::string::npos) {
608 emit(symName, TokenType::Global, targetAddr);
609 } else {
610 // ┌─────────────────────────────────────────────────────────────────────┐
611 // │ AUTO-GENERATE global name: g_<type>_<addr> │
612 // └─────────────────────────────────────────────────────────────────────┘
613 std::string typeName = inferTypeFromInstruction(node->sourceAddress);
614 std::stringstream globalName;
615 globalName << "g_" << typeName << "_" << std::hex << targetAddr;
616 std::string name = globalName.str();
617 
618 if (symDb) {
619 symDb->addSymbol(targetAddr, name,
620 Analysis::SymbolType::GlobalVariable);
621 }
622 
623 emit(name, TokenType::Global, targetAddr);
624 }
625 return;
626 }
627 }
628 
629 emit(node->toString()); // Fallback to basic string repr, could be improved
630}
631 
632void CppEmitter::visit(AST::MemberAccessExpr *node) {
633 // Check if base is RIP
634 bool isRip = false;
635 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(node->base)) {
636 if (var->name == "rip" || var->name == "RIP") {
637 isRip = true;
638 }
639 }
640 
641 if (isRip) {
642 // Address = RIP_Next + Displacement (node->offset)
643 // Heuristic for RIP_Next: sourceAddress + instruction_size.
644 // We try to find the actual instruction size from IR if available.
645 uint64_t ripNext = node->sourceAddress + 7; // Default fallback for x64
646 
647 auto irFunc = DecompilerContext::Get().GetFunctionAt(currentFunc_->address);
648 if (irFunc) {
649 for (const auto &bb : irFunc->basicBlocks) {
650 for (size_t i = 0; i < bb->instructions.size(); ++i) {
651 if (bb->instructions[i].address == node->sourceAddress) {
652 if (i + 1 < bb->instructions.size()) {
653 ripNext = bb->instructions[i + 1].address;
654 } else {
655 // Last instruction in block, RIP_Next is start of next block or
656 // end of current
657 ripNext = bb->endAddress;
658 }
659 goto found_rip;
660 }
661 }
662 }
663 }
664 found_rip:
665 uint64_t targetAddr = ripNext + node->offset;
666 auto symDb = DecompilerContext::Get().GetSymbolDatabase();
667 std::string symName = symDb ? symDb->getSymbolName(targetAddr) : "";
668 
669 if (!symName.empty() && symName.find("var_") == std::string::npos &&
670 symName.find("sub_") == std::string::npos) {
671 emit(symName, TokenType::Global, targetAddr);
672 } else {
673 // ┌─────────────────────────────────────────────────────────────────────┐
674 // │ AUTO-GENERATE global name: g_<type>_<addr> │
675 // │ Register in SymbolDatabase for consistent naming │
676 // └─────────────────────────────────────────────────────────────────────┘
677 std::string typeName = inferTypeFromInstruction(node->sourceAddress);
678 std::stringstream globalName;
679 globalName << "g_" << typeName << "_" << std::hex << targetAddr;
680 std::string name = globalName.str();
681 
682 // Register for future references
683 if (symDb) {
684 symDb->addSymbol(targetAddr, name,
685 Analysis::SymbolType::GlobalVariable);
686 }
687 
688 emit(name, TokenType::Global, targetAddr);
689 }
690 return;
691 }
692 
693 // Check if base is a register variable that needs casting
694 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(node->base)) {
695 std::string name = var->name;
696 // Normalize name (remove reg_ prefix if present, though collecting adds it,
697 // usage might not?) Actually usage in AST is usually "rbx" or "reg_rbx".
698 if (name.find("reg_") == 0)
699 name = name.substr(4);
700 
701 // Check if it's a standard access register
702 if (name == "rax" || name == "rbx" || name == "rcx" || name == "rdx" ||
703 name == "rsi" || name == "rdi" || name == "rbp" || name == "rsp" ||
704 (name.size() >= 2 && name[0] == 'r' && isdigit(name[1]))) {
705 
706 emit("((Struct_", TokenType::Type);
707 emit(name, TokenType::Type);
708 emit("*)", TokenType::Type);
709 node->base->accept(this);
710 emit(")", TokenType::Text);
711 emit("->", TokenType::Text);
712 emit(node->memberName, TokenType::Identifier);
713 return;
714 }
715 }
716 
717 node->base->accept(this);
718 emit("->", TokenType::Text);
719 emit(node->memberName, TokenType::Identifier);
720 // Maybe add a comment relative to offset if debug mode is on?
721}
722 
723void CppEmitter::visit(AST::CastExpr *node) {
724 emit("(", TokenType::Text);
725 emit(node->targetType, TokenType::Type);
726 emit(")", TokenType::Text);
727 // Wrap complex expressions in parentheses to ensure correct precedence
728 bool needsParens = std::dynamic_pointer_cast<AST::BinaryExpr>(node->expr) ||
729 std::dynamic_pointer_cast<AST::UnaryExpr>(node->expr);
730 if (needsParens) emit("(");
731 node->expr->accept(this);
732 if (needsParens) emit(")");
733}
734 
735// ═══════════════════════════════════════════════════════════════════════════
736// Statements
737// ═══════════════════════════════════════════════════════════════════════════
738 
739void CppEmitter::visit(AST::CompoundStatement *node) {
740 emitLine("{ ", TokenType::Text);
741 indentLevel_++;
742 for (const auto &stmt : node->statements) {
743 stmt->accept(this);
744 }
745 indentLevel_--;
746 emitLine("}", TokenType::Text);
747}
748 
749void CppEmitter::visit(AST::ExpressionStatement *node) {
750 indent();
751 
752 // ┌─────────────────────────────────────────────────────────────────────────┐
753 // │ VALIDATE: Only emit expressions that are valid C++ statements │
754 // │ - Assignments (x = y) │
755 // │ - Function calls (foo()) │
756 // │ - Other expressions become comments to preserve debug info │
757 // └─────────────────────────────────────────────────────────────────────────┘
758 
759 bool isValidStatement = false;
760 
761 // Check for function call
762 if (std::dynamic_pointer_cast<AST::CallExpr>(node->expression)) {
763 isValidStatement = true;
764 }
765 // Check for assignment expression
766 else if (auto binExpr =
767 std::dynamic_pointer_cast<AST::BinaryExpr>(node->expression)) {
768 if (binExpr->op == AST::BinaryExpr::Op::Assign) {
769 isValidStatement = true;
770 }
771 }
772 // Check for cast expression (valid in C++)
773 else if (std::dynamic_pointer_cast<AST::CastExpr>(node->expression)) {
774 isValidStatement = true;
775 }
776 
777 if (isValidStatement) {
778 node->expression->accept(this);
779 emit(";\n", TokenType::Text);
780 } else {
781 // Emit as comment to preserve debug info without syntax errors
782 emit("// ", TokenType::Comment);
783 node->expression->accept(this);
784 emit(";\n", TokenType::Comment);
785 }
786}
787 
788void CppEmitter::visit(AST::IfStatement *node) {
789 indent();
790 emit("if", TokenType::Keyword);
791 emit(" (", TokenType::Text);
792 node->condition->accept(this);
793 emit(") {\n", TokenType::Text);
794 
795 indentLevel_++;
796 if (auto compound =
797 std::dynamic_pointer_cast<AST::CompoundStatement>(node->thenBranch)) {
798 for (const auto &s : compound->statements)
799 s->accept(this);
800 } else if (node->thenBranch) {
801 node->thenBranch->accept(this);
802 }
803 indentLevel_--;
804 
805 indent();
806 emit("}", TokenType::Text);
807 
808 if (node->elseBranch) {
809 emit(" ", TokenType::Text);
810 emit("else", TokenType::Keyword);
811 emit(" {\n", TokenType::Text);
812 indentLevel_++;
813 if (auto compound = std::dynamic_pointer_cast<AST::CompoundStatement>(
814 node->elseBranch)) {
815 for (const auto &s : compound->statements)
816 s->accept(this);
817 } else {
818 node->elseBranch->accept(this);
819 }
820 indentLevel_--;
821 indent();
822 emit("}", TokenType::Text);
823 }
824 emit("\n", TokenType::Text);
825}
826 
827void CppEmitter::visit(AST::WhileStatement *node) {
828 // Reset counter (assignment, not declaration - safe for goto)
829 indent(); emit("_loop_limit = 0;\n", TokenType::Text);
830 indent();
831 emit("while", TokenType::Keyword);
832 emit(" (", TokenType::Text);
833 node->condition->accept(this);
834 emit(") {\n", TokenType::Text);
835 indentLevel_++;
836 indent(); emit("if (++_loop_limit > 1000000LL) break; // loop safety\n", TokenType::Comment);
837 if (node->body) {
838 if (auto compound =
839 std::dynamic_pointer_cast<AST::CompoundStatement>(node->body)) {
840 for (const auto &s : compound->statements)
841 s->accept(this);
842 } else {
843 node->body->accept(this);
844 }
845 }
846 indentLevel_--;
847 emitLine("}", TokenType::Text);
848}
849 
850void CppEmitter::visit(AST::DoWhileStatement *node) {
851 // Reset counter (assignment, not declaration - safe for goto)
852 indent(); emit("_loop_limit = 0;\n", TokenType::Text);
853 indent();
854 emit("do", TokenType::Keyword);
855 emit(" {\n", TokenType::Text);
856 indentLevel_++;
857 indent(); emit("if (++_loop_limit > 1000000LL) break; // loop safety\n", TokenType::Comment);
858 if (node->body) {
859 if (auto compound =
860 std::dynamic_pointer_cast<AST::CompoundStatement>(node->body)) {
861 for (const auto &s : compound->statements)
862 s->accept(this);
863 } else {
864 node->body->accept(this);
865 }
866 }
867 indentLevel_--;
868 indent();
869 emit("} ", TokenType::Text);
870 emit("while", TokenType::Keyword);
871 emit(" (", TokenType::Text);
872 node->condition->accept(this);
873 emit(");\n", TokenType::Text);
874}
875 
876void CppEmitter::visit(AST::ForStatement *node) {
877 indent();
878 emit("for (", TokenType::Keyword);
879 if (node->init)
880 node->init->accept(this);
881 emit("; ", TokenType::Text);
882 if (node->condition)
883 node->condition->accept(this);
884 emit("; ", TokenType::Text);
885 if (node->step)
886 node->step->accept(this);
887 emit(") {\n", TokenType::Text);
888 indentLevel_++;
889 if (node->body) {
890 if (auto compound =
891 std::dynamic_pointer_cast<AST::CompoundStatement>(node->body)) {
892 for (const auto &s : compound->statements)
893 s->accept(this);
894 } else {
895 node->body->accept(this);
896 }
897 }
898 indentLevel_--;
899 emitLine("}", TokenType::Text);
900}
901 
902void CppEmitter::visit(AST::ReturnStatement *node) {
903 indent();
904 emit("return", TokenType::Keyword);
905 if (node->value) {
906 emit(" ", TokenType::Text);
907 node->value->accept(this);
908 }
909 emit(";\n", TokenType::Text);
910}
911 
912void CppEmitter::visit(AST::BreakStatement *node) {
913 emitLine("break;", TokenType::Keyword);
914}
915 
916void CppEmitter::visit(AST::ContinueStatement *node) {
917 emitLine("continue;", TokenType::Keyword);
918}
919 
920void CppEmitter::visit(AST::GotoStatement *node) {
921 indent();
922 // Only emit goto if the target label exists in this function
923 if (emittedLabels_.count(node->label)) {
924 // Safety: bail out if this function has looped too many times via gotos
925 emit("if (++_goto_count > 100000LL) goto _func_exit; ", TokenType::Comment);
926 emit("goto", TokenType::Keyword);
927 emit(" ", TokenType::Text);
928 emit(node->label, TokenType::Identifier, node->targetAddress);
929 emit(";\n", TokenType::Text);
930 } else {
931 // Label doesn't exist - emit as comment to avoid compilation error
932 emit("// goto ", TokenType::Comment);
933 emit(node->label, TokenType::Comment);
934 emit("; // label not in scope\n", TokenType::Comment);
935 }
936}
937 
938void CppEmitter::visit(AST::LabelStatement *node) {
939 // Labels are dedented slightly usually
940 // indent();
941 emit(node->name, TokenType::Identifier, node->address);
942 emit(":\n", TokenType::Text);
943}
944 
945void CppEmitter::visit(AST::CaseStmt *node) {
946 // case val:
947 for (auto val : node->values) {
948 if (indentLevel_ > 0)
949 indentLevel_--; // Case labels are dedented
950 indent();
951 if (indentLevel_ > 0)
952 indentLevel_++; // Restore
953 
954 emit("case ", TokenType::Keyword);
955 emit(std::to_string(val), TokenType::Number);
956 emit(":\n", TokenType::Text);
957 }
958 if (node->isDefault) {
959 if (indentLevel_ > 0)
960 indentLevel_--;
961 indent();
962 if (indentLevel_ > 0)
963 indentLevel_++;
964 emit("default:\n", TokenType::Keyword);
965 }
966 
967 // Body (CompoundStatement handles braces, but often cases don't enforce
968 // braces in C++) However, our CaseStmt has a CompoundStatement body. Standard
969 // C++: case 1: { ... }
970 if (node->body) {
971 node->body->accept(this);
972 }
973}
974 
975void CppEmitter::visit(AST::SwitchStmt *node) {
976 indent();
977 emit("switch", TokenType::Keyword);
978 emit(" (", TokenType::Text);
979 node->condition->accept(this);
980 emit(") {\n", TokenType::Text);
981 
982 indentLevel_++;
983 for (const auto &cse : node->cases) {
984 cse->accept(this);
985 }
986 indentLevel_--;
987 
988 indent();
989 emit("}\n", TokenType::Text);
990}
991 
992// ═══════════════════════════════════════════════════════════════════════════
// Collect Used Register Variables (reg_*, xmm*, unknown_op, rax/r8/eax-style names)
994// ═══════════════════════════════════════════════════════════════════════════
995void CppEmitter::collectUsedRegsExpr(
996 const std::shared_ptr<AST::Expression> &expr) {
997 if (!expr)
998 return;
999 
1000 if (auto var = std::dynamic_pointer_cast<AST::VariableExpr>(expr)) {
1001 const std::string &n = var->name;
1002 // Collect: reg_*, xmm*, unknown_op, and standard registers (rax, rbx, etc.)
1003 if (n.find("reg_") == 0 || n.find("xmm") == 0 || n == "unknown_op" ||
1004 n == "rax" || n == "rbx" || n == "rcx" || n == "rdx" || n == "rsi" ||
1005 n == "rdi" || n == "rbp" || n == "rsp" ||
1006 (n.size() >= 2 && n[0] == 'r' && isdigit(n[1])) || // r8, r9, r10...
1007 (n.size() == 3 && n[0] == 'e') // eax, ebx...
1008 ) {
1009 usedRegs_.insert(n);
1010 }
1011 } else if (auto bin = std::dynamic_pointer_cast<AST::BinaryExpr>(expr)) {
1012 collectUsedRegsExpr(bin->left);
1013 collectUsedRegsExpr(bin->right);
1014 } else if (auto unary = std::dynamic_pointer_cast<AST::UnaryExpr>(expr)) {
1015 collectUsedRegsExpr(unary->operand);
1016 } else if (auto call = std::dynamic_pointer_cast<AST::CallExpr>(expr)) {
1017 for (const auto &arg : call->arguments)
1018 collectUsedRegsExpr(arg);
1019 } else if (auto mem =
1020 std::dynamic_pointer_cast<AST::MemberAccessExpr>(expr)) {
1021 collectUsedRegsExpr(mem->base);
1022 } else if (auto memex = std::dynamic_pointer_cast<AST::MemoryExpr>(expr)) {
1023 collectUsedRegsExpr(memex->base);
1024 collectUsedRegsExpr(memex->offset);
1025 } else if (auto cast = std::dynamic_pointer_cast<AST::CastExpr>(expr)) {
1026 collectUsedRegsExpr(cast->expr);
1027 }
1028}
1029 
1030void CppEmitter::collectUsedRegs(const std::shared_ptr<AST::Statement> &stmt) {
1031 if (!stmt)
1032 return;
1033 
1034 if (auto compound = std::dynamic_pointer_cast<AST::CompoundStatement>(stmt)) {
1035 for (const auto &s : compound->statements)
1036 collectUsedRegs(s);
1037 } else if (auto exprStmt =
1038 std::dynamic_pointer_cast<AST::ExpressionStatement>(stmt)) {
1039 collectUsedRegsExpr(exprStmt->expression);
1040 } else if (auto ifStmt = std::dynamic_pointer_cast<AST::IfStatement>(stmt)) {
1041 collectUsedRegsExpr(ifStmt->condition);
1042 collectUsedRegs(ifStmt->thenBranch);
1043 collectUsedRegs(ifStmt->elseBranch);
1044 } else if (auto whileStmt =
1045 std::dynamic_pointer_cast<AST::WhileStatement>(stmt)) {
1046 collectUsedRegsExpr(whileStmt->condition);
1047 collectUsedRegs(whileStmt->body);
1048 } else if (auto doWhile =
1049 std::dynamic_pointer_cast<AST::DoWhileStatement>(stmt)) {
1050 collectUsedRegsExpr(doWhile->condition);
1051 collectUsedRegs(doWhile->body);
1052 } else if (auto forStmt =
1053 std::dynamic_pointer_cast<AST::ForStatement>(stmt)) {
1054 collectUsedRegsExpr(forStmt->init);
1055 collectUsedRegsExpr(forStmt->condition);
1056 collectUsedRegsExpr(forStmt->step);
1057 collectUsedRegs(forStmt->body);
1058 } else if (auto ret = std::dynamic_pointer_cast<AST::ReturnStatement>(stmt)) {
1059 collectUsedRegsExpr(ret->value);
1060 } else if (auto sw = std::dynamic_pointer_cast<AST::SwitchStmt>(stmt)) {
1061 collectUsedRegsExpr(sw->condition);
1062 for (const auto &cs : sw->cases)
1063 collectUsedRegs(cs);
1064 } else if (auto cs = std::dynamic_pointer_cast<AST::CaseStmt>(stmt)) {
1065 // CaseStmt::values is vector<int64_t>, no expressions to check
1066 if (cs->body) {
1067 for (const auto &s : cs->body->statements)
1068 collectUsedRegs(s);
1069 }
1070 }
1071}
1072 
1073} // namespace ShadPKG::Decompiler::Codegen