From b811409dd5495184bec649281ba69f5e51df7ade Mon Sep 17 00:00:00 2001 From: Brenden Wu Date: Wed, 25 Feb 2026 23:38:44 +0000 Subject: [PATCH] Add wsARN for SNS Config --- config/notifiers.go | 17 +- go.mod | 10 +- go.sum | 9 +- notify/sns/aws_round_tripper.go | 34 ++ notify/sns/aws_round_tripper_test.go | 97 +++++ notify/sns/sns.go | 577 +++++++++++++++++++++------ notify/sns/sns_test.go | 490 +++++++++++++++++++++-- 7 files changed, 1051 insertions(+), 183 deletions(-) create mode 100644 notify/sns/aws_round_tripper.go create mode 100644 notify/sns/aws_round_tripper_test.go diff --git a/config/notifiers.go b/config/notifiers.go index 66372814ec..724730720d 100644 --- a/config/notifiers.go +++ b/config/notifiers.go @@ -897,14 +897,15 @@ type SNSConfig struct { HTTPConfig *commoncfg.HTTPClientConfig `yaml:"http_config,omitempty" json:"http_config,omitempty"` - APIUrl string `yaml:"api_url,omitempty" json:"api_url,omitempty"` - Sigv4 sigv4.SigV4Config `yaml:"sigv4" json:"sigv4"` - TopicARN string `yaml:"topic_arn,omitempty" json:"topic_arn,omitempty"` - PhoneNumber string `yaml:"phone_number,omitempty" json:"phone_number,omitempty"` - TargetARN string `yaml:"target_arn,omitempty" json:"target_arn,omitempty"` - Subject string `yaml:"subject,omitempty" json:"subject,omitempty"` - Message string `yaml:"message,omitempty" json:"message,omitempty"` - Attributes map[string]string `yaml:"attributes,omitempty" json:"attributes,omitempty"` + APIUrl string `yaml:"api_url,omitempty" json:"api_url,omitempty"` + Sigv4 sigv4.SigV4Config `yaml:"sigv4" json:"sigv4"` + TopicARN string `yaml:"topic_arn,omitempty" json:"topic_arn,omitempty"` + PhoneNumber string `yaml:"phone_number,omitempty" json:"phone_number,omitempty"` + TargetARN string `yaml:"target_arn,omitempty" json:"target_arn,omitempty"` + Subject string `yaml:"subject,omitempty" json:"subject,omitempty"` + Message string `yaml:"message,omitempty" json:"message,omitempty"` + Attributes map[string]string `yaml:"attributes,omitempty" json:"attributes,omitempty"` + WorkspaceArn string `yaml:"workspace_arn,omitempty" json:"workspace_arn,omitempty"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. diff --git a/go.mod b/go.mod index ca5e720a57..67057f00de 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,10 @@ require ( github.com/KimMachineGun/automemlimit v0.7.5 github.com/alecthomas/kingpin/v2 v2.4.0 github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b + github.com/aws/aws-sdk-go v1.55.8 github.com/aws/aws-sdk-go-v2 v1.41.1 github.com/aws/aws-sdk-go-v2/config v1.32.7 - github.com/aws/aws-sdk-go-v2/credentials v1.19.7 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.6 - github.com/aws/aws-sdk-go-v2/service/sns v1.39.11 - github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 - github.com/aws/smithy-go v1.24.0 github.com/cenkalti/backoff/v4 v4.3.0 github.com/cespare/xxhash/v2 v2.3.0 github.com/coder/quartz v0.3.0 @@ -34,6 +31,7 @@ require ( github.com/matttproud/golang_protobuf_extensions v1.0.4 github.com/oklog/run v1.2.0 github.com/oklog/ulid/v2 v2.1.1 + github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.23.2 github.com/prometheus/common v0.67.5 github.com/prometheus/exporter-toolkit v0.15.1 @@ -63,6 +61,7 @@ require ( require ( github.com/armon/go-metrics v0.4.1 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.7 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 // indirect @@ -72,6 +71,8 @@ require ( github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.9 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 // indirect + github.com/aws/smithy-go v1.24.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/coreos/go-systemd/v22 v22.6.0 // indirect @@ -105,6 +106,7 @@ require ( github.com/hashicorp/go-msgpack/v2 v2.1.5 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jpillora/backoff v1.0.0 // indirect github.com/julienschmidt/httprouter v1.3.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect diff --git a/go.sum b/go.sum index 23f9c1848f..32fe99ee4b 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJ github.com/armon/go-metrics v0.4.1/go.mod h1:E6amYzXo6aW1tqzoZGT755KkbgrJsSdpwZ+3JqfkOG4= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aws/aws-sdk-go v1.55.8 h1:JRmEUbU52aJQZ2AjX4q4Wu7t4uZjOu71uyNmaWlUkJQ= +github.com/aws/aws-sdk-go v1.55.8/go.mod h1:ZkViS9AqA6otK+JBBNH2++sx1sgxrPKcSzPPvQkUtXk= github.com/aws/aws-sdk-go-v2 v1.41.1 h1:ABlyEARCDLN034NhxlRUSZr4l71mh+T5KAeGh6cerhU= github.com/aws/aws-sdk-go-v2 v1.41.1/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= github.com/aws/aws-sdk-go-v2/config v1.32.7 h1:vxUyWGUwmkQ2g19n7JY/9YL8MfAIl7bTesIUykECXmY= @@ -99,8 +101,6 @@ github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.6 h1:l4mxH8imZoflVEWWa github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.6/go.mod h1:1qwmvfRBGTQ5shUxu+eQO/S2+O6o6SxbvcvtN62kmc0= github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 h1:VrhDvQib/i0lxvr3zqlUwLwJP4fpmpyD9wYG1vfSu+Y= github.com/aws/aws-sdk-go-v2/service/signin v1.0.5/go.mod h1:k029+U8SY30/3/ras4G/Fnv/b88N4mAfliNn08Dem4M= -github.com/aws/aws-sdk-go-v2/service/sns v1.39.11 h1:Ke7RS0NuP9Xwk31prXYcFGA1Qfn8QmNWcxyjKPcXZdc= -github.com/aws/aws-sdk-go-v2/service/sns v1.39.11/go.mod h1:hdZDKzao0PBfJJygT7T92x2uVcWc/htqlhrjFIjnHDM= github.com/aws/aws-sdk-go-v2/service/sso v1.30.9 h1:v6EiMvhEYBoHABfbGB4alOYmCIrcgyPPiBE1wZAEbqk= github.com/aws/aws-sdk-go-v2/service/sso v1.30.9/go.mod h1:yifAsgBxgJWn3ggx70A3urX2AN49Y5sJTD1UQFlfqBw= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13 h1:gd84Omyu9JLriJVCbGApcLzVR3XtmC4ZDPcAI6Ftvds= @@ -389,6 +389,10 @@ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1: github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/jessevdk/go-flags v1.6.1 h1:Cvu5U8UGrLay1rZfv/zP7iLpSHGUZ/Ou68T0iX1bBK4= github.com/jessevdk/go-flags v1.6.1/go.mod h1:Mk8T1hIAWpOiJiHa9rJASDK2UGWji0EuPGBnNLMooyc= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -478,6 +482,7 @@ github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCko github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/notify/sns/aws_round_tripper.go b/notify/sns/aws_round_tripper.go new file mode 100644 index 0000000000..9224a76be5 --- /dev/null +++ b/notify/sns/aws_round_tripper.go @@ -0,0 +1,34 @@ +package sns + +import ( + "fmt" + "net/http" + + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/prometheus/alertmanager/config" +) + +type confusedDeputyRoundTripper struct { + workspaceArn arn.ARN + rt http.RoundTripper +} + +func (rt *confusedDeputyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("x-amz-delegation-source-account", rt.workspaceArn.AccountID) + req.Header.Set("x-amz-delegation-source-arn", rt.workspaceArn.String()) + return rt.rt.RoundTrip(req) +} + +// newConfusedDeputyRoundTripper adds confused deputy headers +func newConfusedDeputyRoundTripper(c *config.SNSConfig, rt http.RoundTripper) (http.RoundTripper, error) { + if c.WorkspaceArn == "" { + return rt, nil + } + + arn, err := arn.Parse(c.WorkspaceArn) + + if err != nil { + return nil, fmt.Errorf("%s is not a valid arn", c.WorkspaceArn) + } + return &confusedDeputyRoundTripper{arn, rt}, nil +} diff --git a/notify/sns/aws_round_tripper_test.go b/notify/sns/aws_round_tripper_test.go new file mode 100644 index 0000000000..d6959c2b84 --- /dev/null +++ b/notify/sns/aws_round_tripper_test.go @@ -0,0 +1,97 @@ +package sns + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/prometheus/alertmanager/config" + commoncfg "github.com/prometheus/common/config" +) + +func TestRoundTripperWithArnNotConfigured(t *testing.T) { + var testCases = []struct { + name string + snsConfig config.SNSConfig + expectedHeaders map[string]string + deniedHeaders []string + expectedErrorMessage string + }{ + { + name: "Workspace invalid Arn configured", + snsConfig: config.SNSConfig{ + WorkspaceArn: "arn:--Invalid", + }, + expectedHeaders: map[string]string{}, + deniedHeaders: []string{}, + expectedErrorMessage: "arn:--Invalid is not a valid arn", + }, + { + name: "Workspace Arn not configured", + snsConfig: config.SNSConfig{}, + expectedHeaders: map[string]string{}, + deniedHeaders: []string{ + "x-amz-source-account", + "x-amz-source-arn", + "x-amz-delegation-source-arn", + "x-amz-delegation-source-account", + }, + }, + { + name: "Workspace Arn configured", + snsConfig: config.SNSConfig{ + WorkspaceArn: "arn:aws:aps:us-west-2:948363459592:workspace/ws-de4908b6-950e-4c4c-9e49-ec68169bc4c7", + }, + expectedHeaders: map[string]string{ + "x-amz-delegation-source-account": "948363459592", + "x-amz-delegation-source-arn": "arn:aws:aps:us-west-2:948363459592:workspace/ws-de4908b6-950e-4c4c-9e49-ec68169bc4c7", + }, + deniedHeaders: []string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testServer := newTestServer(func(w http.ResponseWriter, r *http.Request) { + for _, name := range tc.deniedHeaders { + if _, ok := r.Header[name]; ok { + t.Fatalf("Header %s should not be set", name) + } + } + + for key, value := range tc.expectedHeaders { + if r.Header.Get(key) != value { + t.Fatalf("The received Headers (%s) does not contain all expected headers (%s).", r.Header, tc.expectedHeaders) + return + } + } + }) + + defer testServer.Close() + + client, err := commoncfg.NewClientFromConfig(commoncfg.HTTPClientConfig{}, "test") + + if err != nil && err.Error() != tc.expectedErrorMessage { + t.Fatal(err.Error()) + } + + client.Transport, err = newConfusedDeputyRoundTripper(&tc.snsConfig, client.Transport) + + if err != nil && err.Error() != tc.expectedErrorMessage { + t.Fatal(err.Error()) + } + + _, err = client.Get(testServer.URL) + + if err != nil && err.Error() != tc.expectedErrorMessage { + t.Fatal(err.Error()) + } + }) + } +} + +func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { + testServer := httptest.NewUnstartedServer(http.HandlerFunc(handler)) + testServer.Start() + return testServer +} diff --git a/notify/sns/sns.go b/notify/sns/sns.go index 80b039b8b1..841e47b5d7 100644 --- a/notify/sns/sns.go +++ b/notify/sns/sns.go @@ -15,23 +15,23 @@ package sns import ( "context" - "errors" + "encoding/json" "fmt" "log/slog" "net/http" + "regexp" "strings" + "time" "unicode/utf8" - "github.com/aws/aws-sdk-go-v2/aws" - awsconfig "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/service/sns" - snstypes "github.com/aws/aws-sdk-go-v2/service/sns/types" - "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/smithy-go" - smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/pkg/errors" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sns" commoncfg "github.com/prometheus/common/config" "github.com/prometheus/alertmanager/config" @@ -40,6 +40,45 @@ import ( "github.com/prometheus/alertmanager/types" ) +const ( + // Message components + Message = "Message" + Subject = "Subject" + MessageAttribute = "MessageAttribute" + + // Modified Message attribute value format + ComponentAndModifiedReason = "%s: %s" + + // The errors + MessageNotValidUtf8 = "Error - not a valid UTF-8 encoded string" + MessageIsEmpty = "Error - the message should not be empty" + MessageSizeExceeded = "Error - the message has been truncated from %dKB because it exceeds the %dKB size limit" + SubjectContainsIllegalChars = "Error - contains control- or non-ASCII characters" + SubjectSizeExceeded = "Error - subject has been truncated from %d characters because it exceeds the 100 character size limit" + SubjectEmpty = "Error - subject, if provided, must be non-empty" + MessageAttributeSizeExceeded = "Error - %d of message attributes have been removed because of %dKB size limit exceeded" + MessageAttributeNotValidKeyOrValue = "Error - %d of message attributes have been removed because of invalid MessageAttributeKey or MessageAttributeValue" + + // Message components size limit + subjectSizeLimitInCharacters = 100 + messageAttributeKeyLimitInCharacters = 256 + // Max message size for a message in a SNS publish request is 256KB, except for SMS messages where the limit is 1600 characters/runes. + messageSizeLimitInBytes = 256 * 1024 + messageSizeLimitInCharactersForSMS = 1600 +) + +var isInvalidMessageAttributeKeyPrefix = regexp.MustCompile(`^(AWS\.)|^(Amazon\.)|^(\.)`).MatchString +var isInvalidMessageAttributeKeySuffix = regexp.MustCompile(`\.$`).MatchString +var isInvalidMessageAttributeKeySubstring = regexp.MustCompile(`\.{2}`).MatchString +var isValidMessageAttributeKeyCharacters = regexp.MustCompile(`^[a-zA-Z0-9_\-.]*$`).MatchString + +var truncatedMessageAttributeKey = "truncated" +var truncatedMessageAttributeValue = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")} +var modifiedMessageAttributeKey = "modified" + +// Used for testing +var jsonMarshal = json.Marshal + // Notifier implements a Notifier for SNS notifications. type Notifier struct { conf *config.SNSConfig @@ -47,14 +86,23 @@ type Notifier struct { logger *slog.Logger client *http.Client retrier *notify.Retrier + isFifo *bool } // New returns a new SNS notification handler. func New(c *config.SNSConfig, t *template.Template, l *slog.Logger, httpOpts ...commoncfg.HTTPClientOption) (*Notifier, error) { - client, err := notify.NewClientWithTracing(*c.HTTPConfig, "sns", httpOpts...) + client, err := commoncfg.NewClientFromConfig(*c.HTTPConfig, "sns", append(httpOpts, commoncfg.WithHTTP2Disabled())...) + if err != nil { + return nil, err + } + + // Custom AWS Round Tripper + client.Transport, err = newConfusedDeputyRoundTripper(c, client.Transport) + if err != nil { return nil, err } + return &Notifier{ conf: c, tmpl: t, @@ -66,199 +114,466 @@ func New(c *config.SNSConfig, t *template.Template, l *slog.Logger, httpOpts ... func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, error) { var ( + err error tmplErr error data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger) tmpl = notify.TmplText(n.tmpl, data, &tmplErr) ) - client, err := n.createSNSClient(ctx, tmpl, &tmplErr) + client, err := createSNSClient(n.client, n, tmpl, &tmplErr) if err != nil { - // V2 error handling is different. We don't have awserr.RequestFailure. - // We can check for a generic smithy.APIError to see if it's a service error. - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - // To maintain compatibility with the retrier, we attempt to get an HTTP status code. - var respErr *smithyhttp.ResponseError - if errors.As(err, &respErr) && respErr.Response != nil { - return n.retrier.Check(respErr.Response.StatusCode, strings.NewReader(apiErr.ErrorMessage())) - } - // Fallback if we can't get a status code. - return true, fmt.Errorf("failed to create SNS client: %s: %s", apiErr.ErrorCode(), apiErr.ErrorMessage()) + if e, ok := err.(awserr.RequestFailure); ok { + return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message())) + } else { + return true, err } - return true, err } - publishInput, err := n.createPublishInput(ctx, tmpl, &tmplErr) + publishInput, err := createPublishInput(ctx, n, tmpl, &tmplErr) if err != nil { return true, err } - publishOutput, err := client.Publish(ctx, publishInput) + publishOutput, err := client.Publish(publishInput) if err != nil { - // V2 error handling uses errors.As to inspect the error chain. - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - var statusCode int - var respErr *smithyhttp.ResponseError - // Try to extract the HTTP status code for the retrier. - if errors.As(err, &respErr) && respErr.Response != nil { - statusCode = respErr.Response.StatusCode - } - - // If we got a status code, use the retrier logic. - if statusCode != 0 { - retryable, checkErr := n.retrier.Check(statusCode, strings.NewReader(apiErr.ErrorMessage())) - reasonErr := notify.NewErrorWithReason(notify.GetFailureReasonFromStatusCode(statusCode), checkErr) - return retryable, reasonErr - } + if e, ok := err.(awserr.RequestFailure); ok { + retryable, error := n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message())) + reasonErr := notify.NewErrorWithReason(notify.GetFailureReasonFromStatusCode(e.StatusCode()), error) + return retryable, reasonErr + } else { + return true, err } - // Fallback for non-API errors or if status code extraction fails. - return true, err } - n.logger.Debug("SNS message successfully published", "message_id", aws.ToString(publishOutput.MessageId), "sequence_number", aws.ToString(publishOutput.SequenceNumber)) + n.logger.Debug("SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber) return false, nil } -func (n *Notifier) createSNSClient(ctx context.Context, tmpl func(string) string, tmplErr *error) (*sns.Client, error) { - // Base configuration options that apply to both STS (if used) and the final SNS client. - baseCfgOpts := []func(*awsconfig.LoadOptions) error{ - awsconfig.WithHTTPClient(n.client), - awsconfig.WithRegion(n.conf.Sigv4.Region), +func createSNSClient(httpClient *http.Client, n *Notifier, tmpl func(string) string, tmplErr *error) (*sns.SNS, error) { + var creds *credentials.Credentials = nil + // If there are provided sigV4 credentials we want to use those to create a session. + if n.conf.Sigv4.AccessKey != "" && n.conf.Sigv4.SecretKey != "" { + creds = credentials.NewStaticCredentials(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "") } - if n.conf.Sigv4.Profile != "" { - baseCfgOpts = append(baseCfgOpts, awsconfig.WithSharedConfigProfile(n.conf.Sigv4.Profile)) + sess, err := session.NewSessionWithOptions(session.Options{ + Config: aws.Config{ + Region: aws.String(n.conf.Sigv4.Region), + Endpoint: aws.String(tmpl(n.conf.APIUrl)), + }, + Profile: n.conf.Sigv4.Profile, + }) + if err != nil { + return nil, err } - if n.conf.Sigv4.AccessKey != "" { - creds := credentials.NewStaticCredentialsProvider(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "") - baseCfgOpts = append(baseCfgOpts, awsconfig.WithCredentialsProvider(creds)) + if *tmplErr != nil { + return nil, notify.NewErrorWithReason(notify.ClientErrorReason, errors.Wrap(*tmplErr, "execute 'api_url' template")) } - // Final configuration options for the SNS client. - snsCfgOpts := baseCfgOpts - - // If a RoleARN is provided, create an STS client to assume the role. - // This uses a separate config load to ensure the STS client does not use a custom SNS endpoint. if n.conf.Sigv4.RoleARN != "" { - stsCfg, err := awsconfig.LoadDefaultConfig(ctx, baseCfgOpts...) - if err != nil { - return nil, fmt.Errorf("failed to load base config for STS: %w", err) + var stsSess *session.Session + if n.conf.APIUrl == "" { + stsSess = sess + } else { + // If we have set the API URL we need to create a new session to get the STS Credentials. + stsSess, err = session.NewSessionWithOptions(session.Options{ + Config: aws.Config{ + Region: aws.String(n.conf.Sigv4.Region), + Credentials: creds, + }, + Profile: n.conf.Sigv4.Profile, + }) + if err != nil { + return nil, err + } } - stsClient := sts.NewFromConfig(stsCfg) - stsProvider := stscreds.NewAssumeRoleProvider(stsClient, n.conf.Sigv4.RoleARN) - // Add the AssumeRole provider to the options for the SNS client config. - snsCfgOpts = append(snsCfgOpts, awsconfig.WithCredentialsProvider(aws.NewCredentialsCache(stsProvider))) + creds = stscreds.NewCredentials(stsSess, n.conf.Sigv4.RoleARN) } - - // Resolve the API URL from the template. - apiURL := tmpl(n.conf.APIUrl) - if *tmplErr != nil { - return nil, notify.NewErrorWithReason(notify.ClientErrorReason, fmt.Errorf("execute 'api_url' template: %w", *tmplErr)) - } - if apiURL != "" { - snsCfgOpts = append(snsCfgOpts, awsconfig.WithBaseEndpoint(apiURL)) - } - - // Load the final configuration for the SNS client. - snsCfg, err := awsconfig.LoadDefaultConfig(ctx, snsCfgOpts...) - if err != nil { - return nil, fmt.Errorf("failed to load final config for SNS: %w", err) - } - - // We will always need a region to be set. - if snsCfg.Region == "" { + // Use our generated session with credentials to create the SNS Client. + client := sns.New(sess, &aws.Config{Credentials: creds, HTTPClient: httpClient}) + // We will always need a region to be set by either the local config or the environment. + if aws.StringValue(sess.Config.Region) == "" { return nil, fmt.Errorf("region not configured in sns.sigv4.region or in default credentials chain") } - - return sns.NewFromConfig(snsCfg), nil + return client, nil } -func (n *Notifier) createPublishInput(ctx context.Context, tmpl func(string) string, tmplErr *error) (*sns.PublishInput, error) { +func createPublishInput(ctx context.Context, n *Notifier, tmpl func(string) string, tmplErr *error) (*sns.PublishInput, error) { + var modifiedReasons []string publishInput := &sns.PublishInput{} - messageAttributes := n.createMessageAttributes(tmpl) + messageAttributes := createAndValidateMessageAttributes(n, tmpl, &modifiedReasons) if *tmplErr != nil { - return nil, notify.NewErrorWithReason(notify.ClientErrorReason, fmt.Errorf("execute 'attributes' template: %w", *tmplErr)) + return nil, notify.NewErrorWithReason(notify.ClientErrorReason, errors.Wrap(*tmplErr, "execute 'attributes' template")) } - - // Max message size for a message in an SNS publish request is 256KB, - // except for SMS messages where the limit is 1600 characters/runes. - messageSizeLimit := 256 * 1024 if n.conf.TopicARN != "" { - topicARN := tmpl(n.conf.TopicARN) + topicTmpl := tmpl(n.conf.TopicARN) if *tmplErr != nil { - return nil, notify.NewErrorWithReason(notify.ClientErrorReason, fmt.Errorf("execute 'topic_arn' template: %w", *tmplErr)) + return nil, notify.NewErrorWithReason(notify.ClientErrorReason, errors.Wrap(*tmplErr, "execute 'topic_arn' template")) } - publishInput.TopicArn = aws.String(topicARN) - // If we are using a topic ARN, it could be a FIFO topic specified by the topic's suffix ".fifo". - if strings.HasSuffix(topicARN, ".fifo") { - key, err := notify.ExtractGroupKey(ctx) + publishInput.SetTopicArn(topicTmpl) + if n.isFifo == nil { + // If we are using a topic ARN it could be a FIFO topic specified by the topic postfix .fifo. + n.isFifo = aws.Bool(n.conf.TopicARN[len(n.conf.TopicARN)-5:] == ".fifo") + } + if *n.isFifo { + // Deduplication key and Message Group ID are only added if it's a FIFO SNS Topic. + groupKey, err := notify.ExtractGroupKey(ctx) if err != nil { return nil, err } - publishInput.MessageDeduplicationId = aws.String(key.Hash()) - publishInput.MessageGroupId = aws.String(key.Hash()) + now, ok := notify.Now(ctx) + if !ok { + return nil, errors.New("failed to extract now timestamp from context") + } + publishInput.SetMessageDeduplicationId(createMessageDeduplicationId(groupKey, now)) + publishInput.SetMessageGroupId(groupKey.Hash()) } } if n.conf.PhoneNumber != "" { - publishInput.PhoneNumber = aws.String(tmpl(n.conf.PhoneNumber)) + publishInput.SetPhoneNumber(tmpl(n.conf.PhoneNumber)) if *tmplErr != nil { - return nil, notify.NewErrorWithReason(notify.ClientErrorReason, fmt.Errorf("execute 'phone_number' template: %w", *tmplErr)) + return nil, notify.NewErrorWithReason(notify.ClientErrorReason, errors.Wrap(*tmplErr, "execute 'phone_number' template")) } - // If we have an SMS message, we need to truncate to 1600 characters/runes. - messageSizeLimit = 1600 } if n.conf.TargetARN != "" { - publishInput.TargetArn = aws.String(tmpl(n.conf.TargetARN)) + publishInput.SetTargetArn(tmpl(n.conf.TargetARN)) if *tmplErr != nil { - return nil, notify.NewErrorWithReason(notify.ClientErrorReason, fmt.Errorf("execute 'target_arn' template: %w", *tmplErr)) + return nil, notify.NewErrorWithReason(notify.ClientErrorReason, errors.Wrap(*tmplErr, "execute 'target_arn' template")) } } - tmplMessage := tmpl(n.conf.Message) + messageToSend := tmpl(n.conf.Message) if *tmplErr != nil { - return nil, notify.NewErrorWithReason(notify.ClientErrorReason, fmt.Errorf("execute 'message' template: %w", *tmplErr)) + return nil, notify.NewErrorWithReason(notify.ClientErrorReason, errors.Wrap(*tmplErr, "execute 'message' template")) } - messageToSend, isTrunc, err := validateAndTruncateMessage(tmplMessage, messageSizeLimit) + validationErr := validateMessage(n.logger, messageToSend, &modifiedReasons) + + if validationErr != nil { + messageToSend = validationErr.Error() + // If we modified the message with error message we need to add a message attribute showing that it was truncated. + messageAttributes[truncatedMessageAttributeKey] = truncatedMessageAttributeValue + } + + templatedSubject := tmpl(n.conf.Subject) + if *tmplErr != nil { + return nil, notify.NewErrorWithReason(notify.ClientErrorReason, errors.Wrap(*tmplErr, "execute 'subject' template")) + } + if n.conf.Subject != "" || templatedSubject != "" { + subjectToSend := validateAndTruncateSubject(n.logger, templatedSubject, &modifiedReasons) + + publishInput.SetSubject(subjectToSend) + } + + truncateAttributes, truncatedMessage, err := truncateMessageAttributesAndMessage(n.logger, n.conf.PhoneNumber, messageAttributes, messageToSend, validationErr != nil, &modifiedReasons) if err != nil { return nil, err } - if isTrunc { - // If we truncated the message we need to add a message attribute showing that it was truncated. - messageAttributes["truncated"] = snstypes.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")} + + err = addModifiedMessageAttributes(truncateAttributes, modifiedReasons) + if err != nil { + return nil, err } - publishInput.Message = aws.String(messageToSend) - publishInput.MessageAttributes = messageAttributes + publishInput.SetMessage(truncatedMessage) + publishInput.SetMessageAttributes(truncateAttributes) - if n.conf.Subject != "" { - publishInput.Subject = aws.String(tmpl(n.conf.Subject)) - if *tmplErr != nil { - return nil, notify.NewErrorWithReason(notify.ClientErrorReason, fmt.Errorf("execute 'subject' template: %w", *tmplErr)) + return publishInput, nil +} + +func createMessageDeduplicationId(groupKey notify.Key, now time.Time) string { + var deduplicationId = groupKey.String() + now.String() + return notify.Key(deduplicationId).Hash() +} + +func addModifiedMessageAttributes(attributes map[string]*sns.MessageAttributeValue, modifiedReasons []string) error { + if len(modifiedReasons) > 0 { + valueString, err := getModifiedReasonMessageAttributeValue(modifiedReasons) + if err != nil { + return err } + attributes[modifiedMessageAttributeKey] = &sns.MessageAttributeValue{DataType: aws.String("String.Array"), StringValue: aws.String(valueString)} } - return publishInput, nil + return nil } -func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (string, bool, error) { +func validateMessage(logger *slog.Logger, message string, modifiedReasons *[]string) error { if !utf8.ValidString(message) { - return "", false, fmt.Errorf("non utf8 encoded message string") + *modifiedReasons = append(*modifiedReasons, fmt.Sprintf(ComponentAndModifiedReason, Message, MessageNotValidUtf8)) + logger.Info("Message has been modified because of invalid UTF-8 encoded string.", "originalMessage", message) + return errors.New(MessageNotValidUtf8) } - if len(message) <= maxMessageSizeInBytes { - return message, false, nil + if len(message) == 0 { + *modifiedReasons = append(*modifiedReasons, fmt.Sprintf(ComponentAndModifiedReason, Message, MessageIsEmpty)) + logger.Info("Message has been modified because the content was empty.") + return errors.New(MessageIsEmpty) + } + return nil +} + +func validateAndTruncateSubject(logger *slog.Logger, subject string, modifiedReasons *[]string) string { + if subject == "" { + *modifiedReasons = append(*modifiedReasons, fmt.Sprintf(ComponentAndModifiedReason, Subject, SubjectEmpty)) + logger.Info("Subject has been modified because it is empty.", "originalSubject", subject) + return SubjectEmpty + } + + charactersInSubject := utf8.RuneCountInString(subject) + + if !isASCIINonControl(subject) { + *modifiedReasons = append(*modifiedReasons, fmt.Sprintf(ComponentAndModifiedReason, Subject, SubjectContainsIllegalChars)) + if charactersInSubject > subjectSizeLimitInCharacters { + subject = subject[:subjectSizeLimitInCharacters] + logger.Info("Subject has been modified because it contains control or non-ASCII characters.", "originalSubject(truncated)", subject) + } else { + logger.Info("Subject has been modified because it contains control or non-ASCII characters.", "originalSubject", subject) + } + + return SubjectContainsIllegalChars + } + + if charactersInSubject <= subjectSizeLimitInCharacters { + return subject } + // If the message is larger than our specified size we have to truncate. - truncated := make([]byte, maxMessageSizeInBytes) - copy(truncated, message) - return string(truncated), true, nil + logger.Info("Subject has been truncated because it exceeds size limit.") + *modifiedReasons = append(*modifiedReasons, fmt.Sprintf(ComponentAndModifiedReason, Subject, fmt.Sprintf(SubjectSizeExceeded, charactersInSubject))) + return subject[:subjectSizeLimitInCharacters] } -func (n *Notifier) createMessageAttributes(tmpl func(string) string) map[string]snstypes.MessageAttributeValue { +func createAndValidateMessageAttributes(n *Notifier, tmpl func(string) string, modifiedReasons *[]string) map[string]*sns.MessageAttributeValue { + numberOfInvalidMessageAttributes := 0 // Convert the given attributes map into the AWS Message Attributes Format. - attributes := make(map[string]snstypes.MessageAttributeValue, len(n.conf.Attributes)) + attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes)) for k, v := range n.conf.Attributes { - attributes[tmpl(k)] = snstypes.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))} + attributeKey := tmpl(k) + attributeValue := tmpl(v) + if !isValidateMessageAttribute(attributeKey, attributeValue) { + numberOfInvalidMessageAttributes++ + n.logger.Debug("MessageAttribute has been removed because of invalid key/value.", "attributeKey", attributeKey, "attributeValue", attributeValue) + continue + } + attributes[attributeKey] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(attributeValue)} + } + + if numberOfInvalidMessageAttributes > 0 { + n.logger.Info("MessageAttributes has been removed because of invalid key/value.", "numberOfRemovedAttributes", numberOfInvalidMessageAttributes) + *modifiedReasons = append( + *modifiedReasons, + fmt.Sprintf(ComponentAndModifiedReason, MessageAttribute, fmt.Sprintf(MessageAttributeNotValidKeyOrValue, numberOfInvalidMessageAttributes)), + ) } return attributes } + +func isASCIINonControl(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < 32 || s[i] >= 127 { + return false + } + } + return true +} + +/* + The priority to fit in the size is + 1. The modified reasons why will become attributes["modified"] + 2. message attributes + 3. message +*/ +func truncateMessageAttributesAndMessage(logger *slog.Logger, phoneNumber string, attributes map[string]*sns.MessageAttributeValue, + message string, isMessageModified bool, modifiedReasons *[]string) (map[string]*sns.MessageAttributeValue, string, error) { + if phoneNumber != "" { + charactersInSubject := utf8.RuneCountInString(message) + if charactersInSubject <= messageSizeLimitInCharactersForSMS { + return attributes, message, nil + } + + // SMS doesn't use customized messageAttributes + return attributes, message[:messageSizeLimitInCharactersForSMS], nil + } + + truncatedAttributes, attributesSize, err := truncateMessageAttributes(logger, attributes, modifiedReasons, isMessageModified, message) + + if err != nil { + return attributes, message, err + } + + truncatedMessage, isMessageTruncated, err := truncateMessage(logger, modifiedReasons, message, attributesSize) + if err != nil { + return attributes, message, err + } + + if isMessageTruncated { + truncatedAttributes[truncatedMessageAttributeKey] = truncatedMessageAttributeValue + } + + return truncatedAttributes, truncatedMessage, nil +} + +func createMessageAttributeSizeExceededReason(numberOfAttributeToBeTruncate int) string { + return fmt.Sprintf(ComponentAndModifiedReason, + MessageAttribute, + fmt.Sprintf(MessageAttributeSizeExceeded, numberOfAttributeToBeTruncate, messageSizeLimitInBytes/1024)) +} + +func createMessageSizeExceededReason(originMessage string) string { + return fmt.Sprintf( + ComponentAndModifiedReason, + Message, + fmt.Sprintf(MessageSizeExceeded, len(originMessage)/1024, messageSizeLimitInBytes/1024)) +} + +func getMessageSizeExceedReservedBytes(message string) (int, error) { + reservedTruncateAttributeValue := truncatedMessageAttributeValue + reservedTruncateAttributeBytes := + len(truncatedMessageAttributeKey) + len(*reservedTruncateAttributeValue.DataType) + len(*reservedTruncateAttributeValue.StringValue) + reservedMessageModifiedReasons := []string{ + createMessageSizeExceededReason(message), + } + reservedMessageModifiedReasonsBytes, err := getModifiedReasonMessageAttributeSize(reservedMessageModifiedReasons) + if err != nil { + return 0, err + } + + return reservedTruncateAttributeBytes + reservedMessageModifiedReasonsBytes, nil +} + +func truncateMessage(logger *slog.Logger, modifiedReasons *[]string, message string, attributeSize int) (string, bool, error) { + modifiedReasonBytes, err := getModifiedReasonMessageAttributeSize(*modifiedReasons) + if err != nil { + return message, false, err + } + + availableBytes := messageSizeLimitInBytes - modifiedReasonBytes - attributeSize + + if len(message) <= availableBytes { + return message, false, nil + } + + messageSizeExceedReservedBytes, err := getMessageSizeExceedReservedBytes(message) + if err != nil { + return message, false, err + } + availableBytes -= messageSizeExceedReservedBytes + // If the message is larger than our specified size we have to truncate. + *modifiedReasons = append(*modifiedReasons, createMessageSizeExceededReason(message)) + + truncated := make([]byte, availableBytes) + copy(truncated, message) + logger.Info("Message has been truncated because it exceeds size limit.", "originSize", len(message), "truncatedSize", len(truncated)) + return string(truncated), true, nil +} + +func truncateMessageAttributes(logger *slog.Logger, attributes map[string]*sns.MessageAttributeValue, + modifiedReasons *[]string, isMessageModified bool, message string) (map[string]*sns.MessageAttributeValue, int, error) { + + modifiedReasonBytes, err := getModifiedReasonMessageAttributeSize(*modifiedReasons) + if err != nil { + return attributes, 0, err + } + + availableBytes := messageSizeLimitInBytes - modifiedReasonBytes + + // We need to at least keep 1 byte for the message + availableBytes = availableBytes - 1 + // If message already gets modified, it means we replace the original message with an error. We don't want to truncate message in this case + if isMessageModified { + availableBytes = availableBytes - len(message) + } + + truncatedAttributes, attributeSize := fitMessageAttributeInAvailableSize(attributes, availableBytes) + + reservedMessageAttributeModifiedReasons := []string{ + // reserved for maximum number of attributes can be truncate + createMessageAttributeSizeExceededReason(len(attributes)), + } + reservedMessageAttributeModifiedReasonsBytes, err := getModifiedReasonMessageAttributeSize(reservedMessageAttributeModifiedReasons) + if err != nil { + return truncatedAttributes, attributeSize, err + } + + if len(truncatedAttributes) < len(attributes) { + availableBytes -= reservedMessageAttributeModifiedReasonsBytes + // truncate message attributes again in order to fit in the message attribute modified reasons + truncatedAttributes, attributeSize = fitMessageAttributeInAvailableSize(attributes, availableBytes) + } + + reservedMessageModifiedBytes, err := getMessageSizeExceedReservedBytes(message) + if err != nil { + return truncatedAttributes, attributeSize, err + } + + if !isMessageModified && len(message) > availableBytes-attributeSize { + availableBytes -= reservedMessageModifiedBytes + // truncate message attributes again in order to fit in the message modified reasons + truncatedAttributes, attributeSize = fitMessageAttributeInAvailableSize(attributes, availableBytes) + } + + if len(truncatedAttributes) < len(attributes) { + removedNumber := len(attributes) - len(truncatedAttributes) + logger.Info("MessageAttribute has been removed because it exceeds size limit.", "numberOfRemovedAttributes", removedNumber) + *modifiedReasons = append(*modifiedReasons, createMessageAttributeSizeExceededReason(removedNumber)) + } + + return truncatedAttributes, attributeSize, nil +} + +func fitMessageAttributeInAvailableSize(attributes map[string]*sns.MessageAttributeValue, availableBytes int) (map[string]*sns.MessageAttributeValue, int) { + attributesSize := 0 + truncatedAttributes := make(map[string]*sns.MessageAttributeValue) + + for k, v := range attributes { + pendingAddingAttributeSize := len(k) + len(*v.DataType) + len(*v.StringValue) + if attributesSize+pendingAddingAttributeSize <= availableBytes { + truncatedAttributes[k] = v + attributesSize += pendingAddingAttributeSize + } + } + + return truncatedAttributes, attributesSize +} + +func getModifiedReasonMessageAttributeSize(modifiedReasons []string) (int, error) { + if len(modifiedReasons) > 0 { + valueString, err := getModifiedReasonMessageAttributeValue(modifiedReasons) + if err != nil { + return 0, err + } + return len("String.Array") + len(modifiedMessageAttributeKey) + len(valueString), nil + } + return 0, nil +} + +func getModifiedReasonMessageAttributeValue(modifiedReasons []string) (string, error) { + jsonString, err := jsonMarshal(modifiedReasons) + if err != nil { + return "", err + } + + return string(jsonString), nil +} + +func isValidateMessageAttribute(messageAttributeKey string, messageAttributeValue string) bool { + if len(messageAttributeKey) == 0 || len(messageAttributeValue) == 0 { + return false + } + + if !isValidMessageAttributeKeyCharacters(messageAttributeKey) || + isInvalidMessageAttributeKeyPrefix(messageAttributeKey) || + isInvalidMessageAttributeKeySuffix(messageAttributeKey) || + isInvalidMessageAttributeKeySubstring(messageAttributeKey) { + return false + } + + if utf8.RuneCountInString(messageAttributeKey) > messageAttributeKeyLimitInCharacters { + return false + } + + if !utf8.ValidString(messageAttributeValue) { + return false + } + + return true +} diff --git a/notify/sns/sns_test.go b/notify/sns/sns_test.go index 7bafdc733d..0b1d0af7d4 100644 --- a/notify/sns/sns_test.go +++ b/notify/sns/sns_test.go @@ -15,55 +15,438 @@ package sns import ( "context" + "encoding/json" + "fmt" "net/url" + "strings" "testing" + "unicode/utf8" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/pkg/errors" + "github.com/prometheus/alertmanager/config" + "github.com/prometheus/alertmanager/template" + "github.com/prometheus/alertmanager/types" commoncfg "github.com/prometheus/common/config" "github.com/prometheus/common/promslog" "github.com/prometheus/sigv4" - "github.com/stretchr/testify/require" - "github.com/prometheus/alertmanager/config" - "github.com/prometheus/alertmanager/template" - "github.com/prometheus/alertmanager/types" + "github.com/stretchr/testify/require" ) var logger = promslog.NewNopLogger() -func TestValidateAndTruncateMessage(t *testing.T) { - sBuff := make([]byte, 257*1024) +func TestValidateMessage(t *testing.T) { + var modifiedReasons []string + + invalidUtf8String := "\xc3\x28" + err := validateMessage(logger, invalidUtf8String, &modifiedReasons) + require.Equal(t, MessageNotValidUtf8, err.Error()) + require.Equal(t, 1, len(modifiedReasons)) + require.Equal(t, "Message: Error - not a valid UTF-8 encoded string", modifiedReasons[0]) + require.Equal(t, len(modifiedReasons), 1) + + emptyString := "" + err = validateMessage(logger, emptyString, &modifiedReasons) + require.Equal(t, MessageIsEmpty, err.Error()) + require.Equal(t, 2, len(modifiedReasons)) + require.Equal(t, "Message: Error - the message should not be empty", modifiedReasons[1]) +} + +func TestValidateAndTruncateSubject(t *testing.T) { + var modifiedReasons []string + notTruncate := make([]rune, 100) + for i := range notTruncate { + notTruncate[i] = 'e' + } + subject := validateAndTruncateSubject(logger, string(notTruncate), &modifiedReasons) + require.Equal(t, string(notTruncate), subject) + require.Equal(t, 100, utf8.RuneCountInString(string(subject))) + + modifiedReasons = nil + willBeTruncate := make([]rune, 101) + for i := range willBeTruncate { + willBeTruncate[i] = 'e' + } + subject = validateAndTruncateSubject(logger, string(willBeTruncate), &modifiedReasons) + require.Equal(t, string(notTruncate), subject) + require.Equal(t, 1, len(modifiedReasons)) + require.Equal(t, "Subject: Error - subject has been truncated from 101 characters because it exceeds the 100 character size limit", modifiedReasons[0]) + + modifiedReasons = nil + subjectWithNonAsciiAndExceedingSize := make([]rune, 102) + subjectWithNonAsciiAndExceedingSize[0] = '\xc3' + subjectWithNonAsciiAndExceedingSize[1] = '\x28' + for i := 2; i < 102; i++ { + subjectWithNonAsciiAndExceedingSize[i] = 'e' + } + + subject = validateAndTruncateSubject(logger, string(subjectWithNonAsciiAndExceedingSize), &modifiedReasons) + require.Equal(t, SubjectContainsIllegalChars, subject) + require.Equal(t, 1, len(modifiedReasons)) + require.Equal(t, "Subject: Error - contains control- or non-ASCII characters", modifiedReasons[0]) + + modifiedReasons = nil + nonAsciiString := "\xc3\x28" + subject = validateAndTruncateSubject(logger, nonAsciiString, &modifiedReasons) + require.Equal(t, SubjectContainsIllegalChars, subject) + require.Equal(t, 1, len(modifiedReasons)) + require.Equal(t, "Subject: Error - contains control- or non-ASCII characters", modifiedReasons[0]) + + modifiedReasons = nil + asciiControlString := "\a\b\t" + subject = validateAndTruncateSubject(logger, asciiControlString, &modifiedReasons) + require.Equal(t, SubjectContainsIllegalChars, subject) + require.Equal(t, 1, len(modifiedReasons)) + require.Equal(t, "Subject: Error - contains control- or non-ASCII characters", modifiedReasons[0]) + + modifiedReasons = nil + newLineString := "abc\ndef" + subject = validateAndTruncateSubject(logger, newLineString, &modifiedReasons) + require.Equal(t, SubjectContainsIllegalChars, subject) + require.Equal(t, 1, len(modifiedReasons)) + require.Equal(t, "Subject: Error - contains control- or non-ASCII characters", modifiedReasons[0]) + + modifiedReasons = nil + emptyString := "" + subject = validateAndTruncateSubject(logger, emptyString, &modifiedReasons) + require.Equal(t, SubjectEmpty, subject) + require.Equal(t, 1, len(modifiedReasons)) + require.Equal(t, "Subject: Error - subject, if provided, must be non-empty", modifiedReasons[0]) +} + +func TestCreateAndValidateMessageAttributes(t *testing.T) { + var modifiedReasons []string + attributes := map[string]string{ + "Invalid0": "", + ".Invalid1": "123", + "Invalid2.": "123", + "AWS.Invalid3": "123", + "Amazon.Invalid4": "123", + "Invalid..5": "123", + "Valid0": "123", + "AmazonValid1": "123", + "valid.2": "123", + "valid-_3": "123", + } + notifier, err := New( + &config.SNSConfig{ + Attributes: attributes, + HTTPConfig: &commoncfg.HTTPClientConfig{}, + }, + CreateTmpl(t), + logger, + ) + require.NoError(t, err) + + attributesAfterValidation := createAndValidateMessageAttributes(notifier, temlFunction(t), &modifiedReasons) + + require.Equal(t, 4, len(attributesAfterValidation)) + require.Equal(t, true, attributesAfterValidation["Valid0"] != nil) + require.Equal(t, true, attributesAfterValidation["AmazonValid1"] != nil) + require.Equal(t, true, attributesAfterValidation["valid.2"] != nil) + require.Equal(t, true, attributesAfterValidation["valid-_3"] != nil) + require.Equal(t, len(modifiedReasons), 1) + require.Equal(t, "MessageAttribute: Error - 6 of message attributes have been removed because of invalid MessageAttributeKey or MessageAttributeValue", modifiedReasons[0]) +} + +func TestAddModifiedMessageAttributes(t *testing.T) { + reasons := []string{"1", "2"} + attributes := map[string]*sns.MessageAttributeValue{ + "truncated": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}, + } + + addModifiedMessageAttributes(attributes, reasons) + + require.Equal(t, 2, len(attributes)) + require.Equal(t, "[\"1\",\"2\"]", *attributes["modified"].StringValue) +} + +func TestTruncateMessageAttributesAndMessage_TotalSmallerThanSizeLimit(t *testing.T) { + logger := promslog.NewNopLogger() + + reasons := []string{"1", "2"} + sBuff := make([]byte, 30*1024) for i := range sBuff { sBuff[i] = byte(33) } - truncatedMessage, isTruncated, err := validateAndTruncateMessage(string(sBuff), 256*1024) - require.True(t, isTruncated) - require.NoError(t, err) - require.NotEqual(t, sBuff, truncatedMessage) - require.Len(t, truncatedMessage, 256*1024) - sBuff = make([]byte, 100) + attributes := map[string]*sns.MessageAttributeValue{ + "truncated": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}, + "customized": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + } + + truncateAttributes, truncatedMessage, _ := truncateMessageAttributesAndMessage(logger, "", attributes, string(sBuff), false, &reasons) + require.Equal(t, 2, len(truncateAttributes)) + require.Equal(t, len(string(sBuff)), len(truncatedMessage)) + require.Equal(t, 2, len(reasons)) + require.Equal(t, true, getTotalSizeInBytes(reasons, truncateAttributes, truncatedMessage) <= messageSizeLimitInBytes) +} + +func TestTruncateMessageAttributesAndMessage_SMS(t *testing.T) { + reasons := []string{"1", "2"} + smsBuff := make([]rune, 1700) + for i := range smsBuff { + smsBuff[i] = 'e' + } + attributes := map[string]*sns.MessageAttributeValue{ + "truncated": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}, + "customized": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(smsBuff))}, + } + _, truncatedMessage, _ := truncateMessageAttributesAndMessage(logger, "123", attributes, string(smsBuff), false, &reasons) + require.Equal(t, messageSizeLimitInCharactersForSMS, utf8.RuneCountInString(truncatedMessage)) +} + +func TestTruncateMessageAttributesAndMessage_MessageAttributesLargerThanSizeLimit(t *testing.T) { + reasons := []string{"1", "2"} + sBuff := make([]byte, 150*1024) for i := range sBuff { sBuff[i] = byte(33) } - truncatedMessage, isTruncated, err = validateAndTruncateMessage(string(sBuff), 100) - require.False(t, isTruncated) + attributes := map[string]*sns.MessageAttributeValue{ + "truncated": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}, + "customized1": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + "customized2": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + } + truncateAttributes, truncatedMessage, _ := truncateMessageAttributesAndMessage(logger, "", attributes, string(sBuff), false, &reasons) + require.Equal(t, 2, len(truncateAttributes)) + require.Equal(t, true, len(truncatedMessage) < 150*1024) + require.Equal(t, "true", *truncateAttributes["truncated"].StringValue) + require.Equal(t, 4, len(reasons)) + require.Equal(t, true, getTotalSizeInBytes(reasons, truncateAttributes, truncatedMessage) < messageSizeLimitInBytes) +} + +func TestTruncateMessageAttributesAndMessage_messageHasBeenModified(t *testing.T) { + // messageAttributes + message > 256KB, however the message has already been modified, truncate the messageAttributes and keep the original message + reasons := []string{"1", "2"} + sBuff := make([]byte, 150*1024) + for i := range sBuff { + sBuff[i] = byte(33) + } + attributes := map[string]*sns.MessageAttributeValue{ + "truncated": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}, + "customized1": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + "customized2": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + } + truncateAttributes, truncatedMessage, _ := truncateMessageAttributesAndMessage(logger, "", attributes, string(sBuff), true, &reasons) + require.Equal(t, 1, len(truncateAttributes)) + require.Equal(t, 150*1024, len(truncatedMessage)) + require.Equal(t, 3, len(reasons)) + require.Equal(t, true, getTotalSizeInBytes(reasons, truncateAttributes, truncatedMessage) <= messageSizeLimitInBytes) + +} + +func TestTruncateMessageAttributesAndMessage_atLeast1ByteForMessage(t *testing.T) { + //we still have rooms for reasons and at least 1 byte for message + messageBuff := make([]byte, messageSizeLimitInBytes) + for i := range messageBuff { + messageBuff[i] = byte(33) + } + + reservedMessageModifiedBytes, _ := getMessageSizeExceedReservedBytes(string(messageBuff)) + reasons := []string{"1", "2"} + modifiedReasonBytes, _ := getModifiedReasonMessageAttributeSize(reasons) + sBuff := make([]byte, messageSizeLimitInBytes-reservedMessageModifiedBytes-1-modifiedReasonBytes-len("customized1")-len("String")) + for i := range sBuff { + sBuff[i] = byte(33) + } + attributes := map[string]*sns.MessageAttributeValue{ + "customized1": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + } + truncateAttributes, truncatedMessage, _ := truncateMessageAttributesAndMessage(logger, "", attributes, string(sBuff), false, &reasons) + require.Equal(t, 2, len(truncateAttributes)) + require.Equal(t, "true", *truncateAttributes["truncated"].StringValue) + require.Equal(t, true, len(truncatedMessage) >= 1) + require.Equal(t, 3, len(reasons)) + fmt.Println("message", len(truncatedMessage)) + require.Equal(t, true, getTotalSizeInBytes(reasons, truncateAttributes, truncatedMessage) <= messageSizeLimitInBytes) +} + +func TestTruncateMessageAttributesAndMessage_truncateMessage(t *testing.T) { + reasons := []string{"1", "2"} + sBuff := make([]byte, 3*1024) + for i := range sBuff { + sBuff[i] = byte(33) + } + attributes := map[string]*sns.MessageAttributeValue{ + "customized1": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + } + + sBuffMessage := make([]byte, 256*1024) + + truncateAttributes, truncatedMessage, _ := truncateMessageAttributesAndMessage(logger, "", attributes, string(sBuffMessage), false, &reasons) + require.Equal(t, 2, len(truncateAttributes)) + require.Equal(t, "true", *truncateAttributes["truncated"].StringValue) + require.Equal(t, true, len(truncatedMessage) >= 1) + require.Equal(t, 3, len(reasons)) + require.Equal(t, true, getTotalSizeInBytes(reasons, truncateAttributes, truncatedMessage) <= messageSizeLimitInBytes) +} + +func TestTruncateMessageAttributesAndMessage_exactSize(t *testing.T) { + var reasons []string + sBuff := make([]byte, 128*1024-len("String")-len("customized1")) + for i := range sBuff { + sBuff[i] = byte(33) + } + attributes := map[string]*sns.MessageAttributeValue{ + "customized1": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + } + + sBuffMessage := make([]byte, 128*1024) + + truncateAttributes, truncatedMessage, _ := truncateMessageAttributesAndMessage(logger, "", attributes, string(sBuffMessage), false, &reasons) + require.Equal(t, 1, len(truncateAttributes)) + require.Equal(t, true, truncateAttributes["truncated"] == nil) + require.Equal(t, true, len(truncatedMessage) == 128*1024) + require.Equal(t, 0, len(reasons)) + require.Equal(t, true, getTotalSizeInBytes(reasons, truncateAttributes, truncatedMessage) == 256*1024) +} + +func TestTruncateMessageAttributesAndMessage_marshalFailure(t *testing.T) { + storedMarshal := jsonMarshal + jsonMarshal = fakemarshal + defer restoremarshal(storedMarshal) + + reasons := []string{"1", "2"} + sBuff := make([]byte, 30*1024) + for i := range sBuff { + sBuff[i] = byte(33) + } + + attributes := map[string]*sns.MessageAttributeValue{ + "truncated": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}, + "customized": &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(string(sBuff))}, + } + + _, _, err := truncateMessageAttributesAndMessage(logger, "", attributes, string(sBuff), false, &reasons) + require.Equal(t, true, err != nil) +} + +func TestCreatePublishInput_noErrors(t *testing.T) { + var ( + ctx = context.Background() + temlErr error + ) + attributes := map[string]string{ + "attribName1": "attribValue1", + "attribName2": "attribValue2", + "attribName3": "attribValue3", + } + notifier, err := New( + &config.SNSConfig{ + Attributes: attributes, + HTTPConfig: &commoncfg.HTTPClientConfig{}, + TopicARN: "TestTopic", + PhoneNumber: "TestPhone", + TargetARN: "TestTarget", + Subject: "TestSubject", + Message: "TestMessage", + }, + CreateTmpl(t), + logger, + ) require.NoError(t, err) - require.Equal(t, string(sBuff), truncatedMessage) - invalidUtf8String := "\xc3\x28" - _, _, err = validateAndTruncateMessage(invalidUtf8String, 100) - require.Error(t, err) + publishInput, err := createPublishInput(ctx, notifier, temlFunction(t), &temlErr) + + require.Equal(t, "TestTopic", *publishInput.TopicArn) + require.Equal(t, "TestPhone", *publishInput.PhoneNumber) + require.Equal(t, "TestTarget", *publishInput.TargetArn) + require.Equal(t, "TestSubject", *publishInput.Subject) + require.Equal(t, "TestMessage", *publishInput.Message) + + _, hasModifiedAttrib := publishInput.MessageAttributes["modified"] + require.False(t, hasModifiedAttrib) +} + +func TestCreatePublishInput_subjectOmitted(t *testing.T) { + var ( + ctx = context.Background() + temlErr error + ) + attributes := map[string]string{ + "attribName1": "attribValue1", + "attribName2": "attribValue2", + "attribName3": "attribValue3", + } + notifier, err := New( + &config.SNSConfig{ + Attributes: attributes, + HTTPConfig: &commoncfg.HTTPClientConfig{}, + TopicARN: "TestTopic", + PhoneNumber: "TestPhone", + TargetARN: "TestTarget", + Subject: "", + Message: "TestMessage", + }, + CreateTmpl(t), + logger, + ) + require.NoError(t, err) + + publishInput, err := createPublishInput(ctx, notifier, temlFunction(t), &temlErr) + + require.Equal(t, "TestTopic", *publishInput.TopicArn) + require.Equal(t, "TestPhone", *publishInput.PhoneNumber) + require.Equal(t, "TestTarget", *publishInput.TargetArn) + require.Nil(t, publishInput.Subject) + require.Equal(t, "TestMessage", *publishInput.Message) + + require.Nil(t, publishInput.MessageAttributes["modified"]) } -func TestNotifyWithInvalidTemplate(t *testing.T) { +func TestCreatePublishInput_subjectEmpty(t *testing.T) { + var ( + ctx = context.Background() + temlErr error + ) + attributes := map[string]string{ + "attribName1": "attribValue1", + "attribName2": "attribValue2", + "attribName3": "attribValue3", + } + notifier, err := New( + &config.SNSConfig{ + Attributes: attributes, + HTTPConfig: &commoncfg.HTTPClientConfig{}, + TopicARN: "TestTopic", + PhoneNumber: "TestPhone", + TargetARN: "TestTarget", + Subject: "TestSubject", + Message: "TestMessage", + }, + CreateTmpl(t), + logger, + ) + require.NoError(t, err) + temlFunc := func(input string) string { + if input == "TestSubject" { + return "" + } + return input + } + + publishInput, err := createPublishInput(ctx, notifier, temlFunc, &temlErr) + + require.Equal(t, "TestTopic", *publishInput.TopicArn) + require.Equal(t, "TestPhone", *publishInput.PhoneNumber) + require.Equal(t, "TestTarget", *publishInput.TargetArn) + require.Equal(t, SubjectEmpty, *publishInput.Subject) + require.Equal(t, "TestMessage", *publishInput.Message) + + require.Contains(t, *publishInput.MessageAttributes["modified"].StringValue, SubjectEmpty) +} + +func TestNotify_errorInTemplate(t *testing.T) { for _, tc := range []struct { title string - errMsg string + errorMsg string updateCfg func(*config.SNSConfig) }{ { - title: "with invalid Attribute template", - errMsg: "execute 'attributes' template", + title: "with invalid Attribute template", + errorMsg: "execute 'attributes' template", updateCfg: func(cfg *config.SNSConfig) { cfg.Attributes = map[string]string{ "attribName1": "{{ template \"unknown_template\" . }}", @@ -71,48 +454,49 @@ func TestNotifyWithInvalidTemplate(t *testing.T) { }, }, { - title: "with invalid TopicArn template", - errMsg: "execute 'topic_arn' template", + title: "with invalid TopicArn template", + errorMsg: "execute 'topic_arn' template", updateCfg: func(cfg *config.SNSConfig) { cfg.TopicARN = "{{ template \"unknown_template\" . }}" }, }, { - title: "with invalid PhoneNumber template", - errMsg: "execute 'phone_number' template", + title: "with invalid PhoneNumber template", + errorMsg: "execute 'phone_number' template", updateCfg: func(cfg *config.SNSConfig) { cfg.PhoneNumber = "{{ template \"unknown_template\" . }}" }, }, { - title: "with invalid Message template", - errMsg: "execute 'message' template", + title: "with invalid Message template", + errorMsg: "execute 'message' template", updateCfg: func(cfg *config.SNSConfig) { cfg.Message = "{{ template \"unknown_template\" . }}" }, }, { - title: "with invalid Subject template", - errMsg: "execute 'subject' template", + title: "with invalid Subject template", + errorMsg: "execute 'subject' template", updateCfg: func(cfg *config.SNSConfig) { cfg.Subject = "{{ template \"unknown_template\" . }}" }, }, { - title: "with invalid APIUrl template", - errMsg: "execute 'api_url' template", + title: "with invalid APIUrl template", + errorMsg: "execute 'api_url' template", updateCfg: func(cfg *config.SNSConfig) { cfg.APIUrl = "{{ template \"unknown_template\" . }}" }, }, { - title: "with invalid TargetARN template", - errMsg: "execute 'target_arn' template", + title: "with invalid TargetARN template", + errorMsg: "execute 'target_arn' template", updateCfg: func(cfg *config.SNSConfig) { cfg.TargetARN = "{{ template \"unknown_template\" . }}" }, }, } { + tc := tc t.Run(tc.title, func(t *testing.T) { snsCfg := &config.SNSConfig{ HTTPConfig: &commoncfg.HTTPClientConfig{}, @@ -126,23 +510,53 @@ func TestNotifyWithInvalidTemplate(t *testing.T) { } notifier, err := New( snsCfg, - createTmpl(t), + CreateTmpl(t), logger, ) require.NoError(t, err) var alerts []*types.Alert _, err = notifier.Notify(context.Background(), alerts...) require.Error(t, err) - require.Contains(t, err.Error(), "template \"unknown_template\" not defined") - require.Contains(t, err.Error(), tc.errMsg) + require.Equal(t, true, err != nil) + require.True(t, strings.Contains(err.Error(), "template \"unknown_template\" not defined")) + require.True(t, strings.Contains(err.Error(), tc.errorMsg)) }) } } +func getTotalSizeInBytes(modifiedReasons []string, attributes map[string]*sns.MessageAttributeValue, message string) int { + attributesSize := 0 + for k, v := range attributes { + attributesSize += len(k) + len(*v.DataType) + len(*v.StringValue) + } + + modifiedReasonsSize := 0 + if len(modifiedReasons) > 0 { + jsonString, _ := json.Marshal(modifiedReasons) + modifiedReasonsSize = len("String.Array") + len("modified") + len(string(jsonString)) + } + return modifiedReasonsSize + attributesSize + len(message) +} + // CreateTmpl returns a ready-to-use template. -func createTmpl(t *testing.T) *template.Template { +func CreateTmpl(t *testing.T) *template.Template { tmpl, err := template.FromGlobs([]string{}) require.NoError(t, err) tmpl.ExternalURL, _ = url.Parse("http://am") return tmpl } + +// CreateTmpl returns a ready-to-use template. +func temlFunction(t *testing.T) func(string) string { + return func(input string) string { + return input + } +} + +func fakemarshal(v interface{}) ([]byte, error) { + return []byte{}, errors.New("Marshalling failed") +} + +func restoremarshal(replace func(v interface{}) ([]byte, error)) { + jsonMarshal = replace +}