Implement struct == and != on SPIR-V.

OpenGL ES2 allows structs to be compared: http://screen/6KnX4ZfkdLtqDWv
This already worked in GLSL, Metal and SkVM.

Change-Id: Iaf7029c0c1ea9d447348c8280a2788f0d36befad
Bug: skia:11846
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/393598
Auto-Submit: John Stiles <johnstiles@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/resources/sksl/shared/StructsInFunctions.sksl b/resources/sksl/shared/StructsInFunctions.sksl
index 8f34630..7c62697 100644
--- a/resources/sksl/shared/StructsInFunctions.sksl
+++ b/resources/sksl/shared/StructsInFunctions.sksl
@@ -4,6 +4,8 @@
 
 struct S { float x; int y; };
 
+struct Nested { S a, b; };
+
 S returns_a_struct() {
     S s;
     s.x = 1;
@@ -25,6 +27,16 @@
     float x = accepts_a_struct(s);
     modifies_a_struct(s);
 
-    bool valid = (x == 3) && (s.x == 2) && (s.y == 3);
+    S expected;
+    expected.x = 2;
+    expected.y = 3;
+
+    Nested n1, n2, n3;
+    n1.a = n1.b = returns_a_struct();
+    n3 = n2 = n1;
+    modifies_a_struct(n3.b);
+
+    bool valid = (x == 3) && (s.x == 2) && (s.y == 3) &&
+                 (s == expected) && (s != returns_a_struct()) && (n1 == n2) && (n1 != n3);
     return valid ? colorGreen : colorRed;
 }
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 027bf85..60e63d4 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -2394,6 +2394,9 @@
                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
                                                    SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
             }
+            if (operandType->isStruct()) {
+                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
+            }
             SkASSERT(resultType.isBoolean());
             const Type* tmpType;
             if (operandType->isVector()) {
@@ -2413,6 +2416,9 @@
                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
                                                    SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
             }
+            if (operandType->isStruct()) {
+                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
+            }
             [[fallthrough]];
         case Token::Kind::TK_LOGICALXOR:
             SkASSERT(resultType.isBoolean());
@@ -2503,21 +2509,85 @@
     }
 }
 
+SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
+                                                SpvId rhs, OutputStream& out) {
+    // The inputs must be structs containing fields, and the op must be == or !=.
+    SkASSERT(structType.isStruct());
+    SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ);
+    const std::vector<Type::Field>& fields = structType.fields();
+    SkASSERT(!fields.empty());
+
+    // Synthesize equality checks for each field in the struct.
+    const Type& boolType = *fContext.fTypes.fBool;
+    SpvId boolTypeId = this->getType(boolType);
+    SpvId allComparisons = (SpvId)-1;
+    for (int index = 0; index < (int)fields.size(); ++index) {
+        // Get the left and right versions of this field.
+        const Type& fieldType = *fields[index].fType;
+        SpvId fieldTypeId = this->getType(fieldType);
+
+        SpvId fieldL = this->nextId(&fieldType);
+        this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldL, lhs, index, out);
+        SpvId fieldR = this->nextId(&fieldType);
+        this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldR, rhs, index, out);
+        // Use `writeBinaryExpression` with the requested == or != operator on these fields.
+        SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
+                                                       boolType, out);
+        // If this is the first field, we don't need to merge comparison results with anything.
+        if (allComparisons == (SpvId)-1) {
+            allComparisons = comparison;
+            continue;
+        }
+        // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
+        SpvId logicalOp = this->nextId(&boolType);
+        switch (op.kind()) {
+            case Token::Kind::TK_EQEQ:
+                this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
+                                       comparison, allComparisons, out);
+                break;
+            case Token::Kind::TK_NEQ:
+                this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
+                                       comparison, allComparisons, out);
+                break;
+            default:
+                return (SpvId)-1;
+        }
+        allComparisons = logicalOp;
+    }
+    return allComparisons;
+}
+
+static float division_by_literal_value(Operator op, const Expression& right) {
+    // If this is a division by a literal value, returns that literal value. Otherwise, returns 0.
+    if (op.kind() == Token::Kind::TK_SLASH && right.is<FloatLiteral>()) {
+        float rhsValue = right.as<FloatLiteral>().value();
+        if (std::isfinite(rhsValue)) {
+            return rhsValue;
+        }
+    }
+    return 0.0f;
+}
+
 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
     const Expression* left = b.left().get();
     const Expression* right = b.right().get();
     Operator op = b.getOperator();
-    // handle cases where we don't necessarily evaluate both LHS and RHS
+
     switch (op.kind()) {
         case Token::Kind::TK_EQ: {
+            // Handles assignment.
             SpvId rhs = this->writeExpression(*right, out);
             this->getLValue(*left, out)->store(rhs, out);
             return rhs;
         }
         case Token::Kind::TK_LOGICALAND:
-            return this->writeLogicalAnd(b, out);
+            // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
+            return this->writeLogicalAnd(*b.left(), *b.right(), out);
+
         case Token::Kind::TK_LOGICALOR:
-            return this->writeLogicalOr(b, out);
+            // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
+            return this->writeLogicalOr(*b.left(), *b.right(), out);
+
         default:
             break;
     }
@@ -2532,20 +2602,17 @@
         lhs = this->writeExpression(*left, out);
     }
 
-    SpvId rhs = (SpvId)-1;
-    if (op.kind() == Token::Kind::TK_SLASH && right->is<FloatLiteral>()) {
-        float rhsValue = right->as<FloatLiteral>().value();
-        if (std::isfinite(rhsValue) && rhsValue != 0.0f) {
-            // Rewrite floating-point division by a literal into multiplication by the reciprocal.
-            // This converts `expr / 2` into `expr * 0.5`
-            // This improves codegen, especially for certain types of divides (e.g. vector/scalar).
-            op = Operator(Token::Kind::TK_STAR);
-            FloatLiteral reciprocal{right->fOffset, 1.0f / rhsValue, &right->type()};
-            rhs = this->writeExpression(reciprocal, out);
-        }
-    }
-
-    if (rhs == (SpvId)-1) {
+    SpvId rhs;
+    float rhsValue = division_by_literal_value(op, *right);
+    if (rhsValue != 0.0f) {
+        // Rewrite floating-point division by a literal into multiplication by the reciprocal.
+        // This converts `expr / 2` into `expr * 0.5`
+        // This improves codegen, especially for certain types of divides (e.g. vector/scalar).
+        op = Operator(Token::Kind::TK_STAR);
+        FloatLiteral reciprocal{right->fOffset, 1.0f / rhsValue, &right->type()};
+        rhs = this->writeExpression(reciprocal, out);
+    } else {
+        // Write the right-hand side expression normally.
         rhs = this->writeExpression(*right, out);
     }
 
@@ -2557,18 +2624,18 @@
     return result;
 }
 
-SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
-    SkASSERT(a.getOperator().kind() == Token::Kind::TK_LOGICALAND);
+SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
+                                          OutputStream& out) {
     BoolLiteral falseLiteral(/*offset=*/-1, /*value=*/false, fContext.fTypes.fBool.get());
     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
-    SpvId lhs = this->writeExpression(*a.left(), out);
+    SpvId lhs = this->writeExpression(left, out);
     SpvId rhsLabel = this->nextId(nullptr);
     SpvId end = this->nextId(nullptr);
     SpvId lhsBlock = fCurrentBlock;
     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
     this->writeLabel(rhsLabel, out);
-    SpvId rhs = this->writeExpression(*a.right(), out);
+    SpvId rhs = this->writeExpression(right, out);
     SpvId rhsBlock = fCurrentBlock;
     this->writeInstruction(SpvOpBranch, end, out);
     this->writeLabel(end, out);
@@ -2578,18 +2645,18 @@
     return result;
 }
 
-SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
-    SkASSERT(o.getOperator().kind() == Token::Kind::TK_LOGICALOR);
+SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
+                                         OutputStream& out) {
     BoolLiteral trueLiteral(/*offset=*/-1, /*value=*/true, fContext.fTypes.fBool.get());
     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
-    SpvId lhs = this->writeExpression(*o.left(), out);
+    SpvId lhs = this->writeExpression(left, out);
     SpvId rhsLabel = this->nextId(nullptr);
     SpvId end = this->nextId(nullptr);
     SpvId lhsBlock = fCurrentBlock;
     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
     this->writeLabel(rhsLabel, out);
-    SpvId rhs = this->writeExpression(*o.right(), out);
+    SpvId rhs = this->writeExpression(right, out);
     SpvId rhsBlock = fCurrentBlock;
     this->writeInstruction(SpvOpBranch, end, out);
     this->writeLabel(end, out);
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 31d4c66..ee44ca1 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -324,6 +324,9 @@
                                 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
                                 SpvOp_ mergeOperator, OutputStream& out);
 
+    SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
+                                OutputStream& out);
+
     SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
                                          SpvOp_ floatOperator, SpvOp_ intOperator,
                                          OutputStream& out);
@@ -347,9 +350,9 @@
 
     SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
 
-    SpvId writeLogicalAnd(const BinaryExpression& b, OutputStream& out);
+    SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
 
-    SpvId writeLogicalOr(const BinaryExpression& o, OutputStream& out);
+    SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
 
     SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
 
diff --git a/tests/sksl/shared/StructsInFunctions.asm.frag b/tests/sksl/shared/StructsInFunctions.asm.frag
index c47265c..ad8f735 100644
--- a/tests/sksl/shared/StructsInFunctions.asm.frag
+++ b/tests/sksl/shared/StructsInFunctions.asm.frag
@@ -19,6 +19,13 @@
 OpName %main "main"
 OpName %s_0 "s"
 OpName %x "x"
+OpName %expected "expected"
+OpName %Nested "Nested"
+OpMemberName %Nested 0 "a"
+OpMemberName %Nested 1 "b"
+OpName %n1 "n1"
+OpName %n2 "n2"
+OpName %n3 "n3"
 OpName %valid "valid"
 OpDecorate %sk_FragColor RelaxedPrecision
 OpDecorate %sk_FragColor Location 0
@@ -35,10 +42,23 @@
 OpMemberDecorate %S 1 Offset 4
 OpDecorate %35 RelaxedPrecision
 OpDecorate %59 RelaxedPrecision
-OpDecorate %83 RelaxedPrecision
-OpDecorate %91 RelaxedPrecision
-OpDecorate %93 RelaxedPrecision
-OpDecorate %94 RelaxedPrecision
+OpMemberDecorate %Nested 0 Offset 0
+OpMemberDecorate %Nested 0 RelaxedPrecision
+OpMemberDecorate %Nested 1 Offset 16
+OpMemberDecorate %Nested 1 RelaxedPrecision
+OpDecorate %76 RelaxedPrecision
+OpDecorate %78 RelaxedPrecision
+OpDecorate %102 RelaxedPrecision
+OpDecorate %103 RelaxedPrecision
+OpDecorate %114 RelaxedPrecision
+OpDecorate %126 RelaxedPrecision
+OpDecorate %127 RelaxedPrecision
+OpDecorate %150 RelaxedPrecision
+OpDecorate %151 RelaxedPrecision
+OpDecorate %172 RelaxedPrecision
+OpDecorate %180 RelaxedPrecision
+OpDecorate %182 RelaxedPrecision
+OpDecorate %183 RelaxedPrecision
 %float = OpTypeFloat 32
 %v4float = OpTypeVector %float 4
 %_ptr_Output_v4float = OpTypePointer Output %v4float
@@ -64,11 +84,13 @@
 %36 = OpTypeFunction %float %_ptr_Function_S
 %45 = OpTypeFunction %void %_ptr_Function_S
 %54 = OpTypeFunction %v4float
+%float_2 = OpConstant %float 2
+%int_3 = OpConstant %int 3
+%Nested = OpTypeStruct %S %S
+%_ptr_Function_Nested = OpTypePointer Function %Nested
 %_ptr_Function_bool = OpTypePointer Function %bool
 %false = OpConstantFalse %bool
 %float_3 = OpConstant %float 3
-%float_2 = OpConstant %float 2
-%int_3 = OpConstant %int 3
 %_ptr_Function_v4float = OpTypePointer Function %v4float
 %_ptr_Uniform_v4float = OpTypePointer Uniform %v4float
 %_entrypoint_v = OpFunction %void None %18
@@ -116,8 +138,13 @@
 %s_0 = OpVariable %_ptr_Function_S Function
 %x = OpVariable %_ptr_Function_float Function
 %60 = OpVariable %_ptr_Function_S Function
+%expected = OpVariable %_ptr_Function_S Function
+%n1 = OpVariable %_ptr_Function_Nested Function
+%n2 = OpVariable %_ptr_Function_Nested Function
+%n3 = OpVariable %_ptr_Function_Nested Function
+%79 = OpVariable %_ptr_Function_S Function
 %valid = OpVariable %_ptr_Function_bool Function
-%84 = OpVariable %_ptr_Function_v4float Function
+%173 = OpVariable %_ptr_Function_v4float Function
 %57 = OpFunctionCall %S %returns_a_struct_S
 OpStore %s_0 %57
 %59 = OpLoad %S %s_0
@@ -125,41 +152,143 @@
 %61 = OpFunctionCall %float %accepts_a_struct_fS %60
 OpStore %x %61
 %62 = OpFunctionCall %void %modifies_a_struct_vS %s_0
-%66 = OpLoad %float %x
-%68 = OpFOrdEqual %bool %66 %float_3
-OpSelectionMerge %70 None
-OpBranchConditional %68 %69 %70
-%69 = OpLabel
-%71 = OpAccessChain %_ptr_Function_float %s_0 %int_0
-%72 = OpLoad %float %71
-%74 = OpFOrdEqual %bool %72 %float_2
-OpBranch %70
-%70 = OpLabel
-%75 = OpPhi %bool %false %55 %74 %69
-OpSelectionMerge %77 None
-OpBranchConditional %75 %76 %77
-%76 = OpLabel
-%78 = OpAccessChain %_ptr_Function_int %s_0 %int_1
-%79 = OpLoad %int %78
-%81 = OpIEqual %bool %79 %int_3
-OpBranch %77
-%77 = OpLabel
-%82 = OpPhi %bool %false %70 %81 %76
-OpStore %valid %82
-%83 = OpLoad %bool %valid
-OpSelectionMerge %88 None
-OpBranchConditional %83 %86 %87
-%86 = OpLabel
-%89 = OpAccessChain %_ptr_Uniform_v4float %13 %int_1
-%91 = OpLoad %v4float %89
-OpStore %84 %91
-OpBranch %88
-%87 = OpLabel
-%92 = OpAccessChain %_ptr_Uniform_v4float %13 %int_0
-%93 = OpLoad %v4float %92
-OpStore %84 %93
-OpBranch %88
+%65 = OpAccessChain %_ptr_Function_float %expected %int_0
+OpStore %65 %float_2
+%67 = OpAccessChain %_ptr_Function_int %expected %int_1
+OpStore %67 %int_3
+%73 = OpFunctionCall %S %returns_a_struct_S
+%74 = OpAccessChain %_ptr_Function_S %n1 %int_1
+OpStore %74 %73
+%75 = OpAccessChain %_ptr_Function_S %n1 %int_0
+OpStore %75 %73
+%76 = OpLoad %Nested %n1
+OpStore %n2 %76
+OpStore %n3 %76
+%77 = OpAccessChain %_ptr_Function_S %n3 %int_1
+%78 = OpLoad %S %77
+OpStore %79 %78
+%80 = OpFunctionCall %void %modifies_a_struct_vS %79
+%81 = OpLoad %S %79
+OpStore %77 %81
+%85 = OpLoad %float %x
+%87 = OpFOrdEqual %bool %85 %float_3
+OpSelectionMerge %89 None
+OpBranchConditional %87 %88 %89
 %88 = OpLabel
-%94 = OpLoad %v4float %84
-OpReturnValue %94
+%90 = OpAccessChain %_ptr_Function_float %s_0 %int_0
+%91 = OpLoad %float %90
+%92 = OpFOrdEqual %bool %91 %float_2
+OpBranch %89
+%89 = OpLabel
+%93 = OpPhi %bool %false %55 %92 %88
+OpSelectionMerge %95 None
+OpBranchConditional %93 %94 %95
+%94 = OpLabel
+%96 = OpAccessChain %_ptr_Function_int %s_0 %int_1
+%97 = OpLoad %int %96
+%98 = OpIEqual %bool %97 %int_3
+OpBranch %95
+%95 = OpLabel
+%99 = OpPhi %bool %false %89 %98 %94
+OpSelectionMerge %101 None
+OpBranchConditional %99 %100 %101
+%100 = OpLabel
+%102 = OpLoad %S %s_0
+%103 = OpLoad %S %expected
+%104 = OpCompositeExtract %float %102 0
+%105 = OpCompositeExtract %float %103 0
+%106 = OpFOrdEqual %bool %104 %105
+%107 = OpCompositeExtract %int %102 1
+%108 = OpCompositeExtract %int %103 1
+%109 = OpIEqual %bool %107 %108
+%110 = OpLogicalAnd %bool %109 %106
+OpBranch %101
+%101 = OpLabel
+%111 = OpPhi %bool %false %95 %110 %100
+OpSelectionMerge %113 None
+OpBranchConditional %111 %112 %113
+%112 = OpLabel
+%114 = OpLoad %S %s_0
+%115 = OpFunctionCall %S %returns_a_struct_S
+%116 = OpCompositeExtract %float %114 0
+%117 = OpCompositeExtract %float %115 0
+%118 = OpFOrdNotEqual %bool %116 %117
+%119 = OpCompositeExtract %int %114 1
+%120 = OpCompositeExtract %int %115 1
+%121 = OpINotEqual %bool %119 %120
+%122 = OpLogicalOr %bool %121 %118
+OpBranch %113
+%113 = OpLabel
+%123 = OpPhi %bool %false %101 %122 %112
+OpSelectionMerge %125 None
+OpBranchConditional %123 %124 %125
+%124 = OpLabel
+%126 = OpLoad %Nested %n1
+%127 = OpLoad %Nested %n2
+%128 = OpCompositeExtract %S %126 0
+%129 = OpCompositeExtract %S %127 0
+%130 = OpCompositeExtract %float %128 0
+%131 = OpCompositeExtract %float %129 0
+%132 = OpFOrdEqual %bool %130 %131
+%133 = OpCompositeExtract %int %128 1
+%134 = OpCompositeExtract %int %129 1
+%135 = OpIEqual %bool %133 %134
+%136 = OpLogicalAnd %bool %135 %132
+%137 = OpCompositeExtract %S %126 1
+%138 = OpCompositeExtract %S %127 1
+%139 = OpCompositeExtract %float %137 0
+%140 = OpCompositeExtract %float %138 0
+%141 = OpFOrdEqual %bool %139 %140
+%142 = OpCompositeExtract %int %137 1
+%143 = OpCompositeExtract %int %138 1
+%144 = OpIEqual %bool %142 %143
+%145 = OpLogicalAnd %bool %144 %141
+%146 = OpLogicalAnd %bool %145 %136
+OpBranch %125
+%125 = OpLabel
+%147 = OpPhi %bool %false %113 %146 %124
+OpSelectionMerge %149 None
+OpBranchConditional %147 %148 %149
+%148 = OpLabel
+%150 = OpLoad %Nested %n1
+%151 = OpLoad %Nested %n3
+%152 = OpCompositeExtract %S %150 0
+%153 = OpCompositeExtract %S %151 0
+%154 = OpCompositeExtract %float %152 0
+%155 = OpCompositeExtract %float %153 0
+%156 = OpFOrdNotEqual %bool %154 %155
+%157 = OpCompositeExtract %int %152 1
+%158 = OpCompositeExtract %int %153 1
+%159 = OpINotEqual %bool %157 %158
+%160 = OpLogicalOr %bool %159 %156
+%161 = OpCompositeExtract %S %150 1
+%162 = OpCompositeExtract %S %151 1
+%163 = OpCompositeExtract %float %161 0
+%164 = OpCompositeExtract %float %162 0
+%165 = OpFOrdNotEqual %bool %163 %164
+%166 = OpCompositeExtract %int %161 1
+%167 = OpCompositeExtract %int %162 1
+%168 = OpINotEqual %bool %166 %167
+%169 = OpLogicalOr %bool %168 %165
+%170 = OpLogicalOr %bool %169 %160
+OpBranch %149
+%149 = OpLabel
+%171 = OpPhi %bool %false %125 %170 %148
+OpStore %valid %171
+%172 = OpLoad %bool %valid
+OpSelectionMerge %177 None
+OpBranchConditional %172 %175 %176
+%175 = OpLabel
+%178 = OpAccessChain %_ptr_Uniform_v4float %13 %int_1
+%180 = OpLoad %v4float %178
+OpStore %173 %180
+OpBranch %177
+%176 = OpLabel
+%181 = OpAccessChain %_ptr_Uniform_v4float %13 %int_0
+%182 = OpLoad %v4float %181
+OpStore %173 %182
+OpBranch %177
+%177 = OpLabel
+%183 = OpLoad %v4float %173
+OpReturnValue %183
 OpFunctionEnd
diff --git a/tests/sksl/shared/StructsInFunctions.glsl b/tests/sksl/shared/StructsInFunctions.glsl
index 9d7005f..99db12a 100644
--- a/tests/sksl/shared/StructsInFunctions.glsl
+++ b/tests/sksl/shared/StructsInFunctions.glsl
@@ -6,6 +6,10 @@
     float x;
     int y;
 };
+struct Nested {
+    S a;
+    S b;
+};
 S returns_a_struct_S() {
     S s;
     s.x = 1.0;
@@ -23,6 +27,15 @@
     S s = returns_a_struct_S();
     float x = accepts_a_struct_fS(s);
     modifies_a_struct_vS(s);
-    bool valid = (x == 3.0 && s.x == 2.0) && s.y == 3;
+    S expected;
+    expected.x = 2.0;
+    expected.y = 3;
+    Nested n1;
+    Nested n2;
+    Nested n3;
+    n1.a = (n1.b = returns_a_struct_S());
+    n3 = (n2 = n1);
+    modifies_a_struct_vS(n3.b);
+    bool valid = (((((x == 3.0 && s.x == 2.0) && s.y == 3) && s == expected) && s != returns_a_struct_S()) && n1 == n2) && n1 != n3;
     return valid ? colorGreen : colorRed;
 }
diff --git a/tests/sksl/shared/StructsInFunctions.metal b/tests/sksl/shared/StructsInFunctions.metal
index 94f27b2..1076644 100644
--- a/tests/sksl/shared/StructsInFunctions.metal
+++ b/tests/sksl/shared/StructsInFunctions.metal
@@ -5,6 +5,10 @@
     float x;
     int y;
 };
+struct Nested {
+    S a;
+    S b;
+};
 struct Uniforms {
     float4 colorRed;
     float4 colorGreen;
@@ -20,6 +24,12 @@
     modifies_a_struct_vS(_var0);
     s = _var0;
 }
+void modifies_a_struct_vS(thread S& s);
+void _skOutParamHelper1_modifies_a_struct_vS(thread Nested& n3) {
+    S _var0 = n3.b;
+    modifies_a_struct_vS(_var0);
+    n3.b = _var0;
+}
 
 S returns_a_struct_S() {
     S s;
@@ -40,7 +50,16 @@
     S s = returns_a_struct_S();
     float x = accepts_a_struct_fS(s);
     _skOutParamHelper0_modifies_a_struct_vS(s);
-    bool valid = (x == 3.0 && s.x == 2.0) && s.y == 3;
+    S expected;
+    expected.x = 2.0;
+    expected.y = 3;
+    Nested n1;
+    Nested n2;
+    Nested n3;
+    n1.a = (n1.b = returns_a_struct_S());
+    n3 = (n2 = n1);
+    _skOutParamHelper1_modifies_a_struct_vS(n3);
+    bool valid = (((((x == 3.0 && s.x == 2.0) && s.y == 3) && s == expected) && s != returns_a_struct_S()) && n1 == n2) && n1 != n3;
     _out.sk_FragColor = valid ? _uniforms.colorGreen : _uniforms.colorRed;
     return _out;
 }