From 16134b9877ee0ea66ae084cacf04ecff5e97a21a Mon Sep 17 00:00:00 2001 From: KCarretto Date: Tue, 9 Jul 2024 17:57:34 -0400 Subject: [PATCH] Allow downloads of HostFiles (#793) --- implants/lib/pb/src/generated/c2.rs | 1 + implants/lib/pb/src/generated/eldritch.rs | 1 + tavern/app.go | 3 + tavern/internal/cdn/download_hostfile.go | 62 ++++++++ tavern/internal/cdn/download_hostfile_test.go | 132 ++++++++++++++++++ tavern/internal/cdn/download_test.go | 10 +- tavern/internal/cdn/errors.go | 2 + 7 files changed, 206 insertions(+), 5 deletions(-) create mode 100644 tavern/internal/cdn/download_hostfile.go create mode 100644 tavern/internal/cdn/download_hostfile_test.go diff --git a/implants/lib/pb/src/generated/c2.rs b/implants/lib/pb/src/generated/c2.rs index 85563379..c388d9b6 100644 --- a/implants/lib/pb/src/generated/c2.rs +++ b/implants/lib/pb/src/generated/c2.rs @@ -1,3 +1,4 @@ +// This file is @generated by prost-build. /// Agent information to identify the type of beacon. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/implants/lib/pb/src/generated/eldritch.rs b/implants/lib/pb/src/generated/eldritch.rs index 6c7cac2b..ac4647c3 100644 --- a/implants/lib/pb/src/generated/eldritch.rs +++ b/implants/lib/pb/src/generated/eldritch.rs @@ -1,3 +1,4 @@ +// This file is @generated by prost-build. /// Tome for eldritch to execute. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/tavern/app.go b/tavern/app.go index 68295dca..bc4ecf95 100644 --- a/tavern/app.go +++ b/tavern/app.go @@ -203,6 +203,9 @@ func NewServer(ctx context.Context, options ...func(*Config)) (*Server, error) { AllowUnauthenticated: true, AllowUnactivated: true, }, + "/cdn/hostfiles/": tavernhttp.Endpoint{ + Handler: cdn.NewHostFileDownloadHandler(client, "/cdn/hostfiles/"), + }, "/cdn/upload": tavernhttp.Endpoint{ Handler: cdn.NewUploadHandler(client), }, diff --git a/tavern/internal/cdn/download_hostfile.go b/tavern/internal/cdn/download_hostfile.go new file mode 100644 index 00000000..d83e8f95 --- /dev/null +++ b/tavern/internal/cdn/download_hostfile.go @@ -0,0 +1,62 @@ +package cdn + +import ( + "bytes" + "net/http" + "path/filepath" + "strconv" + "strings" + + "realm.pub/tavern/internal/ent" + "realm.pub/tavern/internal/ent/hostfile" + "realm.pub/tavern/internal/errors" +) + +// maxFileIDLen is the maximum length the string for a file id may be. +const ( + maxFileIDLen = 256 +) + +// NewHostFileDownloadHandler returns an HTTP handler responsible for downloading a HostFile from the CDN. +func NewHostFileDownloadHandler(graph *ent.Client, prefix string) http.Handler { + return errors.WrapHandler(func(w http.ResponseWriter, req *http.Request) error { + ctx := req.Context() + + // Get the HostFile ID from the request URI + fileIDStr := strings.TrimPrefix(req.URL.Path, prefix) + if fileIDStr == "" || fileIDStr == "." || fileIDStr == "/" || len(fileIDStr) > maxFileIDLen { + return ErrInvalidFileID + } + fileID, err := strconv.Atoi(fileIDStr) + if err != nil { + return ErrInvalidFileID + } + + fileQuery := graph.HostFile.Query().Where(hostfile.ID(fileID)) + + // If hash was provided, check to see if the file has been updated. Note that + // http.ServeContent should handle this, but we want to avoid the expensive DB + // query where possible. + if hash := req.Header.Get(HeaderIfNoneMatch); hash != "" { + if exists := fileQuery.Clone().Where(hostfile.Hash(hash)).ExistX(ctx); exists { + return ErrFileNotModified + } + } + + // Ensure the file exists + if exists := fileQuery.Clone().ExistX(ctx); !exists { + return ErrFileNotFound + } + + f := fileQuery.OnlyX(ctx) + + // Set Etag to hash of file + w.Header().Set(HeaderEtag, f.Hash) + + // Set Content-Type and serve content + w.Header().Set("Content-Type", "application/octet-stream") + http.ServeContent(w, req, filepath.Base(f.Path), f.LastModifiedAt, bytes.NewReader(f.Content)) + + return nil + }) +} diff --git a/tavern/internal/cdn/download_hostfile_test.go b/tavern/internal/cdn/download_hostfile_test.go new file mode 100644 index 00000000..8e5a1fd4 --- /dev/null +++ b/tavern/internal/cdn/download_hostfile_test.go @@ -0,0 +1,132 @@ +package cdn_test + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "realm.pub/tavern/internal/c2/c2pb" + "realm.pub/tavern/internal/cdn" + "realm.pub/tavern/internal/ent/enttest" +) + +// TestDownloadHostFile asserts that the download handler exhibits expected behavior. +func TestDownloadHostFile(t *testing.T) { + graph := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer graph.Close() + + ctx := context.Background() + existingHost := graph.Host.Create(). + SetIdentifier("test-host"). + SetPlatform(c2pb.Host_PLATFORM_LINUX). + SaveX(ctx) + existingBeacon := graph.Beacon.Create(). + SetHost(existingHost). + SetIdentifier("ABCDEFG"). + SaveX(ctx) + existingTome := graph.Tome.Create(). + SetName("Wowza"). + SetDescription("Why did we require this?"). + SetAuthor("kcarretto"). + SetEldritch("blah"). + SaveX(ctx) + existingQuest := graph.Quest.Create(). + SetName("HelloWorld"). + SetTome(existingTome). + SaveX(ctx) + existingTask := graph.Task.Create(). + SetBeacon(existingBeacon). + SetQuest(existingQuest). + SaveX(ctx) + + existingHostFile := graph.HostFile.Create(). + SetPath("/existing/file"). + SetContent([]byte(`some data`)). + SetHost(existingHost). + SetTask(existingTask). + SaveX(ctx) + + handler := cdn.NewHostFileDownloadHandler(graph, "/download/") + + tests := []struct { + name string + + reqURL string + reqMethod string + reqBody io.Reader + reqHeaders map[string][]string + + wantStatus int + wantBody []byte + wantErr error + }{ + { + name: "Valid", + reqURL: fmt.Sprintf("/download/%d", existingHostFile.ID), + wantBody: existingHostFile.Content, + }, + { + name: "NotFound", + reqURL: "/download/123", + wantBody: []byte(fmt.Sprintf("%s\n", cdn.ErrFileNotFound.Error())), + wantStatus: cdn.ErrFileNotFound.StatusCode, + }, + { + name: "InvalidID/Alphabet", + reqURL: "/download/abcd", + wantBody: []byte(fmt.Sprintf("%s\n", cdn.ErrInvalidFileID.Error())), + wantStatus: cdn.ErrInvalidFileID.StatusCode, + }, + { + name: "InvalidID/Empty", + reqURL: "/download/", + wantBody: []byte(fmt.Sprintf("%s\n", cdn.ErrInvalidFileID.Error())), + wantStatus: cdn.ErrInvalidFileID.StatusCode, + }, + { + name: "Cached", + reqURL: fmt.Sprintf("/download/%d", existingHostFile.ID), + reqHeaders: map[string][]string{ + cdn.HeaderIfNoneMatch: {existingHostFile.Hash}, + }, + wantBody: []byte(fmt.Sprintf("%s\n", cdn.ErrFileNotModified.Error())), + wantStatus: cdn.ErrFileNotModified.StatusCode, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Build Request + req, reqErr := http.NewRequest(tc.reqMethod, tc.reqURL, tc.reqBody) + require.NoError(t, reqErr) + for key, vals := range tc.reqHeaders { + for _, val := range vals { + req.Header.Add(key, val) + } + } + + // Default to wanting OK status + if tc.wantStatus == 0 { + tc.wantStatus = http.StatusOK + } + + // Send request and record response + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + result := w.Result() + assert.Equal(t, tc.wantStatus, result.StatusCode) + + // Attempt to read the response body + body, err := io.ReadAll(result.Body) + require.NoError(t, err, "failed to parse response body") + defer result.Body.Close() + + assert.Equal(t, tc.wantBody, body) + }) + } +} diff --git a/tavern/internal/cdn/download_test.go b/tavern/internal/cdn/download_test.go index 9a6a0790..73e91909 100644 --- a/tavern/internal/cdn/download_test.go +++ b/tavern/internal/cdn/download_test.go @@ -16,7 +16,7 @@ import ( "realm.pub/tavern/internal/errors" ) -// TestUpload asserts that the download handler exhibits expected behavior. +// TestDownload asserts that the download handler exhibits expected behavior. func TestDownload(t *testing.T) { graph := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") defer graph.Close() @@ -26,7 +26,7 @@ func TestDownload(t *testing.T) { t.Run("File", newDownloadTest( graph, - newDownloadRequest(t, existingFile.Name), + newDownloadRequest(existingFile.Name), func(t *testing.T, fileContent []byte, err *errors.HTTP) { assert.Nil(t, err) assert.Equal(t, string(expectedContent), string(fileContent)) @@ -34,7 +34,7 @@ func TestDownload(t *testing.T) { )) t.Run("CachedFile", newDownloadTest( graph, - newDownloadRequest(t, existingFile.Name, withIfNoneMatchHeader(existingFile.Hash)), + newDownloadRequest(existingFile.Name, withIfNoneMatchHeader(existingFile.Hash)), func(t *testing.T, fileContent []byte, err *errors.HTTP) { require.NotNil(t, err) assert.Equal(t, http.StatusNotModified, err.StatusCode) @@ -44,7 +44,7 @@ func TestDownload(t *testing.T) { )) t.Run("NonExistentFile", newDownloadTest( graph, - newDownloadRequest(t, "ThisFileDoesNotExist"), + newDownloadRequest("ThisFileDoesNotExist"), func(t *testing.T, fileContent []byte, err *errors.HTTP) { require.NotNil(t, err) assert.Equal(t, http.StatusNotFound, err.StatusCode) @@ -92,7 +92,7 @@ func newDownloadTest(graph *ent.Client, req *http.Request, checks ...func(t *tes } // newDownloadRequest is a helper to create an http.Request for a file download -func newDownloadRequest(t *testing.T, fileName string, options ...func(*http.Request)) *http.Request { +func newDownloadRequest(fileName string, options ...func(*http.Request)) *http.Request { req := httptest.NewRequest(http.MethodGet, "/download/"+fileName, nil) for _, opt := range options { opt(req) diff --git a/tavern/internal/cdn/errors.go b/tavern/internal/cdn/errors.go index dd2aa7af..a7715531 100644 --- a/tavern/internal/cdn/errors.go +++ b/tavern/internal/cdn/errors.go @@ -6,11 +6,13 @@ import ( "realm.pub/tavern/internal/errors" ) +// ErrInvalidFileID occurs when an invalid file id is provided to the API. // ErrInvalidFileName occurs when an invalid file name is provided to the API. // ErrInvalidFileContent occurs when an upload is attempted with invalid file content. // ErrFileNotModified occurs when a download is attempted using the If-None-Match header and the file has not been modified. // ErrFileNotFound occurs when a download is attempted for a file that does not exist. var ( + ErrInvalidFileID = errors.NewHTTP("invalid file id provided", http.StatusBadRequest) ErrInvalidFileName = errors.NewHTTP("invalid file name provided", http.StatusBadRequest) ErrInvalidFileContent = errors.NewHTTP("invalid file content provided", http.StatusBadRequest) ErrFileNotModified = errors.NewHTTP("file has not been modified", http.StatusNotModified)