package dsutil

import (
	"context"
	"math/rand"
	"sort"
	"time"

	"cloud.google.com/go/datastore"
	"go.skia.org/infra/go/ds"
	"go.skia.org/infra/go/sklog"
	"go.skia.org/infra/go/util"
)

// DefaultConsistencyDelta is the default (assumed) time it takes for the
// results of an eventually consistent query to become consistent.
const DefaultConsistencyDelta time.Duration = 5 * time.Minute

// RecentKeysList records recently added and deleted keys in a container entity
// and keeps them there for a configurable time window.
//
// Merging these recent keys into the result of an eventually consistent query
// yields a consistent listing of the entries (see Recently.Combine).
//
// The keys are expected to be time based, i.e. created via TimeSortableKey, so
// that they can be sorted and their embedded timestamp can be extracted.
type RecentKeysList struct {
	// client is the cloud datastore client used for non-transactional reads.
	client *datastore.Client

	// containerKey identifies the entity that stores the recently changed keys.
	containerKey *datastore.Key

	// consistencyDeltaMs is the length of the time window, in milliseconds,
	// for which a changed key is retained in the container.
	consistencyDeltaMs int64
}

// NewRecentKeysList returns a RecentKeysList that records the key changes made
// through Add and Delete in the entity identified by containerKey.
// Whenever that entity is updated, keys whose embedded timestamp is older than
// consistentDelta are dropped from it. containerKey is usually of kind
// ds.HELPER_RECENT_KEYS, but this is not strictly required.
func NewRecentKeysList(client *datastore.Client, containerKey *datastore.Key, consistentDelta time.Duration) *RecentKeysList {
	deltaMs := int64(consistentDelta / time.Millisecond)
	ret := &RecentKeysList{}
	ret.client = client
	ret.containerKey = containerKey
	ret.consistencyDeltaMs = deltaMs
	return ret
}

// Add records newKey as recently added, using the transaction tx. Call it from
// the same transaction that creates the new entity. Keys that have fallen out
// of the consistency window are pruned from the container in the same write.
// newKey must have a positive numeric ID, as produced by TimeSortableKey;
// otherwise an error is returned and nothing is written.
func (l *RecentKeysList) Add(tx *datastore.Transaction, newKey *datastore.Key) error {
	if newKey.ID > 0 {
		return l.updateRecentKeys(tx, newKey, false)
	}
	return sklog.FmtErrorf("Key contains invalid numeric ID. It must be generated by calling TimeSortableKey")
}

// Delete records removeKey as recently deleted, using the transaction tx, and
// removes it from the set of recently added keys. Call it whenever an entity of
// the target collection is removed, from the same transaction, so that no
// dangling keys remain in the container.
//
// Like Add, it rejects keys without a positive numeric ID, i.e. keys that were
// not created via TimeSortableKey. Such keys cannot be handled correctly. A
// zero ID (e.g. a name-based key) decodes to a timestamp far in the future, so
// it would never expire from the container. Recently.Combine would then drop
// every queried key that also has ID zero.
//
// NOTE(review): the consistency window is checked against the timestamp
// embedded in the key (its creation time), not the time of the deletion.
// Deleting an entity whose key is older than the window is therefore not
// recorded, and an eventually consistent query may still return it briefly.
func (l *RecentKeysList) Delete(tx *datastore.Transaction, removeKey *datastore.Key) error {
	if removeKey.ID <= 0 {
		return sklog.FmtErrorf("Key contains invalid numeric ID. It must be generated by calling TimeSortableKey")
	}
	return l.updateRecentKeys(tx, removeKey, true)
}

// updateRecentKeys reads the container entity within tx and adds key to the
// recently changed keys (or removes it, when remove is true). It writes the
// container back within tx only if something changed. A missing container
// entity is treated as empty.
func (l *RecentKeysList) updateRecentKeys(tx *datastore.Transaction, key *datastore.Key, remove bool) error {
	recent := &Recently{}
	// A direct Get by key is strongly consistent, unlike a query.
	err := tx.Get(l.containerKey, recent)
	if err != nil && err != datastore.ErrNoSuchEntity {
		return err
	}

	if !recent.update(key, l.consistencyDeltaMs, remove) {
		// Nothing changed, so there is nothing to write back.
		return nil
	}
	_, err = tx.Put(l.containerKey, recent)
	return err
}

// GetRecent loads the recently added and deleted keys from the container
// entity, outside of any transaction, and returns them as a Recently value.
// A missing container entity yields an empty Recently.
// Issue this read alongside the eventually consistent query. Then pass the
// query result to Combine to obtain a consistent snapshot of the collection.
func (l *RecentKeysList) GetRecent() (*Recently, error) {
	recent := &Recently{}
	err := l.client.Get(context.Background(), l.containerKey, recent)
	switch {
	case err == nil, err == datastore.ErrNoSuchEntity:
		return recent, nil
	default:
		return nil, err
	}
}

// Recently holds the keys that were added or deleted within the consistency
// window. Its Combine method merges them with an eventually consistent query
// result to produce the complete list of valid keys.
// The fields are exported because Recently is persisted to the datastore (see
// updateRecentKeys). Do not modify them directly. Use RecentKeysList.Add,
// RecentKeysList.Delete and Recently.Combine instead.
type Recently struct {
	// Added holds the keys added within the window, sorted by ascending ID.
	Added []*datastore.Key

	// Deleted holds the keys deleted within the window, sorted by ascending ID.
	Deleted []*datastore.Key
}

// Combine merges queried, the result of an eventually consistent query, with
// the recently added and deleted keys:
//   - queried keys that were recently deleted are dropped,
//   - recently added keys are included exactly once, even if the query also
//     returned them.
//
// The result is sorted by ascending ID. For keys created via TimeSortableKey,
// this means the embedded times are in descending order (newest first).
// Duplicates within queried itself are not removed.
func (r *Recently) Combine(queried []*datastore.Key) []*datastore.Key {
	addMap := toMap(r.Added)
	delMap := toMap(r.Deleted)

	// Added and Deleted are kept mutually exclusive by update, so only the
	// queried keys have to be checked against both sets.
	ret := make([]*datastore.Key, 0, len(r.Added)+len(queried))
	ret = append(ret, r.Added...)
	for _, k := range queried {
		_, added := addMap[k.ID]
		_, deleted := delMap[k.ID]
		if !added && !deleted {
			ret = append(ret, k)
		}
	}
	sort.Slice(ret, func(i, j int) bool { return ret[i].ID < ret[j].ID })
	return ret
}

// update moves keyToUpdate into Added, or into Deleted when remove is true,
// and removes it from the other list. Both lists stay sorted by ascending ID
// and free of duplicates.
// Keys whose embedded time is more than evConsistentDeltaMs before now are
// pruned from both lists. If keyToUpdate itself is that old, neither list gains
// or loses it. The return value reports whether either list changed.
func (r *Recently) update(keyToUpdate *datastore.Key, evConsistentDeltaMs int64, remove bool) bool {
	cutoffMs := util.TimeStamp(time.Millisecond) - evConsistentDeltaMs
	// Both calls must run unconditionally so that both lists are pruned.
	deletedChanged := updateSortedKeys(&r.Deleted, keyToUpdate, cutoffMs, !remove)
	addedChanged := updateSortedKeys(&r.Added, keyToUpdate, cutoffMs, remove)
	return deletedChanged || addedChanged
}

// sortableIDMask has the lowest 63 bits set, so it equals the largest int64.
// getSortableTimeID and GetTimeFromID XOR IDs with it to invert their order.
const sortableIDMask int64 = 1<<63 - 1

// TimeSortableKey returns a new key of the given kind whose numeric ID embeds
// timeStampMs, a timestamp in milliseconds. A zero timeStampMs means "now".
// The ID is inverted: sorting keys by ascending ID orders their embedded
// timestamps in descending order, so the newest keys come first.
// GetTimeFromID recovers the timestamp from the key's ID.
//
// NOTE: part of the ID is random. Programs using TimeSortableKey should seed
// the default random number generator via rand.Seed(...). This matters on Go
// versions before 1.20, where the global source is deterministic by default.
func TimeSortableKey(kind ds.Kind, timeStampMs int64) *datastore.Key {
	key := ds.NewKey(kind)
	ts := timeStampMs
	if ts == 0 {
		ts = util.TimeStamp(time.Millisecond)
	}
	key.ID = getSortableTimeID(ts)
	return key
}

// GetTimeFromID extracts the millisecond timestamp embedded in id. The id is
// assumed to come from a key created by TimeSortableKey; any other id yields a
// meaningless value.
func GetTimeFromID(id int64) int64 {
	// Undo the inversion, then drop the 20 low bits (16 random + 4 version).
	raw := id ^ sortableIDMask
	return raw >> 20
}

// getSortableTimeID builds a non-negative 64-bit ID that embeds timeStampMs.
// The ID is inverted so that IDs sorted in ascending order carry timestamps in
// descending order, with the newest first.
//
// The layout is adapted from:
//
// https://github.com/luci/luci-py/blob/master/appengine/swarming/server/task_request.py#L1078
//
// Before inversion, the value consists of (from the most significant bit down):
//   - 1 bit, always 0, to keep the value positive.
//   - 43 bits of timeStampMs, the time since the epoch at 1ms resolution. This
//     lasts 2**43 ms, roughly 278 years, i.e. until about the year 2248.
//   - 16 random bits, which make collisions between IDs generated in the same
//     millisecond unlikely.
//   - 4 bits holding the layout version 0x1. After inversion they are stored
//     as 0xE. Bump this value if the packing of IDs changes incompatibly
//     (GetTimeFromID must then be updated as well).
//
// The returned ID is that value XOR-ed with sortableIDMask (lowest 63 bits
// set), so increasing timestamps map to decreasing IDs.
func getSortableTimeID(timeStampMs int64) int64 {
	const (
		timeShift   = 20     // bits below the timestamp: 16 random + 4 version
		randomShift = 4      // bits below the random value: 4 version
		randomMask  = 0xFFFF // keeps 16 random bits
		version     = 0x1    // layout version stored in the lowest 4 bits
	)
	randomBits := rand.Int63() & randomMask
	raw := timeStampMs<<timeShift | randomBits<<randomShift | version
	return raw ^ sortableIDMask
}

// updateSortedKeys inserts k into the slice pointed to by keys, or deletes k
// from it when remove is true. The slice must be sorted by ascending ID. Before
// that, it drops every key whose embedded time is older than newerThanMs (see
// filterSortedKeys). Keys are compared by their numeric ID only.
// It reports whether *keys was modified.
func updateSortedKeys(keys *[]*datastore.Key, k *datastore.Key, newerThanMs int64, remove bool) bool {
	// Remove all keys that no longer need to be cached.
	changed := filterSortedKeys(keys, newerThanMs)

	// If the key is outside the time window we don't have to add or remove it.
	// Any stored key with the same ID has the same embedded time and was
	// therefore already pruned above.
	// NOTE(review): the window is checked against the time embedded in the key
	// (its creation time), not the time of this change. Deleting an entity whose
	// key is older than the window is therefore not recorded, so an eventually
	// consistent query may still return it for a short while.
	if GetTimeFromID(k.ID) < newerThanMs {
		return changed
	}

	// Binary search for the position of k. found reports whether a key with
	// the same ID is already stored at that position.
	idx := sort.Search(len(*keys), func(i int) bool { return (*keys)[i].ID >= k.ID })
	found := idx < len(*keys) && (*keys)[idx].ID == k.ID
	if remove {
		// If we have found the key we need to remove it.
		if found {
			// Shift the tail left by one. Then nil out the now-unused last slot
			// so the backing array does not keep the removed pointer alive.
			copy((*keys)[idx:], (*keys)[idx+1:])
			(*keys)[len(*keys)-1] = nil
			*keys = (*keys)[:len(*keys)-1]
			changed = true
		}
		return changed
	}

	// If we found it nothing has to be done.
	if found {
		return changed
	}

	*keys = append(*keys, k)
	// If the target location is at the end, we are done. Otherwise the append
	// above has grown the slice by one element. The elements from idx onwards
	// can then be shifted one slot to the right and k written at idx, without
	// an extra allocation.
	if idx < (len(*keys) - 1) {
		// Note: copy moves only the previously existing keys, because dst is one
		// element shorter than src.
		copy((*keys)[idx+1:], (*keys)[idx:])
		(*keys)[idx] = k
	}
	return true
}

// filterSortedKeys truncates *keys to the keys whose embedded time (see
// GetTimeFromID) is not older than newerThanMs. *keys must be sorted by
// ascending ID, i.e. by descending embedded time. It reports whether any key
// was dropped.
func filterSortedKeys(keys *[]*datastore.Key, newerThanMs int64) bool {
	// Times descend along the slice, so all expired keys form a suffix and the
	// first expired index can be found by binary search.
	idx := sort.Search(len(*keys), func(i int) bool {
		return GetTimeFromID((*keys)[i].ID) < newerThanMs
	})

	if idx == len(*keys) {
		return false
	}

	// Clear the dropped pointers so the backing array does not keep them alive.
	// This mirrors the removal path in updateSortedKeys.
	for i := idx; i < len(*keys); i++ {
		(*keys)[i] = nil
	}
	*keys = (*keys)[:idx]
	return true
}

// toMap indexes keys by their numeric ID so that membership can be checked in
// constant time. If several keys share an ID, the last one wins.
func toMap(keys []*datastore.Key) map[int64]*datastore.Key {
	byID := make(map[int64]*datastore.Key, len(keys))
	for i := range keys {
		byID[keys[i].ID] = keys[i]
	}
	return byID
}

// GetFn returns a Get function with a uniform signature, whether or not a
// transaction is in use. The function reads through tx when tx is non-nil, and
// otherwise through client using context.TODO().
func GetFn(client *datastore.Client, tx *datastore.Transaction) func(*datastore.Key, interface{}) error {
	if tx == nil {
		return func(k *datastore.Key, dst interface{}) error {
			return client.Get(context.TODO(), k, dst)
		}
	}
	return tx.Get
}

// PutFn returns a Put function with a uniform signature, whether or not a
// transaction is in use. The function writes through tx when tx is non-nil,
// and otherwise through client using context.TODO(). The key (or pending key)
// returned by the underlying Put is discarded.
func PutFn(client *datastore.Client, tx *datastore.Transaction) func(*datastore.Key, interface{}) error {
	if tx == nil {
		return func(k *datastore.Key, val interface{}) error {
			_, err := client.Put(context.TODO(), k, val)
			return err
		}
	}
	return func(k *datastore.Key, val interface{}) error {
		_, err := tx.Put(k, val)
		return err
	}
}

// TxActions collects actions that depend on the outcome of a transaction. These
// are typically datastore operations that are related to the transaction but
// deliberately executed outside of it, e.g. for performance reasons.
// Register actions with AddCommitFn and AddRollbackFn, and execute them with
// Run once the transaction has finished.
type TxActions struct {
	commitActions, rollbackActions []TxActionFn
}

// TxActionFn is an action executed outside of a transaction in response to its
// outcome: either a successful commit or a rollback. TxActions.Run logs any
// non-nil error it returns.
type TxActionFn func() error

// Run executes the registered actions once the transaction has finished.
// If err is nil, the actions added via AddCommitFn run. Otherwise the actions
// added via AddRollbackFn run. Actions run in the order they were added. An
// error from one action is logged and does not stop the remaining actions.
func (t *TxActions) Run(err error) {
	phase, actions := "Commit", t.commitActions
	if err != nil {
		phase, actions = "Rollback", t.rollbackActions
	}
	for _, action := range actions {
		if actionErr := action(); actionErr != nil {
			sklog.Errorf("Error during %s: %s", phase, actionErr)
		}
	}
}

// Batch splits keySlice into consecutive batches of at most size keys each.
// Its main purpose is to deal with datastore calls that limit the number of
// keys handled at once, e.g. DeleteMulti can only delete 500 keys at a time.
//
// Each batch is a sub-slice of keySlice whose capacity is capped at its own
// length, so appending to one batch cannot overwrite keys of the next batch.
// A size <= 0 means "no limit" and yields a single batch containing all keys.
// An empty keySlice yields an empty, non-nil result.
func Batch(keySlice []*datastore.Key, size int) [][]*datastore.Key {
	n := len(keySlice)
	if n == 0 {
		return [][]*datastore.Key{}
	}
	// Clamp size to n. This handles "no limit" and keeps start+size below from
	// overflowing for very large sizes.
	if size <= 0 || size > n {
		size = n
	}

	ret := make([][]*datastore.Key, 0, (n+size-1)/size)
	for start := 0; start < n; start += size {
		end := util.MinInt(start+size, n)
		// Full slice expression: cap == len, so the batches do not alias.
		ret = append(ret, keySlice[start:end:end])
	}
	return ret
}

// AddCommitFn registers fn to be executed by Run when the related transaction
// succeeded. Commit functions run in the order they were registered.
func (t *TxActions) AddCommitFn(fn TxActionFn) {
	updated := append(t.commitActions, fn)
	t.commitActions = updated
}

// AddRollbackFn registers fn to be executed by Run when the related transaction
// failed. Rollback functions run in the order they were registered.
func (t *TxActions) AddRollbackFn(fn TxActionFn) {
	updated := append(t.rollbackActions, fn)
	t.rollbackActions = updated
}
