Skip to content

Commit

Permalink
Add envprovider package for accessing environment variables
Browse files Browse the repository at this point in the history
Signed-off-by: Burak Varlı <[email protected]>
  • Loading branch information
unexge committed Jan 27, 2025
1 parent 22d3773 commit 04180b5
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 56 deletions.
90 changes: 90 additions & 0 deletions pkg/driver/node/envprovider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Package envprovider provides utilities for accessing environment variables to pass Mountpoint.
package envprovider

import (
"fmt"
"os"
"slices"
)

const (
EnvRegion = "AWS_REGION"
EnvDefaultRegion = "AWS_DEFAULT_REGION"
EnvSTSRegionalEndpoints = "AWS_STS_REGIONAL_ENDPOINTS"
EnvMaxAttempts = "AWS_MAX_ATTEMPTS"
EnvProfile = "AWS_PROFILE"
EnvConfigFile = "AWS_CONFIG_FILE"
EnvSharedCredentialsFile = "AWS_SHARED_CREDENTIALS_FILE"
EnvRoleARN = "AWS_ROLE_ARN"
EnvWebIdentityTokenFile = "AWS_WEB_IDENTITY_TOKEN_FILE"
EnvEC2MetadataDisabled = "AWS_EC2_METADATA_DISABLED"
EnvAccessKeyID = "AWS_ACCESS_KEY_ID"
EnvSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
EnvSessionToken = "AWS_SESSION_TOKEN"

EnvMountpointCacheKey = "UNSTABLE_MOUNTPOINT_CACHE_KEY"
)

// Key represents an environment variable name.
type Key = string

// Value represents an environment variable value.
type Value = string

// Environment represents a list of environment variables as key-value pairs.
type Environment map[Key]Value

// envAllowlist is the list of environment variables to pass-by by default.
// If any of these set, it will be returned as-is in [Default].
var envAllowlist = []Key{
EnvRegion,
EnvDefaultRegion,
EnvSTSRegionalEndpoints,
}

// Region returns detected region from environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.
// It returns an empty string if both is unset.
func Region() Value {
region := os.Getenv(EnvRegion)
if region != "" {
return region
}
return os.Getenv(EnvDefaultRegion)
}

// Default returns list of environment variables to pass Mountpoint.
func Default() Environment {
environment := make(Environment)
for _, key := range envAllowlist {
val := os.Getenv(key)
if val != "" {
environment[key] = val
}
}
return environment
}

// List returns a sorted slice of environment variables in "KEY=VALUE" format.
func (env Environment) List() []string {
list := []string{}
for key, val := range env {
list = append(list, format(key, val))
}
slices.Sort(list)
return list
}

// Delete deletes the environment variable with the specified key.
func (env Environment) Delete(key Key) {
delete(env, key)
}

// Set adds or updates the environment variable with the specified key and value.
func (env Environment) Set(key Key, value Value) {
env[key] = value
}

// format formats given key and value to be used as an environment variable.
func format(key Key, value Value) string {
return fmt.Sprintf("%s=%s", key, value)
}
211 changes: 211 additions & 0 deletions pkg/driver/node/envprovider/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package envprovider_test

import (
"testing"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider"
"github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert"
)

func TestGettingRegion(t *testing.T) {
testCases := []struct {
name string
envRegion string
envDefaultRegion string
want string
}{
{
name: "both region envs are set",
envRegion: "us-west-1",
envDefaultRegion: "us-east-1",
want: "us-west-1",
},
{
name: "only default region env is set",
envRegion: "",
envDefaultRegion: "us-east-1",
want: "us-east-1",
},
{
name: "no region env is set",
envRegion: "",
envDefaultRegion: "",
want: "",
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("AWS_REGION", testCase.envRegion)
t.Setenv("AWS_DEFAULT_REGION", testCase.envDefaultRegion)
assert.Equals(t, testCase.want, envprovider.Region())
})
}
}

func TestProvidingDefaultEnvironmentVariables(t *testing.T) {
testCases := []struct {
name string
env map[string]string
want []string
}{
{
name: "no env vars set",
env: map[string]string{},
want: []string{},
},
{
name: "some allowed env vars set",
env: map[string]string{
"AWS_REGION": "us-west-1",
"AWS_DEFAULT_REGION": "us-east-1",
"AWS_STS_REGIONAL_ENDPOINTS": "regional",
"AWS_MAX_ATTEMPTS": "10",
},
want: []string{
"AWS_DEFAULT_REGION=us-east-1",
"AWS_REGION=us-west-1",
"AWS_STS_REGIONAL_ENDPOINTS=regional",
},
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
for k, v := range testCase.env {
t.Setenv(k, v)
}
assert.Equals(t, testCase.want, envprovider.Default().List())
})
}
}

func TestRemovingAKeyFromListOfEnvironmentVariables(t *testing.T) {
testCases := []struct {
name string
env envprovider.Environment
key string
want envprovider.Environment
}{
{
name: "empty environment",
env: envprovider.Environment{},
key: "AWS_REGION",
want: envprovider.Environment{},
},
{
name: "remove existing key",
env: envprovider.Environment{"AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": "us-east-1"},
key: "AWS_REGION",
want: envprovider.Environment{"AWS_DEFAULT_REGION": "us-east-1"},
},
{
name: "remove non-existing key",
env: envprovider.Environment{"AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": "us-east-1"},
key: "AWS_MAX_ATTEMPTS",
want: envprovider.Environment{"AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": "us-east-1"},
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
testCase.env.Delete(testCase.key)
assert.Equals(t, testCase.want, testCase.env)
})
}
}

func TestSettingKeyValueInEnvironmentVariables(t *testing.T) {
testCases := []struct {
name string
env envprovider.Environment
key string
value string
want envprovider.Environment
}{
{
name: "add to empty environment",
env: envprovider.Environment{},
key: "AWS_REGION",
value: "us-west-1",
want: envprovider.Environment{"AWS_REGION": "us-west-1"},
},
{
name: "update existing key",
env: envprovider.Environment{"AWS_REGION": "us-west-1"},
key: "AWS_REGION",
value: "us-east-1",
want: envprovider.Environment{"AWS_REGION": "us-east-1"},
},
{
name: "add new key to non-empty environment",
env: envprovider.Environment{"AWS_REGION": "us-west-1"},
key: "AWS_DEFAULT_REGION",
value: "us-east-1",
want: envprovider.Environment{"AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": "us-east-1"},
},
{
name: "set empty value",
env: envprovider.Environment{"AWS_REGION": "us-west-1"},
key: "AWS_DEFAULT_REGION",
value: "",
want: envprovider.Environment{"AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": ""},
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
testCase.env.Set(testCase.key, testCase.value)
assert.Equals(t, testCase.want, testCase.env)
})
}
}

func TestEnvironmentList(t *testing.T) {
testCases := []struct {
name string
env envprovider.Environment
want []string
}{
{
name: "empty environment",
env: envprovider.Environment{},
want: []string{},
},
{
name: "single environment variable",
env: envprovider.Environment{"AWS_REGION": "us-west-1"},
want: []string{"AWS_REGION=us-west-1"},
},
{
name: "multiple environment variables are sorted",
env: envprovider.Environment{
"AWS_REGION": "us-west-1",
"AWS_DEFAULT_REGION": "us-east-1",
"AWS_STS_REGIONAL_ENDPOINTS": "regional",
},
want: []string{
"AWS_DEFAULT_REGION=us-east-1",
"AWS_REGION=us-west-1",
"AWS_STS_REGIONAL_ENDPOINTS=regional",
},
},
{
name: "environment variables with empty values",
env: envprovider.Environment{
"AWS_REGION": "",
"AWS_DEFAULT_REGION": "us-east-1",
},
want: []string{
"AWS_DEFAULT_REGION=us-east-1",
"AWS_REGION=",
},
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
assert.Equals(t, testCase.want, testCase.env.List())
})
}
}
23 changes: 12 additions & 11 deletions pkg/driver/node/mounter/credential_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"k8s.io/klog/v2"
k8sstrings "k8s.io/utils/strings"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext"
"github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint"
)
Expand Down Expand Up @@ -109,14 +110,14 @@ func (c *CredentialProvider) provideFromDriver() (*MountCredentials, error) {

return &MountCredentials{
AuthenticationSource: AuthenticationSourceDriver,
AccessKeyID: os.Getenv(keyIdEnv),
SecretAccessKey: os.Getenv(accessKeyEnv),
SessionToken: os.Getenv(sessionTokenEnv),
Region: os.Getenv(regionEnv),
DefaultRegion: os.Getenv(defaultRegionEnv),
AccessKeyID: os.Getenv(envprovider.EnvAccessKeyID),
SecretAccessKey: os.Getenv(envprovider.EnvSecretAccessKey),
SessionToken: os.Getenv(envprovider.EnvSessionToken),
Region: os.Getenv(envprovider.EnvRegion),
DefaultRegion: os.Getenv(envprovider.EnvDefaultRegion),
WebTokenPath: hostTokenPath,
StsEndpoints: os.Getenv(stsEndpointsEnv),
AwsRoleArn: os.Getenv(roleArnEnv),
StsEndpoints: os.Getenv(envprovider.EnvSTSRegionalEndpoints),
AwsRoleArn: os.Getenv(envprovider.EnvRoleARN),
}, nil
}

Expand Down Expand Up @@ -150,7 +151,7 @@ func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string
return nil, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage)
}

defaultRegion := os.Getenv(defaultRegionEnv)
defaultRegion := os.Getenv(envprovider.EnvDefaultRegion)
if defaultRegion == "" {
defaultRegion = region
}
Expand All @@ -177,7 +178,7 @@ func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string

Region: region,
DefaultRegion: defaultRegion,
StsEndpoints: os.Getenv(stsEndpointsEnv),
StsEndpoints: os.Getenv(envprovider.EnvSTSRegionalEndpoints),
WebTokenPath: hostTokenPath,
AwsRoleArn: awsRoleARN,

Expand Down Expand Up @@ -258,13 +259,13 @@ func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, args mountpo
return region, nil
}

region = os.Getenv(regionEnv)
region = os.Getenv(envprovider.EnvRegion)
if region != "" {
klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from `AWS_REGION` env variable", region)
return region, nil
}

region = os.Getenv(defaultRegionEnv)
region = os.Getenv(envprovider.EnvDefaultRegion)
if region != "" {
klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from `AWS_DEFAULT_REGION` env variable", region)
return region, nil
Expand Down
Loading

0 comments on commit 04180b5

Please sign in to comment.