blob: 099048a4f80f4878b3a1f3f2657312a5f2d80a7a [file] [log] [blame]
package login
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/securecookie"
ttlcache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.skia.org/infra/go/deepequal/assertdeep"
loginMocks "go.skia.org/infra/go/login/mocks"
"go.skia.org/infra/go/secret"
"go.skia.org/infra/go/secret/mocks"
"go.skia.org/infra/go/testutils"
"golang.org/x/oauth2"
"google.golang.org/api/googleapi"
oauth2_api "google.golang.org/api/oauth2/v2"
"google.golang.org/api/option"
)
// Fixed inputs shared by the tests in this file.
const (
	// saltForTesting is the salt handed to initLogin and to stateFromParts.
	saltForTesting = "salt"

	// sessionIDForTesting is the session cookie value attached to requests by
	// setupForOAuth2CallbackHandlerTest.
	sessionIDForTesting = "abcdef0123456"

	// codeForTesting is the OAuth2 "code" query parameter expected by the
	// mocked Exchange call.
	codeForTesting = "oauth2 code for testing"

	// bearerToken is the access token sent to, and checked by, the emulated
	// token validation endpoint.
	bearerToken = "fake-bearer-token"
)
// errMockError is the error that mocks are told to return when a test needs a
// dependency call to fail.
var errMockError = fmt.Errorf("error returned from mocks")
// initLoginForTests calls initLogin with fixed test credentials, a localhost
// redirect URL, saltForTesting and the SkiaOrg domain. It fails the calling
// test immediately if initialization returns an error.
func initLoginForTests(t *testing.T) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	ctx := context.Background()
	err := initLogin(ctx, "id", "secret", "http://localhost", saltForTesting, SkiaOrg)
	require.NoError(t, err)
}
// TestLoginURL checks loginURL both without and with an existing session
// cookie on the incoming request.
func TestLoginURL(t *testing.T) {
	initLoginForTests(t)

	// No session cookie yet: loginURL should set one (SameSite=None, Secure),
	// must not force the approval prompt, and should include the URL-encoded
	// Referer in the returned URL.
	w := httptest.NewRecorder()
	r, err := http.NewRequest("GET", "http://example.com/", nil)
	require.NoError(t, err)
	r.Header.Set("Referer", "https://foo.org")
	url := loginURL(w, r)
	// ResponseRecorder.HeaderMap is deprecated; read the headers via Result().
	setCookie := w.Result().Header.Get("Set-Cookie")
	assert.Contains(t, setCookie, sessionCookieName, "Session cookie should be set.")
	assert.Contains(t, setCookie, "SameSite=None", "SameSite should be set.")
	assert.Contains(t, setCookie, "Secure", "Secure should be set.")
	assert.Contains(t, url, "approval_prompt=auto", "Not forced into prompt.")
	assert.Contains(t, url, "%3Ahttps%3A%2F%2Ffoo.org")

	// Session cookie already present: loginURL should not set a new one and
	// should carry the existing cookie value in the returned URL.
	r.AddCookie(&http.Cookie{
		Name:  sessionCookieName,
		Value: "some-random-state",
	})
	w = httptest.NewRecorder()
	url = loginURL(w, r)
	assert.NotContains(t, w.Result().Header.Get("Set-Cookie"), sessionCookieName, "Session cookie should not be set again.")
	assert.Contains(t, url, "some-random-state", "Pass state in Login URL.")
}
// TestLoggedInAs runs the checks in testLoggedInAs once for every entry in
// AllDomainNames, each as a subtest named after the domain.
func TestLoggedInAs(t *testing.T) {
	initLoginForTests(t)
	for _, domain := range AllDomainNames {
		domain := domain // per-iteration copy for the subtest closure
		t.Run(string(domain), func(t *testing.T) {
			testLoggedInAs(t, domain)
		})
	}
}
// testLoggedInAs switches the login package to domain and then verifies that:
//   - AuthenticatedAs fails and returns "" when the request has no session cookie,
//   - cookieFor produces a cookie whose Domain is the configured domain,
//   - AuthenticatedAs returns the session's email once that cookie is attached,
//   - loginURL does not force the approval prompt.
func testLoggedInAs(t *testing.T, domain DomainName) {
	require.NoError(t, setDomain(domain))
	r, err := http.NewRequest("GET", fmt.Sprintf("http://www.%s/", domain), nil)
	require.NoError(t, err)

	email, err := AuthenticatedAs(r)
	require.Error(t, err)
	assert.Equal(t, "", email, "No skid cookie means not logged in.")

	s := Session{
		Email:     "fred@chromium.org",
		ID:        "12345",
		AuthScope: emailScope,
		Token:     nil,
	}
	cookie, err := cookieFor(&s, r)
	// cookie is dereferenced below, so stop here rather than panic on nil.
	require.NoError(t, err)
	assert.Equal(t, string(domain), cookie.Domain)
	r.AddCookie(cookie)

	email, err = AuthenticatedAs(r)
	require.NoError(t, err)
	assert.Equal(t, "fred@chromium.org", email, "Correctly get logged in email.")

	w := httptest.NewRecorder()
	url := loginURL(w, r)
	assert.Contains(t, url, "approval_prompt=auto", "Not forced into prompt.")
}
// TestDomainFromHost checks how domainFromHost maps hosts, with and without
// ports, to a cookie domain when login is configured for SkiaOrg.
func TestDomainFromHost(t *testing.T) {
	initLoginForTests(t)
	cases := []struct {
		host string
		want string
	}{
		{host: "localhost:10110", want: "localhost"},
		{host: "localhost", want: "localhost"},
		{host: "skia.org", want: "skia.org"},
		{host: "perf.skia.org", want: "skia.org"},
		{host: "perf.skia.org:443", want: "skia.org"},
		// A host outside the configured domain is still mapped to skia.org.
		{host: "example.com:443", want: "skia.org"},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, domainFromHost(tc.host), "host %q", tc.host)
	}
}
// TestDomainFromHost_LuciApp initializes login for the LuciApp domain with an
// empty redirect URL, then checks domainFromHost's host-to-domain mapping and
// the resulting OAuth2 RedirectURL.
func TestDomainFromHost_LuciApp(t *testing.T) {
	require.NoError(t, initLogin(context.Background(), "id", "secret", "", saltForTesting, LuciApp))

	for _, tc := range []struct{ host, want string }{
		{"localhost:10110", "localhost"},
		{"localhost", "localhost"},
		{"luci.app", "luci.app"},
		{"perf.luci.app", "luci.app"},
		{"perf.luci.app:443", "luci.app"},
		{"example.com:443", "luci.app"},
	} {
		assert.Equal(t, tc.want, domainFromHost(tc.host), "host %q", tc.host)
	}

	redirectURL := oauthConfig.(*oauth2.Config).RedirectURL
	assert.Equal(t, "https://luci.app/oauth2callback/", redirectURL)
}
// TestIsAuthorized checks that isAuthorized accepts user and service-account
// email addresses and rejects a string that is not an email address.
func TestIsAuthorized(t *testing.T) {
	initLoginForTests(t)
	for _, allowed := range []string{
		"fred@chromium.org",
		"service-account@proj.iam.gserviceaccount.com",
	} {
		assert.True(t, isAuthorized(allowed))
	}
	assert.False(t, isAuthorized("this is not an email"))
}
// TestTryLoadingFromGCPSecret_Success checks that tryLoadingFromGCPSecret reads
// the client ID/secret JSON and the cookie salt from the secret client, returns
// their values, and actually requests both secrets.
func TestTryLoadingFromGCPSecret_Success(t *testing.T) {
	ctx := context.Background()
	client := &mocks.Client{}
	secretValue := `{
	"client_id": "fake-client-id",
	"client_secret": "fake-client-secret"
}`
	client.On("Get", ctx, loginSecretProject, clientIDandSecretName, secret.VersionLatest).Return(secretValue, nil)
	client.On("Get", ctx, loginSecretProject, saltSecretName, secret.VersionLatest).Return("fake-salt", nil)

	cookieSalt, clientID, clientSecret, err := tryLoadingFromGCPSecret(ctx, client)
	require.NoError(t, err)
	require.Equal(t, "fake-salt", cookieSalt)
	require.Equal(t, "fake-client-id", clientID)
	require.Equal(t, "fake-client-secret", clientSecret)
	// Fail if either expected secret lookup never happened.
	client.AssertExpectations(t)
}
// TestLoadSaltFromGCPSecret_Success checks that loadSaltFromGCPSecret returns
// the salt secret fetched through the secret client.
func TestLoadSaltFromGCPSecret_Success(t *testing.T) {
	ctx := context.Background()
	client := &mocks.Client{}
	client.On("Get", ctx, loginSecretProject, saltSecretName, secret.VersionLatest).Return("fake-salt", nil)

	cookieSalt, err := loadSaltFromGCPSecret(ctx, client)
	require.NoError(t, err)
	require.Equal(t, "fake-salt", cookieSalt)
	// Fail if the expected secret lookup never happened.
	client.AssertExpectations(t)
}
// TestStateFromPartsAndPartsFromStateRoundTrip_Success checks that
// partsFromState recovers the session ID and redirect URL passed to
// stateFromParts, along with the salted hash of that redirect URL.
func TestStateFromPartsAndPartsFromStateRoundTrip_Success(t *testing.T) {
	sessionIDSent := "sessionID"
	redirectURLSent := "https://example.org"
	state := stateFromParts(sessionIDSent, saltForTesting, redirectURLSent)

	sessionID, hash, redirectURL, err := partsFromState(state)
	require.NoError(t, err)
	// testify takes (expected, actual); the sent values are the expectations.
	require.Equal(t, sessionIDSent, sessionID)
	require.Equal(t, redirectURLSent, redirectURL)
	require.Equal(t, hashForURL(saltForTesting, redirectURL), hash)
}
// TestPartsFromState_MissingOnePart_ReturnsError removes the first
// "."-separated component of a valid state and checks that partsFromState
// rejects the result with errMalformedState.
func TestPartsFromState_MissingOnePart_ReturnsError(t *testing.T) {
	validState := stateFromParts("sessionID", saltForTesting, "https://example.org")

	components := strings.Split(validState, ".")
	truncated := strings.Join(components[1:], ".")

	_, _, _, err := partsFromState(truncated)
	require.ErrorIs(t, err, errMalformedState)
}
// TestOAuth2CallbackHandler_NoCookieSet_Returns500 calls the callback handler
// with a request that has no session cookie and checks the response body
// reports the missing session state.
func TestOAuth2CallbackHandler_NoCookieSet_Returns500(t *testing.T) {
	req, err := http.NewRequest("GET", "http://example.com/", nil)
	require.NoError(t, err)
	rec := httptest.NewRecorder()

	OAuth2CallbackHandler(rec, req)

	require.Contains(t, rec.Body.String(), "Missing session state")
}
// setupForOAuth2CallbackHandlerTest returns a fresh ResponseRecorder and a GET
// request for url that carries a session cookie whose value is
// sessionIDForTesting. As a side effect it sets the package-level cookieSalt to
// saltForTesting and rebuilds secureCookie from that salt.
func setupForOAuth2CallbackHandlerTest(t *testing.T, url string) (*httptest.ResponseRecorder, *http.Request) {
	// Report failures at the calling test's line rather than inside this helper.
	t.Helper()
	w := httptest.NewRecorder()
	r, err := http.NewRequest("GET", url, nil)
	require.NoError(t, err)
	cookieSalt = saltForTesting
	secureCookie = securecookie.New([]byte(cookieSalt), nil)
	r.AddCookie(&http.Cookie{
		Name:  sessionCookieName,
		Value: sessionIDForTesting,
	})
	return w, r
}
// TestOAuth2CallbackHandler_CookieSetButStateNotSet_Returns500 sends a request
// that has a session cookie but no "state" query parameter and checks that the
// handler reports an invalid session state.
func TestOAuth2CallbackHandler_CookieSetButStateNotSet_Returns500(t *testing.T) {
	rec, req := setupForOAuth2CallbackHandlerTest(t, "https://skia.org/")

	OAuth2CallbackHandler(rec, req)

	body := rec.Body.String()
	require.Contains(t, body, "Invalid session state")
}
// TestOAuth2CallbackHandler_CookieSetButSessionIDDoesNotMatchSessionIDInState_Returns500
// builds the callback state for a different session ID than the one in the
// session cookie and checks that the handler reports the mismatch.
func TestOAuth2CallbackHandler_CookieSetButSessionIDDoesNotMatchSessionIDInState_Returns500(t *testing.T) {
	foreignState := stateFromParts("wrongSessionID", saltForTesting, "/foo")
	rec, req := setupForOAuth2CallbackHandlerTest(t, "https://skia.org/?state="+foreignState)

	OAuth2CallbackHandler(rec, req)

	require.Contains(t, rec.Body.String(), "Session state doesn't match callback state.")
}
// TestOAuth2CallbackHandler_HashOfRedirectURLDoesNotMatch_Returns500 builds
// the state with a salt other than saltForTesting, so the redirect URL hash is
// wrong, and checks that the handler rejects the redirect URL.
func TestOAuth2CallbackHandler_HashOfRedirectURLDoesNotMatch_Returns500(t *testing.T) {
	badHashState := stateFromParts(sessionIDForTesting, "using the wrong salt here will change the hash", "/foo")
	rec, req := setupForOAuth2CallbackHandlerTest(t, "https://skia.org/?state="+badHashState)

	OAuth2CallbackHandler(rec, req)

	require.Contains(t, rec.Body.String(), "Invalid redirect URL")
}
// TestOAuth2CallbackHandler_ExchangeReturnsError_Returns500 makes the mocked
// OAuth2 code exchange fail and checks that the handler reports an
// authentication failure.
func TestOAuth2CallbackHandler_ExchangeReturnsError_Returns500(t *testing.T) {
	state := stateFromParts(sessionIDForTesting, saltForTesting, "/foo")
	u := fmt.Sprintf("https://skia.org/?state=%s&code=%s", state, codeForTesting)
	w, r := setupForOAuth2CallbackHandlerTest(t, u)
	oauthConfigMock := loginMocks.NewOAuthConfig(t)
	oauthConfigMock.On("Exchange", testutils.AnyContext, codeForTesting).Return(nil, errMockError)

	// Install the mock for this test only. The mock is bound to this t, so it
	// must not remain in the package-level oauthConfig for later tests.
	previousConfig := oauthConfig
	t.Cleanup(func() { oauthConfig = previousConfig })
	oauthConfig = oauthConfigMock

	OAuth2CallbackHandler(w, r)
	require.Contains(t, w.Body.String(), "Failed to authenticate")
}
// TestOAuth2CallbackHandler_ExtractEmailAndAccountIDFromTokenReturnsError_Returns500
// makes the mocked exchange return a token without an id_token extra field and
// checks that the handler reports the missing id_token.
func TestOAuth2CallbackHandler_ExtractEmailAndAccountIDFromTokenReturnsError_Returns500(t *testing.T) {
	state := stateFromParts(sessionIDForTesting, saltForTesting, "/foo")
	u := fmt.Sprintf("https://skia.org/?state=%s&code=%s", state, codeForTesting)
	w, r := setupForOAuth2CallbackHandlerTest(t, u)
	oauthConfigMock := loginMocks.NewOAuthConfig(t)
	token := &oauth2.Token{}
	token = token.WithExtra(map[string]string{})
	oauthConfigMock.On("Exchange", testutils.AnyContext, codeForTesting).Return(token, nil)

	// Install the mock for this test only. The mock is bound to this t, so it
	// must not remain in the package-level oauthConfig for later tests.
	previousConfig := oauthConfig
	t.Cleanup(func() { oauthConfig = previousConfig })
	oauthConfig = oauthConfigMock

	OAuth2CallbackHandler(w, r)
	require.Contains(t, w.Body.String(), "No id_token returned")
}
// TestOAuth2CallbackHandler_HappyPath makes the mocked exchange return a token
// with a decodable id_token and checks that the handler responds with a 302
// redirect to the URL recorded in the state ("/foo").
func TestOAuth2CallbackHandler_HappyPath(t *testing.T) {
	state := stateFromParts(sessionIDForTesting, saltForTesting, "/foo")
	u := fmt.Sprintf("https://skia.org/?state=%s&code=%s", state, codeForTesting)
	w, r := setupForOAuth2CallbackHandlerTest(t, u)
	oauthConfigMock := loginMocks.NewOAuthConfig(t)
	middle := base64.URLEncoding.EncodeToString([]byte(`{
	"email": "somebody@example.org",
	"sub": "123"
}`))
	token := &oauth2.Token{}
	tokenWith := token.WithExtra(map[string]interface{}{idTokenKeyName: "a." + middle + ".c"})
	oauthConfigMock.On("Exchange", testutils.AnyContext, codeForTesting).Return(tokenWith, nil)

	// Install the mock for this test only. The mock is bound to this t, so it
	// must not remain in the package-level oauthConfig for later tests.
	previousConfig := oauthConfig
	t.Cleanup(func() { oauthConfig = previousConfig })
	oauthConfig = oauthConfigMock

	OAuth2CallbackHandler(w, r)
	require.Contains(t, w.Body.String(), "Found")
	// testify takes (expected, actual).
	require.Equal(t, http.StatusFound, w.Result().StatusCode)
	require.Equal(t, "/foo", w.Header().Get("Location"))
}
// TestExtractEmailAndAccountIDFromToken_InvalidForm_ReturnsFailureMessage
// passes an id_token with only two "."-separated segments and checks that the
// returned failure message calls it invalid.
func TestExtractEmailAndAccountIDFromToken_InvalidForm_ReturnsFailureMessage(t *testing.T) {
	tok := (&oauth2.Token{}).WithExtra(map[string]interface{}{idTokenKeyName: "a.b"})

	_, _, failure := extractEmailAndAccountIDFromToken(tok)

	require.Contains(t, failure, "Invalid id_token")
}
func TestExtractEmailAndAccountIDFromToken_InvalidBase64_ReturnsFailureMessage(t *testing.T) {
token := &oauth2.Token{}
tokenWith := token.WithExtra(map[string]interface{}{idTokenKeyName: "a.??;;::not-valid-base64.c"})
_, _, msg := extractEmailAndAccountIDFromToken(tokenWith)
require.Contains(t, msg, "Failed to base64 decode id_token")
}
// TestExtractEmailAndAccountIDFromToken_DecodedBase64IsNotValidJSON_ReturnsFailureMessage
// passes an id_token whose middle segment is valid base64 but decodes to
// malformed JSON and checks the reported JSON decode failure.
func TestExtractEmailAndAccountIDFromToken_DecodedBase64IsNotValidJSON_ReturnsFailureMessage(t *testing.T) {
	payload := base64.URLEncoding.EncodeToString([]byte("{not-valid-json"))
	idToken := "a." + payload + ".c"
	tok := (&oauth2.Token{}).WithExtra(map[string]interface{}{idTokenKeyName: idToken})

	_, _, failure := extractEmailAndAccountIDFromToken(tok)

	require.Contains(t, failure, "Failed to JSON decode id_token")
}
// TestExtractEmailAndAccountIDFromToken_EmailIsNotValidJSON_ReturnsFailureMessage
// passes an id_token whose JSON is well formed but whose "email" claim is not an
// email address, and checks the reported invalid-email failure.
// NOTE(review): despite the test name, the JSON here is valid; it is the email
// value that is invalid.
func TestExtractEmailAndAccountIDFromToken_EmailIsNotValidJSON_ReturnsFailureMessage(t *testing.T) {
	claims := []byte(`{
	"email": "not-a-valid-email-address",
	"sub": "123"
}`)
	idToken := "a." + base64.URLEncoding.EncodeToString(claims) + ".c"
	tok := (&oauth2.Token{}).WithExtra(map[string]interface{}{idTokenKeyName: idToken})

	_, _, failure := extractEmailAndAccountIDFromToken(tok)

	require.Contains(t, failure, "Invalid email address received")
}
// TestExtractEmailAndAccountIDFromToken_HappyPath checks that the "email" and
// "sub" claims of a well-formed id_token are returned with an empty failure
// message.
func TestExtractEmailAndAccountIDFromToken_HappyPath(t *testing.T) {
	claims := []byte(`{
	"email": "somebody@example.org",
	"sub": "123"
}`)
	idToken := "a." + base64.URLEncoding.EncodeToString(claims) + ".c"
	tok := (&oauth2.Token{}).WithExtra(map[string]interface{}{idTokenKeyName: idToken})

	email, accountID, failure := extractEmailAndAccountIDFromToken(tok)

	require.Empty(t, failure)
	require.Equal(t, "somebody@example.org", email)
	require.Equal(t, "123", accountID)
}
// TestSetDomain_ValidDomainName_Success checks that setDomain accepts every
// entry in AllDomainNames, with one subtest per domain.
func TestSetDomain_ValidDomainName_Success(t *testing.T) {
	for _, name := range AllDomainNames {
		name := name // per-iteration copy for the subtest closure
		t.Run(string(name), func(t *testing.T) {
			err := setDomain(name)
			require.NoError(t, err)
		})
	}
}
// TestSetDomain_UnknonwDomainName_ReturnsError checks that setDomain rejects a
// domain name that is not one of the known domains.
func TestSetDomain_UnknonwDomainName_ReturnsError(t *testing.T) {
	unknown := DomainName("this-in-not-a-known-domain.example.com")
	err := setDomain(unknown)
	require.Error(t, err)
}
// setupForValidateBearerToken points the package-level tokenValidatorService
// at a local httptest server that emulates the token validation endpoint. The
// server checks that the request's access_token equals bearerToken and replies
// with tokenInfo encoded as JSON. It also installs a fresh, empty
// validBearerTokenCache.
//
// When the test finishes, the server is shut down and the previous
// tokenValidatorService and validBearerTokenCache are restored.
func setupForValidateBearerToken(t *testing.T, tokenInfo *oauth2_api.Tokeninfo) {
	t.Helper()

	// The handler runs on the server's goroutine. require (FailNow) must only be
	// called from the test goroutine, so record failures with assert instead.
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.NoError(t, r.ParseForm())
		assert.Equal(t, bearerToken, r.FormValue("access_token"))
		assert.NoError(t, json.NewEncoder(w).Encode(tokenInfo))
	}))
	t.Cleanup(testServer.Close)

	// Keep the package-level state this helper overwrites (and that tests may nil
	// out) from leaking into later tests.
	previousService, previousCache := tokenValidatorService, validBearerTokenCache
	t.Cleanup(func() {
		tokenValidatorService = previousService
		validBearerTokenCache = previousCache
	})

	// Replace the default tokenValidatorService with one that talks to the
	// emulation server built above.
	var err error
	tokenValidatorService, err = oauth2_api.NewService(context.Background(),
		option.WithHTTPClient(testServer.Client()),
		option.WithEndpoint(testServer.URL))
	require.NoError(t, err)

	// Start every test with an empty cache.
	validBearerTokenCache = ttlcache.New(validBearerTokenCacheLifetime, validBearerTokenCacheCleanup)
}
// TestValidateBearerToken_HappyPath checks that validateBearerToken returns the
// Tokeninfo served by the emulated validation endpoint for a verified,
// unexpired token.
func TestValidateBearerToken_HappyPath(t *testing.T) {
	want := &oauth2_api.Tokeninfo{
		Email:         "user@example.org",
		ExpiresIn:     3600, // seconds
		VerifiedEmail: true,
	}
	setupForValidateBearerToken(t, want)

	got, err := validateBearerToken(context.Background(), bearerToken)
	require.NoError(t, err)

	// Compare everything except ServerResponse.
	got.ServerResponse = googleapi.ServerResponse{}
	assertdeep.Equal(t, want, got)
}
// TestValidateBearerToken_ValidatedTokenExistsInCache_Success pre-populates
// validBearerTokenCache and checks that validateBearerToken returns the cached
// Tokeninfo without needing tokenValidatorService.
func TestValidateBearerToken_ValidatedTokenExistsInCache_Success(t *testing.T) {
	cached := &oauth2_api.Tokeninfo{
		Email:         "user@example.org",
		ExpiresIn:     3600, // seconds
		VerifiedEmail: true,
	}
	setupForValidateBearerToken(t, cached)
	validBearerTokenCache.Set(bearerToken, cached, ttlcache.DefaultExpiration)

	// With the validator service gone, only the cache can satisfy the call.
	tokenValidatorService = nil

	got, err := validateBearerToken(context.Background(), bearerToken)
	require.NoError(t, err)
	assertdeep.Equal(t, cached, got)
}
// TestValidateBearerToken_FirstRequestAddsTokenToCache_SecondCallReturnsTokenFromCache
// validates a token once through the emulated endpoint, then removes
// tokenValidatorService and checks that a second validation of the same token
// still succeeds with the same Tokeninfo.
func TestValidateBearerToken_FirstRequestAddsTokenToCache_SecondCallReturnsTokenFromCache(t *testing.T) {
	want := &oauth2_api.Tokeninfo{
		Email:         "user@example.org",
		ExpiresIn:     3600, // seconds
		VerifiedEmail: true,
	}
	setupForValidateBearerToken(t, want)
	ctx := context.Background()

	// First call goes to the emulated endpoint.
	first, err := validateBearerToken(ctx, bearerToken)
	require.NoError(t, err)
	// Compare everything except ServerResponse.
	first.ServerResponse = googleapi.ServerResponse{}
	assertdeep.Equal(t, want, first)

	// Second call: with the validator service gone, only the cache can answer.
	tokenValidatorService = nil
	second, err := validateBearerToken(ctx, bearerToken)
	require.NoError(t, err)
	assertdeep.Equal(t, want, second)
}
// TestValidateBearerToken_EmailNotValidated_ReturnsError checks that
// validateBearerToken rejects a token whose email is not verified.
func TestValidateBearerToken_EmailNotValidated_ReturnsError(t *testing.T) {
	expectedTokenInfo := &oauth2_api.Tokeninfo{
		Email:         "user@example.org",
		ExpiresIn:     3600, // seconds
		VerifiedEmail: false,
	}
	setupForValidateBearerToken(t, expectedTokenInfo)
	_, err := validateBearerToken(context.Background(), bearerToken)
	// Fail cleanly rather than panic on err.Error() if no error was returned.
	require.Error(t, err)
	require.Contains(t, err.Error(), "email not verified")
}
// TestValidateBearerToken_TokenExpired_ReturnsError checks that
// validateBearerToken rejects a token whose ExpiresIn is 0.
func TestValidateBearerToken_TokenExpired_ReturnsError(t *testing.T) {
	expectedTokenInfo := &oauth2_api.Tokeninfo{
		Email:         "user@example.org",
		ExpiresIn:     0, // seconds
		VerifiedEmail: true,
	}
	setupForValidateBearerToken(t, expectedTokenInfo)
	_, err := validateBearerToken(context.Background(), bearerToken)
	// Fail cleanly rather than panic on err.Error() if no error was returned.
	require.Error(t, err)
	require.Contains(t, err.Error(), "token is expired")
}
// TestViaBearerToken_HappyPath sends a request with an "Authorization: Bearer"
// header and checks that viaBearerToken returns the email from the validated
// Tokeninfo.
func TestViaBearerToken_HappyPath(t *testing.T) {
	setupForValidateBearerToken(t, &oauth2_api.Tokeninfo{
		Email:         "user@example.org",
		ExpiresIn:     3600, // seconds
		VerifiedEmail: true,
	})

	req := httptest.NewRequest("GET", "/", nil)
	req.Header.Add("Authorization", "Bearer "+bearerToken)

	email, err := viaBearerToken(req)
	require.NoError(t, err)
	require.Equal(t, "user@example.org", email)
}