Add some ast.Expr.IsEtc methods
diff --git a/internal/cgen/builtin.go b/internal/cgen/builtin.go
index 54eb913..78f316a 100644
--- a/internal/cgen/builtin.go
+++ b/internal/cgen/builtin.go
@@ -728,16 +728,16 @@
}
func (g *gen) writeExprDotPtr(b *buffer, n *a.Expr, sideEffectsOnly bool, depth uint32) error {
- if n.Operator() == t.IDDotDot {
- if err := g.writeExpr(b, n.LHS().AsExpr(), sideEffectsOnly, depth); err != nil {
+ if arrayOrSlice, lo, _, ok := n.IsSlice(); ok {
+ if err := g.writeExpr(b, arrayOrSlice, sideEffectsOnly, depth); err != nil {
return err
}
- if n.LHS().AsExpr().MType().IsSliceType() {
+ if arrayOrSlice.MType().IsSliceType() {
b.writes(".ptr")
}
- if n.MHS() != nil {
+ if lo != nil {
b.writes(" + ")
- if err := g.writeExpr(b, n.MHS().AsExpr(), sideEffectsOnly, depth); err != nil {
+ if err := g.writeExpr(b, lo, sideEffectsOnly, depth); err != nil {
return err
}
}
@@ -969,13 +969,8 @@
// matchFooIndexIndexPlus8 matches n with "foo[index .. index + 8]" or "foo[..
// 8]". It returns a nil foo if there isn't a match.
func matchFooIndexIndexPlus8(n *a.Expr) (foo *a.Expr, index *a.Expr) {
- if n.Operator() != t.IDDotDot {
- return nil, nil
- }
- foo = n.LHS().AsExpr()
- index = n.MHS().AsExpr()
- rhs := n.RHS().AsExpr()
- if rhs == nil {
+ foo, index, rhs, ok := n.IsSlice()
+ if !ok || (rhs == nil) {
return nil, nil
}
diff --git a/internal/cgen/cgen.go b/internal/cgen/cgen.go
index 146a814..72242c0 100644
--- a/internal/cgen/cgen.go
+++ b/internal/cgen/cgen.go
@@ -1035,9 +1035,9 @@
}
func (g *gen) writeConstList(b *buffer, n *a.Expr) error {
- if n.Operator() == t.IDComma {
+ if args, ok := n.IsList(); ok {
b.writeb('{')
- for i, o := range n.Args() {
+ for i, o := range args {
if i&7 == 0 {
b.writeb('\n')
}
diff --git a/internal/cgen/var.go b/internal/cgen/var.go
index 0dadad7..9e59443 100644
--- a/internal/cgen/var.go
+++ b/internal/cgen/var.go
@@ -45,8 +45,8 @@
switch p.Kind() {
case a.KExpr:
// Look for p matching "args.name.etc(etc)".
- recv, meth, args := p.AsExpr().IsMethodCall()
- if recv == nil {
+ recv, meth, args, ok := p.AsExpr().IsMethodCall()
+ if !ok {
return nil
}
if recv.IsArgsDotFoo() == name {
diff --git a/lang/ast/ast.go b/lang/ast/ast.go
index b1469c8..c1d4613 100644
--- a/lang/ast/ast.go
+++ b/lang/ast/ast.go
@@ -386,13 +386,41 @@
return 0
}
-func (n *Expr) IsMethodCall() (recv *Expr, meth t.ID, args []*Node) {
+func (n *Expr) IsIndex() (arrayOrSlice *Expr, index *Expr, ok bool) {
+ if n.id0 == t.IDOpenBracket {
+ return n.lhs.AsExpr(), n.rhs.AsExpr(), true
+ }
+ return nil, nil, false
+}
+
+func (n *Expr) IsList() (args []*Node, ok bool) {
+ if n.id0 == t.IDComma {
+ return n.list0, true
+ }
+ return nil, false
+}
+
+func (n *Expr) IsMethodCall() (recv *Expr, meth t.ID, args []*Node, ok bool) {
if n.id0 == t.IDOpenParen {
if o := n.lhs; o.id0 == t.IDDot {
- return o.lhs.AsExpr(), o.id2, n.list0
+ return o.lhs.AsExpr(), o.id2, n.list0, true
}
}
- return nil, 0, nil
+ return nil, 0, nil, false
+}
+
+func (n *Expr) IsSelector() (lhs *Expr, field t.ID, ok bool) {
+ if n.id0 == t.IDDot {
+ return n.lhs.AsExpr(), n.id2, true
+ }
+ return nil, 0, false
+}
+
+func (n *Expr) IsSlice() (arrayOrSlice *Expr, lo *Expr, hi *Expr, ok bool) {
+ if n.id0 == t.IDDotDot {
+ return n.lhs.AsExpr(), n.mhs.AsExpr(), n.rhs.AsExpr(), true
+ }
+ return nil, nil, nil, false
}
func NewExpr(flags Flags, operator t.ID, ident t.ID, lhs *Node, mhs *Node, rhs *Node, args []*Node) *Expr {
diff --git a/lang/check/bounds.go b/lang/check/bounds.go
index 46815e7..6301680 100644
--- a/lang/check/bounds.go
+++ b/lang/check/bounds.go
@@ -356,15 +356,8 @@
func (q *checker) hasIsErrorFact(id t.ID) bool {
for _, x := range q.facts {
- if (x.Operator() != t.IDOpenParen) || (len(x.Args()) != 0) {
- continue
- }
- x = x.LHS().AsExpr()
- if (x.Operator() != t.IDDot) || (x.Ident() != t.IDIsError) {
- continue
- }
- x = x.LHS().AsExpr()
- if (x.Operator() != 0) || (x.Ident() != id) {
+ if lhs, meth, args, _ := x.IsMethodCall(); (meth != t.IDIsError) || (len(args) != 0) ||
+ (lhs.Operator() != 0) || (lhs.Ident() != id) {
continue
}
return true
@@ -525,16 +518,16 @@
}
// Look for "lhs = x[i .. j]" where i and j are constants.
- if rhs.Operator() == t.IDDotDot {
+ if _, i, j, ok := rhs.IsSlice(); ok {
icv := (*big.Int)(nil)
- if i := rhs.MHS().AsExpr(); i == nil {
+ if i == nil {
icv = zero
} else if i.ConstValue() != nil {
icv = i.ConstValue()
}
jcv := (*big.Int)(nil)
- if j := rhs.RHS().AsExpr(); (j != nil) && (j.ConstValue() != nil) {
+ if (j != nil) && (j.ConstValue() != nil) {
jcv = j.ConstValue()
}
@@ -1265,15 +1258,8 @@
func (q *checker) canUndoByte(recv *a.Expr) error {
for _, x := range q.facts {
- if x.Operator() != t.IDOpenParen || len(x.Args()) != 0 {
- continue
- }
- x = x.LHS().AsExpr()
- if x.Operator() != t.IDDot || x.Ident() != t.IDCanUndoByte {
- continue
- }
- x = x.LHS().AsExpr()
- if !x.Eq(recv) {
+ if lhs, meth, args, _ := x.IsMethodCall(); (meth != t.IDCanUndoByte) || (len(args) != 0) ||
+ !lhs.Eq(recv) {
continue
}
return q.facts.update(func(o *a.Expr) (*a.Expr, error) {
diff --git a/lang/check/check.go b/lang/check/check.go
index 64e6cfd..05b3004 100644
--- a/lang/check/check.go
+++ b/lang/check/check.go
@@ -391,12 +391,13 @@
func (c *Checker) checkConstElement(n *a.Expr, nb bounds, nLists int) error {
if nLists > 0 {
nLists--
- if n.Operator() != t.IDComma {
+ if args, ok := n.IsList(); !ok {
return fmt.Errorf("invalid const value %q", n.Str(c.tm))
- }
- for _, o := range n.Args() {
- if err := c.checkConstElement(o.AsExpr(), nb, nLists); err != nil {
- return err
+ } else {
+ for _, o := range args {
+ if err := c.checkConstElement(o.AsExpr(), nb, nLists); err != nil {
+ return err
+ }
}
}
return nil
diff --git a/lang/check/optimize.go b/lang/check/optimize.go
index 1517363..1f9b723 100644
--- a/lang/check/optimize.go
+++ b/lang/check/optimize.go
@@ -55,16 +55,16 @@
// Check if receiver looks like "a[i .. j]" where i and j are constants and
// ((j - i) >= advance).
- if receiver.Operator() == t.IDDotDot {
+ if _, i, j, ok := receiver.IsSlice(); ok {
icv := (*big.Int)(nil)
- if i := receiver.MHS().AsExpr(); i == nil {
+ if i == nil {
icv = zero
} else if i.ConstValue() != nil {
icv = i.ConstValue()
}
jcv := (*big.Int)(nil)
- if j := receiver.RHS().AsExpr(); (j != nil) && (j.ConstValue() != nil) {
+ if (j != nil) && (j.ConstValue() != nil) {
jcv = j.ConstValue()
}