blob: 0be333c7a9b85a9b23081d2a19949c0c720fe3f5 [file] [log] [blame]
// Package oauth2redirect is a reverse proxy that runs in front of applications
// and takes care of handling the oauth2 redirect leg of the OAuth 3-legged
// flow. It passes all other traffic to the application it is running in front
// of.
//
// This is useful so that we don't need to redeploy docsyserver every time a
// change is made to //go/login; instead, just this smaller proxy can be
// deployed at the same time as go/auth-proxy is deployed.
package oauth2redirect
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
	"time"

	"go.skia.org/infra/go/cleanup"
	"go.skia.org/infra/go/common"
	"go.skia.org/infra/go/httputils"
	"go.skia.org/infra/go/login"
	"go.skia.org/infra/go/skerr"
	"go.skia.org/infra/go/sklog"
)
const (
	// appName names the application; it is passed to common.InitWith and is
	// also the name of the FlagSet built by Flagset.
	appName = "oauth2redirect"

	// serverReadTimeout and serverWriteTimeout are the ReadTimeout and
	// WriteTimeout of the http.Server created in Run.
	serverReadTimeout  = time.Hour
	serverWriteTimeout = time.Hour

	// drainTime bounds how long the graceful shutdown registered in
	// registerCleanup waits for in-flight requests before giving up.
	drainTime = time.Minute
)
// proxy is the top-level request handler: ServeHTTP answers the login-related
// endpoints itself and hands everything else to reverseProxy.
type proxy struct {
	// reverseProxy forwards requests to the application being fronted.
	reverseProxy http.Handler
}

// newProxy returns a proxy whose non-login traffic is forwarded to target by
// a single-host reverse proxy.
func newProxy(target *url.URL) *proxy {
	return &proxy{reverseProxy: httputil.NewSingleHostReverseProxy(target)}
}
// ServeHTTP dispatches on the request path: the OAuth2 callback, login, and
// logout paths are handled by the login package, and every other request is
// forwarded to the proxied application.
func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path
	if path == login.DefaultOAuth2Callback {
		login.OAuth2CallbackHandler(w, r)
		return
	}
	if path == login.LoginPath {
		login.AuthenticateUser(w, r)
		return
	}
	if path == login.LogoutPath {
		login.UnauthenticateUser(w, r)
		return
	}
	// Not a login endpoint: hand the request to the fronted application.
	p.reverseProxy.ServeHTTP(w, r)
}
// App is the oauth2redirect application.
type App struct {
	// port is the HTTP service address the server listens on (-port flag).
	port string
	// promPort is the metrics service address (-prom-port flag), handed to
	// common.PrometheusOpt.
	promPort string
	// local is true when running locally rather than in production; when set,
	// Run does not wrap the handler in httputils.HealthzAndHTTPS.
	local bool
	// targetPort is the raw -target_port flag value: either a port such as
	// ":9000" or a full URL.
	targetPort string
	// domain is the -domain flag value, used to build the login.DomainName
	// option passed to login.Init.
	domain string
	// target is targetPort resolved to a URL by parseTargetPort.
	target *url.URL
	// server is the HTTP server; nil until Run is called.
	server *http.Server
	// proxy is the request handler; nil until Run is called.
	proxy *proxy
}
// Flagset returns a flag.FlagSet named after the application whose flags
// write directly into a's configuration fields. Because it is created with
// flag.ExitOnError, a parse failure terminates the process.
func (a *App) Flagset() *flag.FlagSet {
	domainUsage := fmt.Sprintf("The domain to handle oauth2 callbacks for, choose from: %q", login.AllDomainNames)

	flags := flag.NewFlagSet(appName, flag.ExitOnError)
	flags.StringVar(&a.port, "port", ":8000", "HTTP service address (e.g., ':8000')")
	flags.StringVar(&a.promPort, "prom-port", ":20000", "Metrics service address (e.g., ':10110')")
	flags.BoolVar(&a.local, "local", false, "Running locally if true. As opposed to in production.")
	flags.StringVar(&a.targetPort, "target_port", ":9000", "The port we are proxying to, or a full URL.")
	flags.StringVar(&a.domain, "domain", string(login.SkiaOrg), domainUsage)
	return flags
}
// newEmptyApp returns an App with every field at its zero value. Configuration
// is filled in later by flag parsing, and the proxy is created by Run.
func newEmptyApp() *App {
	return new(App)
}
// New returns a new *App configured from the command line.
//
// It runs common.InitWith with the App's FlagSet and Prometheus port, then
// initializes the login package with opts followed by a login.DomainName
// option built from the -domain flag, and finally resolves the -target_port
// flag into the proxy target. The caller's opts slice is never modified.
//
// Every failure is returned as an error; New never terminates the process.
func New(ctx context.Context, opts ...login.InitOption) (*App, error) {
	ret := newEmptyApp()
	err := common.InitWith(
		appName,
		common.PrometheusOpt(&ret.promPort),
		common.FlagSetOpt(ret.Flagset()),
	)
	if err != nil {
		return nil, skerr.Wrap(err)
	}

	// Copy before appending: when the caller passes New(ctx, s...) with a
	// slice that has spare capacity, appending to opts in place would write
	// into the caller's backing array.
	initOpts := make([]login.InitOption, 0, len(opts)+1)
	initOpts = append(initOpts, opts...)
	initOpts = append(initOpts, login.DomainName(ret.domain))
	if err := login.Init(ctx, "", initOpts...); err != nil {
		// Report the failure to the caller instead of exiting, consistent
		// with every other error path in this constructor.
		return nil, skerr.Wrap(err)
	}

	target, err := parseTargetPort(ret.targetPort)
	if err != nil {
		return nil, skerr.Wrap(err)
	}
	ret.target = target
	ret.registerCleanup()
	return ret, nil
}
// parseTargetPort converts the -target_port flag value into the URL to proxy to.
//
// u is either a bare port such as ":8000", which is interpreted as
// "http://localhost:8000", or an absolute URL such as
// "https://example.com:9000". An error is returned if u cannot be parsed, or
// if the result lacks a scheme or host: the single-host reverse proxy copies
// both from the target into every forwarded request, so a URL without them
// (e.g. from "9000" or "localhost:9000") would yield a proxy that cannot
// forward anything.
func parseTargetPort(u string) (*url.URL, error) {
	raw := u
	if strings.HasPrefix(u, ":") {
		raw = "http://localhost" + u
	}
	target, err := url.Parse(raw)
	if err != nil {
		return nil, err
	}
	if target.Scheme == "" || target.Host == "" {
		return nil, fmt.Errorf("target %q must be a port like \":8000\" or an absolute URL with a scheme and host", u)
	}
	return target, nil
}
// registerCleanup arranges, via cleanup.AtExit, for the HTTP server (if Run
// has created one) to be shut down gracefully, waiting at most drainTime for
// in-flight requests. Shutdown failures are logged, not returned.
func (a *App) registerCleanup() {
	cleanup.AtExit(func() {
		srv := a.server
		if srv == nil {
			// Run was never called; nothing to stop.
			return
		}
		sklog.Info("Shutdown server gracefully.")
		drainCtx, cancel := context.WithTimeout(context.Background(), drainTime)
		defer cancel()
		if err := srv.Shutdown(drainCtx); err != nil {
			sklog.Error(err)
		}
	})
}
// Run starts the application serving. It does not return unless there is an
// error or the server is stopped.
//
// The handler is the login-aware proxy, wrapped by
// httputils.HealthzAndHTTPS unless running locally. When ctx is cancelled,
// the server is shut down gracefully (waiting at most drainTime for in-flight
// requests) and Run returns nil once the drain completes. A shutdown
// triggered elsewhere (e.g. the cleanup.AtExit hook) also yields nil. Any
// other serving error is returned.
func (a *App) Run(ctx context.Context) error {
	a.proxy = newProxy(a.target)
	var h http.Handler = a.proxy
	if !a.local {
		h = httputils.HealthzAndHTTPS(h)
	}
	server := &http.Server{
		Addr:           a.port,
		Handler:        h,
		ReadTimeout:    serverReadTimeout,
		WriteTimeout:   serverWriteTimeout,
		MaxHeaderBytes: 1 << 20,
	}
	a.server = server

	// Honor ctx: on cancellation, drain and stop the server, which makes
	// ListenAndServe below return http.ErrServerClosed. The watcher exits
	// as soon as serving ends for any other reason.
	serveDone := make(chan struct{})
	watcherDone := make(chan struct{})
	go func() {
		defer close(watcherDone)
		select {
		case <-ctx.Done():
			shutdownCtx, cancel := context.WithTimeout(context.Background(), drainTime)
			defer cancel()
			if err := server.Shutdown(shutdownCtx); err != nil {
				sklog.Error(err)
			}
		case <-serveDone:
		}
	}()

	err := server.ListenAndServe()
	close(serveDone)
	// ListenAndServe returns as soon as Shutdown closes the listeners, while
	// Shutdown is still draining; wait so Run neither returns mid-drain nor
	// leaks the watcher goroutine.
	<-watcherDone

	if errors.Is(err, http.ErrServerClosed) {
		// This is an orderly shutdown.
		return nil
	}
	return skerr.Wrap(err)
}
// Main builds the App from the command-line flags and serves until the server
// stops. It returns a non-nil error if construction or serving fails, and nil
// after an orderly shutdown.
func Main() error {
	background := context.Background()
	a, err := New(background)
	if err != nil {
		return skerr.Wrap(err)
	}
	return a.Run(background)
}