blob: 1f9b872d74d748c0dae807d84bc8afc526de89e8 [file] [log] [blame]
package main
import (
"context"
"flag"
"fmt"
"os"
"path/filepath"
"go.skia.org/infra/go/sklog"
"go.skia.org/infra/go/sklog/sklogimpl"
"go.skia.org/infra/go/sklog/stdlogging"
"go.skia.org/infra/rag/go/config"
"go.skia.org/infra/rag/go/eval"
"go.skia.org/infra/rag/go/filereaders/zip"
"go.skia.org/infra/rag/go/genai"
"go.skia.org/infra/rag/go/ingest/history"
"go.skia.org/infra/rag/go/topicstore"
)
// Layout of the extracted input archive. Each name is joined onto the
// temporary extraction directory before being handed to the ingester.
const (
	// topicsDirName is the sub-directory that holds the topic files.
	topicsDirName = "topics"
	// indexFileName is the index file that accompanies the embeddings.
	indexFileName = "index.pkl"
	// embeddingFileName is the file that stores the embeddings.
	embeddingFileName = "embeddings.npy"
)

// geminiApiKeyEnv names the environment variable the Gemini API key is read
// from.
const geminiApiKeyEnv = "GEMINI_API_KEY"
// main is the entry point of the evaluation tool. It parses the command-line
// flags, validates the required ones, installs the stdout logger and then
// delegates to run. Any error returned by run is reported with a single
// fatal log call.
//
// Flags:
//
//	--zip_path       input zip file (required)
//	--eval_set_path  evaluation set JSON file (required)
//	--config_path    API server config file (defaults to ./configs/demo.json)
func main() {
	zipPath := flag.String("zip_path", "", "Path to the input zip file.")
	evalSetPath := flag.String("eval_set_path", "", "Path to the evaluation set JSON file.")
	configPath := flag.String("config_path", "./configs/demo.json", "Path to the API server config file.")
	flag.Parse()
	if *zipPath == "" || *evalSetPath == "" {
		sklog.Fatal("--zip_path and --eval_set_path are required.")
	}
	sklogimpl.SetLogger(stdlogging.New(os.Stdout))

	// The fatal call lives here, outside run, on purpose: a fatal logger is
	// presumed to exit the process (like log.Fatalf), which would skip any
	// deferred cleanup still pending in the calling function.
	if err := run(context.Background(), *zipPath, *evalSetPath, *configPath); err != nil {
		sklog.Fatalf("Evaluation failed: %v", err)
	}
}

// run performs the whole evaluation: it loads the config, extracts the zip
// into a temporary directory, ingests the topics into an in-memory store,
// runs the evaluation set against it and prints the report.
//
// Failures are returned as wrapped errors instead of being logged fatally so
// that the deferred removal of the temporary directory always executes.
func run(ctx context.Context, zipPath, evalSetPath, configPath string) error {
	// Fail fast: the API key is only needed for the evaluator, but checking it
	// first avoids extracting and ingesting the archive just to abort later.
	apiKey := os.Getenv(geminiApiKeyEnv)
	if apiKey == "" {
		return fmt.Errorf("%s environment variable is not set", geminiApiKeyEnv)
	}

	// 1. Load config
	cfg, err := config.NewApiServerConfigFromFile(configPath)
	if err != nil {
		return fmt.Errorf("loading config: %w", err)
	}

	// 2. Setup stores and ingester
	// Note: We don't need a real blamestore for topic evaluation.
	topicStore := topicstore.NewInMemoryTopicStore()
	ingester := history.New(topicStore, cfg.OutputDimensionality, cfg.DefaultRepoName)

	// 3. Extract ZIP and Ingest
	content, err := os.ReadFile(zipPath)
	if err != nil {
		return fmt.Errorf("reading zip file: %w", err)
	}
	tempDir, err := os.MkdirTemp("", "rag-eval-*")
	if err != nil {
		return fmt.Errorf("creating temp dir: %w", err)
	}
	// Cleanup is best-effort: a failure to remove the directory is logged but
	// does not change the outcome of the evaluation.
	defer func() {
		if err := os.RemoveAll(tempDir); err != nil {
			sklog.Errorf("Error cleaning up temp dir: %v", err)
		}
	}()
	sklog.Infof("Extracting %s to %s", zipPath, tempDir)
	if err := zip.ExtractZipData(content, tempDir); err != nil {
		return fmt.Errorf("extracting zip: %w", err)
	}
	embeddingFilePath := filepath.Join(tempDir, embeddingFileName)
	indexFilePath := filepath.Join(tempDir, indexFileName)
	topicsDirPath := filepath.Join(tempDir, topicsDirName)
	sklog.Infof("Ingesting data into memory store...")
	if err := ingester.IngestTopics(ctx, topicsDirPath, embeddingFilePath, indexFilePath, ""); err != nil {
		return fmt.Errorf("ingesting topics: %w", err)
	}

	// 4. Setup Evaluator
	genAiClient, err := genai.NewLocalGeminiClient(ctx, apiKey)
	if err != nil {
		return fmt.Errorf("creating Gemini client: %w", err)
	}
	evaluator := eval.NewEvaluator(genAiClient, topicStore, cfg.QueryEmbeddingModel, int32(cfg.OutputDimensionality))

	// 5. Load Eval Set and Run
	evalSet, err := eval.LoadEvaluationSet(evalSetPath)
	if err != nil {
		return fmt.Errorf("loading eval set: %w", err)
	}
	sklog.Infof("Running evaluation with %d test cases...", len(evalSet.TestCases))
	report, err := evaluator.Run(ctx, evalSet)
	if err != nil {
		return fmt.Errorf("running evaluation: %w", err)
	}

	// 6. Print Report
	printReport(report)
	return nil
}
// printReport writes a human-readable rendering of report to standard
// output: a header with the aggregate metrics (total queries, mean Recall@5,
// mean MRR) followed by one entry per result. Entries for results that did
// not pass additionally list the expected and the found names. Each field is
// printed on its own line and entries are separated by a blank line.
func printReport(report *eval.SummaryReport) {
	fmt.Println("--- Evaluation Results ---")
	// Printf does not append a newline; each format string must end with one
	// or consecutive fields run together on a single line.
	fmt.Printf("Total Queries: %d\n", report.TotalQueries)
	fmt.Printf("Mean Recall@5: %.4f\n", report.MeanRecallAt5)
	fmt.Printf("Mean MRR: %.4f\n", report.MeanMRR)
	fmt.Println("--------------------------")
	for _, res := range report.Results {
		status := "✅ PASS"
		if !res.Passed {
			status = "❌ FAIL"
		}
		fmt.Printf("%s | Query: %s\n", status, res.Query)
		fmt.Printf(" Recall@5: %.2f | MRR: %.2f\n", res.RecallAt5, res.MRR)
		// Only failures need the detail of what was expected versus found.
		if !res.Passed {
			fmt.Printf(" Expected: %v\n", res.ExpectedNames)
			fmt.Printf(" Found : %v\n", res.FoundNames)
		}
		// Blank line between entries.
		fmt.Println()
	}
	fmt.Println("--------------------------")
}