diff --git a/lambdas/functions/webhook/src/sqs/index.test.ts b/lambdas/functions/webhook/src/sqs/index.test.ts index a01b1a9299..96f16da129 100644 --- a/lambdas/functions/webhook/src/sqs/index.test.ts +++ b/lambdas/functions/webhook/src/sqs/index.test.ts @@ -1,37 +1,54 @@ import { SendMessageCommandInput } from '@aws-sdk/client-sqs'; -import { sendActionRequest } from '.'; -import { describe, it, expect, afterEach, vi } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +const { mockSqsClients, sqsConstructorSpy, tracedClients, logger } = vi.hoisted(() => ({ + mockSqsClients: [] as Array<{ sendMessage: ReturnType }>, + sqsConstructorSpy: vi.fn(), + tracedClients: [] as unknown[], + logger: { debug: vi.fn() }, +})); + +function MockSQS(this: unknown, config?: unknown) { + sqsConstructorSpy(config); + const client = { + sendMessage: vi.fn().mockResolvedValue({}), + }; + mockSqsClients.push(client); + return client; +} -const mockSQS = { - sendMessage: vi.fn(() => { - return {}; - }), -}; vi.mock('@aws-sdk/client-sqs', () => ({ - SQS: vi.fn().mockImplementation(function () { - return mockSQS; + SQS: vi.fn(MockSQS), +})); + +vi.mock('@aws-github-runner/aws-powertools-util', () => ({ + createChildLogger: vi.fn(() => logger), + getTracedAWSV3Client: vi.fn((client: unknown) => { + tracedClients.push(client); + return client; }), })); -vi.mock('@aws-github-runner/aws-ssm-util'); + +const cleanEnv = process.env; describe('Test sending message to SQS.', () => { - const queueUrl = 'https://sqs.eu-west-1.amazonaws.com/123456789/queued-builds'; - const message = { - eventType: 'type', - id: 0, - installationId: 0, - repositoryName: 'test', - repositoryOwner: 'owner', - queueId: queueUrl, - queueFifo: false, - repoOwnerType: 'Organization', - }; + beforeEach(() => { + vi.resetModules(); + vi.clearAllMocks(); + mockSqsClients.length = 0; + tracedClients.length = 0; + process.env = { ...cleanEnv }; + }); afterEach(() => { - vi.clearAllMocks(); + process.env = { ...cleanEnv }; }); it('no fifo queue', async () => { + const queueUrl = 'https://sqs.eu-west-1.amazonaws.com/123456789/queued-builds'; + const message = createMessage(queueUrl); + const { sendActionRequest } = await import('.'); + // Arrange const sqsMessage: SendMessageCommandInput = { QueueUrl: queueUrl, @@ -42,7 +59,73 @@ describe('Test sending message to SQS.', () => { const result = sendActionRequest(message); // Assert - expect(mockSQS.sendMessage).toHaveBeenCalledWith(sqsMessage); + expect(sqsConstructorSpy).toHaveBeenCalledWith({ region: 'eu-west-1' }); + expect(mockSqsClients[0].sendMessage).toHaveBeenCalledWith(sqsMessage); + expect(tracedClients).toHaveLength(1); + expect(logger.debug).toHaveBeenCalledTimes(1); await expect(result).resolves.not.toThrow(); }); + + it('falls back to AWS_REGION when the queue url is invalid', async () => { + process.env.AWS_REGION = 'us-east-2'; + const { sendActionRequest } = await import('.'); + + await sendActionRequest(createMessage('not-a-valid-url')); + + expect(sqsConstructorSpy).toHaveBeenCalledTimes(1); + expect(sqsConstructorSpy).toHaveBeenCalledWith({ region: 'us-east-2' }); + expect(mockSqsClients[0].sendMessage).toHaveBeenCalledTimes(1); + expect(tracedClients).toHaveLength(1); + }); + + it('creates a client without an explicit region when no region can be resolved', async () => { + delete process.env.AWS_REGION; + const { sendActionRequest } = await import('.'); + + await sendActionRequest(createMessage('not-a-valid-url')); + + expect(sqsConstructorSpy).toHaveBeenCalledTimes(1); + expect(sqsConstructorSpy).toHaveBeenCalledWith({}); + expect(mockSqsClients[0].sendMessage).toHaveBeenCalledTimes(1); + expect(tracedClients).toHaveLength(1); + }); + + it('reuses the same client for multiple queues in the same region', async () => { + const { sendActionRequest } = await import('.'); + + await sendActionRequest(createMessage('https://sqs.us-east-1.amazonaws.com/123456789/queue-a')); + await sendActionRequest(createMessage('https://sqs.us-east-1.amazonaws.com/123456789/queue-b')); + + expect(sqsConstructorSpy).toHaveBeenCalledTimes(1); + expect(sqsConstructorSpy).toHaveBeenCalledWith({ region: 'us-east-1' }); + expect(mockSqsClients[0].sendMessage).toHaveBeenCalledTimes(2); + expect(tracedClients).toHaveLength(1); + }); + + it('creates a separate client per region', async () => { + const { sendActionRequest } = await import('.'); + + await sendActionRequest(createMessage('https://sqs.us-east-1.amazonaws.com/123456789/queue-a')); + await sendActionRequest(createMessage('https://sqs.eu-west-1.amazonaws.com/123456789/queue-b')); + + expect(sqsConstructorSpy).toHaveBeenCalledTimes(2); + expect(sqsConstructorSpy).toHaveBeenNthCalledWith(1, { region: 'us-east-1' }); + expect(sqsConstructorSpy).toHaveBeenNthCalledWith(2, { region: 'eu-west-1' }); + expect(mockSqsClients).toHaveLength(2); + expect(mockSqsClients[0].sendMessage).toHaveBeenCalledTimes(1); + expect(mockSqsClients[1].sendMessage).toHaveBeenCalledTimes(1); + expect(tracedClients).toHaveLength(2); + }); }); + +function createMessage(queueId: string) { + return { + eventType: 'type', + id: 0, + installationId: 0, + repositoryName: 'test', + repositoryOwner: 'owner', + queueId, + repoOwnerType: 'Organization', + }; +} diff --git a/lambdas/functions/webhook/src/sqs/index.ts b/lambdas/functions/webhook/src/sqs/index.ts index a028d7dcc4..8af6096b59 100644 --- a/lambdas/functions/webhook/src/sqs/index.ts +++ b/lambdas/functions/webhook/src/sqs/index.ts @@ -4,6 +4,8 @@ import { createChildLogger, getTracedAWSV3Client } from '@aws-github-runner/aws- const logger = createChildLogger('sqs'); +const sqsClientsByRegion = new Map(); + export interface ActionRequestMessage { id: number; eventType: string; @@ -32,7 +34,8 @@ export interface GithubWorkflowEvent { } export const sendActionRequest = async (message: ActionRequestMessage): Promise => { - const sqs = getTracedAWSV3Client(new SQS({ region: process.env.AWS_REGION })); + const region = getRegionFromQueueUrl(message.queueId) ?? process.env.AWS_REGION; + const sqs = getSqsClient(region); const sqsMessage: SendMessageCommandInput = { QueueUrl: message.queueId, @@ -43,3 +46,30 @@ export const sendActionRequest = async (message: ActionRequestMessage): Promise< await sqs.sendMessage(sqsMessage); }; + +function getSqsClient(region: string | undefined): SQS { + if (!region) { + return getTracedAWSV3Client(new SQS({})); + } + + const cached = sqsClientsByRegion.get(region); + if (cached) { + return cached; + } + + const client = getTracedAWSV3Client(new SQS({ region })); + sqsClientsByRegion.set(region, client); + return client; +} + +function getRegionFromQueueUrl(queueUrl: string): string | undefined { + try { + const url = new URL(queueUrl); + const parts = url.hostname.split('.'); + if (parts.length >= 3 && parts[0] === 'sqs') { + return parts[1]; + } + } catch {} + + return undefined; +}