Fix several undefined behavior issues identified by @nrathaus.

Fixes #147.
This commit is contained in:
Michael Hansen
2018-01-28 14:33:26 -08:00
parent a9a362254e
commit bf60a5831b
12 changed files with 152 additions and 55 deletions

View File

@@ -65,7 +65,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
/* Store the current stack for the except/finally statement(s) */
stack_hist.push(stack);
PycRef<ASTBlock> tryblock = new ASTBlock(ASTBlock::BLK_TRY, curblock->end(), true);
blocks.push(tryblock.cast<ASTBlock>());
blocks.push(tryblock);
curblock = blocks.top();
} else if (else_pop
&& opcode != Pyc::JUMP_FORWARD_A
@@ -555,7 +555,8 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<PycString> varname = code->getName(operand);
if (varname->value()[0] == '_' && varname->value()[1] == '[') {
if (varname->length() >= 2 && varname->value()[0] == '_'
&& varname->value()[1] == '[') {
/* Don't show deletes that are a result of list comps. */
break;
}
@@ -1294,7 +1295,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
}
break;
case Pyc::LOAD_DEREF_A:
stack.push(new ASTName(code->getCellVar(operand).cast<PycString>()));
stack.push(new ASTName(code->getCellVar(operand)));
break;
case Pyc::LOAD_FAST_A:
if (mod->verCompare(1, 3) < 0)
@@ -1441,7 +1442,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (curblock->blktype() == ASTBlock::BLK_WITH) {
curblock.cast<ASTWithBlock>()->setExpr(value);
} else {
curblock.cast<ASTCondBlock>()->init();
curblock->init();
}
break;
} else if (value.type() == ASTNode::NODE_INVALID
@@ -1728,7 +1729,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
case Pyc::STORE_DEREF_A:
{
if (unpack) {
PycRef<ASTNode> name = new ASTName(code->getCellVar(operand).cast<PycString>());
PycRef<ASTNode> name = new ASTName(code->getCellVar(operand));
PycRef<ASTNode> tup = stack.top();
if (tup.type() == ASTNode::NODE_TUPLE) {
@@ -1753,7 +1754,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
} else {
PycRef<ASTNode> value = stack.top();
stack.pop();
PycRef<ASTNode> name = new ASTName(code->getCellVar(operand).cast<PycString>());
PycRef<ASTNode> name = new ASTName(code->getCellVar(operand));
curblock->append(new ASTStore(value, name));
}
}
@@ -1897,7 +1898,8 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
stack.pop();
PycRef<PycString> varname = code->getName(operand);
if (varname->value()[0] == '_' && varname->value()[1] == '[') {
if (varname->length() >= 2 && varname->value()[0] == '_'
&& varname->value()[1] == '[') {
/* Don't show stores of list comp append objects. */
break;
}
@@ -2270,7 +2272,12 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
for (ASTCall::kwparam_t::const_iterator p = call->kwparams().begin(); p != call->kwparams().end(); ++p) {
if (!first)
fputs(", ", pyc_output);
fprintf(pyc_output, "%s = ", p->first.cast<ASTName>()->name()->value());
if (p->first.type() == ASTNode::NODE_NAME) {
fprintf(pyc_output, "%s = ", p->first.cast<ASTName>()->name()->value());
} else {
PycRef<PycString> str_name = p->first.cast<ASTObject>()->object().require_cast<PycString>();
fprintf(pyc_output, "%s = ", str_name->value());
}
print_src(p->second, mod);
first = false;
}

View File

@@ -78,6 +78,7 @@ int PycBuffer::getBuffer(int bytes, void* buffer)
{
if (m_pos + bytes > m_size)
bytes = m_size - m_pos;
memcpy(buffer, (m_buffer + m_pos), bytes);
if (bytes != 0)
memcpy(buffer, (m_buffer + m_pos), bytes);
return bytes;
}

View File

@@ -11,37 +11,51 @@ void PycCode::load(PycData* stream, PycModule* mod)
if (mod->majorVer() >= 3)
m_kwOnlyArgCount = stream->get32();
else
m_kwOnlyArgCount = 0;
if (mod->verCompare(1, 3) >= 0 && mod->verCompare(2, 3) < 0)
m_numLocals = stream->get16();
else if (mod->verCompare(2, 3) >= 0)
m_numLocals = stream->get32();
else
m_numLocals = 0;
if (mod->verCompare(1, 5) >= 0 && mod->verCompare(2, 3) < 0)
m_stackSize = stream->get16();
else if (mod->verCompare(2, 3) >= 0)
m_stackSize = stream->get32();
else
m_stackSize = 0;
if (mod->verCompare(1, 3) >= 0 && mod->verCompare(2, 3) < 0)
m_flags = stream->get16();
else if (mod->verCompare(2, 3) >= 0)
m_flags = stream->get32();
else
m_flags = 0;
m_code = LoadObject(stream, mod).cast<PycString>();
m_consts = LoadObject(stream, mod).cast<PycTuple>();
m_names = LoadObject(stream, mod).cast<PycTuple>();
m_code = LoadObject(stream, mod).require_cast<PycString>();
m_consts = LoadObject(stream, mod).require_cast<PycTuple>();
m_names = LoadObject(stream, mod).require_cast<PycTuple>();
if (mod->verCompare(1, 3) >= 0)
m_varNames = LoadObject(stream, mod).cast<PycTuple>();
m_varNames = LoadObject(stream, mod).require_cast<PycTuple>();
else
m_varNames = new PycTuple;
if (mod->verCompare(2, 1) >= 0)
m_freeVars = LoadObject(stream, mod).cast<PycTuple>();
m_freeVars = LoadObject(stream, mod).require_cast<PycTuple>();
else
m_freeVars = new PycTuple;
if (mod->verCompare(2, 1) >= 0)
m_cellVars = LoadObject(stream, mod).cast<PycTuple>();
m_cellVars = LoadObject(stream, mod).require_cast<PycTuple>();
else
m_cellVars = new PycTuple;
m_fileName = LoadObject(stream, mod).cast<PycString>();
m_name = LoadObject(stream, mod).cast<PycString>();
m_fileName = LoadObject(stream, mod).require_cast<PycString>();
m_name = LoadObject(stream, mod).require_cast<PycString>();
if (mod->verCompare(1, 5) >= 0 && mod->verCompare(2, 3) < 0)
m_firstLine = stream->get16();
@@ -49,5 +63,7 @@ void PycCode::load(PycData* stream, PycModule* mod)
m_firstLine = stream->get32();
if (mod->verCompare(1, 5) >= 0)
m_lnTable = LoadObject(stream, mod).cast<PycString>();
m_lnTable = LoadObject(stream, mod).require_cast<PycString>();
else
m_lnTable = new PycString;
}

View File

@@ -53,16 +53,16 @@ public:
{ return m_consts->get(idx); }
PycRef<PycString> getName(int idx) const
{ return m_names->get(idx).cast<PycString>(); }
{ return m_names->get(idx).require_cast<PycString>(); }
PycRef<PycString> getVarName(int idx) const
{ return m_varNames->get(idx).cast<PycString>(); }
{ return m_varNames->get(idx).require_cast<PycString>(); }
PycRef<PycString> getCellVar(int idx) const
{
return (idx >= m_cellVars->size())
? m_freeVars->get(idx - m_cellVars->size()).cast<PycString>()
: m_cellVars->get(idx).cast<PycString>();
? m_freeVars->get(idx - m_cellVars->size()).require_cast<PycString>()
: m_cellVars->get(idx).require_cast<PycString>();
}
const globals_t& getGlobals() const { return m_globalsUsed; }

View File

@@ -1,5 +1,6 @@
#include "pyc_module.h"
#include "data.h"
#include <stdexcept>
void PycModule::setVersion(unsigned int magic)
{
@@ -125,7 +126,7 @@ void PycModule::setVersion(unsigned int magic)
break;
case MAGIC_3_5:
/* fall through */
/* fall through */
case MAGIC_3_5_3:
m_maj = 3;
@@ -169,21 +170,31 @@ void PycModule::loadFromFile(const char* filename)
if (verCompare(3, 3) >= 0)
in.get32(); // Size parameter added in Python 3.3
m_code = LoadObject(&in, this).cast<PycCode>();
m_code = LoadObject(&in, this).require_cast<PycCode>();
}
PycRef<PycString> PycModule::getIntern(int ref) const
{
if (ref < 0)
throw std::out_of_range("Intern index out of range");
std::list<PycRef<PycString> >::const_iterator it = m_interns.begin();
while (ref--)
while (ref-- && it != m_interns.end())
++it;
if (it == m_interns.end())
throw std::out_of_range("Intern index out of range");
return *it;
}
PycRef<PycObject> PycModule::getRef(int ref) const
{
if (ref < 0)
throw std::out_of_range("Ref index out of range");
std::list<PycRef<PycObject> >::const_iterator it = m_refs.begin();
while (ref--)
while (ref-- && it != m_refs.end())
++it;
if (it == m_refs.end())
throw std::out_of_range("Ref index out of range");
return *it;
}

View File

@@ -64,12 +64,12 @@ std::string PycLong::repr() const
std::list<int>::const_iterator bit;
int shift = 0, temp = 0;
for (bit = m_value.begin(); bit != m_value.end(); ++bit) {
temp |= *bit << shift;
temp |= unsigned(*bit & 0xFFFF) << shift;
shift += 15;
if (shift >= 32) {
bits.push_back(temp);
shift -= 32;
temp = *bit >> (15 - shift);
temp = unsigned(*bit & 0xFFFF) >> (15 - shift);
}
}
if (temp)
@@ -89,7 +89,7 @@ std::string PycLong::repr() const
while (iter != bits.rend())
aptr += snprintf(aptr, 9, "%08X", *iter++);
*aptr++ = 'L';
*aptr++ = 0;
*aptr = 0;
return accum;
}

View File

@@ -1,6 +1,8 @@
#ifndef _PYC_OBJECT_H
#define _PYC_OBJECT_H
#include <typeinfo>
template <class _Obj>
class PycRef {
public:
@@ -55,11 +57,19 @@ public:
inline int type() const;
/* This is just for coding convenience -- no type checking is done! */
template <class _Cast>
PycRef<_Cast> cast() const { return static_cast<_Cast*>(m_obj); }
PycRef<_Cast> cast() const { return dynamic_cast<_Cast*>(m_obj); }
bool isIdent(const _Obj* obj) { return m_obj == obj; }
template <class _Cast>
PycRef<_Cast> require_cast() const
{
_Cast* result = dynamic_cast<_Cast*>(m_obj);
if (!result)
throw std::bad_cast();
return result;
}
bool isIdent(const _Obj* obj) const { return m_obj == obj; }
private:
_Obj* m_obj;

View File

@@ -1,6 +1,7 @@
#include "pyc_sequence.h"
#include "pyc_module.h"
#include "data.h"
#include <stdexcept>
/* PycTuple */
void PycTuple::load(PycData* stream, PycModule* mod)
@@ -60,6 +61,20 @@ bool PycList::isEqual(PycRef<PycObject> obj) const
return true;
}
PycRef<PycObject> PycList::get(int idx) const
{
if (idx < 0)
throw std::out_of_range("List index out of range");
value_t::const_iterator it = m_values.begin();
while (idx-- && it != m_values.end())
++it;
if (it == m_values.end())
throw std::out_of_range("List index out of range");
return *it;
}
/* PycDict */
void PycDict::load(PycData* stream, PycModule* mod)
@@ -114,6 +129,19 @@ PycRef<PycObject> PycDict::get(PycRef<PycObject> key) const
return NULL; // Disassembly shouldn't get non-existant keys
}
PycRef<PycObject> PycDict::get(int idx) const
{
if (idx < 0)
throw std::out_of_range("Dict index out of range");
value_t::const_iterator it = m_values.begin();
while (idx-- && it != m_values.end())
++it;
if (it == m_values.end())
throw std::out_of_range("Dict index out of range");
return *it;
}
/* PycSet */
void PycSet::load(PycData* stream, PycModule* mod)
@@ -140,3 +168,16 @@ bool PycSet::isEqual(PycRef<PycObject> obj) const
}
return true;
}
PycRef<PycObject> PycSet::get(int idx) const
{
if (idx < 0)
throw std::out_of_range("Set index out of range");
value_t::const_iterator it = m_values.begin();
while (idx-- && it != m_values.end())
++it;
if (it == m_values.end())
throw std::out_of_range("Set index out of range");
return *it;
}

View File

@@ -28,7 +28,7 @@ public:
void load(class PycData* stream, class PycModule* mod);
const value_t& values() const { return m_values; }
PycRef<PycObject> get(int idx) const { return m_values[idx]; }
PycRef<PycObject> get(int idx) const { return m_values.at(idx); }
private:
value_t m_values;
@@ -45,12 +45,7 @@ public:
void load(class PycData* stream, class PycModule* mod);
const value_t& values() const { return m_values; }
PycRef<PycObject> get(int idx) const
{
value_t::const_iterator it = m_values.begin();
for (int i=0; i<idx; i++) ++it;
return *it;
}
PycRef<PycObject> get(int idx) const;
private:
value_t m_values;
@@ -71,12 +66,7 @@ public:
const key_t& keys() const { return m_keys; }
const value_t& values() const { return m_values; }
PycRef<PycObject> get(int idx) const
{
value_t::const_iterator it = m_values.begin();
for (int i=0; i<idx; i++) ++it;
return *it;
}
PycRef<PycObject> get(int idx) const;
private:
key_t m_keys;
@@ -94,12 +84,7 @@ public:
void load(class PycData* stream, class PycModule* mod);
const value_t& values() const { return m_values; }
PycRef<PycObject> get(int idx) const
{
value_t::const_iterator it = m_values.begin();
for (int i=0; i<idx; i++) ++it;
return *it;
}
PycRef<PycObject> get(int idx) const;
private:
value_t m_values;

View File

@@ -2,6 +2,7 @@
#include "pyc_module.h"
#include "data.h"
#include <cstring>
#include <limits>
static void ascii_to_utf8(char** data)
{
@@ -64,6 +65,9 @@ void PycString::load(PycData* stream, PycModule* mod)
else
m_length = stream->get32();
if (m_length < 0 || (m_length > std::numeric_limits<int>::max() - 1))
throw std::bad_alloc();
if (m_length) {
m_value = new char[m_length+1];
stream->getBuffer(m_length, m_value);
@@ -95,6 +99,8 @@ bool PycString::isEqual(const char* str) const
{
if (m_value == str)
return true;
if (!m_value)
return false;
return (strcmp(m_value, str) == 0);
}

View File

@@ -235,12 +235,22 @@ int main(int argc, char* argv[])
}
PycModule mod;
mod.loadFromFile(argv[1]);
try {
mod.loadFromFile(argv[1]);
} catch (std::exception& ex) {
fprintf(stderr, "Error disassembling %s: %s\n", argv[1], ex.what());
return 1;
}
const char* dispname = strrchr(argv[1], PATHSEP);
dispname = (dispname == NULL) ? argv[1] : dispname + 1;
fprintf(pyc_output, "%s (Python %d.%d%s)\n", dispname, mod.majorVer(), mod.minorVer(),
(mod.majorVer() < 3 && mod.isUnicode()) ? " -U" : "");
output_object(mod.code().cast<PycObject>(), &mod, 0);
try {
output_object(mod.code().cast<PycObject>(), &mod, 0);
} catch (std::exception& ex) {
fprintf(stderr, "Error disassembling %s: %s\n", argv[1], ex.what());
return 1;
}
return 0;
}

View File

@@ -15,7 +15,12 @@ int main(int argc, char* argv[])
}
PycModule mod;
mod.loadFromFile(argv[1]);
try {
mod.loadFromFile(argv[1]);
} catch (std::exception& ex) {
fprintf(stderr, "Error loading file %s: %s\n", argv[1], ex.what());
return 1;
}
if (!mod.isValid()) {
fprintf(stderr, "Could not load file %s\n", argv[1]);
return 1;
@@ -25,7 +30,12 @@ int main(int argc, char* argv[])
fputs("# Source Generated with Decompyle++\n", pyc_output);
fprintf(pyc_output, "# File: %s (Python %d.%d%s)\n\n", dispname, mod.majorVer(), mod.minorVer(),
(mod.majorVer() < 3 && mod.isUnicode()) ? " Unicode" : "");
decompyle(mod.code(), &mod);
try {
decompyle(mod.code(), &mod);
} catch (std::exception& ex) {
fprintf(stderr, "Error decompyling %s: %s\n", argv[1], ex.what());
return 1;
}
return 0;
}