diff --git a/internal/config/env.go b/internal/config/env.go index 0505d143..3ceb12be 100644 --- a/internal/config/env.go +++ b/internal/config/env.go @@ -3,7 +3,6 @@ package config import ( "encoding/json" "fmt" - "go-template/pkg/utl/convert" "log" "os" "path/filepath" @@ -11,6 +10,8 @@ import ( "strconv" "github.com/joho/godotenv" + + "go-template/pkg/utl/convert" ) func GetString(key string) string { @@ -57,9 +58,12 @@ func FileName() string { } func LoadEnv() error { + const ( + localEnvFile = "local" + ) _, filename, _, ok := runtime.Caller(0) if !ok { - return fmt.Errorf("Error getting current file path") + return fmt.Errorf("error getting current file path") } prefix := fmt.Sprintf("%s/", filepath.Join(filepath.Dir(filename), "../../")) @@ -71,18 +75,15 @@ func LoadEnv() error { envName := os.Getenv("ENVIRONMENT_NAME") if envName == "" { - envName = "local" + envName = localEnvFile } log.Println("envName: " + envName) envVarInjection := GetBool("ENV_INJECTION") - if !envVarInjection || envName == "local" { + if !envVarInjection || envName == localEnvFile { err = godotenv.Load(fmt.Sprintf("%s.env.%s", prefix, envName)) - if err != nil { - fmt.Printf(".env.%s\n", envName) - log.Println(err) - return err + return fmt.Errorf("failed to load env for environment %q file: %w", envName, err) } fmt.Println("loaded", fmt.Sprintf("%s.env.%s", prefix, envName)) return nil @@ -90,26 +91,45 @@ func LoadEnv() error { dbCredsInjected := GetBool("COPILOT_DB_CREDS_VIA_SECRETS_MANAGER") + // except for local environment the db creds should be + // injected through the secret manager + if envName != localEnvFile && !dbCredsInjected { + return fmt.Errorf("db creds should be injected through secret manager") + } + + // if db creds are injected, extract those if dbCredsInjected { - type copilotSecrets struct { - Username string `json:"username"` - Host string `json:"host"` - DBName string `json:"dbname"` - Password string `json:"password"` - Port int `json:"port"` - } - secrets := &copilotSecrets{} + return extractDBCredsFromSecret() + } + // otherwise + return nil +} - err := json.Unmarshal([]byte(os.Getenv("DB_SECRET")), secrets) - if err != nil { - return err - } +// extractDBCredsFromSecret helper function to extract db secret +func extractDBCredsFromSecret() error { + type copilotSecrets struct { + Username string `json:"username"` + Host string `json:"host"` + DBName string `json:"dbname"` + Password string `json:"password"` + Port int `json:"port"` + } + secrets, dbSecret := &copilotSecrets{}, os.Getenv("DB_SECRET") + + if dbSecret == "" { + return fmt.Errorf("'DB_SECRET' environment var is not set or is empty") + } - os.Setenv("PSQL_DBNAME", secrets.DBName) - os.Setenv("PSQL_HOST", secrets.Host) - os.Setenv("PSQL_PASS", secrets.Password) - os.Setenv("PSQL_PORT", strconv.Itoa(secrets.Port)) - os.Setenv("PSQL_USER", secrets.Username) + err := json.Unmarshal([]byte(dbSecret), secrets) + if err != nil { + return fmt.Errorf("couldn't unmarshal db secret: %w", err) } - return fmt.Errorf("COPILOT_DB_CREDS_VIA_SECRETS_MANAGER should have had a value") + + os.Setenv("PSQL_DBNAME", secrets.DBName) + os.Setenv("PSQL_HOST", secrets.Host) + os.Setenv("PSQL_PASS", secrets.Password) + os.Setenv("PSQL_PORT", strconv.Itoa(secrets.Port)) + os.Setenv("PSQL_USER", secrets.Username) + + return nil } diff --git a/internal/config/env_test.go b/internal/config/env_test.go index 601023f1..a9293e17 100644 --- a/internal/config/env_test.go +++ b/internal/config/env_test.go @@ -2,11 +2,12 @@ package config_test import ( "fmt" - . "go-template/internal/config" "os" "testing" "github.com/stretchr/testify/assert" + + . "go-template/internal/config" ) func TestGetString(t *testing.T) { @@ -231,6 +232,7 @@ func loadLocalEnv() envTestCaseArgs { }, } } + func errorOnEnvInjectionAndCopilotFalse() envTestCaseArgs { return envTestCaseArgs{ name: "Error when ENV_INJECTION and COPILOT_DB_CREDS_VIA_SECRETS_MANAGER false", @@ -254,10 +256,67 @@ func errorOnEnvInjectionAndCopilotFalse() envTestCaseArgs { } } -func loadOnDbCredsInjected(username string, host string, dbname string, password string, port string) envTestCaseArgs { +func loadOnCopilotTrueAndLocalEnv() envTestCaseArgs { return envTestCaseArgs{ - name: "dbCredsInjected True", + name: "Load local without copilot", + wantErr: false, + args: args{ + setEnv: []keyValueArgs{ + { + key: "ENV_INJECTION", + value: "true", + }, + { + key: "ENVIRONMENT_NAME", + value: "local", + }, + { + key: "COPILOT_DB_CREDS_VIA_SECRETS_MANAGER", + value: "false", + }, + }, + }, + } +} + +func errorOnDbCredsInjectedInDevEnv() envTestCaseArgs { + return envTestCaseArgs{ + name: "dbCredsInjected True for `develop` environment,with invalid json in db secret", wantErr: true, + args: args{ + setEnv: []keyValueArgs{ + { + key: "ENV_INJECTION", + value: "true", + }, + { + key: "ENVIRONMENT_NAME", + value: "develop", + }, + { + key: "COPILOT_DB_CREDS_VIA_SECRETS_MANAGER", + value: "true", + }, + { + key: "DB_SECRET", + value: "invalid json", + }, + }, + expectedKeyValues: []keyValueArgs{}, + }, + } +} + +func loadOnDbCredsInjectedInDevEnv( + username string, + host string, + dbname string, + password string, + port string, +) envTestCaseArgs { + return envTestCaseArgs{ + name: "dbCredsInjected True for `develop` environment,and should parse the db secret", + wantErr: false, args: args{ setEnv: []keyValueArgs{ { @@ -274,27 +333,75 @@ func loadOnDbCredsInjected(username string, host string, dbname string, password }, { key: "DB_SECRET", - value: fmt.Sprintf(`{"username": "%s", "password": "%s", "port": "%s", "dbname": "%s", "host": "%s"}`, + value: fmt.Sprintf(`{"username": "%s", "password": "%s", "port": %s, "dbname": "%s", "host": "%s"}`, username, password, port, - host, - dbname), + dbname, + host), }, }, expectedKeyValues: []keyValueArgs{ + { + key: "PSQL_USER", + value: username, + }, + { + key: "PSQL_PORT", + value: port, + }, { key: "PSQL_PASS", value: password, }, + { + key: "PSQL_HOST", + value: host, + }, { key: "PSQL_DBNAME", value: dbname, }, + }, + }, + } +} + +func loadDbCredsInjectedInLocalEnv( + username string, + host string, + dbname string, + password string, + port string, +) envTestCaseArgs { + return envTestCaseArgs{ + name: "dbCredsInjected True for local environment,and should parse the db secret", + wantErr: false, + args: args{ + setEnv: []keyValueArgs{ { - key: "PSQL_HOST", - value: host, + key: "ENV_INJECTION", + value: "true", }, + { + key: "ENVIRONMENT_NAME", + value: "local", + }, + { + key: "COPILOT_DB_CREDS_VIA_SECRETS_MANAGER", + value: "true", + }, + { + key: "DB_SECRET", + value: fmt.Sprintf(`{"username": "%s", "password": "%s", "port": %s, "dbname": "%s", "host": "%s"}`, + username, + password, + port, + dbname, + host), + }, + }, + expectedKeyValues: []keyValueArgs{ { key: "PSQL_USER", value: username, @@ -303,6 +410,18 @@ func loadOnDbCredsInjected(username string, host string, dbname string, password key: "PSQL_PORT", value: port, }, + { + key: "PSQL_PASS", + value: password, + }, + { + key: "PSQL_HOST", + value: host, + }, + { + key: "PSQL_DBNAME", + value: dbname, + }, }, }, } @@ -310,7 +429,7 @@ func loadOnDbCredsInjected(username string, host string, dbname string, password func errorOnWrongEnvName() envTestCaseArgs { return envTestCaseArgs{ - name: "Failed to load env", + name: "Failed to load env for local1", wantErr: true, args: args{ setEnv: []keyValueArgs{ @@ -318,16 +437,24 @@ func errorOnWrongEnvName() envTestCaseArgs { key: "ENVIRONMENT_NAME", value: "local1", }, + { + key: "ENV_INJECTION", + value: "false", + }, }, }, } } + func getTestCases(username string, host string, dbname string, password string, port string) []envTestCaseArgs { return []envTestCaseArgs{ loadLocalEnvIfNoEnvName(), loadLocalEnv(), errorOnEnvInjectionAndCopilotFalse(), - loadOnDbCredsInjected(username, host, dbname, password, port), + loadOnCopilotTrueAndLocalEnv(), + errorOnDbCredsInjectedInDevEnv(), + loadOnDbCredsInjectedInDevEnv(username, host, dbname, password, port), + loadDbCredsInjectedInLocalEnv(username, host, dbname, password, port), errorOnWrongEnvName(), } } @@ -342,7 +469,8 @@ func testLoadEnv(t *testing.T, tt struct { name string wantErr bool args args -}) { +}, +) { if err := LoadEnv(); (err != nil) != tt.wantErr { t.Errorf("LoadEnv() error = %v, wantErr %v", err, tt.wantErr) } else {