Skip to content

Commit 2625475

Browse files
Feat/jinja2 (#20)
* feat/jinja2
1 parent f52c547 commit 2625475

6 files changed

Lines changed: 710 additions & 3 deletions

File tree

cozeloop/entities/prompt.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
class TemplateType(str, Enum):
1111
NORMAL = "normal"
12+
JINJA2 = "jinja2"
1213

1314

1415
class Role(str, Enum):
@@ -26,6 +27,15 @@ class ToolType(str, Enum):
2627
class VariableType(str, Enum):
2728
STRING = "string"
2829
PLACEHOLDER = "placeholder"
30+
BOOLEAN = "boolean"
31+
INTEGER = "integer"
32+
FLOAT = "float"
33+
OBJECT = "object"
34+
ARRAY_STRING = "array<string>"
35+
ARRAY_BOOLEAN = "array<boolean>"
36+
ARRAY_INTEGER = "array<integer>"
37+
ARRAY_FLOAT = "array<float>"
38+
ARRAY_OBJECT = "array<object>"
2939

3040

3141
class ToolChoiceType(str, Enum):

cozeloop/internal/prompt/converter.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,16 @@ def _convert_message(msg: OpenAPIMessage) -> EntityMessage:
5959
def _convert_variable_type(openapi_type: OpenAPIVariableType) -> EntityVariableType:
6060
type_mapping = {
6161
OpenAPIVariableType.STRING: EntityVariableType.STRING,
62-
OpenAPIVariableType.PLACEHOLDER: EntityVariableType.PLACEHOLDER
62+
OpenAPIVariableType.PLACEHOLDER: EntityVariableType.PLACEHOLDER,
63+
OpenAPIVariableType.BOOLEAN: EntityVariableType.BOOLEAN,
64+
OpenAPIVariableType.INTEGER: EntityVariableType.INTEGER,
65+
OpenAPIVariableType.FLOAT: EntityVariableType.FLOAT,
66+
OpenAPIVariableType.OBJECT: EntityVariableType.OBJECT,
67+
OpenAPIVariableType.ARRAY_STRING: EntityVariableType.ARRAY_STRING,
68+
OpenAPIVariableType.ARRAY_INTEGER: EntityVariableType.ARRAY_INTEGER,
69+
OpenAPIVariableType.ARRAY_FLOAT: EntityVariableType.ARRAY_FLOAT,
70+
OpenAPIVariableType.ARRAY_BOOLEAN: EntityVariableType.ARRAY_BOOLEAN,
71+
OpenAPIVariableType.ARRAY_OBJECT: EntityVariableType.ARRAY_OBJECT
6372
}
6473
return type_mapping.get(openapi_type, EntityVariableType.STRING) # Default to STRING type
6574

@@ -122,7 +131,8 @@ def _convert_llm_config(config: OpenAPIModelConfig) -> EntityModelConfig:
122131

123132
def _convert_template_type(openapi_template_type: OpenAPITemplateType) -> EntityTemplateType:
124133
template_mapping = {
125-
OpenAPITemplateType.NORMAL: EntityTemplateType.NORMAL
134+
OpenAPITemplateType.NORMAL: EntityTemplateType.NORMAL,
135+
OpenAPITemplateType.JINJA2: EntityTemplateType.JINJA2
126136
}
127137
return template_mapping.get(openapi_template_type, EntityTemplateType.NORMAL) # Default to NORMAL type
128138

cozeloop/internal/prompt/openapi.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
class TemplateType(str, Enum):
1616
NORMAL = "normal"
17+
JINJA2 = "jinja2"
1718

1819

1920
class Role(str, Enum):
@@ -31,6 +32,15 @@ class ToolType(str, Enum):
3132
class VariableType(str, Enum):
3233
STRING = "string"
3334
PLACEHOLDER = "placeholder"
35+
BOOLEAN = "boolean"
36+
INTEGER = "integer"
37+
FLOAT = "float"
38+
OBJECT = "object"
39+
ARRAY_STRING = "array<string>"
40+
ARRAY_BOOLEAN = "array<boolean>"
41+
ARRAY_INTEGER = "array<integer>"
42+
ARRAY_FLOAT = "array<float>"
43+
ARRAY_OBJECT = "array<object>"
3444

3545

3646
class ToolChoiceType(str, Enum):

cozeloop/internal/prompt/prompt.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from jinja2 import Environment, BaseLoader, Undefined
88
from jinja2.utils import missing, object_type_repr
9+
from jinja2.sandbox import SandboxedEnvironment
910

1011
from cozeloop.spec.tracespec import PROMPT_KEY, INPUT, PROMPT_VERSION, V_SCENE_PROMPT_TEMPLATE, V_SCENE_PROMPT_HUB
1112
from cozeloop.entities.prompt import (Prompt, Message, VariableDef, VariableType, TemplateType, Role,
@@ -153,6 +154,27 @@ def _validate_variable_values_type(self, variable_defs: List[VariableDef], varia
153154
elif var_def.type == VariableType.PLACEHOLDER:
154155
if not (isinstance(val, Message) or (isinstance(val, List) and all(isinstance(item, Message) for item in val))):
155156
raise ValueError(f"type of variable '{var_def.key}' should be Message like object")
157+
elif var_def.type == VariableType.BOOLEAN:
158+
if not isinstance(val, bool):
159+
raise ValueError(f"type of variable '{var_def.key}' should be bool")
160+
elif var_def.type == VariableType.INTEGER:
161+
if not isinstance(val, int):
162+
raise ValueError(f"type of variable '{var_def.key}' should be int")
163+
elif var_def.type == VariableType.FLOAT:
164+
if not isinstance(val, float):
165+
raise ValueError(f"type of variable '{var_def.key}' should be float")
166+
elif var_def.type == VariableType.ARRAY_STRING:
167+
if not isinstance(val, list) or not all(isinstance(item, str) for item in val):
168+
raise ValueError(f"type of variable '{var_def.key}' should be array<string>")
169+
elif var_def.type == VariableType.ARRAY_BOOLEAN:
170+
if not isinstance(val, list) or not all(isinstance(item, bool) for item in val):
171+
raise ValueError(f"type of variable '{var_def.key}' should be array<boolean>")
172+
elif var_def.type == VariableType.ARRAY_INTEGER:
173+
if not isinstance(val, list) or not all(isinstance(item, int) for item in val):
174+
raise ValueError(f"type of variable '{var_def.key}' should be array<integer>")
175+
elif var_def.type == VariableType.ARRAY_FLOAT:
176+
if not isinstance(val, list) or not all(isinstance(item, float) for item in val):
177+
raise ValueError(f"type of variable '{var_def.key}' should be array<float>")
156178

157179
def _format_normal_messages(
158180
self,
@@ -217,7 +239,7 @@ def _render_text_content(
217239
) -> str:
218240
if template_type == TemplateType.NORMAL:
219241
# Create custom Environment using DebugUndefined to preserve original form of undefined variables
220-
env = Environment(
242+
env = SandboxedEnvironment(
221243
loader=BaseLoader(),
222244
undefined=CustomUndefined,
223245
variable_start_string='{{',
@@ -230,10 +252,20 @@ def _render_text_content(
230252
render_vars = {k: variables.get(k, '') for k in variable_def_map.keys()}
231253
# Render template
232254
return template.render(**render_vars)
255+
elif template_type == TemplateType.JINJA2:
256+
return self._render_jinja2_template(template_str, variable_def_map, variables)
233257
else:
234258
raise ValueError(f"text render unsupported template type: {template_type}")
235259

236260

261+
def _render_jinja2_template(self, template_str: str, variable_def_map: Dict[str, VariableDef],
262+
variables: Dict[str, Any]) -> str:
263+
"""渲染 Jinja2 模板"""
264+
env = SandboxedEnvironment()
265+
template = env.from_string(template_str)
266+
render_vars = {k: variables[k] for k in variable_def_map.keys() if variables is not None and k in variables}
267+
return template.render(**render_vars)
268+
237269
class CustomUndefined(Undefined):
238270
__slots__ = ()
239271

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2+
# SPDX-License-Identifier: MIT
3+
4+
import json
5+
import time
6+
from typing import List
7+
8+
import cozeloop
9+
from cozeloop import Message
10+
from cozeloop.entities.prompt import Role
11+
from cozeloop.spec.tracespec import CALL_OPTIONS, ModelCallOption, ModelMessage, ModelInput
12+
13+
14+
def convert_model_input(messages: List[Message]) -> ModelInput:
15+
model_messages = []
16+
for message in messages:
17+
model_messages.append(ModelMessage(
18+
role=str(message.role),
19+
content=message.content if message.content is not None else ""
20+
))
21+
22+
return ModelInput(
23+
messages=model_messages
24+
)
25+
26+
27+
class LLMRunner:
28+
def __init__(self, client):
29+
self.client = client
30+
31+
def llm_call(self, input_data):
32+
"""
33+
Simulate an LLM call and set relevant span tags.
34+
"""
35+
span = self.client.start_span("llmCall", "model")
36+
try:
37+
# Assuming llm is processing
38+
# output = ChatOpenAI().invoke(input=input_data)
39+
40+
# mock resp
41+
time.sleep(1)
42+
output = "I'm a robot. I don't have a specific name. You can give me one."
43+
input_token = 232
44+
output_token = 1211
45+
46+
# set tag key: `input`
47+
span.set_input(convert_model_input(input_data))
48+
# set tag key: `output`
49+
span.set_output(output)
50+
# set tag key: `model_provider`, e.g., openai, etc.
51+
span.set_model_provider("openai")
52+
# set tag key: `start_time_first_resp`
53+
# Timestamp of the first packet return from LLM, unit: microseconds.
54+
# When `start_time_first_resp` is set, a tag named `latency_first_resp` calculated
55+
# based on the span's StartTime will be added, meaning the latency for the first packet.
56+
span.set_start_time_first_resp(int(time.time() * 1000000))
57+
# set tag key: `input_tokens`. The amount of input tokens.
58+
# when the `input_tokens` value is set, it will automatically sum with the `output_tokens` to calculate the `tokens` tag.
59+
span.set_input_tokens(input_token)
60+
# set tag key: `output_tokens`. The amount of output tokens.
61+
# when the `output_tokens` value is set, it will automatically sum with the `input_tokens` to calculate the `tokens` tag.
62+
span.set_output_tokens(output_token)
63+
# set tag key: `model_name`, e.g., gpt-4-1106-preview, etc.
64+
span.set_model_name("gpt-4-1106-preview")
65+
span.set_tags({CALL_OPTIONS: ModelCallOption(
66+
temperature=0.5,
67+
top_p=0.5,
68+
top_k=10,
69+
presence_penalty=0.5,
70+
frequency_penalty=0.5,
71+
max_tokens=1024,
72+
)})
73+
74+
return None
75+
except Exception as e:
76+
raise e
77+
finally:
78+
span.finish()
79+
80+
# If you want to use the jinja templates in prompts, you can refer to the following.
81+
if __name__ == '__main__':
82+
# 1.Create a prompt on the platform
83+
# You can create a Prompt on the platform's Prompt development page (set Prompt Key to 'prompt_hub_demo'),
84+
# add the following messages to the template, and submit a version.
85+
# System: You are a helpful bot, the conversation topic is {{var1}}.
86+
# Placeholder: placeholder1
87+
# User: My question is {{var2}}
88+
# Placeholder: placeholder2
89+
90+
# Set the following environment variables first.
91+
# COZELOOP_WORKSPACE_ID=your workspace id
92+
# COZELOOP_API_TOKEN=your token
93+
# 2.New loop client
94+
client = cozeloop.new_client(
95+
# Set whether to report a trace span when get or format prompt.
96+
# Default value is false.
97+
prompt_trace=True)
98+
99+
# 3. new root span
100+
rootSpan = client.start_span("root_span", "main_span")
101+
102+
# 4. Get the prompt
103+
# If no specific version is specified, the latest version of the corresponding prompt will be obtained
104+
prompt = client.get_prompt(prompt_key="prompt_hub_demo", version="0.0.1")
105+
if prompt is not None:
106+
# Get messages of the prompt
107+
if prompt.prompt_template is not None:
108+
messages = prompt.prompt_template.messages
109+
print(
110+
f"prompt messages: {json.dumps([message.model_dump(exclude_none=True) for message in messages], ensure_ascii=False)}")
111+
# Get llm config of the prompt
112+
if prompt.llm_config is not None:
113+
llm_config = prompt.llm_config
114+
print(f"prompt llm_config: {llm_config.model_dump_json(exclude_none=True)}")
115+
116+
# 5.Format messages of the prompt
117+
formatted_messages = client.prompt_format(prompt, {
118+
"var_string": "hi",
119+
"var_int": 5,
120+
"var_bool": True,
121+
"var_float": 1.0,
122+
"var_object": {
123+
"name": "John",
124+
"age": 30,
125+
"hobbies": ["reading", "coding"],
126+
"address": {
127+
"city": "bejing",
128+
"street": "123 Main",
129+
},
130+
},
131+
"var_array_string": ["hello", "nihao"],
132+
"var_array_boolean": [True, False, True],
133+
"var_array_int": [1, 2, 3, 4],
134+
"var_array_float": [1.0, 2.0],
135+
"var_array_object": [{"key": "123"}, {"value": 100}],
136+
# Placeholder variable type should be Message/List[Message]
137+
"placeholder1": [Message(role=Role.USER, content="Hello!"),
138+
Message(role=Role.ASSISTANT, content="Hello!")]
139+
# Other variables in the prompt template that are not provided with corresponding values will be
140+
# considered as empty values.
141+
})
142+
print(
143+
f"formatted_messages: {json.dumps([message.model_dump(exclude_none=True) for message in formatted_messages], ensure_ascii=False)}")
144+
145+
# 6.LLM call
146+
llm_runner = LLMRunner(client)
147+
llm_runner.llm_call(formatted_messages)
148+
149+
rootSpan.finish()
150+
# 7. (optional) flush or close
151+
# -- force flush, report all traces in the queue
152+
# Warning! In general, this method is not needed to be called, as spans will be automatically reported in batches.
153+
# Note that flush will block and wait for the report to complete, and it may cause frequent reporting,
154+
# affecting performance.
155+
client.flush()

0 commit comments

Comments
 (0)