Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/check-binaries.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ on:
schedule:
- cron: "0 16 * * 1-5" # min h d Mo DoW / 9am PST M-F

permissions:
issues: write

jobs:
check-for-vulnerabilities:
runs-on: ubuntu-latest
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ on:
description: "Information about the release"
required: true
default: "New release"
permissions:
contents: write

jobs:
Release:
environment: Release
Expand Down
20 changes: 20 additions & 0 deletions .github/workflows/validate-branch-into-main.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: Validate PR Branch into Main

on:
pull_request:
branches:
- main

jobs:
validate-pr-branch:
runs-on: ubuntu-latest
steps:
- name: Check source branch
run: |
SOURCE_BRANCH="${{ github.head_ref }}"
if [[ "$SOURCE_BRANCH" != "develop" ]]; then
echo "Error: Only pull requests from develop branch are allowed into main"
echo "Current source branch ($SOURCE_BRANCH)."
exit 1
fi
echo "Source branch is develop - merge allowed"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,4 @@ See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more inform
## License
This project is licensed under the Apache-2.0 License.
This project is licensed under the Apache-2.0 License.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ func (h *HTTPHandler) invoke(w http.ResponseWriter, r *http.Request) {
return
}

invokeReq := rieinvoke.NewRieInvokeRequest(r, w)
invokeReq, err := rieinvoke.NewRieInvokeRequest(r, w)
if err != nil {
h.respondWithError(w, err)
return
}
ctx := logging.WithInvokeID(r.Context(), invokeReq.InvokeID())

metrics := invoke.NewInvokeMetrics(nil, &noOpCounter{})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ func GetInitRequestMessage(fileUtil utils.FileUtil, args []string) (intmodel.Ini
XrayTracingMode: intmodel.XRayTracingModePassThrough,
CurrentWorkingDir: cwd,
RuntimeBinaryCommand: cmd,
AvailabilityZoneId: "",
AmiId: "",

AvailabilityZoneId: "use1-az1",
AmiId: "",
}, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func Test_getInitRequestMessage(t *testing.T) {
XrayTracingMode: intmodel.XRayTracingModePassThrough,
CurrentWorkingDir: "REPLACE",
RuntimeBinaryCommand: []string{"/path/to/bootstrap"},
AvailabilityZoneId: "",
AvailabilityZoneId: "use1-az1",
AmiId: "",
},
},
Expand Down Expand Up @@ -116,7 +116,7 @@ func Test_getInitRequestMessage(t *testing.T) {
XrayTracingMode: intmodel.XRayTracingModePassThrough,
CurrentWorkingDir: "/var/task",
RuntimeBinaryCommand: []string{"/custom/bootstrap", "custom_handler"},
AvailabilityZoneId: "",
AvailabilityZoneId: "use1-az1",
AmiId: "",
},
},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package invoke

const (
RequestIdHeader = "X-Amzn-RequestId"

ClientContextHeader = "X-Amz-Client-Context"

CognitoIdentityHeader = "X-Amz-Cognito-Identity"
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
package invoke

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"time"

Expand All @@ -16,6 +20,11 @@ import (
"github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model"
)

type cognitoIdentity struct {
CognitoIdentityID string `json:"cognitoIdentityId"`
CognitoIdentityPoolID string `json:"cognitoIdentityPoolId"`
}

type rieInvokeRequest struct {
request *http.Request
writer http.ResponseWriter
Expand All @@ -35,18 +44,47 @@ type rieInvokeRequest struct {
functionVersionID string
}

func NewRieInvokeRequest(request *http.Request, writer http.ResponseWriter) *rieInvokeRequest {
func NewRieInvokeRequest(request *http.Request, writer http.ResponseWriter) (*rieInvokeRequest, model.AppError) {

contentType := request.Header.Get(invoke.СontentTypeHeader)
if contentType == "" {
contentType = "application/json"
}

invokeID := request.Header.Get("X-Amzn-RequestId")
invokeID := request.Header.Get(RequestIdHeader)
if invokeID == "" {
invokeID = uuid.New().String()
}

clientContext := ""
if encodedClientContext := request.Header.Get(ClientContextHeader); encodedClientContext != "" {
decodedClientContext, err := base64.StdEncoding.DecodeString(encodedClientContext)
if err != nil {
slog.Warn("Failed to decode X-Amz-Client-Context header", "err", err)
return nil, model.NewClientError(
fmt.Errorf("X-Amz-Client-Context must be a valid base64 encoded string: %w", err),
model.ErrorSeverityInvalid,
model.ErrorMalformedRequest,
)
}
clientContext = string(decodedClientContext)
}

var cognitoIdentityId, cognitoIdentityPoolId string
if cognitoIdentityHeader := request.Header.Get(CognitoIdentityHeader); cognitoIdentityHeader != "" {
var cognito cognitoIdentity
if err := json.Unmarshal([]byte(cognitoIdentityHeader), &cognito); err != nil {
slog.Warn("Failed to parse X-Amz-Cognito-Identity header", "err", err)
return nil, model.NewClientError(
fmt.Errorf("X-Amz-Cognito-Identity must be a valid JSON string: %w", err),
model.ErrorSeverityInvalid,
model.ErrorMalformedRequest,
)
}
cognitoIdentityId = cognito.CognitoIdentityID
cognitoIdentityPoolId = cognito.CognitoIdentityPoolID
}

req := &rieInvokeRequest{
request: request,
writer: writer,
Expand All @@ -56,13 +94,13 @@ func NewRieInvokeRequest(request *http.Request, writer http.ResponseWriter) *rie
responseBandwidthRate: 2 * 1024 * 1024,
responseBandwidthBurstSize: 6 * 1024 * 1024,
traceId: request.Header.Get(invoke.TraceIdHeader),
cognitoIdentityId: "",
cognitoIdentityPoolId: "",
clientContext: request.Header.Get("X-Amz-Client-Context"),
cognitoIdentityId: cognitoIdentityId,
cognitoIdentityPoolId: cognitoIdentityPoolId,
clientContext: clientContext,
responseMode: request.Header.Get(invoke.ResponseModeHeader),
}

return req
return req, nil
}

func (r *rieInvokeRequest) ContentType() string {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model"
)

func TestNewRieInvokeRequest(t *testing.T) {
tests := []struct {
name string
request func() *http.Request
writer http.ResponseWriter
want *rieInvokeRequest
name string
request func() *http.Request
writer http.ResponseWriter
want *rieInvokeRequest
wantError bool
wantErrorContain string
}{
{
name: "no_headers_in_request",
Expand All @@ -37,6 +41,7 @@ func TestNewRieInvokeRequest(t *testing.T) {
cognitoIdentityPoolId: "",
clientContext: "",
},
wantError: false,
},
{
name: "all_headers_present_in_request",
Expand All @@ -46,6 +51,7 @@ func TestNewRieInvokeRequest(t *testing.T) {
r.Header.Set("X-Amzn-Trace-Id", "Root=1-5e1b4151-5ac6c58f3375aa3c7c6b73c9")
r.Header.Set("X-Amz-Client-Context", "eyJjdXN0b20iOnsidGVzdCI6InZhbHVlIn19")
r.Header.Set("X-Amzn-RequestId", "test-invoke-id")
r.Header.Set("X-Amz-Cognito-Identity", `{"cognitoIdentityId":"us-east-1:12345678-1234-1234-1234-123456789012","cognitoIdentityPoolId":"us-east-1:87654321-4321-4321-4321-210987654321"}`)
require.NoError(t, err)
return r
},
Expand All @@ -57,16 +63,80 @@ func TestNewRieInvokeRequest(t *testing.T) {
responseBandwidthRate: 2 * 1024 * 1024,
responseBandwidthBurstSize: 6 * 1024 * 1024,
traceId: "Root=1-5e1b4151-5ac6c58f3375aa3c7c6b73c9",
cognitoIdentityId: "",
cognitoIdentityId: "us-east-1:12345678-1234-1234-1234-123456789012",
cognitoIdentityPoolId: "us-east-1:87654321-4321-4321-4321-210987654321",
clientContext: `{"custom":{"test":"value"}}`,
},
wantError: false,
},
{
name: "malformed_cognito_identity_header",
request: func() *http.Request {
r, err := http.NewRequest("GET", "http://localhost/", nil)
r.Header.Set("X-Amzn-RequestId", "test-invoke-id")
r.Header.Set("X-Amz-Cognito-Identity", "not-valid-json{")
require.NoError(t, err)
return r
},
writer: httptest.NewRecorder(),
want: nil,
wantError: true,
wantErrorContain: "X-Amz-Cognito-Identity must be a valid JSON string",
},
{
name: "malformed_client_context_header",
request: func() *http.Request {
r, err := http.NewRequest("GET", "http://localhost/", nil)
r.Header.Set("X-Amzn-RequestId", "test-invoke-id")
r.Header.Set("X-Amz-Client-Context", "not-valid-base64!!!")
require.NoError(t, err)
return r
},
writer: httptest.NewRecorder(),
want: nil,
wantError: true,
wantErrorContain: "X-Amz-Client-Context must be a valid base64 encoded string",
},
{
name: "partial_cognito_identity_header",
request: func() *http.Request {
r, err := http.NewRequest("GET", "http://localhost/", nil)
r.Header.Set("X-Amzn-RequestId", "test-invoke-id")
r.Header.Set("X-Amz-Cognito-Identity", `{"cognitoIdentityId":"us-east-1:only-id"}`)
require.NoError(t, err)
return r
},
writer: httptest.NewRecorder(),
want: &rieInvokeRequest{
invokeID: "test-invoke-id",
contentType: "application/json",
maxPayloadSize: 6*1024*1024 + 100,
responseBandwidthRate: 2 * 1024 * 1024,
responseBandwidthBurstSize: 6 * 1024 * 1024,
traceId: "",
cognitoIdentityId: "us-east-1:only-id",
cognitoIdentityPoolId: "",
clientContext: "eyJjdXN0b20iOnsidGVzdCI6InZhbHVlIn19",
clientContext: "",
},
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := tt.request()
got := NewRieInvokeRequest(r, tt.writer)
got, err := NewRieInvokeRequest(r, tt.writer)

if tt.wantError {
assert.NotNil(t, err)
assert.Nil(t, got)
assert.Equal(t, model.ErrorMalformedRequest, err.ErrorType())
assert.Equal(t, http.StatusBadRequest, err.ReturnCode())
assert.Contains(t, err.Error(), tt.wantErrorContain)
return
}

assert.Nil(t, err)
require.NotNil(t, got)

tt.want.request = r
tt.want.writer = tt.writer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"os"
"time"

"github.com/google/uuid"
"github.com/aws/aws-lambda-runtime-interface-emulator/internal/lmds"

rieinvoke "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke"
"github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry"
"github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop"
Expand Down Expand Up @@ -47,8 +50,9 @@ func Run(supv supvmodel.ProcessSupervisor, args []string, fileUtil utils.FileUti
responderFactoryFunc := func(_ context.Context, invokeReq interop.InvokeRequest) invoke.InvokeResponseSender {
return rieinvoke.NewResponder(invokeReq)
}
invokeRouter := invoke.NewInvokeRouter(rapid.MaxIdleRuntimesQueueSize, eventsAPI, responderFactoryFunc, timeout.NewRecentCache())
invokeRouter := invoke.NewInvokeRouter(rapid.RuntimePoolSize, eventsAPI, responderFactoryFunc, timeout.NewRecentCache())

metadataToken := uuid.NewString()
deps := rapid.Dependencies{
EventsAPI: eventsAPI,
LogsEgressAPI: telemetry.NewLogsEgress(telemetryAPIRelay, os.Stdout),
Expand All @@ -57,9 +61,10 @@ func Run(supv supvmodel.ProcessSupervisor, args []string, fileUtil utils.FileUti
RuntimeAPIAddrPort: runtimeAPIAddr,
FileUtils: fileUtil,
InvokeRouter: invokeRouter,
MetadataService: lmds.NewService(metadataToken),
}

raptorApp, err := raptor.StartApp(deps, "", noOpLogger{})
raptorApp, err := raptor.StartApp(deps, "", metadataToken, noOpLogger{})
if err != nil {
return nil, nil, nil, fmt.Errorf("could not start runtime api server: %w", err)
}
Expand Down
Loading
Loading