44import json
55from typing import Dict , Any , List , Optional
66
7- from jinja2 import Environment , BaseLoader , Undefined
8- from jinja2 .utils import missing , object_type_repr
7+ from jinja2 import BaseLoader , Undefined
98from jinja2 .sandbox import SandboxedEnvironment
9+ from jinja2 .utils import missing , object_type_repr
1010
11- from cozeloop .spec .tracespec import PROMPT_KEY , INPUT , PROMPT_VERSION , V_SCENE_PROMPT_TEMPLATE , V_SCENE_PROMPT_HUB
1211from cozeloop .entities .prompt import (Prompt , Message , VariableDef , VariableType , TemplateType , Role ,
1312 PromptVariable )
1413from cozeloop .internal import consts
1817from cozeloop .internal .prompt .converter import _convert_prompt , _to_span_prompt_input , _to_span_prompt_output
1918from cozeloop .internal .prompt .openapi import OpenAPIClient , PromptQuery
2019from cozeloop .internal .trace .trace import TraceProvider
20+ from cozeloop .spec .tracespec import PROMPT_KEY , INPUT , PROMPT_VERSION , V_SCENE_PROMPT_TEMPLATE , V_SCENE_PROMPT_HUB , PROMPT_LABEL
2121
2222
2323class PromptProvider :
@@ -39,18 +39,18 @@ def __init__(
3939 auto_refresh = True )
4040 self .prompt_trace = prompt_trace
4141
42- def get_prompt (self , prompt_key : str , version : str = '' ) -> Optional [Prompt ]:
42+ def get_prompt (self , prompt_key : str , version : str = '' , label : str = '' ) -> Optional [Prompt ]:
4343 # Trace reporting
4444 if self .prompt_trace and self .trace_provider is not None :
4545 with self .trace_provider .start_span (consts .TRACE_PROMPT_HUB_SPAN_NAME ,
4646 consts .TRACE_PROMPT_HUB_SPAN_TYPE ,
4747 scene = V_SCENE_PROMPT_HUB ) as prompt_hub_pan :
4848 prompt_hub_pan .set_tags ({
4949 PROMPT_KEY : prompt_key ,
50- INPUT : json .dumps ({PROMPT_KEY : prompt_key , PROMPT_VERSION : version })
50+ INPUT : json .dumps ({PROMPT_KEY : prompt_key , PROMPT_VERSION : version , PROMPT_LABEL : label })
5151 })
5252 try :
53- prompt = self ._get_prompt (prompt_key , version )
53+ prompt = self ._get_prompt (prompt_key , version , label )
5454 if prompt is not None :
5555 prompt_hub_pan .set_tags ({
5656 PROMPT_VERSION : prompt .version ,
@@ -65,20 +65,20 @@ def get_prompt(self, prompt_key: str, version: str = '') -> Optional[Prompt]:
6565 prompt_hub_pan .set_error (str (e ))
6666 raise e
6767 else :
68- return self ._get_prompt (prompt_key , version )
68+ return self ._get_prompt (prompt_key , version , label )
6969
70- def _get_prompt (self , prompt_key : str , version : str ) -> Optional [Prompt ]:
70+ def _get_prompt (self , prompt_key : str , version : str , label : str = '' ) -> Optional [Prompt ]:
7171 """
7272 Get Prompt, prioritize retrieving from cache, if not found then fetch from server
7373 """
7474 # Try to get from cache
75- prompt = self .cache .get (prompt_key , version )
75+ prompt = self .cache .get (prompt_key , version , label )
7676 # If not in cache, fetch from server and cache it
7777 if prompt is None :
78- result = self .openapi_client .mpull_prompt (self .workspace_id , [PromptQuery (prompt_key = prompt_key , version = version )])
78+ result = self .openapi_client .mpull_prompt (self .workspace_id , [PromptQuery (prompt_key = prompt_key , version = version , label = label )])
7979 if result :
8080 prompt = _convert_prompt (result [0 ].prompt )
81- self .cache .set (prompt_key , version , prompt )
81+ self .cache .set (prompt_key , version , label , prompt )
8282 # object cache item should be read only
8383 return prompt .copy (deep = True )
8484
0 commit comments