31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
class GatewayClient:
    """Client for the WebSocket connection to an upstream gateway.

    Lifecycle, as implemented below:

    1. :meth:`connect_and_run_listen_loop` opens the WebSocket, sends a
       JSON-RPC ``register`` request and runs the registration handshake,
       answering the gateway's ``initialize`` / ``tools/list`` requests (the
       latter is forwarded to a stdio subprocess).
    2. Once the gateway reports ``status: registered``, every further message
       is dispatched to the handlers added with :meth:`add_message_handler`.
    3. :meth:`close` closes the socket and drops all handlers.
    """

    # Identity announced in the ``initialize`` response sent during the handshake.
    _CLIENT_NAME = "mcpgateway-client-stdio"
    _CLIENT_VERSION = "0.1.0"
    _PROTOCOL_VERSION = "2024-11-05"
    # Overall budget for the registration handshake, and the cap on each single recv().
    HANDSHAKE_TIMEOUT_SEC = 60.0
    HANDSHAKE_RECV_TIMEOUT_SEC = 15.0

    def __init__(
        self,
        gateway_url: str,
        server_name: str,
        server_id: Optional[str] = None,
        headers: Optional[dict[str, str]] = None,
        ssl_context: Optional[ssl.SSLContext] = None,
    ):
        """Store connection settings; no network I/O happens here.

        Args:
            gateway_url: WebSocket URL of the gateway.
            server_name: Name sent in the ``register`` request.
            server_id: Id used for the ``register`` request; a random UUID4
                string is generated when omitted.
            headers: Extra HTTP headers passed to ``websockets.connect``.
            ssl_context: TLS context passed to ``websockets.connect``.
        """
        self.gateway_url = gateway_url
        self.server_name = server_name
        self.server_id = server_id or str(uuid.uuid4())
        self.websocket: Optional[websockets.ClientConnection] = None
        self.message_handlers: dict[str, Callable[[Any, Optional[str]], Awaitable[None]]] = {}
        self.is_connected = False
        self.headers = headers or {}
        self.ssl_context = ssl_context
        # Task currently running connect_and_run_listen_loop; used as a re-entrancy guard.
        self._listen_task: Optional[asyncio.Task] = None
        # Strong references to in-flight handler tasks: the event loop keeps only
        # weak references, so an unreferenced task may be garbage-collected mid-run.
        self._handler_tasks: set[asyncio.Task] = set()

    def add_message_handler(self, name: str, handler: Callable[[Any, Optional[str]], Awaitable[None]]) -> None:
        """Register (or replace) the coroutine handler ``name`` for post-handshake messages.

        Each handler is called with the decoded JSON message and its ``id`` value
        (``None`` when the message is not a JSON object or carries no ``id``).
        """
        self.message_handlers[name] = handler

    async def send(self, message: Any) -> None:
        """Serialize ``message`` as JSON and send it to the gateway.

        When not connected this only logs at debug level. Send failures are
        logged rather than raised; a closed connection also clears
        ``is_connected``. Errors from ``json.dumps`` propagate to the caller.
        """
        if not self.websocket or not self.is_connected:
            logger.debug("Gateway not connected, unable to send message.")
            return
        message_str = json.dumps(message)
        try:
            await self.websocket.send(message_str)
            logger.debug(f"Message sent to gateway: {message_str[:100]}...")
        except websockets.exceptions.ConnectionClosed:
            logger.warning("Failed to send message: Gateway connection closed.")
            self.is_connected = False
        except Exception:  # pylint: disable=broad-except
            logger.exception("Failed to send message to gateway")

    async def _request_stdio_subprocess(  # noqa: C901
        self,
        method_name: str,
        params: dict,
        proc_stdin: Optional[asyncio.StreamWriter],
        pending_stdio_requests: dict[Any, asyncio.Future],
        loop: asyncio.AbstractEventLoop,
        timeout_sec: float = 15.0,
    ) -> Any:
        """Send one JSON-RPC request to the stdio subprocess and await its reply.

        The request is written to ``proc_stdin`` as a single JSON line with a fresh
        UUID4 string id. A future for that id is stored in
        ``pending_stdio_requests``; presumably the subprocess stdout reader
        (defined elsewhere) resolves it -- confirm against the caller.

        Whatever the outcome (reply, error, timeout or cancellation), the entry is
        removed from ``pending_stdio_requests`` before returning.

        Returns:
            The value the future is resolved with.

        Raises:
            OSError: ``proc_stdin`` is missing or closing.
            ConnectionResetError, BrokenPipeError: the pipe broke while writing.
            asyncio.TimeoutError: no reply within ``timeout_sec`` seconds.
            Any exception set on the future, or raised while writing.
        """
        if not proc_stdin or proc_stdin.is_closing():
            logger.error(f"Subprocess stdin not available for sending request '{method_name}'.")
            raise OSError(f"Subprocess stdin not available for '{method_name}'.")  # noqa: TRY003

        stdio_req_id = str(uuid.uuid4())
        # Serialize before registering the future so non-serializable params
        # cannot leave a dangling entry in pending_stdio_requests.
        request_str = (
            json.dumps({"jsonrpc": "2.0", "id": stdio_req_id, "method": method_name, "params": params}) + "\n"
        )
        future = loop.create_future()
        pending_stdio_requests[stdio_req_id] = future
        logger.debug(f"Client (proxy) -> STDIO (Req ID: {stdio_req_id}, Method: {method_name}): {request_str[:200]}...")
        try:
            try:
                proc_stdin.write(request_str.encode("utf-8"))
                await proc_stdin.drain()
            except (ConnectionResetError, BrokenPipeError) as e:
                logger.error(f"Pipe error writing to subprocess stdin for {stdio_req_id} ('{method_name}'): {e}")  # noqa: TRY400
                raise
            except Exception:
                logger.exception(f"Unexpected error writing to subprocess stdin for {stdio_req_id} ('{method_name}')")
                raise

            logger.debug(
                f"Waiting for STDIO response for internal req ID {stdio_req_id} ('{method_name}', Timeout: {timeout_sec}s)"
            )
            try:
                return await asyncio.wait_for(future, timeout=timeout_sec)
            except asyncio.TimeoutError:
                logger.error(  # noqa: TRY400
                    f"Timeout waiting for STDIO response for method '{method_name}', internal req ID '{stdio_req_id}'."
                )
                raise
            except Exception as e_await:
                logger.debug(
                    f"Error from future for STDIO method '{method_name}', internal req ID '{stdio_req_id}': {e_await}"
                )
                raise
        finally:
            # Runs on every exit path, including cancellation of the calling task.
            # The stdout reader may already have popped the entry, hence the default.
            pending_stdio_requests.pop(stdio_req_id, None)
            # Cancelling, rather than set_exception(), avoids asyncio's
            # "Future exception was never retrieved" warning for a future
            # nobody will await any more.
            if not future.done():
                future.cancel()

    async def _internal_listen_loop(self) -> None:  # noqa: C901
        """Read gateway messages until the connection ends, dispatching each one.

        Each message is JSON-decoded and handed to every registered handler as a
        separate task, so a slow handler does not block reading. Connection
        closure and unexpected errors are logged. Cancellation is logged and
        re-raised. ``is_connected`` is always cleared on exit.
        """
        try:
            ws = self.websocket
            if not ws:  # Should not happen if called correctly
                logger.error("WebSocket not available for listening.")
                return
            logger.debug("Starting internal listen loop for gateway messages.")
            async for message_raw in ws:
                message_str = str(message_raw)[:200]  # Limit log size
                logger.info(f"Received gateway message: {message_str}...")
                try:
                    msg_data = json.loads(message_raw)
                    msg_id = msg_data.get("id") if isinstance(msg_data, dict) else None
                    self._dispatch_to_handlers(msg_data, msg_id)
                except json.JSONDecodeError:
                    logger.exception(f"Received invalid JSON from gateway: {message_raw!r}")
                except Exception:  # pylint: disable=broad-except
                    logger.exception("Error processing gateway message in handler")
        except websockets.exceptions.ConnectionClosedOK:
            logger.info("Gateway connection closed normally (OK).")
        except websockets.exceptions.ConnectionClosedError as e:
            logger.warning(f"Gateway connection closed with error: {e.code} {e.reason}")
        except websockets.exceptions.ConnectionClosed as e:  # Catch-all
            logger.warning(f"Gateway connection closed unexpectedly: {e!r}")
        except asyncio.CancelledError:
            logger.info("Gateway listening loop was cancelled.")
            # Cancellation must propagate, or the cancelling task never learns it happened.
            raise
        except Exception:  # pylint: disable=broad-except
            logger.exception("Unhandled exception in gateway listening loop.")
        finally:
            self.is_connected = False
            logger.debug("Gateway internal listen loop finished.")

    def _dispatch_to_handlers(self, msg_data: Any, msg_id: Any) -> None:
        """Schedule every registered handler on ``msg_data`` as its own task."""
        for handler_name, handler_func in list(self.message_handlers.items()):
            task = asyncio.create_task(
                handler_func(msg_data, msg_id),  # type: ignore[arg-type]
                name=f"gw_handler_{handler_name}_{msg_id or 'no_id'}",
            )
            self._handler_tasks.add(task)
            task.add_done_callback(self._on_handler_task_done)

    def _on_handler_task_done(self, task: asyncio.Task) -> None:
        """Release a finished handler task and log its exception, if it raised one."""
        self._handler_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"Gateway message handler task '{task.get_name()}' failed: {exc!r}", exc_info=exc)

    @staticmethod
    def _jsonrpc_result(req_id: Any, result: Any) -> dict[str, Any]:
        """Build a JSON-RPC 2.0 success response."""
        return {"jsonrpc": "2.0", "id": req_id, "result": result}

    @staticmethod
    def _jsonrpc_error(req_id: Any, code: int, message: str) -> dict[str, Any]:
        """Build a JSON-RPC 2.0 error response."""
        return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}

    async def _handle_server_request_during_handshake(
        self,
        request_data: dict,
        proc_stdin: Optional[asyncio.StreamWriter],
        pending_stdio_requests: dict[Any, asyncio.Future],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Answer a JSON-RPC request that the gateway sends during the handshake.

        ``initialize`` is answered locally. ``tools/list`` is forwarded to the
        stdio subprocess and its result relayed. Any other method gets a
        ``-32601 Method not found`` error. The reply is sent via :meth:`send`;
        if the WebSocket is gone, it is dropped with a warning.
        """
        method = request_data.get("method")
        remote_gateway_req_id = request_data.get("id")

        if method == "initialize":
            logger.info(f"Responding to 'initialize' request (ID: {remote_gateway_req_id}) from gateway.")
            response_payload = self._jsonrpc_result(
                remote_gateway_req_id,
                {
                    "protocolVersion": self._PROTOCOL_VERSION,
                    "serverInfo": {"name": self._CLIENT_NAME, "version": self._CLIENT_VERSION},
                    "capabilities": {},
                },
            )
        elif method == "tools/list":
            logger.info(
                f"Received 'tools/list' (ID: {remote_gateway_req_id}) from remote gateway. Querying stdio subprocess."
            )
            response_payload = await self._list_tools_from_stdio(
                remote_gateway_req_id, proc_stdin, pending_stdio_requests, loop
            )
        else:
            logger.warning(
                f"Received unhandled request method '{method}' (ID: {remote_gateway_req_id}) from gateway during handshake."
            )
            response_payload = self._jsonrpc_error(remote_gateway_req_id, -32601, "Method not found")

        if not self.websocket or not self.is_connected:
            logger.warning(f"Cannot send response for '{method}', websocket not available or not connected.")
            return
        logger.debug(
            f"Sending response for gateway request '{method}' (ID: {remote_gateway_req_id}): {str(response_payload)[:200]}..."
        )
        await self.send(response_payload)

    async def _list_tools_from_stdio(
        self,
        remote_gateway_req_id: Any,
        proc_stdin: Optional[asyncio.StreamWriter],
        pending_stdio_requests: dict[Any, asyncio.Future],
        loop: asyncio.AbstractEventLoop,
    ) -> dict[str, Any]:
        """Query the stdio server for ``tools/list`` and build the reply for the gateway.

        Always returns a JSON-RPC response (result or error); failures are
        logged, not raised.
        """
        if not proc_stdin:
            logger.error("Cannot query stdio for tools/list: proc_stdin is None.")
            return self._jsonrpc_error(remote_gateway_req_id, -32000, "Internal error: stdio not available")
        try:
            stdio_result_content = await self._request_stdio_subprocess(
                "tools/list", {}, proc_stdin, pending_stdio_requests, loop, timeout_sec=20.0
            )
        except asyncio.TimeoutError:
            logger.error(  # noqa: TRY400
                f"Timeout querying stdio subprocess for 'tools/list' (for remote gateway request ID {remote_gateway_req_id})."
            )
            return self._jsonrpc_error(remote_gateway_req_id, -32000, "Timeout obtaining tools from stdio server")
        except Exception as e:
            logger.exception(
                f"Error querying stdio subprocess for 'tools/list' (for remote gateway request ID {remote_gateway_req_id}): {e}"  # noqa: TRY401
            )
            return self._jsonrpc_error(remote_gateway_req_id, -32000, f"Internal error obtaining tools: {e!s}")

        # A usable result is an object whose "tools" value is a list.
        if isinstance(stdio_result_content, dict) and isinstance(stdio_result_content.get("tools"), list):
            logger.info(
                f"Responding to remote gateway's 'tools/list' (ID: {remote_gateway_req_id}) with {len(stdio_result_content['tools'])} tools from stdio."
            )
            return self._jsonrpc_result(remote_gateway_req_id, stdio_result_content)
        logger.error(f"STDIO server provided invalid or unexpected result for tools/list: {stdio_result_content}")
        return self._jsonrpc_error(
            remote_gateway_req_id, -32002, "Invalid response from stdio server for tools/list"
        )

    async def _process_handshake_message(
        self,
        msg_data: dict,
        client_registration_id: str,
        proc_stdin: Optional[asyncio.StreamWriter],
        pending_stdio_requests: dict[Any, asyncio.Future],
        loop: asyncio.AbstractEventLoop,
    ) -> Optional[bool]:
        """Classify and act on one decoded handshake message.

        Returns:
            True once the gateway confirms our registration (``status: registered``),
            False when the handshake must be aborted, or None to keep waiting.
        """
        msg_id = msg_data.get("id")

        # Scenario 1: a reply to our own registration request.
        if msg_id == client_registration_id:
            status = msg_data.get("status")
            if status == "received":
                logger.info(f"Gateway ACKed our registration (ID: {client_registration_id}). Continuing handshake.")
            elif status == "registered":
                logger.info(
                    f"Gateway confirmed full registration (ID: {client_registration_id}). Handshake successful."
                )
                return True
            elif "result" in msg_data:  # The server answered "register" with a JSON-RPC result.
                logger.info(
                    f"Gateway responded with 'result' for our registration (ID: {client_registration_id}). Assuming ACK."
                )
            elif "error" in msg_data:
                logger.error(
                    f"Gateway rejected our registration (ID: {client_registration_id}). Error: {msg_data['error']}"
                )
                return False
            else:
                logger.warning(
                    f"Received unknown status/response for our registration ID {client_registration_id}: {msg_data}"
                )
            return None

        # Scenario 2: a request from the server to us (its id is not our registration id).
        if "method" in msg_data and msg_id is not None:
            logger.info(f"Received request from gateway: method '{msg_data['method']}', ID '{msg_id}'.")
            await self._handle_server_request_during_handshake(msg_data, proc_stdin, pending_stdio_requests, loop)
            return None

        # Scenario 3: a notification from the server (method without id).
        if "method" in msg_data:
            if msg_data["method"] == "notifications/initialized":
                logger.info("Received 'notifications/initialized' from gateway.")
            else:
                logger.info(f"Received notification from gateway: {msg_data['method']}")
            return None

        # Scenario 4: a general server error, not tied to one of our ids.
        if msg_data.get("status") == "error":
            logger.error(
                f"Handshake: Received general error from gateway: {msg_data.get('message', 'Unknown error')}"
            )
            message_text = msg_data.get("message")
            if not isinstance(message_text, str):  # Tolerate a missing or null message.
                message_text = ""
            if "Initialization failed" in message_text or "Unsupported protocol version" in message_text:
                logger.error("Server-side initialization or protocol handshake failed. Aborting.")
                return False
            if message_text.startswith(f"Registration processing error for ID {client_registration_id}"):
                logger.error(f"Gateway reported error processing our registration {client_registration_id}. Aborting.")
                return False
            return None

        logger.debug(f"Received other message during handshake: {str(msg_data)[:100]}...")
        return None

    async def _run_handshake(
        self,
        client_registration_id: str,
        proc_stdin: Optional[asyncio.StreamWriter],
        pending_stdio_requests: dict[Any, asyncio.Future],
        loop: asyncio.AbstractEventLoop,
    ) -> bool:
        """Receive handshake messages until the gateway confirms our registration.

        Per the original notes, the mcpport.gateway server sends:
          1. Ack: {"status": "received", "id": client_registration_id, ...}
          2. Request: {"method": "initialize", "id": SERVER_INIT_ID, ...} (we reply)
          3. Notification: {"method": "notifications/initialized", ...}
             (the server may sleep ~5s here)
          4. Request: {"method": "tools/list", "id": SERVER_TOOLS_ID, ...} (we reply)
          5. Final confirmation: {"status": "registered", "id": client_registration_id, ...}

        Returns:
            True once registered. False on overall timeout, a closed socket, a
            rejection, or an unexpected processing error.
        """
        running_loop = asyncio.get_running_loop()
        deadline = running_loop.time() + self.HANDSHAKE_TIMEOUT_SEC
        while True:
            remaining_time = deadline - running_loop.time()
            if remaining_time <= 0:
                logger.error("Handshake timed out waiting for 'status: registered' from gateway.")
                return False
            ws = self.websocket
            if ws is None:
                logger.error("Websocket is None, cannot recv. Breaking handshake.")
                return False
            try:
                # Cap each recv so the overall deadline is re-checked regularly.
                response_raw = await asyncio.wait_for(
                    ws.recv(), timeout=min(self.HANDSHAKE_RECV_TIMEOUT_SEC, remaining_time)
                )
            except asyncio.TimeoutError:
                logger.debug("Individual recv timed out, continuing handshake loop.")
                continue
            except websockets.exceptions.ConnectionClosed:
                logger.warning("Websocket connection closed during handshake.")
                return False

            logger.debug(f"Handshake recv: {str(response_raw)[:300]}...")
            try:
                msg_data = json.loads(response_raw)
            except (json.JSONDecodeError, UnicodeDecodeError):
                logger.error(f"Handshake: Invalid JSON received: {response_raw!r}")  # noqa: TRY400
                continue
            if not isinstance(msg_data, dict):
                logger.warning(f"Handshake: ignoring non-object JSON message: {str(msg_data)[:100]}...")
                continue

            try:
                outcome = await self._process_handshake_message(
                    msg_data, client_registration_id, proc_stdin, pending_stdio_requests, loop
                )
            except Exception as e:
                logger.exception(f"Error processing message during handshake: {e}")  # noqa: TRY401
                return False
            if outcome is not None:
                return outcome

    async def connect_and_run_listen_loop(  # noqa: C901
        self,
        proc_stdin: Optional[asyncio.StreamWriter],
        pending_stdio_requests: dict[Any, asyncio.Future],
        loop: asyncio.AbstractEventLoop,
    ) -> bool:
        """Connect, perform the registration handshake, then listen until disconnect.

        Args:
            proc_stdin: stdin of the stdio subprocess, used to answer the
                gateway's ``tools/list`` during the handshake.
            pending_stdio_requests: in-flight stdio request ids mapped to their
                futures (presumably resolved by a stdout reader elsewhere -- confirm).
            loop: event loop used to create those futures.

        Returns:
            True if the handshake succeeded and the listen loop later ended.
            False if connecting or the handshake failed; those errors are logged,
            not raised. If another call is still running, returns the current
            ``is_connected`` immediately.

        Raises:
            asyncio.CancelledError: re-raised when the running task is cancelled.
        """
        if self._listen_task is not None and not self._listen_task.done():
            logger.warning("connect_and_run_listen_loop called while a listen task is already active.")
            return self.is_connected
        self._listen_task = asyncio.current_task()

        # Our "register" request uses self.server_id (a UUID4 when none was given);
        # replies to it are recognised by this same id.
        client_registration_id = self.server_id
        try:
            logger.info(f"Attempting to connect to gateway: {self.gateway_url}")
            async with websockets.connect(
                self.gateway_url,
                ping_interval=20,
                ping_timeout=20,
                close_timeout=10,
                additional_headers=self.headers,
                ssl=self.ssl_context,
            ) as ws_connection:
                self.websocket = ws_connection
                self.is_connected = True
                logger.info("Successfully connected to gateway.")

                registration_params = {"name": self.server_name, "version": "1.0.0", "capabilities": {}}
                registration_req = {
                    "jsonrpc": "2.0",
                    "id": client_registration_id,
                    "method": "register",
                    "params": registration_params,
                }
                logger.info(f"Sending client registration request (ID: {client_registration_id}): {registration_req}")
                await self.send(registration_req)

                if not await self._run_handshake(client_registration_id, proc_stdin, pending_stdio_requests, loop):
                    return False

                logger.info(
                    "Handshake sequence complete. Starting main message listening loop (_internal_listen_loop)."
                )
                await self._internal_listen_loop()  # Receives tool calls etc. until disconnect.
                return True
        except websockets.exceptions.InvalidStatus as e:
            logger.error(f"Gateway connection failed: HTTP {e.response.status_code}. Headers: {e.response.headers}")  # noqa: TRY400
        except ConnectionRefusedError:
            logger.error(f"Connection refused by gateway at {self.gateway_url}")  # noqa: TRY400
        except (asyncio.TimeoutError, TimeoutError):  # Distinct classes before Python 3.11.
            logger.error(f"Timeout during connection or overall handshake with {self.gateway_url}")  # noqa: TRY400
        except asyncio.CancelledError:
            logger.info("Gateway connection and listen task was cancelled.")
            raise
        except Exception:
            logger.exception("Unhandled error during gateway connection or listening")
        finally:
            self.is_connected = False
            self.websocket = None  # Never leave a stale connection behind.
            self._listen_task = None
            logger.debug("connect_and_run_listen_loop method finished.")
        return False

    async def close(self) -> None:
        """Close the WebSocket (if any) and drop all message handlers.

        This does not cancel a running :meth:`connect_and_run_listen_loop`.
        Closing the socket is expected to end its listen loop; cancel that
        task separately if needed.
        """
        logger.info("Closing GatewayClient...")
        self.is_connected = False
        ws = self.websocket
        if ws:
            logger.debug("Explicitly closing websocket in GatewayClient.close().")
            try:
                await ws.close(code=1000, reason="Client shutdown initiated")
            except Exception:  # pylint: disable=broad-except
                logger.exception("Exception during websocket close in GatewayClient.close")
        self.websocket = None
        self.message_handlers.clear()
        logger.info("GatewayClient resources released.")
|