blob: 7360571c50766d7bc51b2ca5a1d60520e8b57535 [file] [log] [blame]
/*
Leasing Server for Swarming Bots.
*/
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"html/template"
"net/http"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/unrolled/secure"
swarming_api "go.chromium.org/luci/common/api/swarming/swarming/v1"
"google.golang.org/api/iterator"
"go.skia.org/infra/go/alogin"
"go.skia.org/infra/go/alogin/proxylogin"
"go.skia.org/infra/go/baseapp"
"go.skia.org/infra/go/httputils"
"go.skia.org/infra/go/metrics2"
"go.skia.org/infra/go/roles"
"go.skia.org/infra/go/sklog"
"go.skia.org/infra/go/swarming"
"go.skia.org/infra/go/util"
"go.skia.org/infra/leasing/go/types"
)
const (
// maxLeaseDurationHrs is the maximum number of hours after a task's creation
// time at which its lease may end; extendTaskHandler rejects extensions that
// would go past it.
maxLeaseDurationHrs = 23
// swarmingHardTimeout is not referenced in the visible part of this file;
// presumably the hard timeout given to triggered swarming tasks - TODO confirm.
swarmingHardTimeout = 24 * time.Hour
// leaseTaskPriority is not referenced in the visible part of this file;
// presumably the priority given to triggered swarming tasks - TODO confirm.
leaseTaskPriority = 50
// URIs of the HTML pages that list leases.
myLeasesURI = "/my_leases"
allLeasesURI = "/all_leases"
// getTaskStatusURI is registered directly on the top-level router in
// AddHandlers, outside of the authentication middleware.
getTaskStatusURI = "/_/get_task_status"
// URIs of the POST endpoints registered on the authenticated router.
getLeasesPostURI = "/_/get_leases"
getSupportedPoolsPostURI = "/_/get_supported_pools"
poolDetailsPostURI = "/_/pooldetails"
addTaskPostURI = "/_/add_leasing_task"
extendTaskPostURI = "/_/extend_leasing_task"
expireTaskPostURI = "/_/expire_leasing_task"
)
var (
// Flags
host = flag.String("host", "leasing.skia.org", "HTTP service host")
workdir = flag.String("workdir", ".", "Directory to use for scratch work.")
artifactsDir = flag.String("artifacts_dir", "", "The directory to find leasing server's artifacts.")
pollInterval = flag.Duration("poll_interval", 1*time.Minute, "How often the leasing server will check if tasks have expired.")
poolDetailsUpdateFrequency = flag.Duration("pool_details_update_freq", 5*time.Minute, "How often to call swarming API to refresh the details of supported pools.")
// Datastore params
namespace = flag.String("namespace", "leasing-server", "The Cloud Datastore namespace, such as 'leasing-server'.")
projectName = flag.String("project_name", "google.com:skia-buildbots", "The Google Cloud project name.")
// poolToDetails maps a pool name to its details. It is populated in New and
// replaced periodically by a background goroutine started there.
poolToDetails map[string]*types.PoolDetails
// poolToDetailsMutex guards poolToDetails.
poolToDetailsMutex sync.Mutex
// plogin identifies the logged in user of a request. It is set in New.
plogin *proxylogin.ProxyLogin
)
// New implements baseapp.Constructor.
//
// It creates the work directory, initializes the mailing, login, swarming and
// datastore dependencies, loads the initial pool details, and starts two
// background goroutines: one that refreshes poolToDetails every
// poolDetailsUpdateFrequency and one that polls swarming tasks every
// pollInterval. Any initialization failure is fatal.
func New() (baseapp.App, error) {
	ctx := context.Background()

	// Create workdir if it does not exist.
	if err := os.MkdirAll(*workdir, 0755); err != nil {
		sklog.Fatalf("Could not create %s: %s", *workdir, err)
	}

	// Initialize mailing library.
	MailInit()

	plogin = proxylogin.NewWithDefaults()

	// Initialize swarming.
	if err := SwarmingInit(ctx); err != nil {
		sklog.Fatalf("Failed to init swarming: %s", err)
	}

	// Initialize cloud datastore.
	if err := DatastoreInit(ctx, *projectName, *namespace); err != nil {
		sklog.Fatalf("Failed to init cloud datastore: %s", err)
	}

	initialDetails, err := GetDetailsOfAllPools(ctx)
	if err != nil {
		sklog.Fatalf("Could not get details of all pools: %s", err)
	}
	poolToDetailsMutex.Lock()
	poolToDetails = initialDetails
	poolToDetailsMutex.Unlock()

	go func() {
		for range time.Tick(*poolDetailsUpdateFrequency) {
			// Fetch into locals and outside of the lock: the handlers are not
			// blocked while the pools are being queried, and no variable is
			// shared with the enclosing function.
			details, err := GetDetailsOfAllPools(ctx)
			if err != nil {
				// Keep serving the previous details rather than replacing them
				// with the result of a failed refresh.
				sklog.Errorf("Could not get details of all pools: %s", err)
				continue
			}
			poolToDetailsMutex.Lock()
			poolToDetails = details
			poolToDetailsMutex.Unlock()
		}
	}()

	healthyGauge := metrics2.GetInt64Metric("healthy")
	go func() {
		for range time.Tick(*pollInterval) {
			healthyGauge.Update(1)
			if err := pollSwarmingTasks(ctx); err != nil {
				sklog.Errorf("Error when checking for expired tasks: %v", err)
			}
		}
	}()

	srv := &Server{}
	srv.loadTemplates()
	return srv, nil
}
// Server is the state of the server. Its methods implement the HTTP handlers
// of the leasing server.
type Server struct {
// templates holds the parsed HTML templates; it is set by loadTemplates.
templates *template.Template
}
// loadTemplates parses the HTML templates found in the baseapp resources
// directory, using "{%" and "%}" as delimiters, and stores them in
// srv.templates. It panics if the templates cannot be parsed.
func (srv *Server) loadTemplates() {
	resourcesDir := *baseapp.ResourcesDir
	files := []string{
		filepath.Join(resourcesDir, "index.html"),
		filepath.Join(resourcesDir, "leases_list.html"),
	}
	parsed, err := template.New("").Delims("{%", "%}").ParseFiles(files...)
	srv.templates = template.Must(parsed, err)
}
// user returns the identity that plogin reports for the request, or a fixed
// placeholder address when the server runs in local mode.
func (srv *Server) user(r *http.Request) string {
	if *baseapp.Local {
		return "barney@example.org"
	}
	return string(plogin.LoggedInAs(r))
}
// AddHandlers implements baseapp.App.
//
// The task status endpoint is registered directly on r. Every other endpoint
// is registered on a sub-router that, unless running in local mode, is wrapped
// in middleware requiring the Viewer role.
func (srv *Server) AddHandlers(r chi.Router) {
	// The task status endpoint is polled by swarming bots, so it lives outside
	// of the authenticated router.
	r.Get(getTaskStatusURI, srv.statusHandler)

	// Everything registered on this router requires authentication.
	authed := chi.NewRouter()
	authed.HandleFunc("/", srv.indexHandler)
	authed.HandleFunc("/_/login/status", alogin.LoginStatusHandler(plogin))
	authed.HandleFunc(myLeasesURI, srv.myLeasesHandler)
	authed.HandleFunc(allLeasesURI, srv.allLeasesHandler)

	postRoutes := []struct {
		uri     string
		handler http.HandlerFunc
	}{
		{poolDetailsPostURI, srv.poolDetailsHandler},
		{getSupportedPoolsPostURI, srv.supportedPoolsHandler},
		{getLeasesPostURI, srv.getLeasesHandler},
		{addTaskPostURI, srv.addTaskHandler},
		{extendTaskPostURI, srv.extendTaskHandler},
		{expireTaskPostURI, srv.expireTaskHandler},
	}
	for _, route := range postRoutes {
		authed.Post(route.uri, route.handler)
	}

	// Mount the router, enforcing the role check except when running locally.
	var wrapped http.Handler = authed
	if !*baseapp.Local {
		requireViewer := alogin.ForceRoleMiddleware(plogin, roles.Viewer)
		wrapped = requireViewer(wrapped)
	}
	r.Handle("/*", wrapped)
}
// indexHandler renders the index.html template, passing it the CSP nonce of
// the request.
func (srv *Server) indexHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/html")

	data := map[string]string{
		"Nonce": secure.CSPNonce(r.Context()),
	}
	err := srv.templates.ExecuteTemplate(w, "index.html", data)
	if err == nil {
		return
	}
	httputils.ReportError(w, err, "Failed to expand template.", http.StatusInternalServerError)
}
// Status is the JSON payload returned by statusHandler for a leasing task.
// The field names are the JSON keys of the response.
type Status struct {
// TaskId is the ID of the task's datastore key.
TaskId int64
// Expired mirrors the Done flag of the task.
Expired bool
}
// statusHandler reports, as a JSON-encoded Status, whether the leasing task
// identified by the "task" form value is done.
//
// A missing or non-numeric "task" value is a client error and is answered with
// 400 Bad Request; failures to look up the task or to encode the response are
// answered with 500 Internal Server Error.
func (srv *Server) statusHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	taskParam := r.FormValue("task")
	if taskParam == "" {
		httputils.ReportError(w, nil, "Missing task parameter", http.StatusBadRequest)
		return
	}
	taskID, err := strconv.ParseInt(taskParam, 10, 64)
	if err != nil {
		httputils.ReportError(w, err, "Invalid task parameter", http.StatusBadRequest)
		return
	}

	k, t, err := GetDSTask(taskID)
	if err != nil {
		httputils.ReportError(w, err, "Could not find task", http.StatusInternalServerError)
		return
	}

	status := Status{
		TaskId:  k.ID,
		Expired: t.Done,
	}
	if err := json.NewEncoder(w).Encode(status); err != nil {
		httputils.ReportError(w, err, "Failed to encode JSON", http.StatusInternalServerError)
		return
	}
}
// poolDetailsHandler writes the JSON-encoded details of the pool named by the
// "pool" form value.
//
// A missing "pool" value is answered with 400 Bad Request and an unknown pool
// with 404 Not Found.
func (srv *Server) poolDetailsHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	poolParam := r.FormValue("pool")
	if poolParam == "" {
		httputils.ReportError(w, nil, "Missing pool parameter", http.StatusBadRequest)
		return
	}

	// Serialize while holding the lock, but release it before writing to the
	// client so that a slow connection cannot block the other users of
	// poolToDetailsMutex.
	var (
		body []byte
		err  error
	)
	poolToDetailsMutex.Lock()
	poolDetails, ok := poolToDetails[poolParam]
	if ok {
		body, err = json.Marshal(poolDetails)
	}
	poolToDetailsMutex.Unlock()

	if !ok {
		httputils.ReportError(w, nil, "No such pool", http.StatusNotFound)
		return
	}
	if err != nil {
		httputils.ReportError(w, err, fmt.Sprintf("Failed to encode JSON: %v", err), http.StatusInternalServerError)
		return
	}

	// json.Encoder.Encode terminates its output with a newline; keep the
	// response body identical to what it would have produced.
	if _, err := w.Write(append(body, '\n')); err != nil {
		sklog.Errorf("Failed to send response: %s", err)
	}
}
// supportedPoolsHandler writes the sorted names of all pools in poolToDetails
// as a JSON array. The array is empty, not null, when there are no pools.
func (srv *Server) supportedPoolsHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// Copy the names while holding the lock and release it before sorting and
	// writing to the client, so that a slow connection cannot block the other
	// users of poolToDetailsMutex.
	poolToDetailsMutex.Lock()
	supportedPools := make([]string, 0, len(poolToDetails))
	for p := range poolToDetails {
		supportedPools = append(supportedPools, p)
	}
	poolToDetailsMutex.Unlock()

	sort.Strings(supportedPools)
	if err := json.NewEncoder(w).Encode(supportedPools); err != nil {
		httputils.ReportError(w, err, fmt.Sprintf("Failed to encode JSON: %v", err), http.StatusInternalServerError)
		return
	}
}
// sortTasks implements sort.Interface. It orders tasks by creation time, most
// recently created first.
type sortTasks []*types.Task

// Len returns the number of tasks.
func (a sortTasks) Len() int {
	return len(a)
}

// Swap exchanges the tasks at indices i and j.
func (a sortTasks) Swap(i, j int) {
	a[j], a[i] = a[i], a[j]
}

// Less reports whether the task at index i was created after the one at j.
func (a sortTasks) Less(i, j int) bool {
	return a[j].Created.Before(a[i].Created)
}
// getLeasingTasks returns the tasks yielded by GetAllDSTasks(filterUser),
// ordered from most recently created to least recently created. The
// DatastoreId of every returned task is set to the ID of its datastore key.
//
// The returned slice is never nil on success. On failure the underlying
// iterator error is wrapped, so it stays reachable through errors.Is and
// errors.As.
func getLeasingTasks(filterUser string) ([]*types.Task, error) {
	// Start from an empty, non-nil slice so that callers encoding the result as
	// JSON emit [] rather than null when there are no tasks.
	tasks := []*types.Task{}
	it := GetAllDSTasks(filterUser)
	for {
		t := &types.Task{}
		k, err := it.Next(t)
		if err == iterator.Done {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("Failed to retrieve list of tasks: %w", err)
		}
		t.DatastoreId = k.ID
		tasks = append(tasks, t)
	}
	sort.Sort(sortTasks(tasks))
	return tasks, nil
}
// getLeasesHandler decodes a JSON request of the form
// {"filter_by_user": "..."} and writes the JSON-encoded list of leasing tasks
// returned by getLeasingTasks for that filter.
//
// A body that cannot be decoded is answered with 400 Bad Request; a failure to
// retrieve the tasks is answered with 500 Internal Server Error.
func (srv *Server) getLeasesHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	reqGetLeasesRequest := struct {
		FilterByUser string `json:"filter_by_user"`
	}{}
	if err := json.NewDecoder(r.Body).Decode(&reqGetLeasesRequest); err != nil {
		httputils.ReportError(w, err, "Failed to decode get leases request", http.StatusBadRequest)
		return
	}

	tasks, err := getLeasingTasks(reqGetLeasesRequest.FilterByUser)
	if err != nil {
		httputils.ReportError(w, err, "Failed to retrieve leasing tasks", http.StatusInternalServerError)
		return
	}

	if err := json.NewEncoder(w).Encode(tasks); err != nil {
		sklog.Errorf("Failed to send response: %s", err)
	}
}
// leasesHandlerHelper renders the leases_list.html template, passing it
// filterByUser and the CSP nonce of the request.
func (srv *Server) leasesHandlerHelper(w http.ResponseWriter, r *http.Request, filterByUser string) {
	w.Header().Set("Content-Type", "text/html")

	data := map[string]string{
		"FilterByUser": filterByUser,
		"Nonce":        secure.CSPNonce(r.Context()),
	}
	err := srv.templates.ExecuteTemplate(w, "leases_list.html", data)
	if err == nil {
		return
	}
	httputils.ReportError(w, err, "Failed to expand template.", http.StatusInternalServerError)
}
// myLeasesHandler renders the leases list page with the logged in user as the
// user filter.
func (srv *Server) myLeasesHandler(w http.ResponseWriter, r *http.Request) {
	loggedInUser := string(plogin.LoggedInAs(r))
	srv.leasesHandlerHelper(w, r, loggedInUser)
}
// allLeasesHandler renders the leases list page with an empty user filter.
func (srv *Server) allLeasesHandler(w http.ResponseWriter, r *http.Request) {
	const noUserFilter = ""
	srv.leasesHandlerHelper(w, r, noUserFilter)
}
// extendTaskHandler extends the lease of the task named in the JSON-encoded
// types.ExtendTaskRequest body by the requested number of hours, emails the
// requester about the extension, and writes the updated task as JSON.
//
// The extension must be between 1 and maxLeaseDurationHrs hours, and the new
// lease end time must not be later than maxLeaseDurationHrs hours after the
// creation time of the task. Requests violating either rule, and bodies that
// cannot be decoded, are answered with 400 Bad Request. Datastore and email
// failures are answered with 500 Internal Server Error.
func (srv *Server) extendTaskHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	extendRequest := types.ExtendTaskRequest{}
	if err := json.NewDecoder(r.Body).Decode(&extendRequest); err != nil {
		httputils.ReportError(w, err, "Failed to decode extend request", http.StatusBadRequest)
		return
	}
	// Reject durations that could never be valid. A non-positive duration would
	// shorten the lease instead of extending it, and bounding the value keeps
	// the time.Duration multiplication below from overflowing.
	if extendRequest.DurationHrs <= 0 || extendRequest.DurationHrs > maxLeaseDurationHrs {
		httputils.ReportError(w, nil, fmt.Sprintf("Lease extension must be between 1 and %d hours", maxLeaseDurationHrs), http.StatusBadRequest)
		return
	}

	k, t, err := GetDSTask(extendRequest.TaskID)
	if err != nil {
		httputils.ReportError(w, err, "Could not find task", http.StatusInternalServerError)
		return
	}

	// Add duration hours to the task's lease end time only if it ends up being
	// no later than maxLeaseDurationHrs hours after the task's creation time.
	newLeaseEndTime := t.LeaseEndTime.Add(time.Hour * time.Duration(extendRequest.DurationHrs))
	maxPossibleLeaseEndTime := t.Created.Add(time.Hour * time.Duration(maxLeaseDurationHrs))
	if newLeaseEndTime.After(maxPossibleLeaseEndTime) {
		httputils.ReportError(w, nil, fmt.Sprintf("Can not extend lease beyond %d hours of the task creation time", maxLeaseDurationHrs), http.StatusBadRequest)
		return
	}

	// Change the lease end time.
	t.LeaseEndTime = newLeaseEndTime
	// Reset the warning sent flag since the lease has been extended.
	t.WarningSent = false
	if _, err := UpdateDSTask(k, t); err != nil {
		httputils.ReportError(w, err, "Error updating task in datastore", http.StatusInternalServerError)
		return
	}

	// Inform the requester that the task has been extended by durationHrs.
	if err := SendExtensionEmail(t.Requester, t.SwarmingServer, t.SwarmingTaskId, t.SwarmingBotId, t.EmailThreadingReference, extendRequest.DurationHrs); err != nil {
		httputils.ReportError(w, err, "Error sending extension email", http.StatusInternalServerError)
		return
	}

	if err := json.NewEncoder(w).Encode(t); err != nil {
		sklog.Errorf("Failed to send response: %s", err)
	}
}
// expireTaskHandler marks the task named in the JSON-encoded
// types.ExpireTaskRequest body as done, sets its lease end time to now, emails
// the requester about the completion, and writes the updated task as JSON.
//
// A body that cannot be decoded is answered with 400 Bad Request. Datastore
// and email failures are answered with 500 Internal Server Error.
func (srv *Server) expireTaskHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	expireRequest := types.ExpireTaskRequest{}
	if err := json.NewDecoder(r.Body).Decode(&expireRequest); err != nil {
		httputils.ReportError(w, err, "Failed to decode expire request", http.StatusBadRequest)
		return
	}

	k, t, err := GetDSTask(expireRequest.TaskID)
	if err != nil {
		httputils.ReportError(w, err, "Could not find task", http.StatusInternalServerError)
		return
	}

	// Change the task to Done, change the lease end time to now, and mark the
	// state as successfully completed.
	t.Done = true
	t.LeaseEndTime = time.Now()
	t.SwarmingTaskState = getCompletedStateStr(false)
	if _, err := UpdateDSTask(k, t); err != nil {
		httputils.ReportError(w, err, "Error updating task in datastore", http.StatusInternalServerError)
		return
	}

	// Inform the requester that the task has completed.
	if err := SendCompletionEmail(t.Requester, t.SwarmingServer, t.SwarmingTaskId, t.SwarmingBotId, t.EmailThreadingReference); err != nil {
		httputils.ReportError(w, err, "Error sending completion email", http.StatusInternalServerError)
		return
	}

	if err := json.NewEncoder(w).Encode(t); err != nil {
		sklog.Errorf("Failed to send response: %s", err)
	}
}
// addTaskHandler creates a new leasing task from the JSON-encoded types.Task
// in the request body.
//
// It validates the requested bot, if any, fills in the requester, creation
// time and pending state, stores the task in the datastore, uploads the
// leasing artifacts to CAS, triggers the swarming task, and finally updates
// the stored task with the swarming server and task ID. The resulting task is
// written as JSON.
//
// A body that cannot be decoded, or a bot that is not part of the pool, is
// answered with 400 Bad Request. Every other failure is answered with 500
// Internal Server Error.
func (srv *Server) addTaskHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	// All of the work below is done on behalf of this request, so use its
	// context throughout.
	ctx := r.Context()
	defer util.Close(r.Body)

	// Decode into the already allocated task rather than into a pointer to the
	// pointer: a JSON null body then leaves a zero-valued task behind instead
	// of a nil pointer that would be dereferenced below.
	task := &types.Task{}
	if err := json.NewDecoder(r.Body).Decode(task); err != nil {
		httputils.ReportError(w, err, fmt.Sprintf("Failed to add %T task", task), http.StatusBadRequest)
		return
	}

	key := GetNewDSKey()

	if task.SwarmingBotId != "" {
		// If BotId is specified then validate it so that we can fail fast if
		// necessary.
		validBotID, err := IsBotIDValid(ctx, task.SwarmingPool, task.SwarmingBotId)
		if err != nil {
			httputils.ReportError(w, err, fmt.Sprintf("Error querying swarming for botId %s in pool %s", task.SwarmingBotId, task.SwarmingPool), http.StatusInternalServerError)
			return
		}
		if !validBotID {
			httputils.ReportError(w, nil, fmt.Sprintf("Could not find botId %s in pool %s", task.SwarmingBotId, task.SwarmingPool), http.StatusBadRequest)
			return
		}
	}

	// Populate deviceType only if Android or iOS is the osType.
	if task.OsType != "Android" && !strings.HasPrefix(task.OsType, "iOS") {
		task.DeviceType = ""
	}
	// Add the username of the requester.
	task.Requester = string(plogin.LoggedInAs(r))
	// Add the created time.
	task.Created = time.Now()
	// Set to pending.
	task.SwarmingTaskState = swarming.TASK_STATE_PENDING

	// Reuse the properties of an existing swarming task if one was specified.
	var swarmingProps *swarming_api.SwarmingRpcsTaskProperties
	if task.TaskIdForIsolates != "" {
		t, err := GetSwarmingTaskMetadata(ctx, task.SwarmingPool, task.TaskIdForIsolates)
		if err != nil {
			httputils.ReportError(w, err, fmt.Sprintf("Could not find taskId %s in pool %s", task.TaskIdForIsolates, task.SwarmingPool), http.StatusInternalServerError)
			return
		}
		swarmingProps = swarming.GetTaskRequestProperties(t)
	} else {
		swarmingProps = &swarming_api.SwarmingRpcsTaskProperties{}
	}

	datastoreKey, err := PutDSTask(key, task)
	if err != nil {
		httputils.ReportError(w, err, fmt.Sprintf("Error putting task in datastore: %v", err), http.StatusInternalServerError)
		return
	}

	// Upload artifacts.
	casDigest, err := AddLeasingArtifactsToCAS(ctx, task.SwarmingPool, swarmingProps.CasInputRoot)
	if err != nil {
		httputils.ReportError(w, err, fmt.Sprintf("Error merging CAS inputs: %s", err), http.StatusInternalServerError)
		return
	}

	// Trigger the swarming task. The datastore ID is formatted as an int64 so
	// that it is not truncated on platforms where int is 32 bits wide.
	datastoreID := strconv.FormatInt(datastoreKey.ID, 10)
	swarmingTaskID, err := TriggerSwarmingTask(ctx, task.SwarmingPool, task.Requester, datastoreID, task.OsType, task.DeviceType, task.SwarmingBotId, *host, casDigest, swarmingProps.RelativeCwd, swarmingProps.CipdInput, swarmingProps.Command)
	if err != nil {
		httputils.ReportError(w, err, fmt.Sprintf("Error when triggering swarming task: %v", err), http.StatusInternalServerError)
		return
	}

	// Update the task with swarming fields.
	swarmingInstance := GetSwarmingInstance(task.SwarmingPool)
	task.SwarmingServer = swarmingInstance.SwarmingServer
	task.SwarmingTaskId = swarmingTaskID
	if _, err = UpdateDSTask(datastoreKey, task); err != nil {
		httputils.ReportError(w, err, fmt.Sprintf("Error updating task with swarming fields in datastore: %v", err), http.StatusInternalServerError)
		return
	}

	sklog.Infof("Added %v task into the datastore with key %s", task, datastoreKey)
	if err := json.NewEncoder(w).Encode(task); err != nil {
		sklog.Errorf("Failed to send response: %s", err)
	}
}
// AddMiddleware implements baseapp.App. The server adds no middleware of its
// own, so the returned slice is empty.
func (srv *Server) AddMiddleware() []func(http.Handler) http.Handler {
	noMiddleware := make([]func(http.Handler) http.Handler, 0)
	return noMiddleware
}
func main() {
baseapp.Serve(New, []string{*host})
}