Start to catch int-range errors
diff --git a/include/rive/core/binary_reader.hpp b/include/rive/core/binary_reader.hpp
index d6e2570..eb45f15 100644
--- a/include/rive/core/binary_reader.hpp
+++ b/include/rive/core/binary_reader.hpp
@@ -5,6 +5,7 @@
#include <string>
#include <vector>
#include "rive/span.hpp"
+#include "rive/core/type_conversions.hpp"
namespace rive {
class BinaryReader {
@@ -12,12 +13,16 @@
Span<const uint8_t> m_Bytes;
const uint8_t* m_Position;
bool m_Overflowed;
+ bool m_IntRangeError;
void overflow();
+ void intRangeError();
public:
explicit BinaryReader(Span<const uint8_t>);
bool didOverflow() const;
+ bool didIntRangeError() const;
+ bool hasError() const { return m_Overflowed || m_IntRangeError; }
bool reachedEnd() const;
size_t lengthInBytes() const;
@@ -29,7 +34,18 @@
uint8_t readByte();
uint32_t readUint32();
uint64_t readVarUint64(); // Reads a LEB128 encoded uint64_t
+
+ // This will cast the uint read to the requested size, but if the
+ // raw value was out-of-range, instead returns 0 and sets the IntRangeError.
+ template <typename T> T readVarUintAs() {
+ auto value = this->readVarUint64();
+ if (!fitsIn<T>(value)) {
+ value = 0;
+ this->intRangeError();
+ }
+ return static_cast<T>(value);
+ }
};
} // namespace rive
-#endif
\ No newline at end of file
+#endif
diff --git a/include/rive/core/type_conversions.hpp b/include/rive/core/type_conversions.hpp
new file mode 100644
index 0000000..f188413
--- /dev/null
+++ b/include/rive/core/type_conversions.hpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright 2022 Rive
+ */
+
+#ifndef _RIVE_TYPE_CONVERSIONS_HPP_
+#define _RIVE_TYPE_CONVERSIONS_HPP_
+
+#include "rive/rive_types.hpp"
+#include <limits>
+
+namespace rive {
+
+template <typename T> bool fitsIn(intmax_t x) {
+ return x >= std::numeric_limits<T>::min() &&
+ x <= std::numeric_limits<T>::max();
+}
+
+template <typename T> T castTo(intmax_t x) {
+ assert(sizeof(x) <= 4); // don't use with 64bit types
+ assert(fitsIn<T>(x));
+ return static_cast<T>(x);
+}
+
+} // namespace
+
+#endif
diff --git a/src/core/binary_reader.cpp b/src/core/binary_reader.cpp
index 4ea5189..153ec46 100644
--- a/src/core/binary_reader.cpp
+++ b/src/core/binary_reader.cpp
@@ -6,19 +6,28 @@
using namespace rive;
BinaryReader::BinaryReader(Span<const uint8_t> bytes) :
- m_Bytes(bytes), m_Position(bytes.begin()), m_Overflowed(false) {}
+ m_Bytes(bytes), m_Position(bytes.begin()), m_Overflowed(false), m_IntRangeError(false) {}
-bool BinaryReader::reachedEnd() const { return m_Position == m_Bytes.end() || didOverflow(); }
+bool BinaryReader::reachedEnd() const {
+ return m_Position == m_Bytes.end() || didOverflow() || didIntRangeError();
+}
size_t BinaryReader::lengthInBytes() const { return m_Bytes.size(); }
bool BinaryReader::didOverflow() const { return m_Overflowed; }
+bool BinaryReader::didIntRangeError() const { return m_IntRangeError; }
+
void BinaryReader::overflow() {
m_Overflowed = true;
m_Position = m_Bytes.end();
}
+void BinaryReader::intRangeError() {
+ m_IntRangeError = true;
+ m_Position = m_Bytes.end();
+}
+
uint64_t BinaryReader::readVarUint64() {
uint64_t value;
auto readBytes = decode_uint_leb(m_Position, m_Bytes.end(), &value);
diff --git a/src/file.cpp b/src/file.cpp
index b59eee0..4c7314d 100644
--- a/src/file.cpp
+++ b/src/file.cpp
@@ -51,32 +51,32 @@
// Import a single Rive runtime object.
// Used by the file importer.
static Core* readRuntimeObject(BinaryReader& reader, const RuntimeHeader& header) {
- auto coreObjectKey = reader.readVarUint64();
- auto object = CoreRegistry::makeCoreInstance((int)coreObjectKey);
+ auto coreObjectKey = reader.readVarUintAs<int>();
+ auto object = CoreRegistry::makeCoreInstance(coreObjectKey);
while (true) {
- auto propertyKey = reader.readVarUint64();
+ auto propertyKey = reader.readVarUintAs<uint16_t>();
if (propertyKey == 0) {
// Terminator. https://media.giphy.com/media/7TtvTUMm9mp20/giphy.gif
break;
}
- if (reader.didOverflow()) {
+ if (reader.hasError()) {
delete object;
return nullptr;
}
- if (object == nullptr || !object->deserialize((int)propertyKey, reader)) {
+ if (object == nullptr || !object->deserialize(propertyKey, reader)) {
// We have an unknown object or property, first see if core knows
// the property type.
- int id = CoreRegistry::propertyFieldId((int)propertyKey);
+ int id = CoreRegistry::propertyFieldId(propertyKey);
if (id == -1) {
// No, check if it's in toc.
- id = header.propertyFieldId((int)propertyKey);
+ id = header.propertyFieldId(propertyKey);
}
if (id == -1) {
// Still couldn't find it, give up.
fprintf(stderr,
- "Unknown property key " RIVE_FMT_U64 ", missing from property ToC.\n",
+ "Unknown property key %d, missing from property ToC.\n",
propertyKey);
delete object;
return nullptr;
diff --git a/test/binary_reader_test.cpp b/test/binary_reader_test.cpp
new file mode 100644
index 0000000..ae1735c
--- /dev/null
+++ b/test/binary_reader_test.cpp
@@ -0,0 +1,61 @@
+#include <catch.hpp>
+#include <rive/core/binary_reader.hpp>
+
+template <typename T> void checkFits() {
+ int64_t min = std::numeric_limits<T>::min();
+ int64_t max = std::numeric_limits<T>::max();
+ REQUIRE( rive::fitsIn<T>(max+0));
+ REQUIRE( rive::fitsIn<T>(min-0));
+ REQUIRE(!rive::fitsIn<T>(max+1));
+ REQUIRE(!rive::fitsIn<T>(min-1));
+}
+
+TEST_CASE("fitsIn checks", "[type_conversions]") {
+ checkFits<int8_t>();
+ checkFits<uint8_t>();
+
+ checkFits<int16_t>();
+ checkFits<uint16_t>();
+
+ checkFits<int32_t>();
+ checkFits<uint32_t>();
+}
+
+static uint8_t* packvarint(uint8_t array[], uint64_t value) {
+ while (value > 127) {
+ *array++ = static_cast<uint8_t>(0x80 | (value & 0x7F));
+ value >>= 7;
+ }
+ *array++ = static_cast<uint8_t>(value);
+ return array;
+}
+
+template <typename T> bool checkAs(uint64_t value) {
+ uint8_t storage[16];
+ uint8_t* p = storage;
+
+ p = packvarint(storage, value);
+ rive::BinaryReader reader(rive::Span(storage, p - storage));
+
+ auto newValue = reader.readVarUintAs<T>();
+
+ if (reader.hasError()) {
+ REQUIRE(newValue == 0);
+ }
+
+ return !reader.hasError() && value == newValue;
+}
+
+TEST_CASE("range checks", "[binary_reader]") {
+ REQUIRE( checkAs<uint8_t>(100));
+ REQUIRE( checkAs<uint16_t>(100));
+ REQUIRE( checkAs<uint32_t>(100));
+
+ REQUIRE(!checkAs<uint8_t>(1000));
+ REQUIRE( checkAs<uint16_t>(1000));
+ REQUIRE( checkAs<uint32_t>(1000));
+
+ REQUIRE(!checkAs<uint8_t>(100000));
+ REQUIRE(!checkAs<uint16_t>(100000));
+ REQUIRE( checkAs<uint32_t>(100000));
+}