Make cast() checked by default and add try_cast() for cases where a cast

is not required to be successful.
This commit is contained in:
Michael Hansen
2022-12-01 16:13:31 -08:00
parent 305494c4b2
commit ffeabc3d3f
6 changed files with 36 additions and 38 deletions

View File

@@ -841,7 +841,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (isUninitAsyncFor) {
auto tryBlock = container->nodes().front().cast<ASTBlock>();
if (!tryBlock->nodes().empty() && tryBlock->blktype() == ASTBlock::BLK_TRY) {
auto store = tryBlock->nodes().front().cast<ASTStore>();
auto store = tryBlock->nodes().front().try_cast<ASTStore>();
if (store) {
asyncForBlock.cast<ASTIterBlock>()->setIndex(store->dest());
}
@@ -1798,7 +1798,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<ASTPrint> printNode;
if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT)
printNode = curblock->nodes().back().cast<ASTPrint>();
printNode = curblock->nodes().back().try_cast<ASTPrint>();
if (printNode && printNode->stream() == nullptr && !printNode->eol())
printNode->add(stack.top());
else
@@ -1813,7 +1813,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
PycRef<ASTPrint> printNode;
if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT)
printNode = curblock->nodes().back().cast<ASTPrint>();
printNode = curblock->nodes().back().try_cast<ASTPrint>();
if (printNode && printNode->stream() == stream && !printNode->eol())
printNode->add(stack.top());
else
@@ -1826,7 +1826,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<ASTPrint> printNode;
if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT)
printNode = curblock->nodes().back().cast<ASTPrint>();
printNode = curblock->nodes().back().try_cast<ASTPrint>();
if (printNode && printNode->stream() == nullptr && !printNode->eol())
printNode->setEol(true);
else
@@ -1841,7 +1841,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
PycRef<ASTPrint> printNode;
if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT)
printNode = curblock->nodes().back().cast<ASTPrint>();
printNode = curblock->nodes().back().try_cast<ASTPrint>();
if (printNode && printNode->stream() == stream && !printNode->eol())
printNode->setEol(true);
else
@@ -2152,7 +2152,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (curblock->blktype() == ASTBlock::BLK_FOR
&& !curblock->inited()) {
PycRef<ASTTuple> tuple = tup.cast<ASTTuple>();
PycRef<ASTTuple> tuple = tup.try_cast<ASTTuple>();
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
@@ -2211,7 +2211,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (curblock->blktype() == ASTBlock::BLK_FOR
&& !curblock->inited()) {
PycRef<ASTTuple> tuple = tup.cast<ASTTuple>();
PycRef<ASTTuple> tuple = tup.try_cast<ASTTuple>();
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
@@ -2253,7 +2253,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (curblock->blktype() == ASTBlock::BLK_FOR
&& !curblock->inited()) {
PycRef<ASTTuple> tuple = tup.cast<ASTTuple>();
PycRef<ASTTuple> tuple = tup.try_cast<ASTTuple>();
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
@@ -2735,7 +2735,7 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
if (param.first.type() == ASTNode::NODE_NAME) {
fprintf(pyc_output, "%s = ", param.first.cast<ASTName>()->name()->value());
} else {
PycRef<PycString> str_name = param.first.cast<ASTObject>()->object().require_cast<PycString>();
PycRef<PycString> str_name = param.first.cast<ASTObject>()->object().cast<PycString>();
fprintf(pyc_output, "%s = ", str_name->value());
}
print_src(param.second, mod);
@@ -2902,20 +2902,18 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
break;
case ASTNode::NODE_BLOCK:
{
if (node.cast<ASTBlock>()->blktype() == ASTBlock::BLK_ELSE
&& node.cast<ASTBlock>()->size() == 0)
PycRef<ASTBlock> blk = node.cast<ASTBlock>();
if (blk->blktype() == ASTBlock::BLK_ELSE && blk->size() == 0)
break;
if (node.cast<ASTBlock>()->blktype() == ASTBlock::BLK_CONTAINER) {
if (blk->blktype() == ASTBlock::BLK_CONTAINER) {
end_line();
PycRef<ASTBlock> blk = node.cast<ASTBlock>();
print_block(blk, mod);
end_line();
break;
}
fprintf(pyc_output, "%s", node.cast<ASTBlock>()->type_str());
PycRef<ASTBlock> blk = node.cast<ASTBlock>();
fprintf(pyc_output, "%s", blk->type_str());
if (blk->blktype() == ASTBlock::BLK_IF
|| blk->blktype() == ASTBlock::BLK_ELIF
|| blk->blktype() == ASTBlock::BLK_WHILE) {
@@ -2937,7 +2935,7 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
} else if (blk->blktype() == ASTBlock::BLK_WITH) {
fputs(" ", pyc_output);
print_src(blk.cast<ASTWithBlock>()->expr(), mod);
PycRef<ASTNode> var = blk.cast<ASTWithBlock>()->var();
PycRef<ASTNode> var = blk.try_cast<ASTWithBlock>()->var();
if (var != NULL) {
fputs(" as ", pyc_output);
print_src(var, mod);
@@ -3248,8 +3246,8 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
print_src(dest, mod);
}
}
} else if (src.type() == ASTNode::NODE_BINARY &&
src.cast<ASTBinary>()->is_inplace() == true) {
} else if (src.type() == ASTNode::NODE_BINARY
&& src.cast<ASTBinary>()->is_inplace()) {
print_src(src, mod);
} else {
print_src(dest, mod);
@@ -3325,7 +3323,7 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
PycRef<ASTTernary> ternary = node.cast<ASTTernary>();
//fputs("(", pyc_output);
print_src(ternary->if_expr(), mod);
const auto if_block = ternary->if_block().require_cast<ASTCondBlock>();
const auto if_block = ternary->if_block().cast<ASTCondBlock>();
fputs(" if ", pyc_output);
if (if_block->negative())
fputs("not ", pyc_output);
@@ -3405,7 +3403,7 @@ void decompyle(PycRef<PycCode> code, PycModule* mod)
if (store->src().type() == ASTNode::NODE_OBJECT
&& store->dest().type() == ASTNode::NODE_NAME) {
PycRef<ASTObject> src = store->src().cast<ASTObject>();
PycRef<PycString> srcString = src->object().cast<PycString>();
PycRef<PycString> srcString = src->object().try_cast<PycString>();
PycRef<ASTName> dest = store->dest().cast<ASTName>();
if (srcString != nullptr && srcString->isEqual(code->name().cast<PycObject>())
&& dest->name()->isEqual("__qualname__")) {

View File

@@ -40,27 +40,27 @@ void PycCode::load(PycData* stream, PycModule* mod)
else
m_flags = 0;
m_code = LoadObject(stream, mod).require_cast<PycString>();
m_consts = LoadObject(stream, mod).require_cast<PycSequence>();
m_names = LoadObject(stream, mod).require_cast<PycSequence>();
m_code = LoadObject(stream, mod).cast<PycString>();
m_consts = LoadObject(stream, mod).cast<PycSequence>();
m_names = LoadObject(stream, mod).cast<PycSequence>();
if (mod->verCompare(1, 3) >= 0)
m_varNames = LoadObject(stream, mod).require_cast<PycSequence>();
m_varNames = LoadObject(stream, mod).cast<PycSequence>();
else
m_varNames = new PycTuple;
if (mod->verCompare(2, 1) >= 0)
m_freeVars = LoadObject(stream, mod).require_cast<PycSequence>();
m_freeVars = LoadObject(stream, mod).cast<PycSequence>();
else
m_freeVars = new PycTuple;
if (mod->verCompare(2, 1) >= 0)
m_cellVars = LoadObject(stream, mod).require_cast<PycSequence>();
m_cellVars = LoadObject(stream, mod).cast<PycSequence>();
else
m_cellVars = new PycTuple;
m_fileName = LoadObject(stream, mod).require_cast<PycString>();
m_name = LoadObject(stream, mod).require_cast<PycString>();
m_fileName = LoadObject(stream, mod).cast<PycString>();
m_name = LoadObject(stream, mod).cast<PycString>();
if (mod->verCompare(1, 5) >= 0 && mod->verCompare(2, 3) < 0)
m_firstLine = stream->get16();
@@ -68,7 +68,7 @@ void PycCode::load(PycData* stream, PycModule* mod)
m_firstLine = stream->get32();
if (mod->verCompare(1, 5) >= 0)
m_lnTable = LoadObject(stream, mod).require_cast<PycString>();
m_lnTable = LoadObject(stream, mod).cast<PycString>();
else
m_lnTable = new PycString;
}

View File

@@ -58,19 +58,19 @@ public:
PycRef<PycString> getName(int idx) const
{
return m_names->get(idx).require_cast<PycString>();
return m_names->get(idx).cast<PycString>();
}
PycRef<PycString> getVarName(int idx) const
{
return m_varNames->get(idx).require_cast<PycString>();
return m_varNames->get(idx).cast<PycString>();
}
PycRef<PycString> getCellVar(int idx) const
{
return (idx >= m_cellVars->size())
? m_freeVars->get(idx - m_cellVars->size()).require_cast<PycString>()
: m_cellVars->get(idx).require_cast<PycString>();
? m_freeVars->get(idx - m_cellVars->size()).cast<PycString>()
: m_cellVars->get(idx).cast<PycString>();
}
const globals_t& getGlobals() const { return m_globalsUsed; }

View File

@@ -213,7 +213,7 @@ void PycModule::loadFromFile(const char* filename)
in.get32(); // Size parameter added in Python 3.3
}
m_code = LoadObject(&in, this).require_cast<PycCode>();
m_code = LoadObject(&in, this).cast<PycCode>();
}
void PycModule::loadFromMarshalledFile(const char* filename, int major, int minor)
@@ -230,7 +230,7 @@ void PycModule::loadFromMarshalledFile(const char* filename, int major, int mino
m_maj = major;
m_min = minor;
m_unicode = (major >= 3);
m_code = LoadObject(&in, this).require_cast<PycCode>();
m_code = LoadObject(&in, this).cast<PycCode>();
}
PycRef<PycString> PycModule::getIntern(int ref) const

View File

@@ -70,10 +70,10 @@ public:
inline int type() const;
template <class _Cast>
PycRef<_Cast> cast() const { return dynamic_cast<_Cast*>(m_obj); }
PycRef<_Cast> try_cast() const { return dynamic_cast<_Cast*>(m_obj); }
template <class _Cast>
PycRef<_Cast> require_cast() const
PycRef<_Cast> cast() const
{
_Cast* result = dynamic_cast<_Cast*>(m_obj);
if (!result)

View File

@@ -304,7 +304,7 @@ int main(int argc, char* argv[])
fprintf(pyc_output, "%s (Python %d.%d%s)\n", dispname, mod.majorVer(), mod.minorVer(),
(mod.majorVer() < 3 && mod.isUnicode()) ? " -U" : "");
try {
output_object(mod.code().cast<PycObject>(), &mod, 0);
output_object(mod.code().try_cast<PycObject>(), &mod, 0);
} catch (std::exception& ex) {
fprintf(stderr, "Error disassembling %s: %s\n", infile, ex.what());
return 1;