Skip to content

Commit fe08978

Browse files
authored
Merge pull request openwallet-foundation#1853 from TimoGlastra/fix/return-processing-no-response
fix: return if return route but no response
2 parents 22ca606 + bc084cb commit fe08978

5 files changed

Lines changed: 52 additions & 2 deletions

File tree

aries_cloudagent/transport/inbound/http.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,12 @@ async def inbound_message_handler(self, request: web.BaseRequest):
9696
raise web.HTTPBadRequest()
9797

9898
if inbound.receipt.direct_response_requested:
99-
response = await session.wait_response()
99+
# Wait for the message to be processed. Only send a response if a response
100+
# buffer is present.
101+
await inbound.wait_processing_complete()
102+
response = (
103+
await session.wait_response() if session.response_buffer else None
104+
)
100105

101106
# no more responses
102107
session.can_respond = False

aries_cloudagent/transport/inbound/manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def dispatch_complete(self, message: InboundMessage, completed: CompletedTask):
181181
if session and session.accept_undelivered and not session.response_buffered:
182182
self.process_undelivered(session)
183183

184+
message.dispatch_processing_complete()
185+
184186
def closed_session(self, session: InboundSession):
185187
"""
186188
Clean up a closed session.

aries_cloudagent/transport/inbound/message.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Classes representing inbound messages."""
22

3+
import asyncio
34
from typing import Union
45

56
from .receipt import MessageReceipt
@@ -23,3 +24,12 @@ def __init__(
2324
self.receipt = receipt
2425
self.session_id = session_id
2526
self.transport_type = transport_type
27+
self.processing_complete_event = asyncio.Event()
28+
29+
def dispatch_processing_complete(self):
30+
"""Dispatch processing complete."""
31+
self.processing_complete_event.set()
32+
33+
async def wait_processing_complete(self):
34+
"""Wait for processing to complete."""
35+
await self.processing_complete_event.wait()

aries_cloudagent/transport/inbound/tests/test_http_transport.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def receive_message(
6262
message: InboundMessage,
6363
can_respond: bool = False,
6464
):
65+
message.wait_processing_complete = async_mock.CoroutineMock()
6566
self.message_results.append((message.payload, message.receipt, can_respond))
6667
if self.result_event:
6768
self.result_event.set()
@@ -119,13 +120,15 @@ async def test_send_message_outliers(self):
119120
mock_session.return_value = async_mock.MagicMock(
120121
receive=async_mock.CoroutineMock(
121122
return_value=async_mock.MagicMock(
122-
receipt=async_mock.MagicMock(direct_response_requested=True)
123+
receipt=async_mock.MagicMock(direct_response_requested=True),
124+
wait_processing_complete=async_mock.CoroutineMock(),
123125
)
124126
),
125127
can_respond=True,
126128
profile=InMemoryProfile.test_profile(),
127129
clear_response=async_mock.MagicMock(),
128130
wait_response=async_mock.CoroutineMock(return_value=b"Hello world"),
131+
response_buffer="something",
129132
)
130133
async with self.client.post("/", data=test_message) as resp:
131134
result = await resp.text()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import asyncio
2+
3+
from asynctest import TestCase
4+
5+
from ..message import InboundMessage
6+
from ..receipt import MessageReceipt
7+
8+
9+
class TestInboundMessage(TestCase):
10+
async def test_wait_response(self):
11+
message = InboundMessage(
12+
payload="test",
13+
connection_id="conn_id",
14+
receipt=MessageReceipt(),
15+
session_id="session_id",
16+
)
17+
assert not message.processing_complete_event.is_set()
18+
message.dispatch_processing_complete()
19+
assert message.processing_complete_event.is_set()
20+
21+
message = InboundMessage(
22+
payload="test",
23+
connection_id="conn_id",
24+
receipt=MessageReceipt(),
25+
session_id="session_id",
26+
)
27+
assert not message.processing_complete_event.is_set()
28+
task = message.wait_processing_complete()
29+
message.dispatch_processing_complete()
30+
await asyncio.wait_for(task, 1)

0 commit comments

Comments
 (0)