Skip to content

Commit f8828b0

Browse files
committed
update
1 parent 9c0c0bb commit f8828b0

1 file changed

Lines changed: 45 additions & 31 deletions

File tree

cozeloop/decorator/decorator.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Optional, Callable, Any, overload, Dict, Generic, Iterator, TypeVar, List, cast, AsyncIterator
55
from functools import wraps
66

7-
from langchain_core.runnables import RunnableLambda
7+
from langchain_core.runnables import RunnableLambda, RunnableConfig
88

99
from cozeloop import Client, Span, start_span
1010
from cozeloop.decorator.utils import is_async_func, is_gen_func, is_async_gen_func, is_class_func
@@ -322,8 +322,7 @@ def to_runnable(
322322
Decorator to be RunnableLambda.
323323
324324
:param func: The function to be decorated, Requirements are as follows:
325-
1. The input parameter should only have one field, not multiple fields. You can use dict or Class as the input parameter.
326-
2. When the func is called, parameter config(RunnableConfig) is required, you must use the config containing cozeloop callback handler of 'current request', otherwise, the trace may be lost!
325+
1. When the func is called, parameter config(RunnableConfig) is required, you must use the config containing cozeloop callback handler of 'current request', otherwise, the trace may be lost!
327326
328327
Examples:
329328
@to_runnable
@@ -338,15 +337,14 @@ def decorator(func: Callable):
338337

339338
@wraps(func)
340339
def sync_wrapper(*args: Any, **kwargs: Any):
341-
config = kwargs.pop("config", None)
340+
config = _get_config(**kwargs)
342341
res = None
343342
try:
344-
inp = None
345-
if len(args) == 1:
346-
inp = args[0]
347-
else:
348-
inp = kwargs
349-
res = RunnableLambda(func).invoke(input=inp, config=config, **kwargs)
343+
inp = {
344+
"args": args,
345+
"kwargs": kwargs
346+
}
347+
res = RunnableLambda(_param_wrapped_func).invoke(input=inp, config=config, **kwargs)
350348
if hasattr(res, "__iter__"):
351349
return res
352350
except StopIteration:
@@ -359,15 +357,14 @@ def sync_wrapper(*args: Any, **kwargs: Any):
359357

360358
@wraps(func)
361359
async def async_wrapper(*args: Any, **kwargs: Any):
362-
config = kwargs.pop("config", None)
360+
config = _get_config(**kwargs)
363361
res = None
364362
try:
365-
inp = None
366-
if len(args) == 1:
367-
inp = args[0]
368-
else:
369-
inp = kwargs
370-
res = await RunnableLambda(func).ainvoke(input=inp, config=config, **kwargs)
363+
inp = {
364+
"args": args,
365+
"kwargs": kwargs
366+
}
367+
res = await RunnableLambda(_param_wrapped_func_async).ainvoke(input=inp, config=config, **kwargs)
371368
if hasattr(res, "__aiter__"):
372369
return res
373370
except StopIteration:
@@ -385,14 +382,13 @@ async def async_wrapper(*args: Any, **kwargs: Any):
385382

386383
@wraps(func)
387384
def gen_wrapper(*args: Any, **kwargs: Any):
388-
config = kwargs.pop("config", None)
385+
config = _get_config(**kwargs)
389386
try:
390-
inp = None
391-
if len(args) == 1:
392-
inp = args[0]
393-
else:
394-
inp = kwargs
395-
gen = RunnableLambda(func).invoke(input=inp, config=config, **kwargs)
387+
inp = {
388+
"args": args,
389+
"kwargs": kwargs
390+
}
391+
gen = RunnableLambda(_param_wrapped_func).invoke(input=inp, config=config, **kwargs)
396392
try:
397393
for item in gen:
398394
yield item
@@ -403,14 +399,13 @@ def gen_wrapper(*args: Any, **kwargs: Any):
403399

404400
@wraps(func)
405401
async def async_gen_wrapper(*args: Any, **kwargs: Any):
406-
config = kwargs.pop("config", None)
402+
config = _get_config(**kwargs)
407403
try:
408-
inp = None
409-
if len(args) == 1:
410-
inp = args[0]
411-
else:
412-
inp = kwargs
413-
gen = RunnableLambda(func).invoke(input=inp, config=config, **kwargs)
404+
inp = {
405+
"args": args,
406+
"kwargs": kwargs
407+
}
408+
gen = RunnableLambda(_param_wrapped_func_async).invoke(input=inp, config=config, **kwargs)
414409
items = []
415410
try:
416411
async for item in gen:
@@ -428,6 +423,25 @@ async def async_gen_wrapper(*args: Any, **kwargs: Any):
428423
else:
429424
raise e
430425

426+
# for convert parameter
427+
def _param_wrapped_func(input_dict: dict) -> Any:
428+
args = input_dict.get("args", ())
429+
kwargs = input_dict.get("kwargs", {})
430+
return func(*args, **kwargs)
431+
432+
async def _param_wrapped_func_async(input_dict: dict) -> Any:
433+
args = input_dict.get("args", ())
434+
kwargs = input_dict.get("kwargs", {})
435+
return await func(*args, **kwargs)
436+
437+
def _get_config(**kwargs: Any) -> RunnableConfig | None:
438+
config = kwargs.pop("config", None)
439+
if config is None:
440+
config = RunnableConfig(run_name=func.__name__)
441+
config['run_name'] = func.__name__
442+
elif isinstance(config, dict):
443+
config['run_name'] = func.__name__
444+
return config
431445

432446
if is_async_gen_func(func):
433447
return async_gen_wrapper

0 commit comments

Comments
 (0)