blob: b3eab5a6c8630cf4b188e3e520c01848334451e4 [file] [log] [blame]
package main
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"go.skia.org/infra/go/testutils/unittest"
)
const body = `
<!DOCTYPE html>
<html>
<body>
<details>
<summary>
<object type="image/svg+xml" data="https://dot.skia.org/dot?foo"></object>
</summary>
<pre>
graph {
Hello -- World
}
</pre>
</details>
</body>
</html>
`
func simpleTx(ctx context.Context, format, dotCode string) (string, error) {
return "<svg></svg>", nil
}
func failingTx(ctx context.Context, format, dotCode string) (string, error) {
return "", fmt.Errorf("Failed to transform.")
}
func Test_server_transformHandler(t *testing.T) {
unittest.LargeTest(t)
goodTS := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
_, err := fmt.Fprint(w, body)
assert.NoError(t, err)
}))
defer goodTS.Close()
allowed := regexp.MustCompile(`https://`)
notAllowed := regexp.MustCompile(`https://notallowed.example.org`)
goodRequest, err := http.NewRequest("GET", "https://dot.skia.org/dot?foo", nil)
assert.NoError(t, err)
goodRequest.Header.Set("Referer", goodTS.URL)
goodRequestBadTarget, err := http.NewRequest("GET", "https://dot.skia.org/dot?bar", nil)
assert.NoError(t, err)
goodRequestBadTarget.Header.Set("Referer", goodTS.URL)
goodRequestBadFormat, err := http.NewRequest("GET", "https://dot.skia.org/ls", nil)
assert.NoError(t, err)
goodRequestBadTarget.Header.Set("Referer", goodTS.URL)
goodRequestNoReferer, err := http.NewRequest("GET", "https://dot.skia.org/dot?foo", nil)
assert.NoError(t, err)
badTS := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusInternalServerError)
}))
defer badTS.Close()
badRequest, err := http.NewRequest("GET", "https://dot.skia.org/dot?foo", nil)
assert.NoError(t, err)
badRequest.Header.Set("Referer", badTS.URL)
type fields struct {
client *http.Client
tx transformer
allowed *regexp.Regexp
}
type args struct {
w *httptest.ResponseRecorder
r *http.Request
}
tests := []struct {
name string
fields fields
args args
statusCode int
body string
}{
{
name: "Simple success",
fields: fields{
client: goodTS.Client(),
tx: simpleTx,
allowed: allowed,
},
args: args{
w: httptest.NewRecorder(),
r: goodRequest,
},
statusCode: 200,
body: "<svg></svg>",
},
{
name: "Not allowed by domain regexp.",
fields: fields{
client: goodTS.Client(),
tx: simpleTx,
allowed: notAllowed,
},
args: args{
w: httptest.NewRecorder(),
r: goodRequest,
},
statusCode: 404,
body: "Not an allowed domain.\n",
},
{
name: "The transformation from the input format to SVG fails.",
fields: fields{
client: goodTS.Client(),
tx: failingTx,
allowed: allowed,
},
args: args{
w: httptest.NewRecorder(),
r: goodRequest,
},
statusCode: 404,
body: "Failed to transform.\n",
},
{
name: "Good request, but requested URI not found in source document",
fields: fields{
client: goodTS.Client(),
tx: simpleTx,
allowed: allowed,
},
args: args{
w: httptest.NewRecorder(),
r: goodRequestBadTarget,
},
statusCode: 404,
body: "Failed to find requester URL in source document.\n",
},
{
name: "Request transform by an unknown format",
fields: fields{
client: goodTS.Client(),
tx: simpleTx,
allowed: allowed,
},
args: args{
w: httptest.NewRecorder(),
r: goodRequestBadFormat,
},
statusCode: 404,
body: "Unknown format.\n",
},
{
name: "Request doesn't contain a referer header.",
fields: fields{
client: goodTS.Client(),
tx: simpleTx,
allowed: allowed,
},
args: args{
w: httptest.NewRecorder(),
r: goodRequestNoReferer,
},
statusCode: 404,
body: "Missing Referer header.\n",
},
{
name: "Source server returns error.",
fields: fields{
client: badTS.Client(),
tx: simpleTx,
allowed: allowed,
},
args: args{
w: httptest.NewRecorder(),
r: badRequest,
},
statusCode: 404,
body: "Failed to get 200 fetching referring page.\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := &server{
client: tt.fields.client,
tx: tt.fields.tx,
allowed: tt.fields.allowed,
}
srv.transformHandler(tt.args.w, tt.args.r)
assert.Equal(t, tt.statusCode, tt.args.w.Result().StatusCode)
assert.Equal(t, tt.body, tt.args.w.Body.String())
})
}
}