Skip to content

Commit

Permalink
Fix --html.template option (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkroepke authored Oct 28, 2023
1 parent c56e850 commit 17505c9
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 15 deletions.
4 changes: 2 additions & 2 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ func FlagSet() *flag.FlagSet {
"Path to tls server certificate. (env: CONFIG_HTTP_CERT)",
)
flagSet.String(
"http.callback-template-path",
"http.template",
"",
"Path to a HTML file which is displayed at the end of the screen. (env: CONFIG_HTTP_CALLBACK_TEMPLATE_PATH)",
"Path to a HTML file which is displayed at the end of the screen. (env: CONFIG_HTTP_TEMPLATE)",
)
flagSet.Bool(
"http.check.ipaddr",
Expand Down
9 changes: 3 additions & 6 deletions internal/config/decode_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package config

import (
"fmt"
"html/template"
"net/url"
"path"
"reflect"
"text/template"

"github.com/mitchellh/mapstructure"
)
Expand Down Expand Up @@ -76,11 +77,7 @@ func StringToTemplateHookFunc() mapstructure.DecodeHookFuncType {
return template.Template{}, nil
}

var err error

tmpl := template.New("callback")
tmpl, err = tmpl.ParseFiles(dataString)

tmpl, err := template.New(path.Base(dataString)).ParseFiles(dataString)
if err != nil {
return template.Template{}, fmt.Errorf("error paring template files: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/config/decode_hooks_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package config_test

import (
"html/template"
"net/url"
"reflect"
"testing"
"text/template"

"github.com/jkroepke/openvpn-auth-oauth2/internal/config"
"github.com/mitchellh/mapstructure"
Expand Down Expand Up @@ -58,7 +58,7 @@ func TestStringToTemplateHookFunc(t *testing.T) {
err bool
}{
{"valid", reflect.ValueOf("./../../README.md"), reflect.ValueOf(template.Template{}), func() *template.Template {
tmpl, err := template.New("callback").ParseFiles("./../../README.md")
tmpl, err := template.New("README.md").ParseFiles("./../../README.md")
require.NoError(t, err)

return tmpl
Expand Down
4 changes: 2 additions & 2 deletions internal/config/types.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package config

import (
"html/template"
"net/url"
"text/template"
"time"
)

Expand All @@ -21,7 +21,7 @@ type HTTP struct {
TLS bool `koanf:"tls"`
BaseURL *url.URL `koanf:"baseurl"`
Secret string `koanf:"secret"`
CallbackTemplate *template.Template `koanf:"callback-template-path"`
CallbackTemplate *template.Template `koanf:"template"`
Check HTTPCheck `koanf:"check"`
EnableProxyHeaders bool `koanf:"enable-proxy-headers"`
}
Expand Down
19 changes: 16 additions & 3 deletions internal/oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oauth2
import (
"context"
"encoding/base64"
"fmt"
"log/slog"
"net"
"net/http"
Expand Down Expand Up @@ -36,6 +37,8 @@ func Handler(logger *slog.Logger, conf config.Config, provider Provider, openvpn

func oauth2Start(logger *slog.Logger, provider Provider, conf config.Config, openvpn OpenVPN) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sendCacheHeaders(w)

sessionState := r.URL.Query().Get("state")
if sessionState == "" {
w.WriteHeader(http.StatusBadRequest)
Expand Down Expand Up @@ -76,6 +79,12 @@ func oauth2Start(logger *slog.Logger, provider Provider, conf config.Config, ope
})
}

func sendCacheHeaders(w http.ResponseWriter) {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
}

func checkClientIPAddr(r *http.Request, logger *slog.Logger, session state.State, conf config.Config) (bool, int, string) {
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
Expand Down Expand Up @@ -111,6 +120,10 @@ func oauth2Callback(
w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], encryptedSession string,
rp rp.RelyingParty,
) {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

Expand Down Expand Up @@ -186,7 +199,7 @@ func getAuthTokenUsername(session state.State, user types.UserData) string {
}

func writeError(w http.ResponseWriter, logger *slog.Logger, conf config.Config, httpCode int, errorType, errorDesc string) {
if conf.HTTP.CallbackTemplate == nil || conf.HTTP.CallbackTemplate.Tree == nil {
if conf.HTTP.CallbackTemplate == nil {
http.Error(w, utils.StringConcat(errorType, ": ", errorDesc), httpCode)

return
Expand All @@ -207,15 +220,15 @@ func writeError(w http.ResponseWriter, logger *slog.Logger, conf config.Config,
}

func writeSuccess(w http.ResponseWriter, conf config.Config, logger *slog.Logger) {
if conf.HTTP.CallbackTemplate == nil || conf.HTTP.CallbackTemplate.Tree == nil {
if conf.HTTP.CallbackTemplate == nil {
_, _ = w.Write([]byte(callbackHTML))

return
}

err := conf.HTTP.CallbackTemplate.Execute(w, map[string]string{})
if err != nil {
logger.Error("executing template:", err)
logger.Error(fmt.Sprintf("executing template: %s", err))
w.WriteHeader(http.StatusInternalServerError)
}
}
43 changes: 43 additions & 0 deletions internal/oauth2/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"sync"
"testing"
"text/template"
"time"

"github.com/jkroepke/openvpn-auth-oauth2/internal/config"
Expand Down Expand Up @@ -67,6 +68,42 @@ func TestHandler(t *testing.T) {
true,
"-",
},
{
"with template",
config.Config{
HTTP: config.HTTP{
Secret: testutils.HTTPSecret,
Check: config.HTTPCheck{
IPAddr: false,
},
CallbackTemplate: func() *template.Template {
tmpl, err := template.New("README.md").ParseFiles("./../../README.md")
require.NoError(t, err)

return tmpl
}(),
},
OAuth2: config.OAuth2{
Provider: "generic",
Endpoints: config.OAuth2Endpoints{},
Scopes: []string{"openid", "profile"},
Validate: config.OAuth2Validate{
Groups: make([]string, 0),
Roles: make([]string, 0),
Issuer: true,
IPAddr: false,
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
AuthTokenUser: true,
},
},
"127.0.0.1",
"",
true,
"-",
},
{
"with ipaddr",
config.Config{
Expand Down Expand Up @@ -384,6 +421,12 @@ func TestHandler(t *testing.T) {
return
}

if tt.conf.HTTP.CallbackTemplate != nil {
if !assert.Contains(t, string(body), "openvpn-auth-oauth2 is a management client for OpenVPN that handles the authentication") {
return
}
}

client.Shutdown()
wg.Wait()
})
Expand Down

0 comments on commit 17505c9

Please sign in to comment.