Merge pull request #1613 from piratf/Issue1604_ASCIIValidation

Check before really encoding code points by default.
diff --git a/include/rapidjson/encodings.h b/include/rapidjson/encodings.h
index 0b24467..b7e0516 100644
--- a/include/rapidjson/encodings.h
+++ b/include/rapidjson/encodings.h
@@ -99,6 +99,11 @@
     enum { supportUnicode = 1 };
 
     template<typename OutputStream>
+    static bool ValidateCodePoint(OutputStream&, unsigned codepoint) { // the stream argument is not inspected
+        return codepoint <= 0x10FFFF; // accept everything up to 0x10FFFF; no other range is excluded here
+    }
+
+    template<typename OutputStream>
     static void Encode(OutputStream& os, unsigned codepoint) {
         if (codepoint <= 0x7F) 
             os.Put(static_cast<Ch>(codepoint & 0xFF));
@@ -273,6 +278,16 @@
     enum { supportUnicode = 1 };
 
     template<typename OutputStream>
+    static bool ValidateCodePoint(OutputStream&, unsigned codepoint) { // the stream argument is not inspected
+        if (codepoint <= 0xFFFF) {
+            return (codepoint < 0xD800 || codepoint > 0xDFFF); // reject the surrogate range 0xD800..0xDFFF
+        }
+        else {
+            return codepoint <= 0x10FFFF; // above 0xFFFF only the 0x10FFFF upper bound applies
+        }
+    }
+
+    template<typename OutputStream>
     static void Encode(OutputStream& os, unsigned codepoint) {
         RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2);
         if (codepoint <= 0xFFFF) {
@@ -422,6 +437,11 @@
     enum { supportUnicode = 1 };
 
     template<typename OutputStream>
+    static bool ValidateCodePoint(OutputStream&, unsigned codepoint) { // the stream argument is not inspected
+        return codepoint <= 0x10FFFF; // same upper bound that Encode asserts on below
+    }
+
+    template<typename OutputStream>
     static void Encode(OutputStream& os, unsigned codepoint) {
         RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 4);
         RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
@@ -545,6 +565,11 @@
     enum { supportUnicode = 0 };
 
     template<typename OutputStream>
+    static bool ValidateCodePoint(OutputStream&, unsigned codepoint) { // the stream argument is not inspected
+        return codepoint <= 0x7F; // same upper bound that Encode asserts on below
+    }
+
+    template<typename OutputStream>
     static void Encode(OutputStream& os, unsigned codepoint) {
         RAPIDJSON_ASSERT(codepoint <= 0x7F);
         os.Put(static_cast<Ch>(codepoint & 0xFF));
@@ -620,6 +645,13 @@
 #define RAPIDJSON_ENCODINGS_FUNC(x) UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
 
     template<typename OutputStream>
+    static RAPIDJSON_FORCEINLINE bool ValidateCodePoint(OutputStream& os, unsigned codepoint) {
+        typedef bool (*ValidateCodePointFunc)(OutputStream&, unsigned);
+        static const ValidateCodePointFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(ValidateCodePoint) }; // one validator per encoding listed in the macro
+        return (*f[os.GetType()])(os, codepoint); // dispatch on the stream's runtime type; os.GetType() is used as the index without a bounds check
+    }
+
+    template<typename OutputStream>
     static RAPIDJSON_FORCEINLINE void Encode(OutputStream& os, unsigned codepoint) {
         typedef void (*EncodeFunc)(OutputStream&, unsigned);
         static const EncodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Encode) };
@@ -651,10 +683,69 @@
 };
 
 ///////////////////////////////////////////////////////////////////////////////
+// ValidatableEncoder
+
+/*! Wrapper for TEncoding::Encode / EncodeUnsafe with an optional code point validation step.
+    Since the validation is optional, it is selected at compile time through
+    explicit template specialization, to avoid the overhead of a runtime check.
+
+    \tparam CodePointValidation Whether to validate the code point before encoding.
+*/
+// By default, this encoder validates the code point and returns false on failure, so the caller can generate a parse error.
+// Users can switch the check off by setting 'CodePointValidation' to 'false'.
+template<bool CodePointValidation = true> // validation is on unless explicitly disabled
+class ValidatableEncoder {
+public:
+    template<typename TEncoding, typename OutputStream>
+    static bool Encode(OutputStream &os, unsigned codepoint); // wraps TEncoding::Encode; the bool result reports success
+
+    template<typename TEncoding, typename OutputStream>
+    static bool EncodeUnsafe(OutputStream &os, unsigned codepoint); // wraps TEncoding::EncodeUnsafe; the bool result reports success
+};
+
+template<bool CodePointValidation>
+template<typename TEncoding, typename OutputStream>
+bool
+ValidatableEncoder<CodePointValidation>::Encode(OutputStream &os, unsigned codepoint) { // generic (validating) definition
+    if (!TEncoding::ValidateCodePoint(os, codepoint)) { // ask the target encoding whether it accepts codepoint
+        return false; // rejected: TEncoding::Encode is not called
+    }
+    TEncoding::Encode(os, codepoint);
+    return true;
+}
+
+template<bool CodePointValidation>
+template<typename TEncoding, typename OutputStream>
+bool
+ValidatableEncoder<CodePointValidation>::EncodeUnsafe(OutputStream &os, unsigned codepoint) { // generic (validating) definition
+    if (!TEncoding::ValidateCodePoint(os, codepoint)) { // ask the target encoding whether it accepts codepoint
+        return false; // rejected: TEncoding::EncodeUnsafe is not called
+    }
+    TEncoding::EncodeUnsafe(os, codepoint);
+    return true;
+}
+
+template<>
+template<typename TEncoding, typename OutputStream>
+bool
+ValidatableEncoder<false>::Encode(OutputStream &os, unsigned codepoint) { // specialization with validation switched off
+    TEncoding::Encode(os, codepoint); // forwarded without calling ValidateCodePoint
+    return true; // always reports success
+}
+
+template<>
+template<typename TEncoding, typename OutputStream>
+bool
+ValidatableEncoder<false>::EncodeUnsafe(OutputStream &os, unsigned codepoint) { // specialization with validation switched off
+    TEncoding::EncodeUnsafe(os, codepoint); // forwarded without calling ValidateCodePoint
+    return true; // always reports success
+}
+
+///////////////////////////////////////////////////////////////////////////////
 // Transcoder
 
 //! Encoding conversion.
-template<typename SourceEncoding, typename TargetEncoding>
+template<typename SourceEncoding, typename TargetEncoding, bool CodePointValidation = true>
 struct Transcoder {
     //! Take one Unicode codepoint from source encoding, convert it to target encoding and put it to the output stream.
     template<typename InputStream, typename OutputStream>
@@ -662,8 +753,7 @@
         unsigned codepoint;
         if (!SourceEncoding::Decode(is, &codepoint))
             return false;
-        TargetEncoding::Encode(os, codepoint);
-        return true;
+        return ValidatableEncoder<CodePointValidation>::template Encode<TargetEncoding>(os, codepoint);
     }
 
     template<typename InputStream, typename OutputStream>
@@ -671,8 +761,7 @@
         unsigned codepoint;
         if (!SourceEncoding::Decode(is, &codepoint))
             return false;
-        TargetEncoding::EncodeUnsafe(os, codepoint);
-        return true;
+        return ValidatableEncoder<CodePointValidation>::template EncodeUnsafe<TargetEncoding>(os, codepoint);
     }
 
     //! Validate one Unicode codepoint from an encoded stream.
diff --git a/include/rapidjson/fwd.h b/include/rapidjson/fwd.h
index b74a2b8..cdb5894 100644
--- a/include/rapidjson/fwd.h
+++ b/include/rapidjson/fwd.h
@@ -31,7 +31,7 @@
 template<typename CharType> struct ASCII;
 template<typename CharType> struct AutoUTF;
 
-template<typename SourceEncoding, typename TargetEncoding>
+template<typename SourceEncoding, typename TargetEncoding, bool CodePointValidation>
 struct Transcoder;
 
 // allocators.h
diff --git a/include/rapidjson/reader.h b/include/rapidjson/reader.h
index 13d27c2..ce8bf0d 100644
--- a/include/rapidjson/reader.h
+++ b/include/rapidjson/reader.h
@@ -1028,7 +1028,10 @@
                             RAPIDJSON_PARSE_ERROR(kParseErrorStringUnicodeSurrogateInvalid, escapeOffset);
                         codepoint = (((codepoint - 0xD800) << 10) | (codepoint2 - 0xDC00)) + 0x10000;
                     }
-                    TEncoding::Encode(os, codepoint);
+                    if (!ValidatableEncoder<static_cast<bool>(parseFlags & kParseValidateEncodingFlag)>::template Encode<TEncoding>(os, codepoint))
+                    {
+                        RAPIDJSON_PARSE_ERROR(kParseErrorStringInvalidEncoding, escapeOffset);
+                    }
                 }
                 else
                     RAPIDJSON_PARSE_ERROR(kParseErrorStringEscapeInvalid, escapeOffset);
diff --git a/test/unittest/documenttest.cpp b/test/unittest/documenttest.cpp
index 2b0f269..038ddb8 100644
--- a/test/unittest/documenttest.cpp
+++ b/test/unittest/documenttest.cpp
@@ -631,6 +631,11 @@
     EXPECT_EQ(kParseErrorStringInvalidEncoding, d.GetParseError());
 }
 
+TEST(DocumentDeathTest, Issue1604_ASCIIValidation) { // regression test: escaped non-ASCII code point into an ASCII document
+    GenericDocument<ASCII<> > d_no_check; // "> >" (not ">>") so the nested template also parses before C++11
+    ASSERT_THROW((d_no_check.Parse("\"\\u1234\"")), AssertException); // 0x1234 exceeds the 0x7F bound asserted by the ASCII encoder
+}
+
 // This test does not properly use parsing, just for testing.
 // It must call ClearStack() explicitly to prevent memory leak.
 // But here we cannot as ClearStack() is private.
diff --git a/test/unittest/encodingstest.cpp b/test/unittest/encodingstest.cpp
index 82cf777..86866fe 100644
--- a/test/unittest/encodingstest.cpp
+++ b/test/unittest/encodingstest.cpp
@@ -14,8 +14,6 @@
 
 #include "unittest.h"
 #include "rapidjson/filereadstream.h"
-#include "rapidjson/filewritestream.h"
-#include "rapidjson/encodedstream.h"
 #include "rapidjson/stringbuffer.h"
 
 using namespace rapidjson;
@@ -332,6 +330,12 @@
             }
         }
     }
+
+    // Validate code point before encoding
+    EXPECT_FALSE(ValidatableEncoder<>::Encode<UTF8<> >(os, 0xFFFFFFFF));
+    EXPECT_FALSE(ValidatableEncoder<>::EncodeUnsafe<UTF8<> >(os, 0xFFFFFFFF));
+    EXPECT_THROW(ValidatableEncoder<false>::Encode<UTF8<> >(os, 0xFFFFFFFF), AssertException);
+    EXPECT_THROW(ValidatableEncoder<false>::EncodeUnsafe<UTF8<> >(os, 0xFFFFFFFF), AssertException);
 }
 
 TEST(EncodingsTest, UTF16) {
@@ -392,6 +396,20 @@
             }
         }
     }
+
+    // Validate code point before encoding
+    EXPECT_FALSE(ValidatableEncoder<>::Encode<UTF16<> >(os, 0xFFFFFFFF));
+    EXPECT_FALSE(ValidatableEncoder<>::EncodeUnsafe<UTF16<> >(os, 0xFFFFFFFF));
+    EXPECT_FALSE(ValidatableEncoder<>::Encode<UTF16<> >(os, 0xD800));
+    EXPECT_FALSE(ValidatableEncoder<>::EncodeUnsafe<UTF16<> >(os, 0xD800));
+    EXPECT_FALSE(ValidatableEncoder<>::Encode<UTF16<> >(os, 0xDFFF));
+    EXPECT_FALSE(ValidatableEncoder<>::EncodeUnsafe<UTF16<> >(os, 0xDFFF));
+    EXPECT_THROW(ValidatableEncoder<false>::Encode<UTF16<> >(os, 0xFFFFFFFF), AssertException);
+    EXPECT_THROW(ValidatableEncoder<false>::EncodeUnsafe<UTF16<> >(os, 0xFFFFFFFF), AssertException);
+    EXPECT_THROW(ValidatableEncoder<false>::Encode<UTF16<> >(os, 0xD800), AssertException);
+    EXPECT_THROW(ValidatableEncoder<false>::EncodeUnsafe<UTF16<> >(os, 0xD800), AssertException);
+    EXPECT_THROW(ValidatableEncoder<false>::Encode<UTF16<> >(os, 0xDFFF), AssertException);
+    EXPECT_THROW(ValidatableEncoder<false>::EncodeUnsafe<UTF16<> >(os, 0xDFFF), AssertException);
 }
 
 TEST(EncodingsTest, UTF32) {
@@ -423,6 +441,12 @@
             }
         }
     }
+
+    // Validate code point before encoding
+    EXPECT_FALSE(ValidatableEncoder<>::Encode<UTF32<> >(os, 0xFFFFFFFF));
+    EXPECT_FALSE(ValidatableEncoder<>::EncodeUnsafe<UTF32<> >(os, 0xFFFFFFFF));
+    EXPECT_THROW(ValidatableEncoder<false>::Encode<UTF32<> >(os, 0xFFFFFFFF), AssertException);
+    EXPECT_THROW(ValidatableEncoder<false>::EncodeUnsafe<UTF32<> >(os, 0xFFFFFFFF), AssertException);
 }
 
 TEST(EncodingsTest, ASCII) {
@@ -448,4 +472,10 @@
             EXPECT_EQ(0, StrCmp(encodedStr, os2.GetString()));
         }
     }
+
+    // Validate code point before encoding
+    EXPECT_FALSE(ValidatableEncoder<>::Encode<ASCII<> >(os, 0x0080));
+    EXPECT_FALSE(ValidatableEncoder<>::EncodeUnsafe<ASCII<> >(os, 0x0080));
+    EXPECT_THROW(ValidatableEncoder<false>::Encode<ASCII<> >(os, 0x0080), AssertException);
+    EXPECT_THROW(ValidatableEncoder<false>::EncodeUnsafe<ASCII<> >(os, 0x0080), AssertException);
 }
diff --git a/test/unittest/fwdtest.cpp b/test/unittest/fwdtest.cpp
index 1936d97..353dba0 100644
--- a/test/unittest/fwdtest.cpp
+++ b/test/unittest/fwdtest.cpp
@@ -39,7 +39,7 @@
     UTF32LE<unsigned>* utf32le;
     ASCII<char>* ascii;
     AutoUTF<unsigned>* autoutf;
-    Transcoder<UTF8<char>, UTF8<char> >* transcoder;
+    Transcoder<UTF8<char>, UTF8<char>, true>* transcoder;
 
     // allocators.h
     CrtAllocator* crtallocator;