Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 33 additions & 41 deletions localstack/services/sqs/sqs_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from localstack.services.awslambda import lambda_api
from localstack.services.generic_proxy import ProxyListener


XMLNS_SQS = 'http://queue.amazonaws.com/doc/2012-11-05/'

SUCCESSFUL_SEND_MESSAGE_XML_TEMPLATE = """
Expand All @@ -33,7 +32,7 @@

# list of valid attribute names, and names not supported by the backend (elasticmq)
VALID_ATTRIBUTE_NAMES = ['DelaySeconds', 'MaximumMessageSize', 'MessageRetentionPeriod',
'Policy', 'ReceiveMessageWaitTimeSeconds', 'RedrivePolicy', 'VisibilityTimeout']
'ReceiveMessageWaitTimeSeconds', 'RedrivePolicy', 'VisibilityTimeout']
UNSUPPORTED_ATTRIBUTE_NAMES = [
'DelaySeconds', 'MaximumMessageSize', 'MessageRetentionPeriod', 'Policy', 'RedrivePolicy']

Expand All @@ -56,7 +55,10 @@ def forward_request(self, method, path, data, headers):
if new_response:
return new_response
elif action == 'SetQueueAttributes':
self._set_queue_attributes(path, req_data, headers)
queue_url = self._queue_url(path, req_data, headers)
self._set_queue_attributes(queue_url, req_data)
elif action == 'DeleteQueue':
QUEUE_ATTRIBUTES.pop(self._queue_url(path, req_data, headers), None)

if 'QueueName' in req_data:
encoded_data = urlencode(req_data, doseq=True) if method == 'POST' else ''
Expand Down Expand Up @@ -92,35 +94,6 @@ def return_response(self, method, path, data, headers, response, request_handler
content_str = content_str_original = to_str(response.content)

if response.status_code >= 400:

# Since the following 2 API calls are not implemented in ElasticMQ, we're mocking them
# and letting them to return an empty response
if action == 'TagQueue':
new_response = Response()
new_response.status_code = 200
new_response._content = ("""
<?xml version="1.0"?>
<TagQueueResponse>
<ResponseMetadata>
<RequestId>{}</RequestId>
</ResponseMetadata>
</TagQueueResponse>
""").strip().format(uuid.uuid4())
return new_response
elif action == 'ListQueueTags':
new_response = Response()
new_response.status_code = 200
new_response._content = ("""
<?xml version="1.0"?>
<ListQueueTagsResponse xmlns="{}">
<ListQueueTagsResult/>
<ResponseMetadata>
<RequestId>{}</RequestId>
</ResponseMetadata>
</ListQueueTagsResponse>
""").strip().format(XMLNS_SQS, uuid.uuid4())
return new_response

return response

self._fire_event(req_data, response)
Expand All @@ -137,10 +110,15 @@ def return_response(self, method, path, data, headers, response, request_handler
# expose external hostname:port
external_port = SQS_PORT_EXTERNAL or get_external_port(headers, request_handler)
content_str = re.sub(r'<QueueUrl>\s*([a-z]+)://[^<]*:([0-9]+)/([^<]*)\s*</QueueUrl>',
r'<QueueUrl>\1://%s:%s/\3</QueueUrl>' % (HOSTNAME_EXTERNAL, external_port), content_str)
r'<QueueUrl>\1://%s:%s/\3</QueueUrl>' % (HOSTNAME_EXTERNAL, external_port),
content_str)
# fix queue ARN
content_str = re.sub(r'<([a-zA-Z0-9]+)>\s*arn:aws:sqs:elasticmq:([^<]+)</([a-zA-Z0-9]+)>',
r'<\1>arn:aws:sqs:%s:\2</\3>' % (region_name), content_str)
r'<\1>arn:aws:sqs:%s:\2</\3>' % (region_name), content_str)

if action == 'CreateQueue':
queue_url = re.match(r'.*<QueueUrl>(.*)</QueueUrl>', content_str, re.DOTALL).group(1)
self._set_queue_attributes(queue_url, req_data)

if content_str_original != content_str:
# if changes have been made, return patched response
Expand Down Expand Up @@ -203,7 +181,7 @@ def format_message_attributes(self, data):
msg_attrs[key_name] = {}
# Find vals for each key_id
attrs = [(k, data[k]) for k in data
if k.startswith('{}.{}.'.format(prefix, key_id)) and not k.endswith('.Name')]
if k.startswith('{}.{}.'.format(prefix, key_id)) and not k.endswith('.Name')]
for (attr_k, attr_v) in attrs:
attr_name = attr_k.split('.')[3]
msg_attrs[key_name][attr_name[0].lower() + attr_name[1:]] = attr_v[0]
Expand Down Expand Up @@ -259,14 +237,27 @@ def _format_attributes(self, req_data):
result[key_name] = key_value
return result

# Format attributes as a list. Example input:
# {
# 'AttributeName.1': ['Policy'],
# 'AttributeName.2': ['MessageRetentionPeriod']
# }
def _format_attributes_names(self, req_data):
result = set()
for i in range(1, 500):
key = 'AttributeName.%s' % i
if key not in req_data:
break
result.add(req_data[key][0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming we're accessing index 0 here because req_data is created from urlparse which creates values as lists. We should probably update the documentation comment of this method then: ... 'AttributeName.1': ['Policy'] ....

In the future, we may apply some small preprocessing to extract all values from the lists contained in req_data, as usually they only contain a single item anyway.

return result

def _send_message(self, path, data, req_data, headers):
queue_url = self._queue_url(path, req_data, headers)
queue_name = queue_url.rpartition('/')[2]
message_body = req_data.get('MessageBody', [None])[0]
message_attributes = self.format_message_attributes(req_data)

process_result = lambda_api.process_sqs_message(message_body,
message_attributes, queue_name)
process_result = lambda_api.process_sqs_message(message_body, message_attributes, queue_name)
if process_result:
# If a Lambda was listening, do not add the message to the queue
new_response = Response()
Expand All @@ -279,8 +270,7 @@ def _send_message(self, path, data, req_data, headers):
new_response.status_code = 200
return new_response

def _set_queue_attributes(self, path, req_data, headers):
queue_url = self._queue_url(path, req_data, headers)
def _set_queue_attributes(self, queue_url, req_data):
attrs = self._format_attributes(req_data)
# select only the attributes in UNSUPPORTED_ATTRIBUTE_NAMES
attrs = dict([(k, v) for k, v in attrs.items() if k in UNSUPPORTED_ATTRIBUTE_NAMES])
Expand All @@ -290,13 +280,15 @@ def _set_queue_attributes(self, path, req_data, headers):
def _add_queue_attributes(self, path, req_data, content_str, headers):
flags = re.MULTILINE | re.DOTALL
queue_url = self._queue_url(path, req_data, headers)
requested_attributes = self._format_attributes_names(req_data)
regex = r'(.*<GetQueueAttributesResult>)(.*)(</GetQueueAttributesResult>.*)'
attrs = re.sub(regex, r'\2', content_str, flags=flags)
for key, value in QUEUE_ATTRIBUTES.get(queue_url, {}).items():
if not re.match(r'<Name>\s*%s\s*</Name>' % key, attrs, flags=flags):
if (not requested_attributes or requested_attributes.intersection({'All', key})) and \
not re.match(r'<Name>\s*%s\s*</Name>' % key, attrs, flags=flags):
attrs += '<Attribute><Name>%s</Name><Value>%s</Value></Attribute>' % (key, value)
content_str = (re.sub(regex, r'\1', content_str, flags=flags) +
attrs + re.sub(regex, r'\3', content_str, flags=flags))
attrs + re.sub(regex, r'\3', content_str, flags=flags))
return content_str

def _fire_event(self, req_data, response):
Expand Down
2 changes: 1 addition & 1 deletion localstack/services/sqs/sqs_starter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def check_sqs(expect_shutdown=False, print_error=False):
try:
# wait for port to be opened
wait_for_port_open(DEFAULT_PORT_SQS_BACKEND)
# check S3
# check SQS
out = aws_stack.connect_to_service(service_name='sqs').list_queues()
except Exception as e:
if print_error:
Expand Down
56 changes: 50 additions & 6 deletions tests/integration/test_sqs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import unittest

from localstack.utils import testutil
from localstack.utils.aws import aws_stack
from localstack.utils.common import short_uid, load_file, retry
from .test_lambda import TEST_LAMBDA_PYTHON, LAMBDA_RUNTIME_PYTHON36, TEST_LAMBDA_LIBS
from .lambdas import lambda_integration
from .test_lambda import TEST_LAMBDA_PYTHON, LAMBDA_RUNTIME_PYTHON36, TEST_LAMBDA_LIBS

TEST_QUEUE_NAME = 'TestQueue'

Expand Down Expand Up @@ -70,7 +71,7 @@ def test_publish_get_delete_message(self):
response = self.client.receive_message(QueueUrl=queue_url)
self.assertFalse(response.get('Messages'))
self.client.change_message_visibility(QueueUrl=queue_url,
ReceiptHandle=messages[0]['ReceiptHandle'], VisibilityTimeout=0)
ReceiptHandle=messages[0]['ReceiptHandle'], VisibilityTimeout=0)
for i in range(2):
messages = self.client.receive_message(QueueUrl=queue_url, VisibilityTimeout=0)['Messages']
self.assertEquals(len(messages), 1)
Expand Down Expand Up @@ -116,7 +117,7 @@ def test_send_message_attributes(self):
payload = {}
attrs = {'attr1': {'StringValue': 'val1', 'DataType': 'String'}}
self.client.send_message(QueueUrl=queue_url, MessageBody=json.dumps(payload),
MessageAttributes=attrs)
MessageAttributes=attrs)

result = self.client.receive_message(QueueUrl=queue_url, MessageAttributeNames=['All'])
messages = result['Messages']
Expand Down Expand Up @@ -149,15 +150,16 @@ def test_dead_letter_queue_execution(self):
queue_arn1 = aws_stack.sqs_queue_arn(queue_name1)
policy = {'deadLetterTargetArn': queue_arn1, 'maxReceiveCount': 1}
queue_url2 = self.client.create_queue(QueueName=queue_name2,
Attributes={'RedrivePolicy': json.dumps(policy)})['QueueUrl']
Attributes={'RedrivePolicy': json.dumps(policy)})['QueueUrl']
queue_arn2 = aws_stack.sqs_queue_arn(queue_name2)

# create Lambda and add source mapping
lambda_name = 'test-%s' % short_uid()
zip_file = testutil.create_lambda_archive(load_file(TEST_LAMBDA_PYTHON),
get_content=True, libs=TEST_LAMBDA_LIBS, runtime=LAMBDA_RUNTIME_PYTHON36)
get_content=True, libs=TEST_LAMBDA_LIBS,
runtime=LAMBDA_RUNTIME_PYTHON36)
testutil.create_lambda_function(func_name=lambda_name, zip_file=zip_file,
runtime=LAMBDA_RUNTIME_PYTHON36)
runtime=LAMBDA_RUNTIME_PYTHON36)
lambda_client.create_event_source_mapping(EventSourceArn=queue_arn2, FunctionName=lambda_name)

# add message to SQS, which will trigger the Lambda, resulting in an error
Expand All @@ -174,4 +176,46 @@ def receive_dlq():
self.assertIn('RequestID', msg_attrs)
self.assertIn('ErrorCode', msg_attrs)
self.assertIn('ErrorMessage', msg_attrs)

retry(receive_dlq, retries=8, sleep=2)

def test_set_queue_attribute_at_creation(self):
queue_name = 'queue-%s' % short_uid()

attributes = {
'MessageRetentionPeriod': '604800', # This one is unsupported by ElasticMq and should be saved in memory
'ReceiveMessageWaitTimeSeconds': '10',
'VisibilityTimeout': '30'
}

queue_url = self.client.create_queue(QueueName=queue_name, Attributes=attributes)['QueueUrl']
creation_attributes = self.client.get_queue_attributes(QueueUrl=queue_url, AttributeNames=['All'])

# assertion
self.assertIn('MessageRetentionPeriod', creation_attributes['Attributes'].keys())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: We could also assert that the correct value is returned here (604800). Same for the test test_get_specific_queue_attribute_response below.

self.assertEqual('604800', creation_attributes['Attributes']['MessageRetentionPeriod'])

# cleanup
self.client.delete_queue(QueueUrl=queue_url)

def test_get_specific_queue_attribute_response(self):
queue_name = 'queue-%s' % short_uid()

# Two attributes unsupported by ElasticMq
attributes = {
'MessageRetentionPeriod': '604800',
'DelaySeconds': '10',
}

queue_url = self.client.create_queue(QueueName=queue_name, Attributes=attributes)['QueueUrl']
unsupported_attribute_get = self.client.get_queue_attributes(QueueUrl=queue_url,
AttributeNames=['MessageRetentionPeriod'])
supported_attribute_get = self.client.get_queue_attributes(QueueUrl=queue_url,
AttributeNames=['QueueArn'])
# assertion
self.assertTrue('MessageRetentionPeriod' in unsupported_attribute_get['Attributes'].keys())
self.assertEqual('604800', unsupported_attribute_get['Attributes']['MessageRetentionPeriod'])
self.assertTrue('QueueArn' in supported_attribute_get['Attributes'].keys())

# cleanup
self.client.delete_queue(QueueUrl=queue_url)