blob: dbab217caef6c861b55a16b0efb98a2b6679dfcf [file]
package comment
import (
"context"
"os"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"go.opencensus.io/trace"
"google.golang.org/grpc"
"go.skia.org/infra/comment_rag/go/commentstore"
"go.skia.org/infra/comment_rag/go/spanner"
pb "go.skia.org/infra/comment_rag/proto/comment/v1"
"go.skia.org/infra/go/metrics2"
"go.skia.org/infra/go/skerr"
"go.skia.org/infra/go/sklog"
"go.skia.org/infra/rag/go/genai"
)
// Names of the environment variables consulted when constructing the Gemini
// client.
const (
	// geminiProjectEnvVar names the variable holding the project id used
	// when no API key is provided.
	geminiProjectEnvVar = "GEMINI_PROJECT"
	// geminiLocationEnvVar names the variable holding the location used
	// when no API key is provided.
	geminiLocationEnvVar = "GEMINI_LOCATION"
	// geminiApiKeyEnvVar names the variable holding the Gemini API key. When
	// it is set to a non-empty value, a local client is created with it.
	geminiApiKeyEnvVar = "GEMINI_API_KEY"
)

// defaultCommentsLimit is the number of comments requested from the store
// when the search request does not carry a positive max-comments value.
const defaultCommentsLimit = 10
// ApiService implements the Comment RAG API service. It answers
// SearchComments requests by embedding the incoming query with a GenAI client
// and looking up the nearest comments in the comment store, and it lists the
// supported review categories.
//
// Instances are expected to be created with NewApiService.
type ApiService struct {
// Generated stub embedded so the type satisfies the generated server
// interface; embedding it is the grpc-go convention for staying forward
// compatible with RPCs added to the service later.
pb.UnimplementedCommentRagApiServiceServer
// Store queried by SearchComments for comments matching a query embedding.
commentStore commentstore.CommentStore
// GenAI client used to compute the embedding of the search query.
genAiClient genai.GenAIClient
// Name of the embedding model passed to the GenAI client for each query.
queryEmbeddingModel string
// Output dimensionality requested for the query embedding.
dimensionality int32
// Counter incremented once for every SearchComments call that passes
// request validation.
searchCommentsCounterMetric metrics2.Counter
}
// NewApiService builds an ApiService that searches the given comment store.
//
// The Gemini client is chosen from the environment:
//   - if GEMINI_API_KEY is non-empty, a local client is created with that key;
//   - otherwise GEMINI_PROJECT and GEMINI_LOCATION must both be set, and a
//     client for that project/location is created. A missing value is
//     reported through sklog.Fatalf.
//
// queryEmbeddingModel and dimensionality are stored and later used for every
// query embedding request.
//
// If the Gemini client cannot be created the error is logged and nil is
// returned, so callers must check the result before using it.
func NewApiService(ctx context.Context, commentStore commentstore.CommentStore, queryEmbeddingModel string, dimensionality int32) *ApiService {
	var (
		client    *genai.GeminiClient
		clientErr error
	)

	if key := os.Getenv(geminiApiKeyEnvVar); key == "" {
		// No api key available: fall back to a project/location scoped client.
		project, region := os.Getenv(geminiProjectEnvVar), os.Getenv(geminiLocationEnvVar)
		if project == "" || region == "" {
			sklog.Fatalf("%s and %s environment variables need to be set.", geminiProjectEnvVar, geminiLocationEnvVar)
		}
		sklog.Infof("Creating a new Gemini client for project %s and location %s", project, region)
		client, clientErr = genai.NewGeminiClient(ctx, project, region)
	} else {
		sklog.Infof("Gemini api key specified in the environment, creating a local client.")
		client, clientErr = genai.NewLocalGeminiClient(ctx, key)
	}

	if clientErr != nil {
		sklog.Errorf("Error creating new gemini client: %v", clientErr)
		return nil
	}

	svc := &ApiService{
		commentStore:                commentStore,
		genAiClient:                 client,
		queryEmbeddingModel:         queryEmbeddingModel,
		dimensionality:              dimensionality,
		searchCommentsCounterMetric: metrics2.GetCounter("commentrag_searchComments_count"),
	}
	return svc
}
// RegisterGrpc registers this service as the CommentRagApiService
// implementation on the given gRPC server, using the generated registration
// function.
func (service *ApiService) RegisterGrpc(server *grpc.Server) {
pb.RegisterCommentRagApiServiceServer(server, service)
}
// RegisterHttp registers the service's HTTP handlers on the given
// grpc-gateway mux and returns any error reported by the generated
// registration function.
//
// NOTE(review): this uses the generated "HandlerServer" variant, which per the
// grpc-gateway documentation calls the service implementation directly rather
// than going through a gRPC connection, so gRPC server interceptors presumably
// do not run for HTTP traffic. Confirm this is intended.
func (service *ApiService) RegisterHttp(ctx context.Context, mux *runtime.ServeMux) error {
return pb.RegisterCommentRagApiServiceHandlerServer(ctx, mux, service)
}
// GetServiceDescriptor returns the generated gRPC service descriptor for
// CommentRagApiService. The descriptor is returned by value.
func (service *ApiService) GetServiceDescriptor() grpc.ServiceDesc {
return pb.CommentRagApiService_ServiceDesc
}
// SearchComments implements the SearchComments endpoint.
//
// The request is validated before any backend is contacted:
//   - req must be non-nil;
//   - categories are trimmed, upper-cased and de-duplicated (first occurrence
//     wins, order preserved); blank entries are skipped, any entry rejected by
//     spanner.IsValidCategory fails the request, and at least one valid
//     category must remain;
//   - project, repo and query must be non-empty after trimming whitespace.
//
// For a valid request the query is embedded with the configured model and
// dimensionality, and the comment store is searched with that embedding,
// scoped to the project, repo and categories. At most max_comments records
// are requested when that field is positive, otherwise defaultCommentsLimit.
//
// Validation failures are returned as skerr errors. Errors from the embedding
// client or the comment store are logged and returned unchanged.
func (service *ApiService) SearchComments(ctx context.Context, req *pb.SearchCommentsRequest) (*pb.SearchCommentsResponse, error) {
	if req == nil {
		return nil, skerr.Fmt("request cannot be nil.")
	}

	// Normalize the requested categories so that " style" and "STYLE" are
	// treated as the same value and only sent to the store once.
	var categories []string
	seen := make(map[string]struct{})
	for _, cat := range req.GetCategories() {
		trimmed := strings.TrimSpace(cat)
		if trimmed == "" {
			continue
		}
		normalized := strings.ToUpper(trimmed)
		if !spanner.IsValidCategory(normalized) {
			return nil, skerr.Fmt("invalid category: %q. Supported categories are: %v", cat, spanner.ValidCategories)
		}
		if _, dup := seen[normalized]; dup {
			continue
		}
		seen[normalized] = struct{}{}
		categories = append(categories, normalized)
	}
	if len(categories) == 0 {
		return nil, skerr.Fmt("at least one category must be specified.")
	}

	project := strings.TrimSpace(req.GetProject())
	if project == "" {
		return nil, skerr.Fmt("project cannot be empty.")
	}
	repo := strings.TrimSpace(req.GetRepo())
	if repo == "" {
		return nil, skerr.Fmt("repo cannot be empty.")
	}
	query := strings.TrimSpace(req.GetQuery())
	if query == "" {
		return nil, skerr.Fmt("query cannot be empty.")
	}

	// The query is caller-controlled, so it is always logged with %q: quoting
	// escapes newlines and control characters that could otherwise be used to
	// forge additional log lines.
	sklog.Infof("Received SearchComments request with query: %q", query)
	// Only requests that passed validation are counted.
	service.searchCommentsCounterMetric.Inc(1)

	ctx, span := trace.StartSpan(ctx, "commentrag.service.SearchComments")
	defer span.End()

	// Get the embedding vector for the input query.
	queryEmbedding, err := service.genAiClient.GetEmbedding(ctx, service.queryEmbeddingModel, service.dimensionality, query)
	if err != nil {
		sklog.Errorf("Error getting embedding for query %q: %v", query, err)
		return nil, err
	}
	sklog.Infof("Embedding for query %q has length %d", query, len(queryEmbedding))

	// Search the relevant comments in the store. A non-positive max_comments
	// means "use the default".
	limit := defaultCommentsLimit
	if req.GetMaxComments() > 0 {
		limit = int(req.GetMaxComments())
	}
	foundCases, err := service.commentStore.SearchComments(ctx, queryEmbedding, limit, project, repo, categories)
	if err != nil {
		sklog.Errorf("Error searching for comments: %v", err)
		return nil, err
	}

	// Convert the store records into the response representation.
	resp := &pb.SearchCommentsResponse{}
	for _, c := range foundCases {
		resp.Comments = append(resp.Comments, &pb.SearchCommentsResponse_CommentRecord{
			Id:             c.ID,
			ChangeId:       c.ChangeID,
			Project:        c.Project,
			Category:       c.Category,
			Repo:           c.Repo,
			FilePath:       c.FilePath,
			CommentText:    c.CommentText,
			CodeSnippet:    c.CodeSnippet,
			ClSubject:      c.CLSubject,
			ClDescription:  c.CLDescription,
			Analysis:       c.Analysis,
			CosineDistance: float32(c.Distance),
		})
	}
	sklog.Infof("Returning %d matching comment records", len(resp.Comments))
	return resp, nil
}
// ListValidCategories reports every review category accepted by the
// comment_rag service. The request carries no inputs that are read here, and
// the call never fails.
//
// The response references spanner.ValidCategories directly; the slice is not
// copied.
func (service *ApiService) ListValidCategories(ctx context.Context, req *pb.ListValidCategoriesRequest) (*pb.ListValidCategoriesResponse, error) {
	sklog.Info("Received ListValidCategories request")

	resp := &pb.ListValidCategoriesResponse{}
	resp.Categories = spanner.ValidCategories
	return resp, nil
}