|
| 1 | +import asyncio |
1 | 2 | import logging |
2 | 3 | from functools import wraps |
3 | 4 | from typing import Any, Callable, Optional, TypeVar |
@@ -121,13 +122,99 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: |
121 | 122 | # Return the output of the decorated function |
122 | 123 | return func_output # type: ignore [return-value] |
123 | 124 |
|
| 125 | + @wraps(func) |
| 126 | + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: |
| 127 | + span: Span |
| 128 | + with set_decorator_context( |
| 129 | + DecoratorContext( |
| 130 | + path=decorator_path, |
| 131 | + type="flow", |
| 132 | + version=flow_kernel, |
| 133 | + ) |
| 134 | + ) as decorator_context: |
| 135 | + with opentelemetry_tracer.start_as_current_span(HUMANLOOP_FLOW_SPAN_NAME) as span: # type: ignore |
| 136 | + span.set_attribute(HUMANLOOP_FILE_PATH_KEY, decorator_path) |
| 137 | + span.set_attribute(HUMANLOOP_FILE_TYPE_KEY, file_type) |
| 138 | + trace_id = get_trace_id() |
| 139 | + func_args = bind_args(func, args, kwargs) |
| 140 | + |
| 141 | + # Create the trace ahead so we have a parent ID to reference |
| 142 | + init_log_inputs = { |
| 143 | + "inputs": {k: v for k, v in func_args.items() if k != "messages"}, |
| 144 | + "messages": func_args.get("messages"), |
| 145 | + "trace_parent_id": trace_id, |
| 146 | + } |
| 147 | + this_flow_log: FlowLogResponse = client.flows._log( # type: ignore [attr-defined] |
| 148 | + path=decorator_context.path, |
| 149 | + flow=decorator_context.version, |
| 150 | + log_status="incomplete", |
| 151 | + **init_log_inputs, |
| 152 | + ) |
| 153 | + |
| 154 | + with set_trace_id(this_flow_log.id): |
| 155 | + func_output: Optional[R] |
| 156 | + log_output: Optional[str] |
| 157 | + log_error: Optional[str] |
| 158 | + log_output_message: Optional[ChatMessage] |
| 159 | + try: |
| 160 | + func_output = await func(*args, **kwargs) |
| 161 | + if ( |
| 162 | + isinstance(func_output, dict) |
| 163 | + and len(func_output.keys()) == 2 |
| 164 | + and "role" in func_output |
| 165 | + and "content" in func_output |
| 166 | + ): |
| 167 | + log_output_message = func_output # type: ignore [assignment] |
| 168 | + log_output = None |
| 169 | + else: |
| 170 | + log_output = process_output(func=func, output=func_output) |
| 171 | + log_output_message = None |
| 172 | + log_error = None |
| 173 | + except HumanloopRuntimeError as e: |
| 174 | + # Critical error, re-raise |
| 175 | + client.logs.delete(id=this_flow_log.id) |
| 176 | + span.record_exception(e) |
| 177 | + raise e |
| 178 | + except Exception as e: |
| 179 | + logger.error(f"Error calling {func.__name__}: {e}") |
| 180 | + log_output = None |
| 181 | + log_output_message = None |
| 182 | + log_error = str(e) |
| 183 | + func_output = None |
| 184 | + |
| 185 | + updated_flow_log = { |
| 186 | + "log_status": "complete", |
| 187 | + "output": log_output, |
| 188 | + "error": log_error, |
| 189 | + "output_message": log_output_message, |
| 190 | + "id": this_flow_log.id, |
| 191 | + } |
| 192 | + # Write the Flow Log to the Span on HL_LOG_OT_KEY |
| 193 | + write_to_opentelemetry_span( |
| 194 | + span=span, # type: ignore [arg-type] |
| 195 | + key=HUMANLOOP_LOG_KEY, |
| 196 | + value=updated_flow_log, # type: ignore |
| 197 | + ) |
| 198 | + # Return the output of the decorated function |
| 199 | + return func_output # type: ignore [return-value] |
| 200 | + |
| 201 | + # If the decorated function is an async function, return the async wrapper |
| 202 | + if asyncio.iscoroutinefunction(func): |
| 203 | + async_wrapper.file = File( # type: ignore |
| 204 | + path=decorator_path, |
| 205 | + type=file_type, # type: ignore [arg-type, typeddict-item] |
| 206 | + version=FlowDict(**flow_kernel), # type: ignore |
| 207 | + callable=async_wrapper, |
| 208 | + ) |
| 209 | + return async_wrapper |
| 210 | + |
| 211 | + # If the decorated function is a sync function, return the sync wrapper |
124 | 212 | wrapper.file = File( # type: ignore |
125 | 213 | path=decorator_path, |
126 | 214 | type=file_type, # type: ignore [arg-type, typeddict-item] |
127 | 215 | version=FlowDict(**flow_kernel), # type: ignore |
128 | 216 | callable=wrapper, |
129 | 217 | ) |
130 | | - |
131 | 218 | return wrapper |
132 | 219 |
|
133 | 220 | return decorator |
0 commit comments