diff --git a/ASTNode.h b/ASTNode.h index 26df413..feef017 100644 --- a/ASTNode.h +++ b/ASTNode.h @@ -16,7 +16,7 @@ public: NODE_CONVERT, NODE_KEYWORD, NODE_RAISE, NODE_EXEC, NODE_BLOCK, NODE_COMPREHENSION, NODE_LOADBUILDCLASS, NODE_AWAITABLE, NODE_FORMATTEDVALUE, NODE_JOINEDSTR, NODE_CONST_MAP, - NODE_ANNOTATED_VAR, NODE_CHAINSTORE, + NODE_ANNOTATED_VAR, NODE_CHAINSTORE, NODE_TERNARY, // Empty node types NODE_LOCALS, @@ -699,4 +699,33 @@ private: PycRef m_type; }; +class ASTTernary : public ASTNode +{ +public: + ASTTernary(PycRef if_block, PycRef if_expr, PycRef else_expr) + : ASTNode(NODE_TERNARY), + m_if_block(std::move(if_block)), + m_if_expr(std::move(if_expr)), + m_else_expr(std::move(else_expr)) + { + } + const PycRef& if_block() const noexcept + { + return m_if_block; + } + const PycRef& if_expr() const noexcept + { + return m_if_expr; + } + const PycRef& else_expr() const noexcept + { + return m_else_expr; + } + +private: + PycRef m_if_block; // contains "condition" and "negative" + PycRef m_if_expr; + PycRef m_else_expr; +}; + #endif diff --git a/ASTree.cpp b/ASTree.cpp index 247d22a..4e597a1 100644 --- a/ASTree.cpp +++ b/ASTree.cpp @@ -29,6 +29,43 @@ static bool printDocstringAndGlobals = false; /* Use this to keep track of whether we need to print a class or module docstring */ static bool printClassDocstring = true; +// shortcut for all top/pop calls +PycRef StackPopTop(FastStack& stack) +{ + const auto node{ stack.top() }; + stack.pop(); + return node; +} + +/* compiler generates very, VERY similar byte code for if/else statement block and if-expression + * statement + * if a: b = 1 + * else: b = 2 + * expression: + * b = 1 if a else 2 + * (see for instance https://stackoverflow.com/a/52202007) + * here, try to guess if just finished else statement is part of if-expression (ternary operator) + * if it is, remove statements from the block and put a ternary node on top of stack + */ +void CheckIfExpr(FastStack& stack, PycRef curblock) +{ + if (stack.empty()) + return; + if (curblock->nodes().size() < 2) + return; + auto rit{ curblock->nodes().crbegin() }; + ++rit; // the last is "else" block, the one before should be "if" (could be "for", ...) + if ((*rit)->type() != ASTNode::NODE_BLOCK || + (*rit).cast()->blktype() != ASTBlock::BLK_IF) + return; + auto else_expr{ StackPopTop(stack) }; + curblock->removeLast(); + auto if_block{ curblock->nodes().back() }; + auto if_expr{ StackPopTop(stack) }; + curblock->removeLast(); + stack.push(new ASTTernary(std::move(if_block), std::move(if_expr), std::move(else_expr))); +} + PycRef BuildFromCode(PycRef code, PycModule* mod) { PycBuffer source(code->code()->value(), code->code()->length()); @@ -109,6 +146,8 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) curblock->append(prev.cast()); prev = curblock; + + CheckIfExpr(stack, curblock); } } @@ -1363,19 +1402,9 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) break; } - if ((curblock->blktype() == ASTBlock::BLK_WHILE - && !curblock->inited()) - || (curblock->blktype() == ASTBlock::BLK_IF - && curblock->size() == 0)) { - PycRef fakeint = new PycInt(1); - PycRef truthy = new ASTObject(fakeint); - - stack.push(truthy); - break; - } - if (!stack_hist.empty()) { - stack = stack_hist.top(); + if (stack.empty()) // if it's part of if-expression, TOS at the moment is the result of "if" part + stack = stack_hist.top(); stack_hist.pop(); } @@ -3256,6 +3285,26 @@ void print_src(PycRef node, PycModule* mod) print_src(type, mod); } break; + case ASTNode::NODE_TERNARY: + { + /* parenthesis might not be needed, + * but when if-expr is part of numerical expression, ternary has the LOWEST precedence + * print(a + b if False else c) + * output is c, not a+c (a+b is calculated first) + */ + PycRef ternary = node.cast(); + fputs("( ", pyc_output); + print_src(ternary->if_expr(), mod); + const auto if_block = ternary->if_block().require_cast(); + fputs(" if ", pyc_output); + if (if_block->negative()) + fputs("not ", pyc_output); + print_src(if_block->cond(), mod); + fputs(" else ", pyc_output); + print_src(ternary->else_expr(), mod); + fputs(" )", pyc_output); + } + break; default: fprintf(pyc_output, "", node->type()); fprintf(stderr, "Unsupported Node type: %d\n", node->type()); diff --git a/FastStack.h b/FastStack.h index dd39464..45f8ed5 100644 --- a/FastStack.h +++ b/FastStack.h @@ -40,6 +40,11 @@ public: return nullptr; } + bool empty() const + { + return m_ptr == -1; + } + private: std::vector> m_stack; int m_ptr;