blob: d01bdd514cbb55237a61bb598168ff864dc305f1 [file] [log] [blame]
package main
import (
"context"
"fmt"
"testing"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.skia.org/infra/mcp/common"
"go.skia.org/infra/mcp/common/mocks"
)
func TestCreateServer_Success(t *testing.T) {
flags := &mcpFlags{
ServiceName: string(HelloWorld),
}
server, err := createMcpSSEServer(flags)
assert.Nil(t, err)
assert.NotNil(t, server)
}
func TestCreateServer_Invalid(t *testing.T) {
flags := &mcpFlags{
ServiceName: "random",
}
server, err := createMcpSSEServer(flags)
assert.NotNil(t, err)
assert.Nil(t, server)
}
func TestCreateMcpSSEServer_ArgumentTypeSwitch(t *testing.T) {
originalServiceRegistry := make(map[mcpservice]serviceFactory)
for k, v := range serviceRegistry {
originalServiceRegistry[k] = v
}
defer func() {
serviceRegistry = originalServiceRegistry
}()
testCases := []struct {
name string
argType common.ToolArgumentType
expectError bool
expectedErrorMsg string
}{
{name: "StringArgument", argType: common.StringArgument, expectError: false},
{name: "BooleanArgument", argType: common.BooleanArgument, expectError: false},
{name: "NumberArgument", argType: common.NumberArgument, expectError: false},
{name: "ObjectArgument", argType: common.ObjectArgument, expectError: false},
{name: "ArrayArgument", argType: common.ArrayArgument, expectError: false},
{name: "InvalidArgumentType", argType: common.ToolArgumentType(99), expectError: true, expectedErrorMsg: "Invalid argument type 99"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockService := &mocks.MockArgumentService{ArgTypeToTest: tc.argType}
testServiceName := "testargumentservice_" + tc.name
testMcpService := mcpservice(testServiceName)
serviceRegistry[testMcpService] = func() common.McpService { return mockService }
defer delete(serviceRegistry, testMcpService)
flags := &mcpFlags{
ServiceName: testServiceName,
}
server, err := createMcpSSEServer(flags)
if tc.expectError {
require.Error(t, err)
assert.Nil(t, server)
if tc.expectedErrorMsg != "" {
assert.Contains(t, err.Error(), tc.expectedErrorMsg)
}
} else {
require.NoError(t, err)
assert.NotNil(t, server)
}
})
}
}
func TestCreateMcpSSEServer_ToolWithNoArguments(t *testing.T) {
originalServiceRegistry := make(map[mcpservice]serviceFactory)
for k, v := range serviceRegistry {
originalServiceRegistry[k] = v
}
defer func() {
serviceRegistry = originalServiceRegistry
}()
mockService := &mocks.MockArgumentService{
CustomTools: []common.Tool{
{Name: "noArgTool", Description: "Tool with no arguments", Arguments: nil, Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil }},
},
}
testServiceName := "testnoargservice"
testMcpService := mcpservice(testServiceName)
serviceRegistry[testMcpService] = func() common.McpService { return mockService }
defer delete(serviceRegistry, testMcpService)
flags := &mcpFlags{ServiceName: testServiceName}
server, err := createMcpSSEServer(flags)
require.NoError(t, err)
assert.NotNil(t, server)
}
func TestCreateMcpSSEServer_ServiceWithNoTools(t *testing.T) {
originalServiceRegistry := make(map[mcpservice]serviceFactory)
for k, v := range serviceRegistry {
originalServiceRegistry[k] = v
}
defer func() {
serviceRegistry = originalServiceRegistry
}()
mockService := &mocks.MockArgumentService{CustomTools: []common.Tool{}} // No tools
testServiceName := "testnotoolservice"
testMcpService := mcpservice(testServiceName)
serviceRegistry[testMcpService] = func() common.McpService { return mockService }
defer delete(serviceRegistry, testMcpService)
flags := &mcpFlags{ServiceName: testServiceName}
server, err := createMcpSSEServer(flags)
require.NoError(t, err)
assert.NotNil(t, server)
}
func TestCreateMcpSSEServer_ServiceInitError(t *testing.T) {
// This test doesn't need to manipulate the global serviceRegistry if the service is already registered or if we register a new one.
// For consistency with other tests, we can follow the same pattern.
expectedErr := fmt.Errorf("init failed")
mockService := &mocks.MockArgumentService{InitError: expectedErr}
testServiceName := "testiniterrorservice" // This service name won't be in the default registry
testMcpService := mcpservice(testServiceName)
serviceRegistry[testMcpService] = func() common.McpService { return mockService }
defer delete(serviceRegistry, testMcpService) // Clean up after test
flags := &mcpFlags{ServiceName: testServiceName}
server, err := createMcpSSEServer(flags)
require.Error(t, err)
assert.Equal(t, expectedErr, err)
assert.Nil(t, server)
}
func TestCreateMcpSSEServer_ToolArgumentFeatures(t *testing.T) {
originalServiceRegistry := make(map[mcpservice]serviceFactory)
for k, v := range serviceRegistry {
originalServiceRegistry[k] = v
}
defer func() {
serviceRegistry = originalServiceRegistry
}()
testCases := []struct {
name string
tool common.Tool
expectError bool
expectedErrorMsg string
}{
{
name: "StringWithEnum",
tool: common.Tool{
Name: "enumTool",
Description: "Tool with enum argument",
Arguments: []common.ToolArgument{
{Name: "enumArg", Description: "Enum argument", Required: true, ArgumentType: common.StringArgument, EnumValues: []string{"val1", "val2"}},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil },
},
expectError: false,
},
{
name: "ArrayWithSchema",
tool: common.Tool{
Name: "arrayTool",
Description: "Tool with array argument",
Arguments: []common.ToolArgument{
{Name: "arrayArg", Description: "Array argument", Required: true, ArgumentType: common.ArrayArgument, ArraySchema: map[string]any{"type": "string"}},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil },
},
expectError: false,
},
{
name: "ArrayWithoutSchema",
tool: common.Tool{
Name: "arrayNoSchemaTool",
Description: "Tool with array argument but no schema",
Arguments: []common.ToolArgument{
{Name: "arrayArgNoSchema", Description: "Array argument without schema", Required: true, ArgumentType: common.ArrayArgument, ArraySchema: map[string]any{}}, // Empty schema
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil },
},
expectError: true,
expectedErrorMsg: "Array type argument arrayArgNoSchema does not have a schema defined",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockService := &mocks.MockArgumentService{CustomTools: []common.Tool{tc.tool}}
testServiceName := "testargfeatureservice_" + tc.name
testMcpService := mcpservice(testServiceName)
serviceRegistry[testMcpService] = func() common.McpService { return mockService }
defer delete(serviceRegistry, testMcpService)
flags := &mcpFlags{ServiceName: testServiceName}
server, err := createMcpSSEServer(flags)
if tc.expectError {
require.Error(t, err)
assert.Nil(t, server)
if tc.expectedErrorMsg != "" {
assert.Contains(t, err.Error(), tc.expectedErrorMsg)
}
} else {
require.NoError(t, err)
assert.NotNil(t, server)
}
})
}
}