Add choose
wuffs_png_decode_filter_1_sub/clang9 905MB/s ± 0% 1339MB/s ± 0% +48.05% (p=0.008 n=5+5)
wuffs_png_decode_filter_1_sub/gcc10 1.12GB/s ± 0% 1.85GB/s ± 0% +64.80% (p=0.016 n=5+4)
diff --git a/internal/cgen/statement.go b/internal/cgen/statement.go
index 598cda2..3ed6bd7 100644
--- a/internal/cgen/statement.go
+++ b/internal/cgen/statement.go
@@ -57,6 +57,8 @@
case a.KAssign:
n := n.AsAssign()
return g.writeStatementAssign(b, n.Operator(), n.LHS(), n.RHS(), depth)
+ case a.KChoose:
+ return g.writeStatementChoose(b, n.AsChoose(), depth)
case a.KIOBind:
return g.writeStatementIOBind(b, n.AsIOBind(), depth)
case a.KIf:
@@ -221,6 +223,22 @@
return nil
}
+func (g *gen) writeStatementChoose(b *buffer, n *a.Choose, depth uint32) error {
+ recv := g.currFunk.astFunc.Receiver()
+ args := n.Args()
+ if len(args) != 1 {
+ return fmt.Errorf("TODO: multiple choice")
+ }
+ id := args[0].AsExpr().Ident()
+ suffix := ""
+ if n.Name() == id {
+ suffix = "__choosy_default"
+ }
+ b.printf("self->private_impl.choosy_%s = &%s%s__%s%s;\n",
+ n.Name().Str(g.tm), g.pkgPrefix, recv.Str(g.tm), id.Str(g.tm), suffix)
+ return nil
+}
+
func (g *gen) writeStatementIOBind(b *buffer, n *a.IOBind, depth uint32) error {
if g.currFunk.ioBinds > maxIOBinds {
return fmt.Errorf("too many temporary variables required")
diff --git a/lang/ast/ast.go b/lang/ast/ast.go
index fc68888..5b5868c 100644
--- a/lang/ast/ast.go
+++ b/lang/ast/ast.go
@@ -40,6 +40,7 @@
KArg
KAssert
KAssign
+ KChoose
KConst
KExpr
KField
@@ -71,6 +72,7 @@
KArg: "KArg",
KAssert: "KAssert",
KAssign: "KAssign",
+ KChoose: "KChoose",
KConst: "KConst",
KExpr: "KExpr",
KField: "KField",
@@ -160,6 +162,7 @@
// Arg . . name Arg
// Assert keyword . lit(reason) Assert
// Assign operator . . Assign
+ // Choose . . name Choose
// Const . pkg name Const
// Expr operator . literal/ident Expr
// Field . . name Field
@@ -198,6 +201,7 @@
func (n *Node) AsArg() *Arg { return (*Arg)(n) }
func (n *Node) AsAssert() *Assert { return (*Assert)(n) }
func (n *Node) AsAssign() *Assign { return (*Assign)(n) }
+func (n *Node) AsChoose() *Choose { return (*Choose)(n) }
func (n *Node) AsConst() *Const { return (*Const)(n) }
func (n *Node) AsExpr() *Expr { return (*Expr)(n) }
func (n *Node) AsField() *Field { return (*Field)(n) }
@@ -647,6 +651,23 @@
}
}
+// Choose is "choose ID2: List0":
+// - ID2: name
+// - List0: <Expr> method names.
+type Choose Node
+
+func (n *Choose) AsNode() *Node { return (*Node)(n) }
+func (n *Choose) Name() t.ID { return n.id2 }
+func (n *Choose) Args() []*Node { return n.list0 }
+
+func NewChoose(name t.ID, args []*Node) *Choose {
+ return &Choose{
+ kind: KChoose,
+ id2: name,
+ list0: args,
+ }
+}
+
// Ret is "return LHS" or "yield LHS":
// - FlagsReturnsError LHS is an error status
// - ID0: <IDReturn|IDYield>
@@ -891,6 +912,32 @@
return (end.kind == KRet) && (end.AsRet().Keyword() == t.IDReturn)
}
+func fieldsEq(xs []*Node, ys []*Node) bool {
+ if len(xs) != len(ys) {
+ return false
+ }
+ for i := range xs {
+ x, y := xs[i].AsField(), ys[i].AsField()
+ if (x.Name() != y.Name()) || !x.XType().Eq(y.XType()) {
+ return false
+ }
+ }
+ return true
+}
+
+func (n *Func) CheckChooseCompatible(o *Func) error {
+ if n.Effect() != o.Effect() {
+ return fmt.Errorf("different effects")
+ }
+ if !fieldsEq(n.In().Fields(), o.In().Fields()) {
+ return fmt.Errorf("different args type")
+ }
+ if !n.Out().Eq(o.Out()) {
+ return fmt.Errorf("different return type")
+ }
+ return nil
+}
+
func NewFunc(flags Flags, filename string, line uint32, receiverName t.ID, funcName t.ID, in *Struct, out *TypeExpr, asserts []*Node, body []*Node) *Func {
return &Func{
kind: KFunc,
@@ -1030,6 +1077,7 @@
// Statement means one of:
// - Assert
// - Assign
+// - Choose
// - IOBind
// - If
// - Iterate
diff --git a/lang/check/bounds.go b/lang/check/bounds.go
index d676efb..a7b5c31 100644
--- a/lang/check/bounds.go
+++ b/lang/check/bounds.go
@@ -245,6 +245,9 @@
return err
}
+ case a.KChoose:
+ // No-op.
+
case a.KIOBind:
n := n.AsIOBind()
if _, err := q.bcheckExpr(n.IO(), 0); err != nil {
diff --git a/lang/check/resolve.go b/lang/check/resolve.go
index c09629c..2ae723a 100644
--- a/lang/check/resolve.go
+++ b/lang/check/resolve.go
@@ -34,6 +34,7 @@
typeExprGeneric2 = a.NewTypeExpr(0, t.IDBase, t.IDDagger2, nil, nil, nil)
typeExprIdeal = a.NewTypeExpr(0, t.IDBase, t.IDQIdeal, nil, nil, nil)
typeExprList = a.NewTypeExpr(0, t.IDBase, t.IDComma, nil, nil, nil)
+ typeExprNonNullptr = a.NewTypeExpr(0, t.IDBase, t.IDQNonNullptr, nil, nil, nil)
typeExprNullptr = a.NewTypeExpr(0, t.IDBase, t.IDQNullptr, nil, nil, nil)
typeExprPackage = a.NewTypeExpr(0, t.IDBase, t.IDQPackage, nil, nil, nil)
typeExprPlaceholder = a.NewTypeExpr(0, t.IDBase, t.IDQPlaceholder, nil, nil, nil)
diff --git a/lang/check/type.go b/lang/check/type.go
index 01d243b..519f35d 100644
--- a/lang/check/type.go
+++ b/lang/check/type.go
@@ -63,6 +63,11 @@
return err
}
+ case a.KChoose:
+ if err := q.tcheckChoose(n.AsChoose()); err != nil {
+ return err
+ }
+
case a.KIf:
for n := n.AsIf(); n != nil; n = n.ElseIf() {
cond := n.Condition()
@@ -1061,6 +1066,29 @@
return nil
}
+func (q *checker) tcheckChoose(n *a.Choose) error {
+ qqid := q.astFunc.QQID()
+ fQQID := t.QQID{qqid[0], qqid[1], n.Name()}
+ f := q.c.funcs[fQQID]
+ if f == nil {
+ return fmt.Errorf("check: no function named %q", fQQID.Str(q.tm))
+ }
+ for _, o := range n.Args() {
+ o := o.AsExpr()
+ gQQID := t.QQID{qqid[0], qqid[1], o.Ident()}
+ g := q.c.funcs[gQQID]
+ if g == nil {
+ return fmt.Errorf("check: no function named %q", gQQID.Str(q.tm))
+ } else if err := f.CheckChooseCompatible(g); err != nil {
+ return fmt.Errorf("check: incompatible choose functions %q and %q: %v",
+ fQQID.Str(q.tm), gQQID.Str(q.tm), err)
+ }
+ o.SetMBounds(bounds{one, one})
+ o.SetMType(typeExprNonNullptr)
+ }
+ return nil
+}
+
var comparisonOps = [...]bool{
t.IDXBinaryNotEq: true,
t.IDXBinaryLessThan: true,
diff --git a/lang/parse/parse.go b/lang/parse/parse.go
index d496231..5fd15e2 100644
--- a/lang/parse/parse.go
+++ b/lang/parse/parse.go
@@ -309,7 +309,7 @@
implements := []*a.Node(nil)
if p.peek1() == t.IDImplements {
p.src = p.src[1:]
- implements, err = p.parseList(t.IDOpenParen, (*parser).parseQualifiedIdentNode)
+ implements, err = p.parseList(t.IDOpenParen, (*parser).parseQualifiedIdentAsTypeExprNode)
if err != nil {
return nil, err
}
@@ -340,7 +340,7 @@
return nil, fmt.Errorf(`parse: unrecognized top level declaration at %s:%d`, p.filename, line)
}
-func (p *parser) parseQualifiedIdentNode() (*a.Node, error) {
+func (p *parser) parseQualifiedIdentAsTypeExprNode() (*a.Node, error) {
pkg, name, err := p.parseQualifiedIdent()
if err != nil {
return nil, err
@@ -367,6 +367,14 @@
return x, y, nil
}
+func (p *parser) parseIdentAsExprNode() (*a.Node, error) {
+ id, err := p.parseIdent()
+ if err != nil {
+ return nil, err
+ }
+ return a.NewExpr(0, 0, id, nil, nil, nil, nil).AsNode(), nil
+}
+
func (p *parser) parseIdent() (t.ID, error) {
if len(p.src) == 0 {
return 0, fmt.Errorf(`parse: expected identifier at %s:%d`, p.filename, p.line())
@@ -790,6 +798,31 @@
n.SetJumpTarget(loop)
return n.AsNode(), nil
+ case t.IDChoose:
+ p.src = p.src[1:]
+ if p.funcEffect.Pure() {
+ return nil, fmt.Errorf(`parse: choose within pure function at %s:%d`, p.filename, p.line())
+ }
+ name, err := p.parseIdent()
+ if err != nil {
+ return nil, err
+ }
+ if x := p.peek1(); x != t.IDEq {
+ got := p.tm.ByID(x)
+ return nil, fmt.Errorf(`parse: expected "=", got %q at %s:%d`, got, p.filename, p.line())
+ }
+ p.src = p.src[1:]
+ if x := p.peek1(); x != t.IDOpenBracket {
+ got := p.tm.ByID(x)
+ return nil, fmt.Errorf(`parse: expected "[", got %q at %s:%d`, got, p.filename, p.line())
+ }
+ p.src = p.src[1:]
+ args, err := p.parseList(t.IDCloseBracket, (*parser).parseIdentAsExprNode)
+ if err != nil {
+ return nil, err
+ }
+ return a.NewChoose(name, args).AsNode(), nil
+
case t.IDIOBind, t.IDIOLimit:
return p.parseIOBindNode()
diff --git a/lang/token/list.go b/lang/token/list.go
index 086cf44..3f5e456 100644
--- a/lang/token/list.go
+++ b/lang/token/list.go
@@ -457,6 +457,7 @@
IDDagger1 = ID(0x106)
IDDagger2 = ID(0x107)
+ IDQNonNullptr = ID(0x10A)
IDQNullptr = ID(0x10B)
IDQPackage = ID(0x10C)
IDQPlaceholder = ID(0x10D)
@@ -811,6 +812,10 @@
IDDagger1: "†", // U+2020 DAGGER
IDDagger2: "‡", // U+2021 DOUBLE DAGGER
+ // IDQNonNullptr is used by the type checker to build an artificial MType
+ // for function pointers.
+ IDQNonNullptr: "«NonNullptr»",
+
// IDQNullptr is used by the type checker to build an artificial MType for
// the nullptr literal.
IDQNullptr: "«Nullptr»",
diff --git a/release/c/wuffs-unsupported-snapshot.c b/release/c/wuffs-unsupported-snapshot.c
index c45ed3e..ace5131 100644
--- a/release/c/wuffs-unsupported-snapshot.c
+++ b/release/c/wuffs-unsupported-snapshot.c
@@ -29851,6 +29851,11 @@
wuffs_base__slice_u8 a_curr);
static wuffs_base__empty_struct
+wuffs_png__decoder__filter_1_distance_4_fallback(
+ wuffs_png__decoder* self,
+ wuffs_base__slice_u8 a_curr);
+
+static wuffs_base__empty_struct
wuffs_png__decoder__filter_2(
wuffs_png__decoder* self,
wuffs_base__slice_u8 a_curr,
@@ -29886,6 +29891,10 @@
wuffs_base__slice_u8 a_curr,
wuffs_base__slice_u8 a_prev);
+static wuffs_base__empty_struct
+wuffs_png__decoder__choose_filter_implementations(
+ wuffs_png__decoder* self);
+
static wuffs_base__status
wuffs_png__decoder__decode_plte(
wuffs_png__decoder* self,
@@ -30056,6 +30065,38 @@
return wuffs_base__make_empty_struct();
}
+// -------- func png.decoder.filter_1_distance_4_fallback
+
+static wuffs_base__empty_struct
+wuffs_png__decoder__filter_1_distance_4_fallback(
+ wuffs_png__decoder* self,
+ wuffs_base__slice_u8 a_curr) {
+ wuffs_base__slice_u8 v_p = {0};
+ uint8_t v_fa0 = 0;
+ uint8_t v_fa1 = 0;
+ uint8_t v_fa2 = 0;
+ uint8_t v_fa3 = 0;
+
+ {
+ wuffs_base__slice_u8 i_slice_p = a_curr;
+ v_p = i_slice_p;
+ v_p.len = 4;
+ uint8_t* i_end0_p = i_slice_p.ptr + (i_slice_p.len / 4) * 4;
+ while (v_p.ptr < i_end0_p) {
+ v_fa0 += v_p.ptr[0];
+ v_p.ptr[0] = v_fa0;
+ v_fa1 += v_p.ptr[1];
+ v_p.ptr[1] = v_fa1;
+ v_fa2 += v_p.ptr[2];
+ v_p.ptr[2] = v_fa2;
+ v_fa3 += v_p.ptr[3];
+ v_p.ptr[3] = v_fa3;
+ v_p.ptr += 4;
+ }
+ }
+ return wuffs_base__make_empty_struct();
+}
+
// -------- func png.decoder.filter_2
static wuffs_base__empty_struct
@@ -30461,6 +30502,7 @@
goto exit;
}
self->private_impl.f_workbuf_length = (((uint64_t)(self->private_impl.f_height)) * (1 + self->private_impl.f_bytes_per_row));
+ wuffs_png__decoder__choose_filter_implementations(self);
{
WUFFS_BASE__COROUTINE_SUSPENSION_POINT(11);
if (WUFFS_BASE__UNLIKELY(iop_a_src == io2_a_src)) {
@@ -30664,6 +30706,18 @@
return status;
}
+// -------- func png.decoder.choose_filter_implementations
+
+static wuffs_base__empty_struct
+wuffs_png__decoder__choose_filter_implementations(
+ wuffs_png__decoder* self) {
+ if (self->private_impl.f_filter_distance == 3) {
+ } else if (self->private_impl.f_filter_distance == 4) {
+ self->private_impl.choosy_filter_1 = &wuffs_png__decoder__filter_1_distance_4_fallback;
+ }
+ return wuffs_base__make_empty_struct();
+}
+
// -------- func png.decoder.decode_plte
static wuffs_base__status
diff --git a/std/png/decode_filter_fallback.wuffs b/std/png/decode_filter_fallback.wuffs
index 634f230..972d990 100644
--- a/std/png/decode_filter_fallback.wuffs
+++ b/std/png/decode_filter_fallback.wuffs
@@ -37,6 +37,25 @@
} endwhile
}
+pri func decoder.filter_1_distance_4_fallback!(curr: slice base.u8) {
+ var p : slice base.u8
+ var fa0 : base.u8
+ var fa1 : base.u8
+ var fa2 : base.u8
+ var fa3 : base.u8
+
+ iterate (p = args.curr)(length: 4, unroll: 1) {
+ fa0 ~mod+= p[0]
+ p[0] = fa0
+ fa1 ~mod+= p[1]
+ p[1] = fa1
+ fa2 ~mod+= p[2]
+ p[2] = fa2
+ fa3 ~mod+= p[3]
+ p[3] = fa3
+ }
+}
+
pri func decoder.filter_2!(curr: slice base.u8, prev: slice base.u8),
choosy,
{
diff --git a/std/png/decode_png.wuffs b/std/png/decode_png.wuffs
index 882baf5..8a89283 100644
--- a/std/png/decode_png.wuffs
+++ b/std/png/decode_png.wuffs
@@ -177,6 +177,7 @@
return "#unsupported PNG file"
}
this.workbuf_length = (this.height as base.u64) * (1 + this.bytes_per_row)
+ this.choose_filter_implementations!()
// Compression.
a8 = args.src.read_u8?()
@@ -244,6 +245,14 @@
this.call_sequence = 3
}
+pri func decoder.choose_filter_implementations!() {
+ if this.filter_distance == 3 {
+ // TODO.
+ } else if this.filter_distance == 4 {
+ choose filter_1 = [filter_1_distance_4_fallback]
+ }
+}
+
pri func decoder.decode_plte?(src: base.io_reader) {
var num_palette_entries : base.u32[..= 256]
var i : base.u32
diff --git a/test/c/std/png.c b/test/c/std/png.c
index d9aafa0..29cb54d 100644
--- a/test/c/std/png.c
+++ b/test/c/std/png.c
@@ -89,20 +89,24 @@
}
const char* //
-do_wuffs_png_swizzle(wuffs_png__decoder* dec,
- uint32_t width,
+do_wuffs_png_swizzle(uint32_t width,
uint32_t height,
uint8_t filter_distance,
wuffs_base__slice_u8 dst,
wuffs_base__slice_u8 workbuf) {
- dec->private_impl.f_width = width;
- dec->private_impl.f_height = height;
- dec->private_impl.f_bytes_per_row = width;
- dec->private_impl.f_filter_distance = filter_distance;
+ wuffs_png__decoder dec;
+ CHECK_STATUS("initialize", wuffs_png__decoder__initialize(
+ &dec, sizeof dec, WUFFS_VERSION,
+ WUFFS_INITIALIZE__DEFAULT_OPTIONS));
+ dec.private_impl.f_width = width;
+ dec.private_impl.f_height = height;
+ dec.private_impl.f_bytes_per_row = width;
+ dec.private_impl.f_filter_distance = filter_distance;
+ wuffs_png__decoder__choose_filter_implementations(&dec);
CHECK_STATUS("prepare",
wuffs_base__pixel_swizzler__prepare(
- &dec->private_impl.f_swizzler,
+ &dec.private_impl.f_swizzler,
wuffs_base__make_pixel_format(WUFFS_BASE__PIXEL_FORMAT__Y),
wuffs_base__empty_slice_u8(),
wuffs_base__make_pixel_format(WUFFS_BASE__PIXEL_FORMAT__Y),
@@ -117,7 +121,7 @@
CHECK_STATUS("set_from_slice",
wuffs_base__pixel_buffer__set_from_slice(&pb, &pc, dst));
CHECK_STATUS("filter_and_swizzle",
- wuffs_png__decoder__filter_and_swizzle(dec, &pb, workbuf));
+ wuffs_png__decoder__filter_and_swizzle(&dec, &pb, workbuf));
return NULL;
}
@@ -198,11 +202,6 @@
{0xAA, 0xD5, 0xC6, 0xE0, 0x36, 0x16, 0x42, 0x33, 0x8F, 0x77, 0xA1, 0x8E},
};
- wuffs_png__decoder dec;
- CHECK_STATUS("initialize", wuffs_png__decoder__initialize(
- &dec, sizeof dec, WUFFS_VERSION,
- WUFFS_INITIALIZE__DEFAULT_OPTIONS));
-
int filter;
for (filter = 1; filter <= 4; filter++) {
int filter_distance;
@@ -218,7 +217,7 @@
memcpy(g_work_slice_u8.ptr + (13 * 1) + 1, src_rows[1], 12);
CHECK_STRING(do_wuffs_png_swizzle(
- &dec, 12, 2, filter_distance, g_have_slice_u8,
+ 12, 2, filter_distance, g_have_slice_u8,
wuffs_base__make_slice_u8(g_work_slice_u8.ptr, 13 * 2)));
wuffs_base__io_buffer have =
@@ -341,11 +340,6 @@
0x65, 0x43, 0x69, 0x72, 0x63, 0x75, 0x73, 0x53, 0x61, 0x6E, 0x64, 0x73},
};
- wuffs_png__decoder dec;
- CHECK_STATUS("initialize", wuffs_png__decoder__initialize(
- &dec, sizeof dec, WUFFS_VERSION,
- WUFFS_INITIALIZE__DEFAULT_OPTIONS));
-
memcpy(g_src_slice_u8.ptr + (97 * 0) + 1, src_rows[0], 96);
memcpy(g_src_slice_u8.ptr + (97 * 1) + 1, src_rows[1], 96);
@@ -370,7 +364,7 @@
wuffs_base__make_slice_u8(g_src_slice_u8.ptr, 97 * 2)));
CHECK_STRING(do_wuffs_png_swizzle(
- &dec, 96, 2, filter_distance, g_have_slice_u8,
+ 96, 2, filter_distance, g_have_slice_u8,
wuffs_base__make_slice_u8(g_work_slice_u8.ptr, 97 * 2)));
wuffs_base__io_buffer have =
@@ -573,6 +567,7 @@
dec.private_impl.f_height = height;
dec.private_impl.f_bytes_per_row = bytes_per_row;
dec.private_impl.f_filter_distance = filter_distance;
+ wuffs_png__decoder__choose_filter_implementations(&dec);
CHECK_STATUS("prepare",
wuffs_base__pixel_swizzler__prepare(