blob: 5e3ce74f9c3b726a4cc3f5e4d43bc249a58f1e70 [file] [log] [blame]
package eval
import (
"context"
"encoding/json"
"os"
"strings"
"go.skia.org/infra/go/skerr"
"go.skia.org/infra/go/sklog"
"go.skia.org/infra/rag/go/genai"
"go.skia.org/infra/rag/go/topicstore"
)
// TestCase represents a single evaluation query and its expected results.
type TestCase struct {
// Query is the text that is embedded and used to search the topic store.
Query string `json:"query"`
// ExpectedTopicNames lists the topic titles the search is expected to
// return. They are compared case-insensitively against the found titles.
ExpectedTopicNames []string `json:"expected_topic_names"`
// Category is optional and is omitted from the JSON encoding when empty.
Category string `json:"category,omitempty"`
}
// EvaluationSet is a collection of TestCases. It is the top-level object
// decoded from the evaluation JSON file by LoadEvaluationSet.
type EvaluationSet struct {
// TestCases holds the queries to evaluate, in the order they are run.
TestCases []TestCase `json:"test_cases"`
}
// EvalResult contains the performance metrics for a specific query.
type EvalResult struct {
// Query is the query text of the test case this result belongs to.
Query string `json:"query"`
// RecallAt5 is the recall computed by calculateRecall over FoundNames and
// ExpectedNames; the search that produces FoundNames asks for 5 topics.
RecallAt5 float64 `json:"recall_at_5"`
// MRR is the reciprocal of the 1-based position of the first found name
// that matches an expected name, or 0 when none match.
MRR float64 `json:"mrr"`
// FoundNames holds the titles of the topics returned by the search, in the
// order they were returned.
FoundNames []string `json:"found_names"`
// ExpectedNames is copied from the test case's ExpectedTopicNames.
ExpectedNames []string `json:"expected_names"`
// Passed is true when RecallAt5 is greater than zero.
Passed bool `json:"passed"`
}
// SummaryReport contains the aggregated results of an evaluation run.
type SummaryReport struct {
// TotalQueries is the number of test cases in the evaluation set, including
// any whose evaluation failed.
TotalQueries int `json:"total_queries"`
// MeanRecallAt5 is the sum of RecallAt5 over Results divided by
// TotalQueries, so failed test cases contribute zero.
MeanRecallAt5 float64 `json:"mean_recall_at_5"`
// MeanMRR is the sum of MRR over Results divided by TotalQueries, so failed
// test cases contribute zero.
MeanMRR float64 `json:"mean_mrr"`
// Results holds one entry per test case that was evaluated without error.
Results []*EvalResult `json:"results"`
}
// Evaluator performs the evaluation. For each test case it embeds the query
// and searches the topic store with the resulting embedding.
type Evaluator struct {
// genAiClient produces the embedding for each query.
genAiClient genai.GenAIClient
// topicStore is searched with the query embedding.
topicStore topicstore.TopicStore
// embeddingModel is the model name passed to GetEmbedding.
embeddingModel string
// dimensionality is the dimensionality value passed to GetEmbedding.
dimensionality int32
}
// NewEvaluator builds an Evaluator that uses genAiClient to embed queries
// (with the given embeddingModel and dimensionality) and topicStore to search
// for matching topics.
func NewEvaluator(genAiClient genai.GenAIClient, topicStore topicstore.TopicStore, embeddingModel string, dimensionality int32) *Evaluator {
	ev := new(Evaluator)
	ev.genAiClient = genAiClient
	ev.topicStore = topicStore
	ev.embeddingModel = embeddingModel
	ev.dimensionality = dimensionality
	return ev
}
// Run performs the evaluation on the provided evaluation set.
//
// Test cases are evaluated in order. A test case whose evaluation fails is
// logged and left out of Results, but it still counts towards TotalQueries,
// so it contributes zero to MeanRecallAt5 and MeanMRR.
//
// Run returns an error if evalSet is nil, or if ctx is done before every test
// case has been evaluated.
func (e *Evaluator) Run(ctx context.Context, evalSet *EvaluationSet) (*SummaryReport, error) {
	if evalSet == nil {
		return nil, skerr.Fmt("evaluation set must not be nil")
	}

	report := &SummaryReport{
		TotalQueries: len(evalSet.TestCases),
	}
	var sumRecallAt5, sumMRR float64
	for _, tc := range evalSet.TestCases {
		// Once the context is done every remaining call would fail and be
		// logged one by one, producing a report that looks like a retrieval
		// failure. Surface the cancellation to the caller instead.
		if err := ctx.Err(); err != nil {
			return nil, skerr.Wrap(err)
		}

		res, err := e.evaluateTestCase(ctx, tc)
		if err != nil {
			// Best effort: one failing query must not abort the whole run.
			sklog.Errorf("Error evaluating test case '%s': %v", tc.Query, err)
			continue
		}
		report.Results = append(report.Results, res)
		sumRecallAt5 += res.RecallAt5
		sumMRR += res.MRR
	}

	// Guard against dividing by zero for an empty evaluation set.
	if report.TotalQueries > 0 {
		report.MeanRecallAt5 = sumRecallAt5 / float64(report.TotalQueries)
		report.MeanMRR = sumMRR / float64(report.TotalQueries)
	}
	return report, nil
}
// evaluateTestCase runs a single test case: it embeds tc.Query, searches the
// topic store with that embedding, and scores the returned topic titles
// against tc.ExpectedTopicNames. Errors from the embedding or search call are
// returned wrapped.
func (e *Evaluator) evaluateTestCase(ctx context.Context, tc TestCase) (*EvalResult, error) {
	// Turn the query text into an embedding vector.
	queryEmbedding, err := e.genAiClient.GetEmbedding(ctx, e.embeddingModel, e.dimensionality, tc.Query)
	if err != nil {
		return nil, skerr.Wrap(err)
	}

	// Ask for the top 5 topics, since the recall metric is Recall@5.
	topics, err := e.topicStore.SearchTopics(ctx, queryEmbedding, 5, "")
	if err != nil {
		return nil, skerr.Wrap(err)
	}

	// Collect the titles in the order the store returned them.
	var titles []string
	for _, topic := range topics {
		titles = append(titles, topic.Title)
	}

	recall := calculateRecall(titles, tc.ExpectedTopicNames)
	return &EvalResult{
		Query:         tc.Query,
		RecallAt5:     recall,
		MRR:           calculateMRR(titles, tc.ExpectedTopicNames),
		FoundNames:    titles,
		ExpectedNames: tc.ExpectedTopicNames,
		// Simple pass criterion: any non-zero recall counts as a pass.
		Passed: recall > 0,
	}, nil
}
// calculateRecall returns the fraction of distinct expected names that occur
// in found. Names are compared case-insensitively.
//
// Each expected name is counted at most once, and names in expected that
// differ only by case count as one, so the result is always within [0, 1].
// When expected is empty there is nothing to miss and the recall is 1.0.
func calculateRecall(found, expected []string) float64 {
	if len(expected) == 0 {
		return 1.0
	}

	// remaining holds the expected names that have not been matched yet.
	remaining := make(map[string]struct{}, len(expected))
	for _, name := range expected {
		remaining[strings.ToLower(name)] = struct{}{}
	}
	total := len(remaining)

	for _, name := range found {
		// delete is a no-op for names that are not (or no longer) expected,
		// so a title repeated in found cannot be counted twice.
		delete(remaining, strings.ToLower(name))
	}

	matched := total - len(remaining)
	return float64(matched) / float64(total)
}
// calculateMRR returns the reciprocal rank of the first entry in found that
// matches any expected name, comparing case-insensitively. Ranks are 1-based,
// so a hit in the first position yields 1.0. It returns 0.0 when no entry in
// found matches.
func calculateMRR(found, expected []string) float64 {
	wanted := make(map[string]struct{}, len(expected))
	for _, exp := range expected {
		wanted[strings.ToLower(exp)] = struct{}{}
	}

	for idx := range found {
		if _, hit := wanted[strings.ToLower(found[idx])]; !hit {
			continue
		}
		rank := idx + 1
		return 1.0 / float64(rank)
	}
	return 0.0
}
// LoadEvaluationSet loads the evaluation set from a JSON file.
//
// The file at filePath is read in full and decoded into an EvaluationSet.
// An error is returned if the file cannot be read or does not contain valid
// JSON for an EvaluationSet; decode errors are annotated with filePath.
func LoadEvaluationSet(filePath string) (*EvaluationSet, error) {
	content, err := os.ReadFile(filePath)
	if err != nil {
		// The os error already names the path, so no extra context is needed.
		return nil, skerr.Wrap(err)
	}

	var evalSet EvaluationSet
	if err := json.Unmarshal(content, &evalSet); err != nil {
		// JSON errors carry no file name; add it so the failure is traceable.
		return nil, skerr.Wrapf(err, "parsing evaluation set %q", filePath)
	}
	return &evalSet, nil
}