44from dspy .adapters import Image as dspy_Image
55from dspy .signatures import Signature as dspy_Signature
66from dspy .utils .callback import BaseCallback
7- from langfuse .client import Langfuse , StatefulGenerationClient # type: ignore
8- from langfuse .decorators import langfuse_context # type: ignore
7+ from langfuse import Langfuse , LangfuseGeneration , get_client
98from litellm .cost_calculator import completion_cost
109from loguru import logger as log
1110from pydantic import BaseModel , Field , ValidationError
@@ -52,7 +51,7 @@ def __init__(self, signature: type[dspy_Signature]) -> None:
5251 )
5352 self .current_prompt = contextvars .ContextVar [str ]("current_prompt" )
5453 self .current_completion = contextvars .ContextVar [str ]("current_completion" )
55- self .current_span = contextvars .ContextVar [StatefulGenerationClient | None ](
54+ self .current_span = contextvars .ContextVar [LangfuseGeneration | None ](
5655 "current_span"
5756 )
5857 self .model_name_at_span_creation = contextvars .ContextVar [str | None ](
@@ -91,8 +90,8 @@ def on_module_end( # noqa
9190 exception : Exception | None = None , # noqa
9291 ) -> None :
9392 metadata = {
94- "existing_trace_id" : langfuse_context .get_current_trace_id (),
95- "parent_observation_id" : langfuse_context .get_current_observation_id (),
93+ "existing_trace_id" : get_client () .get_current_trace_id (),
94+ "parent_observation_id" : get_client () .get_current_observation_id (),
9695 }
9796 outputs_extracted = {} # Default to empty dict
9897 if outputs is not None :
@@ -102,7 +101,7 @@ def on_module_end( # noqa
102101 outputs_extracted = {"value" : outputs }
103102 except Exception as e :
104103 outputs_extracted = {"error_extracting_module_output" : str (e )}
105- langfuse_context . update_current_observation (
104+ get_client (). update_current_span (
106105 input = self .input_field_values .get (None ) or {},
107106 output = outputs_extracted ,
108107 metadata = metadata ,
@@ -134,9 +133,9 @@ def on_lm_start( # noqa
134133 self .current_system_prompt .set (system_prompt )
135134 self .current_prompt .set (user_input )
136135 self .model_name_at_span_creation .set (model_name )
137- trace_id = langfuse_context .get_current_trace_id ()
138- parent_observation_id = langfuse_context .get_current_observation_id ()
139- span_obj : StatefulGenerationClient | None = None
136+ trace_id = get_client () .get_current_trace_id ()
137+ parent_observation_id = get_client () .get_current_observation_id ()
138+ span_obj : LangfuseGeneration | None = None
140139 if trace_id :
141140 span_obj = self .langfuse .generation (
142141 input = user_input ,
@@ -392,8 +391,8 @@ def on_tool_start( # noqa
392391
393392 log .debug (f"Tool call started: { tool_name } with args: { tool_args } " )
394393
395- trace_id = langfuse_context .get_current_trace_id ()
396- parent_observation_id = langfuse_context .get_current_observation_id ()
394+ trace_id = get_client () .get_current_trace_id ()
395+ parent_observation_id = get_client () .get_current_observation_id ()
397396
398397 if trace_id :
399398 # Create a span for the tool call
0 commit comments