| package main |
| |
| import ( |
| "context" |
| |
| "fmt" |
| "net" |
| "net/http" |
| |
| "cloud.google.com/go/spanner" |
| "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" |
| "github.com/urfave/cli/v2" |
| |
| "go.skia.org/infra/comment_rag/go/api/services/comment" |
| "go.skia.org/infra/comment_rag/go/commentstore" |
| "go.skia.org/infra/go/cleanup" |
| "go.skia.org/infra/go/httputils" |
| "go.skia.org/infra/go/metrics2" |
| "go.skia.org/infra/go/skerr" |
| "go.skia.org/infra/go/sklog" |
| "go.skia.org/infra/rag/go/config" |
| "google.golang.org/grpc" |
| "google.golang.org/grpc/reflection" |
| ) |
| |
// defaultOutputDimensionality is an embedding output dimensionality of 768.
//
// NOTE(review): NewApiServer rejects a missing or zero output_dimensionality
// instead of falling back to this value; confirm whether this constant is
// still referenced anywhere in the package.
const (
	defaultOutputDimensionality = 768
)
| |
| // Service defines an interface for a service hosted by the CommentRag server. |
| type Service interface { |
| // RegisterGrpc registers the grpc service with the server instance. |
| RegisterGrpc(server *grpc.Server) |
| |
| // RegisterHttp registers the http service with the server instance. |
| RegisterHttp(ctx context.Context, mux *runtime.ServeMux) error |
| |
| // GetServiceDescriptor returns the service descriptor for the service. |
| GetServiceDescriptor() grpc.ServiceDesc |
| } |
| |
// ApiServerFlags holds the command-line options used to start the API server.
// Use AsCliFlags to bind these fields to urfave/cli flags.
type ApiServerFlags struct {
	// ConfigFilename is the path of the API server config file to load.
	ConfigFilename string

	// Listen addresses (for example ":8000") for the gRPC server, the HTTP
	// gateway and the Prometheus metrics endpoint, respectively.
	GrpcPort, HttpPort, PromPort string

	// Local is bound to the --local flag.
	// NOTE(review): not read by the server setup in this file; confirm its use.
	Local bool
}
| |
| // AsCliFlags returns a slice of cli.Flag. |
| func (flags *ApiServerFlags) AsCliFlags() []cli.Flag { |
| return []cli.Flag{ |
| &cli.StringFlag{ |
| Destination: &flags.ConfigFilename, |
| Name: "config_filename", |
| Value: "./configs/demo.json", |
| Usage: "The name of the config file to use.", |
| }, |
| &cli.StringFlag{ |
| Destination: &flags.GrpcPort, |
| Name: "grpc_port", |
| Value: ":8000", |
| Usage: "The port number to use for grpc server.", |
| }, |
| &cli.StringFlag{ |
| Destination: &flags.HttpPort, |
| Name: "http_port", |
| Value: ":8002", |
| Usage: "The port number to use for http server.", |
| }, |
| &cli.StringFlag{ |
| Destination: &flags.PromPort, |
| Name: "prom_port", |
| Value: ":20000", |
| Usage: "Metrics service address (e.g., ':10110')", |
| }, |
| &cli.BoolFlag{ |
| Destination: &flags.Local, |
| Name: "local", |
| Value: false, |
| }, |
| } |
| } |
| |
| // apiServer defines a struct for creating the server. |
| type apiServer struct { |
| // Spanner database client. |
| dbClient *spanner.Client |
| queryEmbeddingModel string |
| dimensionality int32 |
| |
| // Grpc server objects |
| grpcServer *grpc.Server |
| lisGRPC net.Listener |
| grpcPort string |
| |
| // HTTP server objects |
| httpHandler http.Handler |
| httpPort string |
| } |
| |
// NewApiServer returns a new, initialized apiServer built from flags.
//
// It reads the API server config from flags.ConfigFilename, validates the
// configured output dimensionality, opens a Spanner client for the configured
// database and then wires up the gRPC and HTTP servers via initialize.
//
// Validation happens before the Spanner client is created, and the client is
// closed if initialization fails, so an error return never leaks the client.
func NewApiServer(flags *ApiServerFlags) (*apiServer, error) {
	ctx := context.Background()

	// Named cfg (not config) so the config package is not shadowed.
	cfg, err := config.NewApiServerConfigFromFile(flags.ConfigFilename)
	if err != nil {
		return nil, skerr.Wrapf(err, "reading config file %s", flags.ConfigFilename)
	}

	// Validate before acquiring any resources.
	dimensionality := int32(cfg.OutputDimensionality)
	if dimensionality == 0 {
		return nil, skerr.Fmt("output_dimensionality is required and cannot be missing or 0 in config")
	}
	if dimensionality < 0 {
		return nil, skerr.Fmt("output_dimensionality must be positive in config, got %d", dimensionality)
	}

	// Generate the database identifier string and create the spanner client.
	databaseName := fmt.Sprintf("projects/%s/instances/%s/databases/%s",
		cfg.SpannerConfig.ProjectID, cfg.SpannerConfig.InstanceID, cfg.SpannerConfig.DatabaseID)
	spannerClient, err := spanner.NewClient(ctx, databaseName)
	if err != nil {
		return nil, skerr.Wrapf(err, "creating spanner client for %s", databaseName)
	}

	server := &apiServer{
		dbClient:            spannerClient,
		queryEmbeddingModel: cfg.QueryEmbeddingModel,
		dimensionality:      dimensionality,
		grpcPort:            flags.GrpcPort,
		httpPort:            flags.HttpPort,
	}
	if err := server.initialize(ctx, flags, cfg); err != nil {
		// The server is not returned to the caller, so release the client here.
		spannerClient.Close()
		return nil, err
	}

	return server, nil
}
| |
| // initialize performs the init steps for the apiServer object. |
| func (server *apiServer) initialize(ctx context.Context, flags *ApiServerFlags, cfg *config.ApiServerConfig) error { |
| // Initialize metrics/ |
| metrics2.InitPrometheus(flags.PromPort) |
| |
| // Initialize the Spanner comment RAG storage client |
| commentStore := commentstore.NewSpannerCommentStore(server.dbClient) |
| |
| // Clean standalone Comment service registration |
| serviceList := []Service{ |
| comment.NewApiService(ctx, commentStore, server.queryEmbeddingModel, server.dimensionality), |
| } |
| |
| // Create the GRPC server. |
| opts := []grpc.ServerOption{} |
| server.grpcServer = grpc.NewServer(opts...) |
| |
| sklog.Infof("Registering grpc reflection server.") |
| reflection.Register(server.grpcServer) |
| |
| // Create the HTTP server. |
| gwmux := runtime.NewServeMux() |
| |
| sklog.Info("Registering comment services.") |
| server.registerServices(ctx, serviceList, gwmux) |
| |
| rootMux := http.NewServeMux() |
| rootMux.Handle("/commentrag/", gwmux) |
| |
| server.httpHandler = rootMux |
| |
| // Set up the TCP listener for the GRPC server. |
| var err error |
| server.lisGRPC, err = net.Listen("tcp4", server.grpcPort) |
| if err != nil { |
| sklog.Errorf("failed to listen: %v", err) |
| return err |
| } |
| |
| cleanup.AtExit(server.cleanup) |
| return nil |
| } |
| |
// registerServices registers every service in serviceList with the gRPC
// server and with the HTTP gateway mux gwmux.
//
// A failure to register a service's HTTP handlers is reported via
// sklog.Fatalf. The message names the failing service and the error.
func (server *apiServer) registerServices(ctx context.Context, serviceList []Service, gwmux *runtime.ServeMux) {
	for _, service := range serviceList {
		service.RegisterGrpc(server.grpcServer)
		if err := service.RegisterHttp(ctx, gwmux); err != nil {
			desc := service.GetServiceDescriptor()
			sklog.Fatalf("Error registering http handler for service %s: %v", desc.ServiceName, err)
		}
	}
}
| |
// serve starts the gRPC server on a background goroutine, then runs the HTTP
// REST gateway on the calling goroutine and blocks until it stops.
//
// Unexpected failures of either server are reported via sklog.Fatalf. If the
// HTTP server is closed deliberately (http.ErrServerClosed), serve returns nil.
func (server *apiServer) serve() error {
	// gRPC server runs on a separate goroutine.
	go func() {
		sklog.Infof("Listening GRPC at %s", server.lisGRPC.Addr())
		if err := server.grpcServer.Serve(server.lisGRPC); err != nil {
			sklog.Fatalf("failed to serve grpc: %v", err)
		}
	}()

	// HTTP REST gateway server runs on the calling goroutine.
	httpServer := &http.Server{
		Addr:    server.httpPort,
		Handler: httputils.HealthzAndHTTPS(server.httpHandler),
	}
	sklog.Infof("Listening HTTP at %s", server.httpPort)
	// ListenAndServe always returns a non-nil error. http.ErrServerClosed
	// (returned unwrapped after Shutdown/Close) signals a clean stop.
	if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		sklog.Fatalf("failed to serve HTTP gateway: %v", err)
	}

	return nil
}
| |
| // Cleanup performs a graceful shutdown of Spanner client and gRPC server. |
| func (server *apiServer) cleanup() { |
| sklog.Info("Shutdown comment server gracefully.") |
| if server.grpcServer != nil { |
| server.grpcServer.GracefulStop() |
| } |
| if server.dbClient != nil { |
| server.dbClient.Close() |
| } |
| } |