blob: d30ff45133cf984c7cafaf434897e2ad46c659eb [file] [log] [blame]
package main
import (
"context"
"fmt"
"net"
"net/http"
"cloud.google.com/go/spanner"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/urfave/cli/v2"
"go.skia.org/infra/go/cleanup"
"go.skia.org/infra/go/httputils"
"go.skia.org/infra/go/metrics2"
"go.skia.org/infra/go/sklog"
"go.skia.org/infra/rag/go/api/services/history"
"go.skia.org/infra/rag/go/config"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
)
const defaultOutputDimensionality = 768
// Service defines an interface for a service hosted by the HistoryRag server.
type Service interface {
// RegisterGrpc registers the grpc service with the server instance.
RegisterGrpc(server *grpc.Server)
// RegisterHttp registers the http service with the server instance.
RegisterHttp(ctx context.Context, mux *runtime.ServeMux) error
// GetServiceDescriptor returns the service descriptor for the service.
GetServiceDescriptor() grpc.ServiceDesc
}
// ApiServerFlags defines the commandline flags to start the api server.
type ApiServerFlags struct {
ConfigFilename string
GrpcPort string
HttpPort string
PromPort string
Services cli.StringSlice
Local bool
}
// AsCliFlags returns a slice of cli.Flag.
func (flags *ApiServerFlags) AsCliFlags() []cli.Flag {
return []cli.Flag{
&cli.StringFlag{
Destination: &flags.ConfigFilename,
Name: "config_filename",
Value: "./configs/demo.json",
Usage: "The name of the config file to use.",
},
&cli.StringSliceFlag{
Name: "services",
Value: cli.NewStringSlice("history"),
Usage: "This list of RAG services to host on the api.",
Destination: &flags.Services,
},
&cli.StringFlag{
Destination: &flags.GrpcPort,
Name: "grpc_port",
Value: ":8000",
Usage: "The port number to use for grpc server.",
},
&cli.StringFlag{
Destination: &flags.HttpPort,
Name: "http_port",
Value: ":8002",
Usage: "The port number to use for http server.",
},
&cli.StringFlag{
Destination: &flags.PromPort,
Name: "prom_port",
Value: ":20000",
Usage: "Metrics service address (e.g., ':10110')",
},
&cli.BoolFlag{
Destination: &flags.Local,
Name: "local",
Value: false,
},
}
}
// apiServer defines a struct for creating the server.
type apiServer struct {
// Spanner database client.
dbClient *spanner.Client
queryEmbeddingModel string
dimensionality int32
// Grpc server objects
grpcServer *grpc.Server
lisGRPC net.Listener
grpcPort string
// HTTP server objects
httpHandler *runtime.ServeMux
httpPort string
}
// NewApiServer returns a new instance of the api server based on the provided flags.
func NewApiServer(flags *ApiServerFlags) (*apiServer, error) {
ctx := context.Background()
// Read the configuration.
config, err := config.NewApiServerConfigFromFile(flags.ConfigFilename)
if err != nil {
sklog.Errorf("Error reading config file %s: %v", flags.ConfigFilename, err)
return nil, err
}
// Generate the database identifier string and create the spanner client.
databaseName := fmt.Sprintf("projects/%s/instances/%s/databases/%s", config.SpannerConfig.ProjectID, config.SpannerConfig.InstanceID, config.SpannerConfig.DatabaseID)
spannerClient, err := spanner.NewClient(ctx, databaseName)
if err != nil {
return nil, err
}
dimensionality := int32(config.OutputDimensionality)
if dimensionality == 0 {
dimensionality = defaultOutputDimensionality
}
server := &apiServer{
dbClient: spannerClient,
queryEmbeddingModel: config.QueryEmbeddingModel,
dimensionality: dimensionality,
grpcPort: flags.GrpcPort,
httpPort: flags.HttpPort,
}
err = server.initialize(ctx, flags)
if err != nil {
return nil, err
}
return server, nil
}
// initialize performs the init steps for the apiServer object.
func (server *apiServer) initialize(ctx context.Context, flags *ApiServerFlags) error {
// Initialize metrics/
metrics2.InitPrometheus(flags.PromPort)
// Define the list of services to be hosted based on the "services" flag.
serviceList := []Service{}
var serviceMap = map[string]Service{
"history": history.NewApiService(ctx, server.dbClient, server.queryEmbeddingModel, server.dimensionality),
}
for _, serviceName := range flags.Services.Value() {
service, ok := serviceMap[serviceName]
if !ok {
sklog.Fatalf("Invalid service name: %s", &serviceName)
}
serviceList = append(serviceList, service)
sklog.Infof("Added service: %s", serviceName)
}
// Create the GRPC server.
opts := []grpc.ServerOption{}
server.grpcServer = grpc.NewServer(opts...)
sklog.Infof("Registering grpc reflection server.")
reflection.Register(server.grpcServer)
// Create the HTTP server.
server.httpHandler = runtime.NewServeMux()
sklog.Info("Registering individual services.")
server.registerServices(ctx, serviceList)
// Set up the TCP listener for the GRPC server.
var err error
server.lisGRPC, err = net.Listen("tcp4", server.grpcPort)
if err != nil {
sklog.Errorf("failed to listen: %v", err)
return err
}
cleanup.AtExit(server.cleanup)
return nil
}
// registerServices registers all the hosted services with the server instances.
func (server *apiServer) registerServices(ctx context.Context, serviceList []Service) {
for _, service := range serviceList {
service.RegisterGrpc(server.grpcServer)
err := service.RegisterHttp(ctx, server.httpHandler)
if err != nil {
sklog.Fatalf("Error registering http handler for service %v", err)
}
}
}
// server sets up the server instances to start listening for incoming requests.
func (server *apiServer) serve() error {
// The GRPC server listens on a separate thread.
go func() {
sklog.Infof("Listening GRPC at %s", server.lisGRPC.Addr())
if err := server.grpcServer.Serve(server.lisGRPC); err != nil {
sklog.Fatalf("failed to serve grpc: %v", err)
}
}()
// The http server listens on the main thread.
httpServer := &http.Server{
Addr: server.httpPort,
Handler: httputils.HealthzAndHTTPS(server.httpHandler),
}
sklog.Infof("Listening HTTP at %s", server.httpPort)
if err := httpServer.ListenAndServe(); err != nil {
sklog.Fatalf("failed to serve grpc:")
}
return nil
}
// Cleanup performs a graceful shutdown of the grpc server.
func (server *apiServer) cleanup() {
sklog.Info("Shutdown server gracefully.")
if server.grpcServer != nil {
server.grpcServer.GracefulStop()
}
if server.dbClient != nil {
server.dbClient.Close()
}
}