blob: a2d61280589ef093d19036e1b9f2576976ba9f8e [file] [log] [blame]
// Copyright 2017 The Puffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package check
import (
"errors"
"fmt"
"math/big"
a "github.com/google/puffs/lang/ast"
t "github.com/google/puffs/lang/token"
)
// otherHandSide looks for thisHS on either side of n, where n is a comparison
// (one of ==, !=, <, <=, >=, >).
//
// If thisHS is n's left hand side, it returns n's own operator and n's right
// hand side. If thisHS is n's right hand side, it returns the mirrored
// operator and n's left hand side. For example, "thatHS < thisHS" is reported
// as (>, thatHS), meaning "thisHS > thatHS". In every other case, including
// when n is not a comparison, it returns (0, nil).
func otherHandSide(n *a.Expr, thisHS *a.Expr) (op t.ID, thatHS *a.Expr) {
	op = n.ID0()

	// mirrored is the operator that keeps the comparison true when its two
	// operands are swapped.
	var mirrored t.ID
	switch op.Key() {
	case t.KeyXBinaryEqEq:
		mirrored = t.IDXBinaryEqEq
	case t.KeyXBinaryNotEq:
		mirrored = t.IDXBinaryNotEq
	case t.KeyXBinaryLessThan:
		mirrored = t.IDXBinaryGreaterThan
	case t.KeyXBinaryLessEq:
		mirrored = t.IDXBinaryGreaterEq
	case t.KeyXBinaryGreaterEq:
		mirrored = t.IDXBinaryLessEq
	case t.KeyXBinaryGreaterThan:
		mirrored = t.IDXBinaryLessThan
	default:
		return 0, nil
	}

	switch {
	case thisHS.Eq(n.LHS().Expr()):
		return op, n.RHS().Expr()
	case thisHS.Eq(n.RHS().Expr()):
		return mirrored, n.LHS().Expr()
	}
	return 0, nil
}
// facts is the list of boolean expressions that the checker treats as holding
// at the current point of checking. It is stored as a plain slice:
// appendFact skips exact duplicates (by Expr.Eq) and flattens conjunctions,
// and update rewrites or drops entries in place.
type facts []*a.Expr
// appendFact adds fact to z unless an equal fact (by Expr.Eq) is already
// present.
//
// Conjunctions are flattened: "p and q" and associative "and" expressions
// add each operand as its own fact, not the conjunction itself. For a fact
// with no operator (ID0 is 0), every existing fact of the form
// "fact == other" or "other == fact" also causes other to be appended.
func (z *facts) appendFact(fact *a.Expr) {
	// TODO: make this faster than O(N) by keeping facts sorted somehow?
	for _, x := range *z {
		if x.Eq(fact) {
			return
		}
	}

	switch fact.ID0().Key() {
	case 0:
		// Record fact BEFORE deriving anything from it. The derivation
		// below can lead back to fact: given "a == b", appending "a"
		// derives "b", which in turn derives "a". Only the duplicate check
		// above ends that cycle, and only once fact is already in z.
		// Previously fact was appended after this loop, which recursed
		// without bound.
		existing := *z
		*z = append(*z, fact)
		for _, x := range existing {
			if op, other := otherHandSide(x, fact); op.Key() == t.KeyXBinaryEqEq {
				z.appendFact(other)
			}
		}
		return
	case t.KeyXBinaryAnd:
		z.appendFact(fact.LHS().Expr())
		z.appendFact(fact.RHS().Expr())
		return
	case t.KeyXAssociativeAnd:
		for _, arg := range fact.Args() {
			z.appendFact(arg.Expr())
		}
		return
	}

	*z = append(*z, fact)
}
// update replaces every fact with f(fact), in order. A nil result drops that
// fact; the surviving facts are compacted to the front of the slice, keeping
// their relative order. If f returns an error, update stops and returns it
// immediately, leaving z's length unchanged.
func (z *facts) update(f func(*a.Expr) (*a.Expr, error)) error {
	// kept shares z's backing array, so appends below overwrite in place and
	// never reallocate (kept can never grow longer than *z).
	kept := (*z)[:0]
	for _, fact := range *z {
		replaced, err := f(fact)
		if err != nil {
			return err
		}
		if replaced != nil {
			kept = append(kept, replaced)
		}
	}

	// Nil out the abandoned tail so dropped expressions can be collected.
	for j := len(kept); j < len(*z); j++ {
		(*z)[j] = nil
	}
	*z = kept
	return nil
}
// refine narrows the bounds [nMin..nMax] of the expression n using facts of
// the form "n op c" or "c op n", where c has a constant value and op is one
// of ==, !=, <, <=, >=, >. Each such fact shrinks the range to its
// intersection with the values the fact allows.
//
// If either bound is nil, the bounds are returned unchanged. It returns an
// error if a fact leaves the range empty, i.e. the fact is inconsistent with
// the bounds. tm is used only to format that error message.
func (z facts) refine(n *a.Expr, nMin *big.Int, nMax *big.Int, tm *t.Map) (*big.Int, *big.Int, error) {
	if nMin == nil || nMax == nil {
		return nMin, nMax, nil
	}

	for _, x := range z {
		op, other := otherHandSide(x, n)
		if op == 0 {
			continue
		}
		cv := other.ConstValue()
		if cv == nil {
			continue
		}

		originalNMin, originalNMax, changed := nMin, nMax, false
		switch op.Key() {
		case t.KeyXBinaryNotEq:
			// An excluded value can only narrow the range at either end.
			if nMin.Cmp(cv) == 0 {
				nMin = add1(nMin)
				changed = true
			} else if nMax.Cmp(cv) == 0 {
				nMax = sub1(nMax)
				changed = true
			}
		case t.KeyXBinaryLessThan:
			if nMax.Cmp(cv) >= 0 {
				nMax = sub1(cv)
				changed = true
			}
		case t.KeyXBinaryLessEq:
			if nMax.Cmp(cv) > 0 {
				nMax = cv
				changed = true
			}
		case t.KeyXBinaryEqEq:
			// Intersect with [cv..cv], like the other cases, rather than
			// overwriting the bounds. If cv lies outside [nMin..nMax], this
			// leaves nMin > nMax, so the check below reports the
			// contradiction instead of silently moving the bounds outside
			// the original range.
			if nMin.Cmp(cv) < 0 {
				nMin = cv
			}
			if nMax.Cmp(cv) > 0 {
				nMax = cv
			}
			changed = true
		case t.KeyXBinaryGreaterEq:
			if nMin.Cmp(cv) < 0 {
				nMin = cv
				changed = true
			}
		case t.KeyXBinaryGreaterThan:
			if nMin.Cmp(cv) <= 0 {
				nMin = add1(cv)
				changed = true
			}
		}

		if changed && nMin.Cmp(nMax) > 0 {
			return nil, nil, fmt.Errorf("check: expression %q bounds [%v..%v] inconsistent with fact %q",
				n.String(tm), originalNMin, originalNMax, x.String(tm))
		}
	}
	return nMin, nMax, nil
}
// simplify returns a simplified form of n, or n itself when no rule applies.
// For example, (x - x) becomes 0. The rules currently applied are:
//   - a binary op whose two operands both have constant values is folded
//     into a single constant expression;
//   - (x - x) becomes 0, ((p + q) - p) becomes q, and ((p + q) - q)
//     becomes p;
//   - the operands of a comparison (==, !=, <, <=, >=, >) are simplified
//     recursively, and a new comparison is built only if either one changed.
//
// Errors come from constant evaluation or from interning the folded
// constant's text in tm.
func simplify(tm *t.Map, n *a.Expr) (*a.Expr, error) {
	// TODO: be rigorous about this, not ad hoc.
	op, lhs, rhs := parseBinaryOp(n)

	if lhs != nil && rhs != nil {
		lcv, rcv := lhs.ConstValue(), rhs.ConstValue()
		if lcv != nil && rcv != nil {
			folded, err := evalConstValueBinaryOp(tm, n, lcv, rcv)
			if err != nil {
				return nil, err
			}
			litID, err := tm.Insert(folded.String())
			if err != nil {
				return nil, err
			}
			lit := a.NewExpr(a.FlagsTypeChecked, 0, litID, nil, nil, nil, nil)
			lit.SetConstValue(folded)
			lit.SetMType(typeExprIdeal)
			return lit, nil
		}
	}

	switch op.Key() {
	case t.KeyXBinaryPlus:
		// TODO: more constant folding, so ((x + 1) + 1) becomes (x + 2).

	case t.KeyXBinaryMinus:
		if lhs.Eq(rhs) {
			return zeroExpr, nil
		}
		innerOp, innerLHS, innerRHS := parseBinaryOp(lhs)
		if innerOp.Key() != t.KeyXBinaryPlus {
			break
		}
		switch {
		case innerLHS.Eq(rhs):
			return innerRHS, nil
		case innerRHS.Eq(rhs):
			return innerLHS, nil
		}

	case t.KeyXBinaryNotEq, t.KeyXBinaryLessThan, t.KeyXBinaryLessEq,
		t.KeyXBinaryEqEq, t.KeyXBinaryGreaterEq, t.KeyXBinaryGreaterThan:
		newLHS, err := simplify(tm, lhs)
		if err != nil {
			return nil, err
		}
		newRHS, err := simplify(tm, rhs)
		if err != nil {
			return nil, err
		}
		if newLHS == lhs && newRHS == rhs {
			break
		}
		rebuilt := a.NewExpr(a.FlagsTypeChecked, op, 0, newLHS.Node(), nil, newRHS.Node(), nil)
		rebuilt.SetConstValue(n.ConstValue())
		rebuilt.SetMType(n.MType())
		return rebuilt, nil
	}

	return n, nil
}
// argValue returns the value of the argument in args whose name is name. It
// returns nil if name is not known to tm or no argument has that name.
func argValue(tm *t.Map, args []*a.Node, name string) *a.Expr {
	nameID := tm.ByName(name)
	if nameID == 0 {
		return nil
	}
	// The loop variable is called arg, not a, so that it does not shadow
	// the imported ast package alias.
	for _, arg := range args {
		if arg.Arg().Name() == nameID {
			return arg.Arg().Value()
		}
	}
	return nil
}
// parseBinaryOp splits n into "lhs op rhs". It returns (0, nil, nil) when n's
// operator is not a binary op, or when the operator is "as", which is
// deliberately excluded.
func parseBinaryOp(n *a.Expr) (op t.ID, lhs *a.Expr, rhs *a.Expr) {
	id0 := n.ID0()
	if id0.IsBinaryOp() && id0.Key() != t.KeyAs {
		return id0, n.LHS().Expr(), n.RHS().Expr()
	}
	return 0, nil, nil
}
// proveBinaryOpConstValues reports whether "l op r" holds for every l in
// [lMin..lMax] and every r in [rMin..rMax] (assuming lMin <= lMax and
// rMin <= rMax). It reports false for any op other than the six comparison
// operators.
func proveBinaryOpConstValues(op t.Key, lMin *big.Int, lMax *big.Int, rMin *big.Int, rMax *big.Int) (ok bool) {
	switch op {
	case t.KeyXBinaryNotEq:
		// The two intervals must be disjoint.
		ok = lMax.Cmp(rMin) < 0 || lMin.Cmp(rMax) > 0
	case t.KeyXBinaryLessThan:
		ok = lMax.Cmp(rMin) < 0
	case t.KeyXBinaryLessEq:
		ok = lMax.Cmp(rMin) <= 0
	case t.KeyXBinaryEqEq:
		// Both intervals must collapse to the same single value.
		ok = lMin.Cmp(rMax) == 0 && lMax.Cmp(rMin) == 0
	case t.KeyXBinaryGreaterEq:
		ok = lMin.Cmp(rMax) >= 0
	case t.KeyXBinaryGreaterThan:
		ok = lMin.Cmp(rMax) > 0
	default:
		ok = false
	}
	return ok
}
// proveBinaryOp tries to prove "lhs op rhs". It returns nil on success,
// errFailed if no proof is found, or an error from bcheckExpr.
//
// Two strategies are tried.
//
// First, if either side has a constant value, the other side's bounds from
// bcheckExpr may settle the claim on their own.
//
// Second, each fact is normalized to the form "lhs factOp other". This
// includes facts written mirrored: "y > x" is treated as "x < y". Then:
//   - if factOp implies op and other is rhs, the claim holds;
//   - if factOp is == and both other and rhs have constant values, the claim
//     holds when "other op rhs" holds for those constants.
func (q *checker) proveBinaryOp(op t.Key, lhs *a.Expr, rhs *a.Expr) error {
	lcv := lhs.ConstValue()
	if lcv != nil {
		rMin, rMax, err := q.bcheckExpr(rhs, 0)
		if err != nil {
			return err
		}
		if proveBinaryOpConstValues(op, lcv, lcv, rMin, rMax) {
			return nil
		}
	}
	rcv := rhs.ConstValue()
	if rcv != nil {
		lMin, lMax, err := q.bcheckExpr(lhs, 0)
		if err != nil {
			return err
		}
		if proveBinaryOpConstValues(op, lMin, lMax, rcv, rcv) {
			return nil
		}
	}

	for _, x := range q.facts {
		var factOp t.Key
		var other *a.Expr
		if x.LHS().Expr().Eq(lhs) {
			factOp, other = x.ID0().Key(), x.RHS().Expr()
		} else if mirroredOp, mirroredOther := otherHandSide(x, lhs); mirroredOp != 0 {
			// x is a comparison with lhs on its right hand side.
			factOp, other = mirroredOp.Key(), mirroredOther
		} else {
			continue
		}

		if opImpliesOp(factOp, op) && other.Eq(rhs) {
			return nil
		}

		// Given "lhs == c" for a constant c, "lhs op rcv" holds iff
		// "c op rcv" does. A failed comparison does not end the search, so
		// the result does not depend on the order of the facts.
		if factOp == t.KeyXBinaryEqEq && rcv != nil {
			if factCV := other.ConstValue(); factCV != nil &&
				proveBinaryOpConstValues(op, factCV, factCV, rcv, rcv) {
				return nil
			}
		}
	}
	return errFailed
}
// opImpliesOp returns whether knowing "x op0 y" implies "x op1 y". Every op
// implies itself. In addition:
//   - "x < y" implies "x != y" and "x <= y";
//   - "x > y" implies "x != y" and "x >= y";
//   - "x == y" implies "x <= y" and "x >= y".
func opImpliesOp(op0 t.Key, op1 t.Key) bool {
	if op0 == op1 {
		return true
	}
	switch op0 {
	case t.KeyXBinaryLessThan:
		return op1 == t.KeyXBinaryNotEq || op1 == t.KeyXBinaryLessEq
	case t.KeyXBinaryEqEq:
		// Previously missing: equality implies both non-strict orderings.
		return op1 == t.KeyXBinaryLessEq || op1 == t.KeyXBinaryGreaterEq
	case t.KeyXBinaryGreaterThan:
		return op1 == t.KeyXBinaryNotEq || op1 == t.KeyXBinaryGreaterEq
	}
	return false
}
// errFailedOrNil converts a proof outcome into an error: nil when ok is true,
// errFailed otherwise.
func errFailedOrNil(ok bool) error {
	if !ok {
		return errFailed
	}
	return nil
}
var (
	// errFailed is the sentinel error returned when a proof attempt finds
	// no proof.
	errFailed = errors.New("failed")
)
// proveReasonRequirement proves "lhs op rhs" with q.proveBinaryOp. op must be
// an XBinaryOp; any other op is reported as an internal error. When the proof
// fails, the error names the unproven claim and wraps the proof error's text.
func proveReasonRequirement(q *checker, op t.ID, lhs *a.Expr, rhs *a.Expr) error {
	if !op.IsXBinaryOp() {
		return fmt.Errorf(
			"check: internal error: proveReasonRequirement token.Key (0x%02X) is not an XBinaryOp", op.Key())
	}

	proofErr := q.proveBinaryOp(op.Key(), lhs, rhs)
	if proofErr == nil {
		return nil
	}
	// Rebuild the claim as an expression only to render it in the message.
	claim := a.NewExpr(a.FlagsTypeChecked, op, 0, lhs.Node(), nil, rhs.Node(), nil)
	return fmt.Errorf("cannot prove %q: %v", claim.String(q.tm), proofErr)
}