Parse mAB tags

Change-Id: I760cf44793f64951ff2e275679c64467938f90bc
Reviewed-on: https://skia-review.googlesource.com/109485
Reviewed-by: Mike Klein <mtklein@chromium.org>
Commit-Queue: Brian Osman <brianosman@google.com>
diff --git a/skcms.h b/skcms.h
index 9a1faff..d233484 100644
--- a/skcms.h
+++ b/skcms.h
@@ -22,6 +22,11 @@
     float vals[3][3];
 } skcms_Matrix3x3;
 
+// A row-major 3x4 matrix (ie vals[row][col])
+typedef struct {
+    float vals[3][4];
+} skcms_Matrix3x4;
+
 // A transfer function mapping encoded values to linear values,
 // represented by this 7-parameter piecewise function:
 //
@@ -60,7 +65,7 @@
     // Otherwise, matrix_channels must be 3.
     uint32_t        matrix_channels;
     skcms_Curve     matrix_curves[3];
-    skcms_Matrix3x3 matrix;
+    skcms_Matrix3x4 matrix;
 
     // Required: 3 1D curves. Always present, and output_channels must be 3.
     uint32_t        output_channels;
diff --git a/src/ICCProfile.c b/src/ICCProfile.c
index 875b6bd..f02e3de 100644
--- a/src/ICCProfile.c
+++ b/src/ICCProfile.c
@@ -157,7 +157,8 @@
     uint8_t parameters    [ ];  // 1, 3, 4, 5, or 7 s15.16 parameters, depending on function_type
 } para_Layout;
 
-static bool read_curve_para(const uint8_t* buf, uint32_t size, skcms_Curve* curve) {
+static bool read_curve_para(const uint8_t* buf, uint32_t size,
+                            skcms_Curve* curve, uint32_t* curve_size) {
     if (size < SAFE_SIZEOF(para_Layout)) {
         return false;
     }
@@ -175,6 +176,10 @@
         return false;
     }
 
+    if (curve_size) {
+        *curve_size = SAFE_SIZEOF(para_Layout) + curve_bytes[function_type];
+    }
+
     curve->table_8       = NULL;
     curve->table_16      = NULL;
     curve->table_entries = 0;
@@ -230,7 +235,8 @@
     uint8_t parameters    [ ];  // value_count parameters (8.8 if 1, uint16 (n*65535) if > 1)
 } curv_Layout;
 
-static bool read_curve_curv(const uint8_t* buf, uint32_t size, skcms_Curve* curve) {
+static bool read_curve_curv(const uint8_t* buf, uint32_t size,
+                            skcms_Curve* curve, uint32_t* curve_size) {
     if (size < SAFE_SIZEOF(curv_Layout)) {
         return false;
     }
@@ -242,6 +248,10 @@
         return false;
     }
 
+    if (curve_size) {
+        *curve_size = SAFE_SIZEOF(curv_Layout) + value_count * SAFE_SIZEOF(uint16_t);
+    }
+
     if (value_count < 2) {
         curve->table_8       = NULL;
         curve->table_16      = NULL;
@@ -269,17 +279,19 @@
     return true;
 }
 
-// Parses both curveType and parametricCurveType data
-static bool read_curve(const uint8_t* buf, uint32_t size, skcms_Curve* curve) {
+// Parses both curveType and parametricCurveType data. Ensures that at most 'size' bytes are read.
+// If curve_size is not NULL, writes the number of bytes used by the curve in (*curve_size).
+static bool read_curve(const uint8_t* buf, uint32_t size,
+                       skcms_Curve* curve, uint32_t* curve_size) {
     if (!buf || size < 4 || !curve) {
         return false;
     }
 
     uint32_t type = read_big_u32(buf);
     if (type == make_signature('p', 'a', 'r', 'a')) {
-        return read_curve_para(buf, size, curve);
+        return read_curve_para(buf, size, curve, curve_size);
     } else if (type == make_signature('c', 'u', 'r', 'v')) {
-        return read_curve_curv(buf, size, curve);
+        return read_curve_curv(buf, size, curve, curve_size);
     }
 
     return false;
@@ -405,9 +417,11 @@
 static bool init_a2b_tables(const uint8_t* table_base, uint64_t max_tables_len, uint32_t byte_width,
                             uint32_t input_table_entries, uint32_t output_table_entries,
                             skcms_A2B* a2b) {
+    // byte_width is 1 or 2, [input|output]_table_entries are in [2, 4096], so no overflow
     uint32_t byte_len_per_input_table  = input_table_entries * byte_width;
     uint32_t byte_len_per_output_table = output_table_entries * byte_width;
 
+    // [input|output]_channels are <= 4, so still no overflow
     uint32_t byte_len_all_input_tables  = a2b->input_channels * byte_len_per_input_table;
     uint32_t byte_len_all_output_tables = a2b->output_channels * byte_len_per_output_table;
 
@@ -494,6 +508,177 @@
                            input_table_entries, output_table_entries, a2b);
 }
 
+static bool read_curves(const uint8_t* buf, uint32_t size, uint32_t curve_offset,
+                        uint32_t num_curves, skcms_Curve* curves) {
+    for (uint32_t i = 0; i < num_curves; ++i) {
+        if (curve_offset > size) {
+            return false;
+        }
+
+        uint32_t curve_bytes;
+        if (!read_curve(buf + curve_offset, size - curve_offset, &curves[i], &curve_bytes)) {
+            return false;
+        }
+
+        if (curve_bytes > UINT32_MAX - 3) {
+            return false;
+        }
+        curve_bytes = (curve_bytes + 3) & ~3U;
+
+        uint64_t new_offset_64 = (uint64_t)curve_offset + curve_bytes;
+        curve_offset = (uint32_t)new_offset_64;
+        if (new_offset_64 != curve_offset) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+typedef struct {
+    uint8_t type                 [ 4];
+    uint8_t reserved_a           [ 4];
+    uint8_t input_channels       [ 1];
+    uint8_t output_channels      [ 1];
+    uint8_t reserved_b           [ 2];
+    uint8_t b_curve_offset       [ 4];
+    uint8_t matrix_offset        [ 4];
+    uint8_t m_curve_offset       [ 4];
+    uint8_t clut_offset          [ 4];
+    uint8_t a_curve_offset       [ 4];
+} mAB_Layout;
+
+typedef struct {
+    uint8_t grid_points          [16];
+    uint8_t grid_byte_width      [ 1];
+    uint8_t reserved             [ 3];
+    uint8_t data                 [  ];
+} mABCLUT_Layout;
+
+static bool read_tag_mab(const skcms_ICCTag* tag, skcms_A2B* a2b) {
+    if (tag->size < SAFE_SIZEOF(mAB_Layout)) {
+        return false;
+    }
+
+    const mAB_Layout* mABTag = (const mAB_Layout*)tag->buf;
+
+    a2b->input_channels  = mABTag->input_channels[0];
+    a2b->output_channels = mABTag->output_channels[0];
+
+    // We require exactly three (ie XYZ/Lab/RGB) output channels
+    if (a2b->output_channels != ARRAY_COUNT(a2b->output_curves)) {
+        return false;
+    }
+    // We require no more than four (ie CMYK) input channels
+    if (a2b->input_channels > ARRAY_COUNT(a2b->input_curves)) {
+        return false;
+    }
+
+    uint32_t b_curve_offset = read_big_u32(mABTag->b_curve_offset);
+    uint32_t matrix_offset  = read_big_u32(mABTag->matrix_offset);
+    uint32_t m_curve_offset = read_big_u32(mABTag->m_curve_offset);
+    uint32_t clut_offset    = read_big_u32(mABTag->clut_offset);
+    uint32_t a_curve_offset = read_big_u32(mABTag->a_curve_offset);
+
+    // "B" curves must be present
+    if (0 == b_curve_offset) {
+        return false;
+    }
+
+    if (!read_curves(tag->buf, tag->size, b_curve_offset, a2b->output_channels,
+                     a2b->output_curves)) {
+        return false;
+    }
+
+    // "M" curves and Matrix must be used together
+    if (0 != m_curve_offset) {
+        if (0 == matrix_offset) {
+            return false;
+        }
+        a2b->matrix_channels = a2b->output_channels;
+        if (!read_curves(tag->buf, tag->size, m_curve_offset, a2b->matrix_channels,
+                         a2b->matrix_curves)) {
+            return false;
+        }
+
+        // Read matrix, which is stored as a row-major 3x3, followed by the fourth column
+        if (tag->size < matrix_offset + 12 * SAFE_SIZEOF(uint32_t)) {
+            return false;
+        }
+        const uint8_t* mtx_buf = tag->buf + matrix_offset;
+        a2b->matrix.vals[0][0] = read_big_fixed(mtx_buf + 0);
+        a2b->matrix.vals[0][1] = read_big_fixed(mtx_buf + 4);
+        a2b->matrix.vals[0][2] = read_big_fixed(mtx_buf + 8);
+        a2b->matrix.vals[1][0] = read_big_fixed(mtx_buf + 12);
+        a2b->matrix.vals[1][1] = read_big_fixed(mtx_buf + 16);
+        a2b->matrix.vals[1][2] = read_big_fixed(mtx_buf + 20);
+        a2b->matrix.vals[2][0] = read_big_fixed(mtx_buf + 24);
+        a2b->matrix.vals[2][1] = read_big_fixed(mtx_buf + 28);
+        a2b->matrix.vals[2][2] = read_big_fixed(mtx_buf + 32);
+        a2b->matrix.vals[0][3] = read_big_fixed(mtx_buf + 36);
+        a2b->matrix.vals[1][3] = read_big_fixed(mtx_buf + 40);
+        a2b->matrix.vals[2][3] = read_big_fixed(mtx_buf + 44);
+    } else {
+        if (0 != matrix_offset) {
+            return false;
+        }
+        a2b->matrix_channels = 0;
+    }
+
+    // "A" curves and CLUT must be used together
+    if (0 != a_curve_offset) {
+        if (0 == clut_offset) {
+            return false;
+        }
+        if (!read_curves(tag->buf, tag->size, a_curve_offset, a2b->input_channels,
+                         a2b->input_curves)) {
+            return false;
+        }
+
+        if (tag->size < clut_offset + SAFE_SIZEOF(mABCLUT_Layout)) {
+            return false;
+        }
+        const mABCLUT_Layout* clut = (const mABCLUT_Layout*)(tag->buf + clut_offset);
+
+        if (clut->grid_byte_width[0] == 1) {
+            a2b->grid_8  = clut->data;
+            a2b->grid_16 = NULL;
+        } else if (clut->grid_byte_width[0] == 2) {
+            a2b->grid_8  = NULL;
+            a2b->grid_16 = clut->data;
+        } else {
+            return false;
+        }
+
+        uint64_t grid_size = a2b->output_channels * clut->grid_byte_width[0];
+        for (uint32_t i = 0; i < a2b->input_channels; ++i) {
+            a2b->grid_points[i] = clut->grid_points[i];
+            // The grid only makes sense with at least two points along each axis
+            if (a2b->grid_points[i] < 2) {
+                return false;
+            }
+            grid_size *= a2b->grid_points[i];
+        }
+        if (tag->size < clut_offset + SAFE_SIZEOF(mABCLUT_Layout) + grid_size) {
+            return false;
+        }
+    } else {
+        if (0 != clut_offset) {
+            return false;
+        }
+
+        // If there is no CLUT, the number of input and output channels must match
+        if (a2b->input_channels != a2b->output_channels) {
+            return false;
+        }
+
+        // Zero out the number of input channels to signal that we're skipping this stage
+        a2b->input_channels = 0;
+    }
+
+    return true;
+}
+
 static bool read_a2b(const skcms_ICCProfile* profile, skcms_A2B* a2b) {
     if (!profile || !a2b) {
         return false;
@@ -508,6 +693,8 @@
         return read_tag_mft1(&tag, a2b);
     } else if (tag.type == make_signature('m', 'f', 't', '2')) {
         return read_tag_mft2(&tag, a2b);
+    } else if (tag.type == make_signature('m', 'A', 'B', ' ')) {
+        return read_tag_mab(&tag, a2b);
     }
 
     // TODO: Also parse lutAtoBType ('mAB ')
@@ -612,9 +799,9 @@
         skcms_GetTagBySignature(profile, make_signature('r', 'T', 'R', 'C'), &rTRC) &&
         skcms_GetTagBySignature(profile, make_signature('g', 'T', 'R', 'C'), &gTRC) &&
         skcms_GetTagBySignature(profile, make_signature('b', 'T', 'R', 'C'), &bTRC) &&
-        read_curve(rTRC.buf, rTRC.size, &profile->trc[0]) &&
-        read_curve(gTRC.buf, gTRC.size, &profile->trc[1]) &&
-        read_curve(bTRC.buf, bTRC.size, &profile->trc[2]);
+        read_curve(rTRC.buf, rTRC.size, &profile->trc[0], NULL) &&
+        read_curve(gTRC.buf, gTRC.size, &profile->trc[1], NULL) &&
+        read_curve(bTRC.buf, bTRC.size, &profile->trc[2], NULL);
 
     profile->has_tf       = get_transfer_function(profile, &profile->tf);
     profile->has_toXYZD50 = read_to_XYZD50       (profile, &profile->toXYZD50);