Start to catch int-range errors
diff --git a/include/rive/core/binary_reader.hpp b/include/rive/core/binary_reader.hpp
index d6e2570..eb45f15 100644
--- a/include/rive/core/binary_reader.hpp
+++ b/include/rive/core/binary_reader.hpp
@@ -5,6 +5,7 @@
 #include <string>
 #include <vector>
 #include "rive/span.hpp"
+#include "rive/core/type_conversions.hpp"
 
 namespace rive {
     class BinaryReader {
@@ -12,12 +13,16 @@
         Span<const uint8_t> m_Bytes;
         const uint8_t* m_Position;
         bool m_Overflowed;
+        bool m_IntRangeError;
 
         void overflow();
+        void intRangeError();
 
     public:
         explicit BinaryReader(Span<const uint8_t>);
         bool didOverflow() const;
+        bool didIntRangeError() const;
+        bool hasError() const { return m_Overflowed || m_IntRangeError; }
         bool reachedEnd() const;
 
         size_t lengthInBytes() const;
@@ -29,7 +34,25 @@
         uint8_t readByte();
         uint32_t readUint32();
         uint64_t readVarUint64(); // Reads a LEB128 encoded uint64_t
+
+        // Reads a LEB128-encoded unsigned value and converts it to T.
+        // If the value does not fit in T, returns 0 instead and records an
+        // int-range error (see didIntRangeError() / hasError()).
+        template <typename T> T readVarUintAs() {
+            static_assert(std::numeric_limits<T>::is_integer,
+                          "readVarUintAs requires an integer type");
+            auto value = this->readVarUint64();
+            // value is unsigned, so only T's upper bound can be exceeded.
+            // Compare in the unsigned domain: converting value to a signed
+            // type first would wrap values >= 2^63 to negative numbers that a
+            // signed T such as int would then wrongly accept.
+            if (value > static_cast<uint64_t>(std::numeric_limits<T>::max())) {
+                value = 0;
+                this->intRangeError();
+            }
+            return static_cast<T>(value);
+        }
     };
 } // namespace rive
 
-#endif
\ No newline at end of file
+#endif
diff --git a/include/rive/core/type_conversions.hpp b/include/rive/core/type_conversions.hpp
new file mode 100644
index 0000000..75dadae
--- /dev/null
+++ b/include/rive/core/type_conversions.hpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2022 Rive
+ */
+
+#ifndef _RIVE_TYPE_CONVERSIONS_HPP_
+#define _RIVE_TYPE_CONVERSIONS_HPP_
+
+#include "rive/rive_types.hpp"
+#include <cassert>
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+namespace rive {
+
+namespace detail {
+// Sign test that compiles cleanly for both signed and unsigned U (a plain
+// `x < 0` on an unsigned type triggers "comparison is always false" warnings).
+template <typename U> constexpr bool isNegativeInt(U x, std::true_type) { return x < 0; }
+template <typename U> constexpr bool isNegativeInt(U, std::false_type) { return false; }
+} // namespace detail
+
+// Returns true if the integer x is exactly representable as a T.
+//
+// T and U may be any integer types, signed or unsigned, up to 64 bits. The
+// comparison never goes through the usual arithmetic conversions, so e.g. a
+// uint64_t holding 2^64-1 is correctly reported as NOT fitting in an int
+// (converting it to intmax_t first would wrap it to -1, which would "fit").
+template <typename T, typename U> constexpr bool fitsIn(U x) {
+    static_assert(std::numeric_limits<T>::is_integer, "fitsIn: T must be an integer type");
+    static_assert(std::numeric_limits<U>::is_integer, "fitsIn: x must be an integer");
+    return detail::isNegativeInt(x, std::is_signed<U>{})
+               ? (std::is_signed<T>::value &&
+                  static_cast<intmax_t>(x) >= static_cast<intmax_t>(std::numeric_limits<T>::min()))
+               : static_cast<uintmax_t>(x) <= static_cast<uintmax_t>(std::numeric_limits<T>::max());
+}
+
+// Converts x to T, asserting (when asserts are enabled) that x fits in T.
+// Meant for values already known to be in range; validate untrusted data
+// with fitsIn<T>() instead, since release builds do not check.
+template <typename T, typename U> T castTo(U x) {
+    static_assert(std::numeric_limits<U>::is_integer, "castTo: x must be an integer");
+    assert(fitsIn<T>(x));
+    return static_cast<T>(x);
+}
+
+} // namespace rive
+
+#endif
diff --git a/src/core/binary_reader.cpp b/src/core/binary_reader.cpp
index 4ea5189..153ec46 100644
--- a/src/core/binary_reader.cpp
+++ b/src/core/binary_reader.cpp
@@ -6,19 +6,37 @@ using namespace rive;
 
 
 BinaryReader::BinaryReader(Span<const uint8_t> bytes) :
-    m_Bytes(bytes), m_Position(bytes.begin()), m_Overflowed(false) {}
+    m_Bytes(bytes),
+    m_Position(bytes.begin()),
+    m_Overflowed(false),
+    m_IntRangeError(false) {}
 
-bool BinaryReader::reachedEnd() const { return m_Position == m_Bytes.end() || didOverflow(); }
+// Any recorded error (overflow or int-range) is treated the same as having
+// consumed the whole buffer.
+bool BinaryReader::reachedEnd() const {
+    return hasError() || m_Position == m_Bytes.end();
+}
 
 size_t BinaryReader::lengthInBytes() const { return m_Bytes.size(); }
 
 bool BinaryReader::didOverflow() const { return m_Overflowed; }
 
+bool BinaryReader::didIntRangeError() const {
+    return m_IntRangeError;
+}
+
 void BinaryReader::overflow() {
     m_Overflowed = true;
     m_Position = m_Bytes.end();
 }
 
+// Records that a decoded value did not fit its requested type. Like
+// overflow(), it also moves the read position to the end of the buffer.
+void BinaryReader::intRangeError() {
+    m_Position = m_Bytes.end();
+    m_IntRangeError = true;
+}
+
 uint64_t BinaryReader::readVarUint64() {
     uint64_t value;
     auto readBytes = decode_uint_leb(m_Position, m_Bytes.end(), &value);
diff --git a/src/file.cpp b/src/file.cpp
index b59eee0..4c7314d 100644
--- a/src/file.cpp
+++ b/src/file.cpp
@@ -51,32 +51,35 @@
 // Import a single Rive runtime object.
 // Used by the file importer.
 static Core* readRuntimeObject(BinaryReader& reader, const RuntimeHeader& header) {
-    auto coreObjectKey = reader.readVarUint64();
-    auto object = CoreRegistry::makeCoreInstance((int)coreObjectKey);
+    auto coreObjectKey = reader.readVarUintAs<int>();
+    auto object = CoreRegistry::makeCoreInstance(coreObjectKey);
     while (true) {
-        auto propertyKey = reader.readVarUint64();
+        auto propertyKey = reader.readVarUintAs<uint16_t>();
+        // Check for errors before testing for the terminator: readVarUintAs
+        // returns 0 for an out-of-range key, which would otherwise be taken
+        // as the terminator and a partially-read object would be returned.
+        if (reader.hasError()) {
+            delete object;
+            return nullptr;
+        }
         if (propertyKey == 0) {
             // Terminator. https://media.giphy.com/media/7TtvTUMm9mp20/giphy.gif
             break;
         }
 
-        if (reader.didOverflow()) {
-            delete object;
-            return nullptr;
-        }
-        if (object == nullptr || !object->deserialize((int)propertyKey, reader)) {
+        if (object == nullptr || !object->deserialize(propertyKey, reader)) {
             // We have an unknown object or property, first see if core knows
             // the property type.
-            int id = CoreRegistry::propertyFieldId((int)propertyKey);
+            int id = CoreRegistry::propertyFieldId(propertyKey);
             if (id == -1) {
                 // No, check if it's in toc.
-                id = header.propertyFieldId((int)propertyKey);
+                id = header.propertyFieldId(propertyKey);
             }
 
             if (id == -1) {
                 // Still couldn't find it, give up.
                 fprintf(stderr,
-                        "Unknown property key " RIVE_FMT_U64 ", missing from property ToC.\n",
+                        "Unknown property key %d, missing from property ToC.\n",
                         propertyKey);
                 delete object;
                 return nullptr;
diff --git a/test/binary_reader_test.cpp b/test/binary_reader_test.cpp
new file mode 100644
index 0000000..ae1735c
--- /dev/null
+++ b/test/binary_reader_test.cpp
@@ -0,0 +1,80 @@
+#include <catch.hpp>
+#include <rive/core/binary_reader.hpp>
+
+#include <cstdint>
+#include <limits>
+
+// Asserts that fitsIn<T> accepts both ends of T's range and rejects the
+// values just outside them.
+template <typename T> void checkFits() {
+    int64_t min = std::numeric_limits<T>::min();
+    int64_t max = std::numeric_limits<T>::max();
+    REQUIRE( rive::fitsIn<T>(max+0));
+    REQUIRE( rive::fitsIn<T>(min-0));
+    REQUIRE(!rive::fitsIn<T>(max+1));
+    REQUIRE(!rive::fitsIn<T>(min-1));
+}
+
+TEST_CASE("fitsIn checks", "[type_conversions]") {
+    checkFits<int8_t>();
+    checkFits<uint8_t>();
+
+    checkFits<int16_t>();
+    checkFits<uint16_t>();
+
+    checkFits<int32_t>();
+    checkFits<uint32_t>();
+}
+
+// Writes value into array as LEB128 and returns one past the last byte
+// written. A uint64_t needs at most 10 bytes.
+static uint8_t* packvarint(uint8_t array[], uint64_t value) {
+    while (value > 127) {
+        *array++ = static_cast<uint8_t>(0x80 | (value & 0x7F));
+        value >>= 7;
+    }
+    *array++ = static_cast<uint8_t>(value);
+    return array;
+}
+
+// Encodes value, reads it back with readVarUintAs<T>, and returns true if it
+// round-tripped without error. On failure it also checks that the result is 0
+// and that the reader reports an int-range error and reachedEnd().
+template <typename T> bool checkAs(uint64_t value) {
+    uint8_t storage[16];
+    uint8_t* end = packvarint(storage, value);
+    rive::BinaryReader reader(rive::Span(storage, end - storage));
+
+    auto newValue = reader.readVarUintAs<T>();
+
+    if (reader.hasError()) {
+        REQUIRE(newValue == 0);
+        REQUIRE(reader.didIntRangeError());
+        REQUIRE(reader.reachedEnd());
+    }
+
+    return !reader.hasError() && value == newValue;
+}
+
+TEST_CASE("range checks", "[binary_reader]") {
+    REQUIRE( checkAs<uint8_t>(100));
+    REQUIRE( checkAs<uint16_t>(100));
+    REQUIRE( checkAs<uint32_t>(100));
+
+    REQUIRE(!checkAs<uint8_t>(1000));
+    REQUIRE( checkAs<uint16_t>(1000));
+    REQUIRE( checkAs<uint32_t>(1000));
+
+    REQUIRE(!checkAs<uint8_t>(100000));
+    REQUIRE(!checkAs<uint16_t>(100000));
+    REQUIRE( checkAs<uint32_t>(100000));
+
+    // Exact boundaries of each type, and one past them.
+    REQUIRE( checkAs<uint8_t>(0));
+    REQUIRE( checkAs<uint8_t>(0xFF));
+    REQUIRE(!checkAs<uint8_t>(0x100));
+    REQUIRE( checkAs<uint16_t>(0xFFFF));
+    REQUIRE(!checkAs<uint16_t>(0x10000));
+    REQUIRE( checkAs<uint32_t>(0xFFFFFFFF));
+    REQUIRE(!checkAs<uint32_t>(0x100000000));
+}