Seregon/ShadPKG

A tool for deriving PKG packet encryption keys for ps4 written in c++

C++/47.3 KB/No license
core/decompiler/ir/AST.h
ShadPKG / core / decompiler / ir / AST.h
1#pragma once
2 
3#include <cstdint>
4#include <memory>
5#include <sstream>
6#include <string>
7#include <vector>
8 
9namespace ShadPKG::Decompiler::Analysis {
10class Type;
11}
12 
13namespace ShadPKG::Decompiler::AST {
14 
15// ╔═══════════════════════════════════════════════════════════════════════════╗
16// ║ FORWARD DECLARATIONS ║
17// ╚═══════════════════════════════════════════════════════════════════════════╝
18 
19class ASTVisitor;
20 
21// ╔═══════════════════════════════════════════════════════════════════════════╗
22// ║ BASE AST NODE ║
23// ╚═══════════════════════════════════════════════════════════════════════════╝
24 
25class ASTNode {
26public:
27 virtual ~ASTNode() = default;
28 virtual void accept(ASTVisitor *visitor) = 0;
29 
30 uint64_t sourceAddress = 0; // Original instruction address for debugging
31};
32 
33// ╔═══════════════════════════════════════════════════════════════════════════╗
34// ║ EXPRESSION HIERARCHY ║
35// ║ ║
36// ║ Expression ║
37// ║ ├── ConstantExpr (42, 0xFF, "string") ║
38// ║ ├── VariableExpr (var_4, rax) ║
39// ║ ├── BinaryExpr (a + b, a > b, a && b) ║
40// ║ ├── UnaryExpr (!a, -a, ~a) ║
41// ║ ├── CallExpr (func(args...)) ║
42// ║ └── MemoryExpr (*ptr, arr[i]) ║
43// ╚═══════════════════════════════════════════════════════════════════════════╝
44 
45class Expression : public ASTNode {
46public:
47 enum class Type {
48 Unknown,
49 Int8,
50 Int16,
51 Int32,
52 Int64,
53 UInt8,
54 UInt16,
55 UInt32,
56 UInt64,
57 Float,
58 Double,
59 Pointer,
60 Boolean
61 };
62 
63 Type resultType = Type::Unknown;
64 
65 virtual std::string toString() const = 0;
66};
67 
68// ─────────────────────────────────────────────────────────────────────────────
69// ConstantExpr: Literal values (42, 0xFF, 3.14)
70// ─────────────────────────────────────────────────────────────────────────────
71 
72class ConstantExpr : public Expression {
73public:
74 enum class Kind { Integer, Float, String };
75 
76 Kind kind = Kind::Integer;
77 int64_t intValue = 0;
78 double floatValue = 0.0;
79 std::string strValue;
80 bool isHex = false;
81 
82 ConstantExpr(int64_t val, bool hex = false)
83 : kind(Kind::Integer), intValue(val), isHex(hex) {
84 resultType = Type::Int64;
85 }
86 
87 ConstantExpr(double val) : kind(Kind::Float), floatValue(val) {
88 resultType = Type::Double;
89 }
90 
91 ConstantExpr(const std::string &val) : kind(Kind::String), strValue(val) {
92 resultType = Type::Unknown; // String literal type? Pointer usually.
93 }
94 
95 void accept(ASTVisitor *visitor) override;
96 
97 std::string toString() const override {
98 std::stringstream ss;
99 if (kind == Kind::Integer) {
100 if (isHex)
101 ss << "0x" << std::hex << intValue;
102 else
103 ss << std::dec << intValue;
104 } else if (kind == Kind::Float) {
105 ss << floatValue;
106 } else {
107 ss << "\"" << strValue << "\"";
108 }
109 return ss.str();
110 }
111};
112 
113// ─────────────────────────────────────────────────────────────────────────────
114// VariableExpr: Named variables (var_4, param_0, rax)
115// ─────────────────────────────────────────────────────────────────────────────
116 
117class VariableExpr : public Expression {
118public:
119 std::string name;
120 bool isParameter = false;
121 int stackOffset = 0; // Original [rbp - offset] for reference
122 
123 explicit VariableExpr(const std::string &n) : name(n) {}
124 
125 void accept(ASTVisitor *visitor) override;
126 
127 std::string toString() const override { return name; }
128};
129 
130// ─────────────────────────────────────────────────────────────────────────────
131// BinaryExpr: Two-operand expressions (a + b, a > b)
132// ─────────────────────────────────────────────────────────────────────────────
133 
134class BinaryExpr : public Expression {
135public:
136 enum class Op {
137 // Arithmetic
138 Add,
139 Sub,
140 Mul,
141 Div,
142 Mod,
143 // Bitwise
144 And,
145 Or,
146 Xor,
147 Shl,
148 Shr,
149 // Comparison
150 Eq,
151 Ne,
152 Lt,
153 Le,
154 Gt,
155 Ge,
156 // Logical
157 LogicalAnd,
158 LogicalOr,
159 // Assignment
160 Assign
161 };
162 
163 Op op;
164 std::shared_ptr<Expression> left;
165 std::shared_ptr<Expression> right;
166 
167 BinaryExpr(Op o, std::shared_ptr<Expression> l, std::shared_ptr<Expression> r)
168 : op(o), left(std::move(l)), right(std::move(r)) {}
169 
170 void accept(ASTVisitor *visitor) override;
171 
172 static std::string opToString(Op o) {
173 switch (o) {
174 case Op::Add:
175 return "+";
176 case Op::Sub:
177 return "-";
178 case Op::Mul:
179 return "*";
180 case Op::Div:
181 return "/";
182 case Op::Mod:
183 return "%";
184 case Op::And:
185 return "&";
186 case Op::Or:
187 return "|";
188 case Op::Xor:
189 return "^";
190 case Op::Shl:
191 return "<<";
192 case Op::Shr:
193 return ">>";
194 case Op::Eq:
195 return "==";
196 case Op::Ne:
197 return "!=";
198 case Op::Lt:
199 return "<";
200 case Op::Le:
201 return "<=";
202 case Op::Gt:
203 return ">";
204 case Op::Ge:
205 return ">=";
206 case Op::LogicalAnd:
207 return "&&";
208 case Op::LogicalOr:
209 return "||";
210 case Op::Assign:
211 return "=";
212 }
213 return "?";
214 }
215 
216 std::string toString() const override {
217 return "(" + left->toString() + " " + opToString(op) + " " +
218 right->toString() + ")";
219 }
220};
221 
222// ─────────────────────────────────────────────────────────────────────────────
223// UnaryExpr: Single-operand expressions (!a, -a, ~a)
224// ─────────────────────────────────────────────────────────────────────────────
225 
226class UnaryExpr : public Expression {
227public:
228 enum class Op {
229 Negate, // -a
230 Not, // !a
231 BitwiseNot, // ~a
232 Deref, // *a
233 AddressOf // &a
234 };
235 
236 Op op;
237 std::shared_ptr<Expression> operand;
238 
239 UnaryExpr(Op o, std::shared_ptr<Expression> expr)
240 : op(o), operand(std::move(expr)) {}
241 
242 void accept(ASTVisitor *visitor) override;
243 
244 static std::string opToString(Op o) {
245 switch (o) {
246 case Op::Negate:
247 return "-";
248 case Op::Not:
249 return "!";
250 case Op::BitwiseNot:
251 return "~";
252 case Op::Deref:
253 return "*";
254 case Op::AddressOf:
255 return "&";
256 }
257 return "?";
258 }
259 
260 std::string toString() const override {
261 return opToString(op) + operand->toString();
262 }
263};
264 
265// ─────────────────────────────────────────────────────────────────────────────
266// CallExpr: Function calls (func(a, b, c))
267// ─────────────────────────────────────────────────────────────────────────────
268 
269class CallExpr : public Expression {
270public:
271 std::string functionName;
272 uint64_t targetAddress = 0;
273 std::vector<std::shared_ptr<Expression>> arguments;
274 
275 explicit CallExpr(const std::string &name) : functionName(name) {}
276 explicit CallExpr(uint64_t addr) : targetAddress(addr) {
277 std::stringstream ss;
278 ss << "sub_" << std::hex << addr;
279 functionName = ss.str();
280 }
281 
282 void accept(ASTVisitor *visitor) override;
283 
284 std::string toString() const override {
285 std::string result = functionName + "(";
286 for (size_t i = 0; i < arguments.size(); ++i) {
287 if (i > 0)
288 result += ", ";
289 result += arguments[i]->toString();
290 }
291 return result + ")";
292 }
293};
294 
295// ─────────────────────────────────────────────────────────────────────────────
296// MemoryExpr: Memory access (*ptr, base[offset])
297// ─────────────────────────────────────────────────────────────────────────────
298 
299class MemoryExpr : public Expression {
300public:
301 std::shared_ptr<Expression> base;
302 std::shared_ptr<Expression> offset; // nullptr if just *base
303 int scale = 1; // For array indexing: base + offset * scale
304 
305 explicit MemoryExpr(std::shared_ptr<Expression> b) : base(std::move(b)) {}
306 
307 void accept(ASTVisitor *visitor) override;
308 
309 std::string toString() const override {
310 if (offset) {
311 return base->toString() + "[" + offset->toString() + "]";
312 }
313 return "*" + base->toString();
314 }
315};
316 
317// ─────────────────────────────────────────────────────────────────────────────
318// MemberAccessExpr: Struct member access (base->member)
319// ─────────────────────────────────────────────────────────────────────────────
320 
321class MemberAccessExpr : public Expression {
322public:
323 std::shared_ptr<Expression> base; // The pointer or struct
324 std::string memberName;
325 int offset; // Debug info: the offset being accessed
326 
327 MemberAccessExpr(std::shared_ptr<Expression> b, const std::string &m, int off)
328 : base(std::move(b)), memberName(m), offset(off) {}
329 
330 void accept(ASTVisitor *visitor) override;
331 
332 std::string toString() const override {
333 // Assuming pointer access (->) mostly for this task.
334 // If base is not a pointer, it should be dot (.).
335 // Logic can be refined in Codegen.
336 return base->toString() + "->" + memberName;
337 }
338};
339 
340// ─────────────────────────────────────────────────────────────────────────────
341// CastExpr: Explicit Type Cast ( (int64_t)var )
342// ─────────────────────────────────────────────────────────────────────────────
343 
344class CastExpr : public Expression {
345public:
346 std::shared_ptr<Expression> expr;
347 std::string targetType;
348 
349 CastExpr(std::shared_ptr<Expression> e, const std::string &type)
350 : expr(std::move(e)), targetType(type) {}
351 
352 void accept(ASTVisitor *visitor) override;
353 
354 std::string toString() const override {
355 return "(" + targetType + ")" + expr->toString();
356 }
357};
358 
359// ╔══════════════════════════════════════════════════════════════════════════╗
360// ║ STATEMENT HIERARCHY ║
361// ╠══════════════════════════════════════════════════════════════════════════╣
362// ║ Statement ║
363// ║ ├── CompoundStatement { stmt1; stmt2; ... } ║
364// ║ ├── ExpressionStatement expr; ║
365// ║ ├── IfStatement if (cond) { } else { } ║
366// ║ ├── WhileStatement while (cond) { } ║
367// ║ ├── DoWhileStatement do { } while (cond); ║
368// ║ ├── ForStatement for (init; cond; step) { } ║
369// ║ ├── ReturnStatement return expr; ║
370// ║ ├── BreakStatement break; ║
371// ║ ├── ContinueStatement continue; ║
372// ║ ├── GotoStatement goto label; (fallback for irreducible CFG) ║
373// ║ └── LabelStatement label: ║
374// ╚══════════════════════════════════════════════════════════════════════════╝
375 
376class Statement : public ASTNode {
377public:
378 std::string comment; // Optional inline comment
379};
380 
381// ─────────────────────────────────────────────────────────────────────────────
382// CompoundStatement: Block of statements { ... }
383// ─────────────────────────────────────────────────────────────────────────────
384 
385class CompoundStatement : public Statement {
386public:
387 std::vector<std::shared_ptr<Statement>> statements;
388 
389 void accept(ASTVisitor *visitor) override;
390 
391 void addStatement(std::shared_ptr<Statement> stmt) {
392 statements.push_back(std::move(stmt));
393 }
394};
395 
396// ─────────────────────────────────────────────────────────────────────────────
397// ExpressionStatement: Expression as statement (x = 5;)
398// ─────────────────────────────────────────────────────────────────────────────
399 
400class ExpressionStatement : public Statement {
401public:
402 std::shared_ptr<Expression> expression;
403 
404 explicit ExpressionStatement(std::shared_ptr<Expression> expr)
405 : expression(std::move(expr)) {}
406 
407 void accept(ASTVisitor *visitor) override;
408};
409 
410// ─────────────────────────────────────────────────────────────────────────────
411// IfStatement: if (condition) then-branch [else else-branch]
412// ─────────────────────────────────────────────────────────────────────────────
413 
414class IfStatement : public Statement {
415public:
416 std::shared_ptr<Expression> condition;
417 std::shared_ptr<Statement> thenBranch;
418 std::shared_ptr<Statement> elseBranch; // nullptr if no else
419 
420 IfStatement(std::shared_ptr<Expression> cond,
421 std::shared_ptr<Statement> thenStmt,
422 std::shared_ptr<Statement> elseStmt = nullptr)
423 : condition(std::move(cond)), thenBranch(std::move(thenStmt)),
424 elseBranch(std::move(elseStmt)) {}
425 
426 void accept(ASTVisitor *visitor) override;
427};
428 
429// ─────────────────────────────────────────────────────────────────────────────
430// WhileStatement: while (condition) { body }
431// ─────────────────────────────────────────────────────────────────────────────
432 
433class WhileStatement : public Statement {
434public:
435 std::shared_ptr<Expression> condition;
436 std::shared_ptr<Statement> body;
437 
438 WhileStatement(std::shared_ptr<Expression> cond, std::shared_ptr<Statement> b)
439 : condition(std::move(cond)), body(std::move(b)) {}
440 
441 void accept(ASTVisitor *visitor) override;
442};
443 
444// ─────────────────────────────────────────────────────────────────────────────
445// DoWhileStatement: do { body } while (condition);
446// ─────────────────────────────────────────────────────────────────────────────
447 
448class DoWhileStatement : public Statement {
449public:
450 std::shared_ptr<Statement> body;
451 std::shared_ptr<Expression> condition;
452 
453 DoWhileStatement(std::shared_ptr<Statement> b,
454 std::shared_ptr<Expression> cond)
455 : body(std::move(b)), condition(std::move(cond)) {}
456 
457 void accept(ASTVisitor *visitor) override;
458};
459 
460// ─────────────────────────────────────────────────────────────────────────────
461// ForStatement: for (init; cond; step) { body }
462// ─────────────────────────────────────────────────────────────────────────────
463 
464class ForStatement : public Statement {
465public:
466 std::shared_ptr<Expression> init;
467 std::shared_ptr<Expression> condition;
468 std::shared_ptr<Expression> step;
469 std::shared_ptr<Statement> body;
470 
471 void accept(ASTVisitor *visitor) override;
472};
473 
474// ─────────────────────────────────────────────────────────────────────────────
475// ReturnStatement: return [expr];
476// ─────────────────────────────────────────────────────────────────────────────
477 
478class ReturnStatement : public Statement {
479public:
480 std::shared_ptr<Expression> value; // nullptr for void return
481 
482 ReturnStatement() = default;
483 explicit ReturnStatement(std::shared_ptr<Expression> v)
484 : value(std::move(v)) {}
485 
486 void accept(ASTVisitor *visitor) override;
487};
488 
489// ─────────────────────────────────────────────────────────────────────────────
490// BreakStatement: break;
491// ─────────────────────────────────────────────────────────────────────────────
492 
493class BreakStatement : public Statement {
494public:
495 void accept(ASTVisitor *visitor) override;
496};
497 
498// ─────────────────────────────────────────────────────────────────────────────
499// ContinueStatement: continue;
500// ─────────────────────────────────────────────────────────────────────────────
501 
502class ContinueStatement : public Statement {
503public:
504 void accept(ASTVisitor *visitor) override;
505};
506 
507// ─────────────────────────────────────────────────────────────────────────────
508// GotoStatement: goto label; (fallback for irreducible control flow)
509// ─────────────────────────────────────────────────────────────────────────────
510 
511class GotoStatement : public Statement {
512public:
513 std::string label;
514 uint64_t targetAddress = 0;
515 
516 explicit GotoStatement(const std::string &lbl) : label(lbl) {}
517 explicit GotoStatement(uint64_t addr) : targetAddress(addr) {
518 std::stringstream ss;
519 ss << "loc_" << std::hex << addr;
520 label = ss.str();
521 }
522 
523 void accept(ASTVisitor *visitor) override;
524};
525 
526// ─────────────────────────────────────────────────────────────────────────────
527// LabelStatement: label:
528// ─────────────────────────────────────────────────────────────────────────────
529 
530class LabelStatement : public Statement {
531public:
532 std::string name;
533 uint64_t address = 0;
534 
535 explicit LabelStatement(const std::string &n) : name(n) {}
536 explicit LabelStatement(uint64_t addr) : address(addr) {
537 std::stringstream ss;
538 ss << "loc_" << std::hex << addr;
539 name = ss.str();
540 }
541 
542 void accept(ASTVisitor *visitor) override;
543};
544 
545// ╔═══════════════════════════════════════════════════════════════════════════╗
546// ║ LOCAL VARIABLE ║
547// ╚═══════════════════════════════════════════════════════════════════════════╝
548 
549struct LocalVariable {
550 std::string name;
551 Expression::Type type =
552 Expression::Type::Int32; // Legacy enum for simple cases
553 // NEW: Advanced Type System Integration
554 std::shared_ptr<ShadPKG::Decompiler::Analysis::Type> complexType;
555 
556 int stackOffset = 0; // Original [rbp - offset]
557 int size = 4; // Size in bytes
558 bool isParameter = false;
559 int paramIndex = -1; // If parameter, which one (0, 1, 2...)
560 
561 // If complexType is set, use it. Otherwise fallback to enum.
562 std::string getTypeName() const;
563};
564 
565// ╔═══════════════════════════════════════════════════════════════════════════╗
566// ║ FUNCTION AST ║
567// ╚═══════════════════════════════════════════════════════════════════════════╝
568 
569class FunctionAST {
570public:
571 std::string name;
572 uint64_t address = 0;
573 std::string returnType = "void";
574 
575 std::vector<LocalVariable> parameters;
576 std::vector<LocalVariable> locals;
577 std::shared_ptr<CompoundStatement> body;
578 
579 FunctionAST() : body(std::make_shared<CompoundStatement>()) {}
580};
581 
582// ╔═══════════════════════════════════════════════════════════════════════════╗
583// ║ AST VISITOR ║
584// ║ ║
585// ║ Visitor pattern for traversing AST without modifying node classes. ║
586// ║ Subclass and override visit() methods to implement code generation, ║
587// ║ optimization passes, or analysis. ║
588// ╚═══════════════════════════════════════════════════════════════════════════╝
589 
590// ─────────────────────────────────────────────────────────────────────────────
591// CaseStmt: case val1: case val2: ... body
592// ─────────────────────────────────────────────────────────────────────────────
593 
594class CaseStmt : public Statement {
595public:
596 std::vector<int64_t> values; // Multiple values for stacked cases
597 bool isDefault = false;
598 std::shared_ptr<CompoundStatement> body;
599 
600 CaseStmt() : body(std::make_shared<CompoundStatement>()) {}
601 
602 void accept(ASTVisitor *visitor) override;
603};
604 
605// ─────────────────────────────────────────────────────────────────────────────
606// SwitchStmt: switch (expr) { cases... }
607// ─────────────────────────────────────────────────────────────────────────────
608 
609class SwitchStmt : public Statement {
610public:
611 std::shared_ptr<Expression> condition;
612 std::vector<std::shared_ptr<CaseStmt>> cases;
613 
614 explicit SwitchStmt(std::shared_ptr<Expression> cond)
615 : condition(std::move(cond)) {}
616 
617 void accept(ASTVisitor *visitor) override;
618};
619 
620class ASTVisitor {
621public:
622 virtual ~ASTVisitor() = default;
623 
624 // Expressions
625 virtual void visit(ConstantExpr *node) = 0;
626 virtual void visit(VariableExpr *node) = 0;
627 virtual void visit(BinaryExpr *node) = 0;
628 virtual void visit(UnaryExpr *node) = 0;
629 virtual void visit(CallExpr *node) = 0;
630 virtual void visit(MemoryExpr *node) = 0;
631 virtual void visit(MemberAccessExpr *node) = 0;
632 virtual void visit(CastExpr *node) = 0; // NEW
633 
634 // Statements
635 virtual void visit(CompoundStatement *node) = 0;
636 virtual void visit(ExpressionStatement *node) = 0;
637 virtual void visit(IfStatement *node) = 0;
638 virtual void visit(WhileStatement *node) = 0;
639 virtual void visit(DoWhileStatement *node) = 0;
640 virtual void visit(ForStatement *node) = 0;
641 virtual void visit(ReturnStatement *node) = 0;
642 virtual void visit(BreakStatement *node) = 0;
643 virtual void visit(ContinueStatement *node) = 0;
644 virtual void visit(GotoStatement *node) = 0;
645 virtual void visit(LabelStatement *node) = 0;
646 virtual void visit(CaseStmt *node) = 0; // NEW
647 virtual void visit(SwitchStmt *node) = 0; // NEW
648};
649 
650// ╔═══════════════════════════════════════════════════════════════════════════╗
651// ║ VISITOR ACCEPT IMPLEMENTATIONS ║
652// ╚═══════════════════════════════════════════════════════════════════════════╝
653 
654inline void ConstantExpr::accept(ASTVisitor *v) { v->visit(this); }
655inline void VariableExpr::accept(ASTVisitor *v) { v->visit(this); }
656inline void BinaryExpr::accept(ASTVisitor *v) { v->visit(this); }
657inline void UnaryExpr::accept(ASTVisitor *v) { v->visit(this); }
658inline void CallExpr::accept(ASTVisitor *v) { v->visit(this); }
659inline void MemoryExpr::accept(ASTVisitor *v) { v->visit(this); }
660inline void MemberAccessExpr::accept(ASTVisitor *v) { v->visit(this); } // NEW
661inline void CastExpr::accept(ASTVisitor *v) { v->visit(this); } // NEW
662inline void CompoundStatement::accept(ASTVisitor *v) { v->visit(this); }
663inline void ExpressionStatement::accept(ASTVisitor *v) { v->visit(this); }
664inline void IfStatement::accept(ASTVisitor *v) { v->visit(this); }
665inline void WhileStatement::accept(ASTVisitor *v) { v->visit(this); }
666inline void DoWhileStatement::accept(ASTVisitor *v) { v->visit(this); }
667inline void ForStatement::accept(ASTVisitor *v) { v->visit(this); }
668inline void ReturnStatement::accept(ASTVisitor *v) { v->visit(this); }
669inline void BreakStatement::accept(ASTVisitor *v) { v->visit(this); }
670inline void ContinueStatement::accept(ASTVisitor *v) { v->visit(this); }
671inline void GotoStatement::accept(ASTVisitor *v) { v->visit(this); }
672inline void LabelStatement::accept(ASTVisitor *v) { v->visit(this); }
673inline void CaseStmt::accept(ASTVisitor *v) { v->visit(this); }
674inline void SwitchStmt::accept(ASTVisitor *v) { v->visit(this); }
675 
676} // namespace ShadPKG::Decompiler::AST
677