Add basic protection aginst circular references in pycdas and pycdc.

This fixes the last case of fuzzer errors detected by #572.
This commit is contained in:
Michael Hansen
2025-08-28 16:42:03 -07:00
parent 38799f5cfb
commit 577720302e
2 changed files with 32 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
#include <cstring> #include <cstring>
#include <cstdint> #include <cstdint>
#include <stdexcept> #include <stdexcept>
#include <unordered_set>
#include "ASTree.h" #include "ASTree.h"
#include "FastStack.h" #include "FastStack.h"
#include "pyc_numeric.h" #include "pyc_numeric.h"
@@ -2779,6 +2780,8 @@ void print_formatted_value(PycRef<ASTFormattedValue> formatted_value, PycModule*
pyc_output << "}"; pyc_output << "}";
} }
static std::unordered_set<ASTNode *> node_seen;
void print_src(PycRef<ASTNode> node, PycModule* mod, std::ostream& pyc_output) void print_src(PycRef<ASTNode> node, PycModule* mod, std::ostream& pyc_output)
{ {
if (node == NULL) { if (node == NULL) {
@@ -2787,6 +2790,12 @@ void print_src(PycRef<ASTNode> node, PycModule* mod, std::ostream& pyc_output)
return; return;
} }
if (node_seen.find((ASTNode *)node) != node_seen.end()) {
fputs("WARNING: Circular reference detected\n", stderr);
return;
}
node_seen.insert((ASTNode *)node);
switch (node->type()) { switch (node->type()) {
case ASTNode::NODE_BINARY: case ASTNode::NODE_BINARY:
case ASTNode::NODE_COMPARE: case ASTNode::NODE_COMPARE:
@@ -3442,10 +3451,12 @@ void print_src(PycRef<ASTNode> node, PycModule* mod, std::ostream& pyc_output)
pyc_output << "<NODE:" << node->type() << ">"; pyc_output << "<NODE:" << node->type() << ">";
fprintf(stderr, "Unsupported Node type: %d\n", node->type()); fprintf(stderr, "Unsupported Node type: %d\n", node->type());
cleanBuild = false; cleanBuild = false;
node_seen.erase((ASTNode *)node);
return; return;
} }
cleanBuild = true; cleanBuild = true;
node_seen.erase((ASTNode *)node);
} }
bool print_docstring(PycRef<PycObject> obj, int indent, PycModule* mod, bool print_docstring(PycRef<PycObject> obj, int indent, PycModule* mod,
@@ -3462,8 +3473,16 @@ bool print_docstring(PycRef<PycObject> obj, int indent, PycModule* mod,
return false; return false;
} }
static std::unordered_set<PycCode *> code_seen;
void decompyle(PycRef<PycCode> code, PycModule* mod, std::ostream& pyc_output) void decompyle(PycRef<PycCode> code, PycModule* mod, std::ostream& pyc_output)
{ {
if (code_seen.find((PycCode *)code) != code_seen.end()) {
fputs("WARNING: Circular reference detected\n", stderr);
return;
}
code_seen.insert((PycCode *)code);
PycRef<ASTNode> source = BuildFromCode(code, mod); PycRef<ASTNode> source = BuildFromCode(code, mod);
PycRef<ASTNodeList> clean = source.cast<ASTNodeList>(); PycRef<ASTNodeList> clean = source.cast<ASTNodeList>();
@@ -3557,4 +3576,6 @@ void decompyle(PycRef<PycCode> code, PycModule* mod, std::ostream& pyc_output)
start_line(cur_indent, pyc_output); start_line(cur_indent, pyc_output);
pyc_output << "# WARNING: Decompyle incomplete\n"; pyc_output << "# WARNING: Decompyle incomplete\n";
} }
code_seen.erase((PycCode *)code);
} }

View File

@@ -4,6 +4,7 @@
#include <string> #include <string>
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <unordered_set>
#include "pyc_module.h" #include "pyc_module.h"
#include "pyc_numeric.h" #include "pyc_numeric.h"
#include "bytecode.h" #include "bytecode.h"
@@ -73,6 +74,8 @@ static void iprintf(std::ostream& pyc_output, int indent, const char* fmt, ...)
va_end(varargs); va_end(varargs);
} }
static std::unordered_set<PycObject *> out_seen;
void output_object(PycRef<PycObject> obj, PycModule* mod, int indent, void output_object(PycRef<PycObject> obj, PycModule* mod, int indent,
unsigned flags, std::ostream& pyc_output) unsigned flags, std::ostream& pyc_output)
{ {
@@ -81,6 +84,12 @@ void output_object(PycRef<PycObject> obj, PycModule* mod, int indent,
return; return;
} }
if (out_seen.find((PycObject *)obj) != out_seen.end()) {
fputs("WARNING: Circular reference detected\n", stderr);
return;
}
out_seen.insert((PycObject *)obj);
switch (obj->type()) { switch (obj->type()) {
case PycObject::TYPE_CODE: case PycObject::TYPE_CODE:
case PycObject::TYPE_CODE2: case PycObject::TYPE_CODE2:
@@ -246,6 +255,8 @@ void output_object(PycRef<PycObject> obj, PycModule* mod, int indent,
default: default:
iprintf(pyc_output, indent, "<TYPE: %d>\n", obj->type()); iprintf(pyc_output, indent, "<TYPE: %d>\n", obj->type());
} }
out_seen.erase((PycObject *)obj);
} }
int main(int argc, char* argv[]) int main(int argc, char* argv[])