Skip to content

Commit

Permalink
Allow downloads of HostFiles (#793)
Browse files Browse the repository at this point in the history
  • Loading branch information
KCarretto authored Jul 9, 2024
1 parent 6669bc2 commit 16134b9
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 5 deletions.
1 change: 1 addition & 0 deletions implants/lib/pb/src/generated/c2.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions implants/lib/pb/src/generated/eldritch.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tavern/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ func NewServer(ctx context.Context, options ...func(*Config)) (*Server, error) {
AllowUnauthenticated: true,
AllowUnactivated: true,
},
"/cdn/hostfiles/": tavernhttp.Endpoint{
Handler: cdn.NewHostFileDownloadHandler(client, "/cdn/hostfiles/"),
},
"/cdn/upload": tavernhttp.Endpoint{
Handler: cdn.NewUploadHandler(client),
},
Expand Down
62 changes: 62 additions & 0 deletions tavern/internal/cdn/download_hostfile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package cdn

import (
"bytes"
"net/http"
"path/filepath"
"strconv"
"strings"

"realm.pub/tavern/internal/ent"
"realm.pub/tavern/internal/ent/hostfile"
"realm.pub/tavern/internal/errors"
)

// maxFileIDLen is the maximum length the string for a file id may be.
const (
maxFileIDLen = 256
)

// NewHostFileDownloadHandler returns an HTTP handler responsible for downloading a HostFile from the CDN.
func NewHostFileDownloadHandler(graph *ent.Client, prefix string) http.Handler {
return errors.WrapHandler(func(w http.ResponseWriter, req *http.Request) error {
ctx := req.Context()

// Get the HostFile ID from the request URI
fileIDStr := strings.TrimPrefix(req.URL.Path, prefix)
if fileIDStr == "" || fileIDStr == "." || fileIDStr == "/" || len(fileIDStr) > maxFileIDLen {
return ErrInvalidFileID
}
fileID, err := strconv.Atoi(fileIDStr)
if err != nil {
return ErrInvalidFileID
}

fileQuery := graph.HostFile.Query().Where(hostfile.ID(fileID))

// If hash was provided, check to see if the file has been updated. Note that
// http.ServeContent should handle this, but we want to avoid the expensive DB
// query where possible.
if hash := req.Header.Get(HeaderIfNoneMatch); hash != "" {
if exists := fileQuery.Clone().Where(hostfile.Hash(hash)).ExistX(ctx); exists {
return ErrFileNotModified
}
}

// Ensure the file exists
if exists := fileQuery.Clone().ExistX(ctx); !exists {
return ErrFileNotFound
}

f := fileQuery.OnlyX(ctx)

// Set Etag to hash of file
w.Header().Set(HeaderEtag, f.Hash)

// Set Content-Type and serve content
w.Header().Set("Content-Type", "application/octet-stream")
http.ServeContent(w, req, filepath.Base(f.Path), f.LastModifiedAt, bytes.NewReader(f.Content))

return nil
})
}
132 changes: 132 additions & 0 deletions tavern/internal/cdn/download_hostfile_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package cdn_test

import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/cdn"
"realm.pub/tavern/internal/ent/enttest"
)

// TestDownloadHostFile asserts that the download handler exhibits expected behavior.
func TestDownloadHostFile(t *testing.T) {
graph := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer graph.Close()

ctx := context.Background()
existingHost := graph.Host.Create().
SetIdentifier("test-host").
SetPlatform(c2pb.Host_PLATFORM_LINUX).
SaveX(ctx)
existingBeacon := graph.Beacon.Create().
SetHost(existingHost).
SetIdentifier("ABCDEFG").
SaveX(ctx)
existingTome := graph.Tome.Create().
SetName("Wowza").
SetDescription("Why did we require this?").
SetAuthor("kcarretto").
SetEldritch("blah").
SaveX(ctx)
existingQuest := graph.Quest.Create().
SetName("HelloWorld").
SetTome(existingTome).
SaveX(ctx)
existingTask := graph.Task.Create().
SetBeacon(existingBeacon).
SetQuest(existingQuest).
SaveX(ctx)

existingHostFile := graph.HostFile.Create().
SetPath("/existing/file").
SetContent([]byte(`some data`)).
SetHost(existingHost).
SetTask(existingTask).
SaveX(ctx)

handler := cdn.NewHostFileDownloadHandler(graph, "/download/")

tests := []struct {
name string

reqURL string
reqMethod string
reqBody io.Reader
reqHeaders map[string][]string

wantStatus int
wantBody []byte
wantErr error
}{
{
name: "Valid",
reqURL: fmt.Sprintf("/download/%d", existingHostFile.ID),
wantBody: existingHostFile.Content,
},
{
name: "NotFound",
reqURL: "/download/123",
wantBody: []byte(fmt.Sprintf("%s\n", cdn.ErrFileNotFound.Error())),
wantStatus: cdn.ErrFileNotFound.StatusCode,
},
{
name: "InvalidID/Alphabet",
reqURL: "/download/abcd",
wantBody: []byte(fmt.Sprintf("%s\n", cdn.ErrInvalidFileID.Error())),
wantStatus: cdn.ErrInvalidFileID.StatusCode,
},
{
name: "InvalidID/Empty",
reqURL: "/download/",
wantBody: []byte(fmt.Sprintf("%s\n", cdn.ErrInvalidFileID.Error())),
wantStatus: cdn.ErrInvalidFileID.StatusCode,
},
{
name: "Cached",
reqURL: fmt.Sprintf("/download/%d", existingHostFile.ID),
reqHeaders: map[string][]string{
cdn.HeaderIfNoneMatch: {existingHostFile.Hash},
},
wantBody: []byte(fmt.Sprintf("%s\n", cdn.ErrFileNotModified.Error())),
wantStatus: cdn.ErrFileNotModified.StatusCode,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Build Request
req, reqErr := http.NewRequest(tc.reqMethod, tc.reqURL, tc.reqBody)
require.NoError(t, reqErr)
for key, vals := range tc.reqHeaders {
for _, val := range vals {
req.Header.Add(key, val)
}
}

// Default to wanting OK status
if tc.wantStatus == 0 {
tc.wantStatus = http.StatusOK
}

// Send request and record response
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

result := w.Result()
assert.Equal(t, tc.wantStatus, result.StatusCode)

// Attempt to read the response body
body, err := io.ReadAll(result.Body)
require.NoError(t, err, "failed to parse response body")
defer result.Body.Close()

assert.Equal(t, tc.wantBody, body)
})
}
}
10 changes: 5 additions & 5 deletions tavern/internal/cdn/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"realm.pub/tavern/internal/errors"
)

// TestUpload asserts that the download handler exhibits expected behavior.
// TestDownload asserts that the download handler exhibits expected behavior.
func TestDownload(t *testing.T) {
graph := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer graph.Close()
Expand All @@ -26,15 +26,15 @@ func TestDownload(t *testing.T) {

t.Run("File", newDownloadTest(
graph,
newDownloadRequest(t, existingFile.Name),
newDownloadRequest(existingFile.Name),
func(t *testing.T, fileContent []byte, err *errors.HTTP) {
assert.Nil(t, err)
assert.Equal(t, string(expectedContent), string(fileContent))
},
))
t.Run("CachedFile", newDownloadTest(
graph,
newDownloadRequest(t, existingFile.Name, withIfNoneMatchHeader(existingFile.Hash)),
newDownloadRequest(existingFile.Name, withIfNoneMatchHeader(existingFile.Hash)),
func(t *testing.T, fileContent []byte, err *errors.HTTP) {
require.NotNil(t, err)
assert.Equal(t, http.StatusNotModified, err.StatusCode)
Expand All @@ -44,7 +44,7 @@ func TestDownload(t *testing.T) {
))
t.Run("NonExistentFile", newDownloadTest(
graph,
newDownloadRequest(t, "ThisFileDoesNotExist"),
newDownloadRequest("ThisFileDoesNotExist"),
func(t *testing.T, fileContent []byte, err *errors.HTTP) {
require.NotNil(t, err)
assert.Equal(t, http.StatusNotFound, err.StatusCode)
Expand Down Expand Up @@ -92,7 +92,7 @@ func newDownloadTest(graph *ent.Client, req *http.Request, checks ...func(t *tes
}

// newDownloadRequest is a helper to create an http.Request for a file download
func newDownloadRequest(t *testing.T, fileName string, options ...func(*http.Request)) *http.Request {
func newDownloadRequest(fileName string, options ...func(*http.Request)) *http.Request {
req := httptest.NewRequest(http.MethodGet, "/download/"+fileName, nil)
for _, opt := range options {
opt(req)
Expand Down
2 changes: 2 additions & 0 deletions tavern/internal/cdn/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"realm.pub/tavern/internal/errors"
)

// ErrInvalidFileID occurs when an invalid file id is provided to the API.
// ErrInvalidFileName occurs when an invalid file name is provided to the API.
// ErrInvalidFileContent occurs when an upload is attempted with invalid file content.
// ErrFileNotModified occurs when a download is attempted using the If-None-Match header and the file has not been modified.
// ErrFileNotFound occurs when a download is attempted for a file that does not exist.
var (
ErrInvalidFileID = errors.NewHTTP("invalid file id provided", http.StatusBadRequest)
ErrInvalidFileName = errors.NewHTTP("invalid file name provided", http.StatusBadRequest)
ErrInvalidFileContent = errors.NewHTTP("invalid file content provided", http.StatusBadRequest)
ErrFileNotModified = errors.NewHTTP("file has not been modified", http.StatusNotModified)
Expand Down

0 comments on commit 16134b9

Please sign in to comment.