diff --git a/mackerel.go b/mackerel.go index 65117e2..15c9f26 100644 --- a/mackerel.go +++ b/mackerel.go @@ -2,6 +2,7 @@ package mackerel import ( "bytes" + "context" "encoding/json" "io" "log" @@ -131,68 +132,35 @@ func (c *Client) Request(req *http.Request) (resp *http.Response, err error) { } func requestGet[T any](client *Client, path string) (*T, error) { - return requestNoBody[T](client, http.MethodGet, path, nil) + return requestGetContext[T](context.Background(), client, path) } func requestGetWithParams[T any](client *Client, path string, params url.Values) (*T, error) { - return requestNoBody[T](client, http.MethodGet, path, params) + return requestGetWithParamsContext[T](context.Background(), client, path, params) } func requestGetAndReturnHeader[T any](client *Client, path string) (*T, http.Header, error) { - return requestInternal[T](client, http.MethodGet, path, nil, nil) + return requestGetAndReturnHeaderContext[T](context.Background(), client, path) } func requestPost[T any](client *Client, path string, payload any) (*T, error) { - return requestJSON[T](client, http.MethodPost, path, payload) + return requestPostContext[T](context.Background(), client, path, payload) } func requestPut[T any](client *Client, path string, payload any) (*T, error) { - return requestJSON[T](client, http.MethodPut, path, payload) + return requestPutContext[T](context.Background(), client, path, payload) } func requestDelete[T any](client *Client, path string) (*T, error) { - return requestNoBody[T](client, http.MethodDelete, path, nil) + return requestDeleteContext[T](context.Background(), client, path) } func requestJSON[T any](client *Client, method, path string, payload any) (*T, error) { - var body bytes.Buffer - err := json.NewEncoder(&body).Encode(payload) - if err != nil { - return nil, err - } - data, _, err := requestInternal[T](client, method, path, nil, &body) - return data, err -} - -func requestNoBody[T any](client *Client, method, path string, params url.Values) (*T, error) { - data, _, err := requestInternal[T](client, method, path, params, nil) - return data, err + return requestJSONContext[T](context.Background(), client, method, path, payload) } func requestInternal[T any](client *Client, method, path string, params url.Values, body io.Reader) (*T, http.Header, error) { - req, err := http.NewRequest(method, client.urlFor(path, params).String(), body) - if err != nil { - return nil, nil, err - } - if body != nil || method != http.MethodGet { - req.Header.Add("Content-Type", "application/json") - } - - resp, err := client.Request(req) - if err != nil { - return nil, nil, err - } - defer func() { - io.Copy(io.Discard, resp.Body) // nolint - resp.Body.Close() - }() - - var data T - err = json.NewDecoder(resp.Body).Decode(&data) - if err != nil { - return nil, nil, err - } - return &data, resp.Header, nil + return requestInternalContext[T](context.Background(), client, method, path, params, body) } func (c *Client) compatRequestJSON(method string, path string, payload interface{}) (*http.Response, error) { diff --git a/mackerel_context.go b/mackerel_context.go new file mode 100644 index 0000000..9c4f54c --- /dev/null +++ b/mackerel_context.go @@ -0,0 +1,75 @@ +package mackerel + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/url" +) + +func requestGetContext[T any](ctx context.Context, client *Client, path string) (*T, error) { + return requestNoBodyContext[T](ctx, client, http.MethodGet, path, nil) +} + +func requestGetWithParamsContext[T any](ctx context.Context, client *Client, path string, params url.Values) (*T, error) { + return requestNoBodyContext[T](ctx, client, http.MethodGet, path, params) +} + +func requestGetAndReturnHeaderContext[T any](ctx context.Context, client *Client, path string) (*T, http.Header, error) { + return requestInternalContext[T](ctx, client, http.MethodGet, path, nil, nil) +} + +func requestPostContext[T any](ctx context.Context, client *Client, path string, payload any) (*T, error) { + return requestJSONContext[T](ctx, client, http.MethodPost, path, payload) +} + +func requestPutContext[T any](ctx context.Context, client *Client, path string, payload any) (*T, error) { + return requestJSONContext[T](ctx, client, http.MethodPut, path, payload) +} + +func requestDeleteContext[T any](ctx context.Context, client *Client, path string) (*T, error) { + return requestNoBodyContext[T](ctx, client, http.MethodDelete, path, nil) +} + +func requestJSONContext[T any](ctx context.Context, client *Client, method, path string, payload any) (*T, error) { + var body bytes.Buffer + err := json.NewEncoder(&body).Encode(payload) + if err != nil { + return nil, err + } + data, _, err := requestInternalContext[T](ctx, client, method, path, nil, &body) + return data, err +} + +func requestNoBodyContext[T any](ctx context.Context, client *Client, method, path string, params url.Values) (*T, error) { + data, _, err := requestInternalContext[T](context.Background(), client, method, path, params, nil) + return data, err +} + +func requestInternalContext[T any](ctx context.Context, client *Client, method, path string, params url.Values, body io.Reader) (*T, http.Header, error) { + req, err := http.NewRequestWithContext(ctx, method, client.urlFor(path, params).String(), body) + if err != nil { + return nil, nil, err + } + if body != nil || method != http.MethodGet { + req.Header.Add("Content-Type", "application/json") + } + + resp, err := client.Request(req) + if err != nil { + return nil, nil, err + } + defer func() { + io.Copy(io.Discard, resp.Body) // nolint + resp.Body.Close() + }() + + var data T + err = json.NewDecoder(resp.Body).Decode(&data) + if err != nil { + return nil, nil, err + } + return &data, resp.Header, nil +} diff --git a/mackerel_test.go b/mackerel_test.go index c26f43d..78e8bb1 100644 --- a/mackerel_test.go +++ b/mackerel_test.go @@ -2,6 +2,8 @@ package mackerel import ( "bytes" + "context" + "errors" "fmt" "io" "log" @@ -11,6 +13,7 @@ import ( "os" "strings" "testing" + "time" ) func TestRequest(t *testing.T) { @@ -77,6 +80,24 @@ func Test_requestInternal(t *testing.T) { } } +func Test_requestInternalContext_cancel(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + time.Sleep(1 * time.Second) + fmt.Fprint(res, "ok") + })) + t.Cleanup(ts.Close) + + client, _ := NewClientWithOptions("dummy-key", ts.URL, false) + ctx, cancel := context.WithCancelCause(context.Background()) + expectedErr := errors.New("expected error") + cancel(expectedErr) + + _, _, err := requestInternalContext[struct{}](ctx, client, "GET", "/", nil, nil) + if cause := context.Cause(ctx); err == nil || !errors.Is(cause, expectedErr) { + t.Errorf("got %v; want %v", cause, expectedErr) + } +} + func TestUrlFor(t *testing.T) { client, _ := NewClientWithOptions("dummy-key", "https://example.com/with/ignored/path", false) expected := "https://example.com/some/super/endpoint"