Skip to content

Commit 22a2973

Browse files
committed
got llm querying working on container
1 parent 7e69e24 commit 22a2973

12 files changed

Lines changed: 6683 additions & 0 deletions

llm-querying/__init__.py

Whitespace-only changes.

llm-querying/agent.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# make a langchain graph instance with the model
2+
from typing_extensions import TypedDict, List, Annotated, Literal
3+
from dataset_and_llm import llm
4+
from prompts import make_prompt, FLOPCounts
5+
from io_cost import get_query_cost
6+
import operator
7+
from langgraph.graph import StateGraph, END
8+
from langgraph.checkpoint.sqlite import SqliteSaver
9+
from langchain.schema import AIMessage
10+
from configuration import Configuration
11+
import sqlite3
12+
13+
class BaselineQueryState(TypedDict, total=False):
    """State carried through the baseline FLOP-count query graph.

    total=False: every key is optional, so each graph node may return a
    partial update; keys annotated with operator.add are merged by list
    concatenation across node updates.
    """
    # --- input problem description (filled in by get_input_problem) ---
    source_code: str
    combined_name: str
    kernel_name: str
    exec_args: str
    grid_size: str  # "(x, y, z)" string, as stored in the dataset
    block_size: str  # "(x, y, z)" string
    total_num_threads: str
    compile_commands: str

    # ground-truth measurements; never shown to the LLM, used only to score
    # how close the LLM prediction is to reality
    empirical_sp_flop_count: float
    empirical_dp_flop_count: float

    prompt_type: Literal["simple", "full"]

    # raw LLM response messages, accumulated across queries
    raw_flop_counts: Annotated[List[AIMessage], operator.add]

    # parsed predictions from the structured LLM output
    predicted_sp_flop_count: int
    predicted_dp_flop_count: int
    predicted_sp_flop_count_explanation: str
    predicted_dp_flop_count_explanation: str

    # token/cost bookkeeping, list-concatenated so totals can be summed later
    input_tokens: Annotated[List[int], operator.add]
    output_tokens: Annotated[List[int], operator.add]
    total_cost: Annotated[List[float], operator.add]

    total_query_time: Annotated[List[float], operator.add]
    # accumulated error messages from failed steps
    error: Annotated[List[str], operator.add]
41+
42+
# Calculate the total number of threads from the gridSz and the blockSz
# grid size is a string of format "(x, y, z)"
# block size is a string of format "(x, y, z)"
def calc_total_threads(gridSz: str, blockSz: str) -> str:
    """Return the total thread count (grid volume * block volume) as a string.

    Both arguments are "(x, y, z)" tuple literals serialized as strings in
    the dataset CSV.
    """
    # Local imports keep the module-level import block untouched.
    import ast
    import math

    # FIX: ast.literal_eval instead of eval — parses the tuple literal
    # without executing arbitrary code coming from the dataset file.
    grid = ast.literal_eval(gridSz)
    block = ast.literal_eval(blockSz)
    total_threads = math.prod(grid) * math.prod(block)
    return str(total_threads)
50+
51+
def get_input_problem(state: BaselineQueryState, config):
    """Graph node 0: pull the target kernel's dataset row out of the run
    configuration and seed the graph state with it.

    Returns a partial BaselineQueryState. The empirical_* values are kept
    only for scoring and are never passed to the LLM.

    Raises:
        ValueError: if no 'input_problem_row' was supplied in the config.
    """
    configurable = config.get("configurable", {})
    verbose = configurable.get("verbose_printing", False)

    row = configurable.get("input_problem_row", None)

    prompt_type = configurable.get("prompt_type", "simple")

    # FIX: validate BEFORE indexing the row. The original read
    # row['combined_name'] first, so a missing row raised TypeError and the
    # assert was unreachable (and `assert` is stripped under `python -O`).
    if row is None:
        raise ValueError("Target problem not found: no 'input_problem_row' in the configuration.")

    combined_name = row['combined_name']

    if verbose:
        print("---------- BEGIN STEP 0: GET INPUT PROBLEM ----------", flush=True)

    to_return = {'source_code' : row['source_code'],
                 'combined_name' : combined_name,
                 'kernel_name' : row['Kernel Name'],
                 'exec_args' : row['exeArgs'],
                 'grid_size' : row['Grid Size'],
                 'block_size' : row['Block Size'],
                 'total_num_threads' : calc_total_threads(row['Grid Size'], row['Block Size']),
                 'compile_commands' : row['compile_commands'],
                 # these "true" values do not get passed to the LLMs
                 # they are used to calculate how close the LLM prediction is to the ground-truth
                 'empirical_sp_flop_count' : row['SP_FLOP'],
                 'empirical_dp_flop_count' : row['DP_FLOP'],
                 'prompt_type' : prompt_type
                 }

    if verbose:
        # source_code is large; print everything else for debugging
        for k, v in to_return.items():
            if k != "source_code":
                print(f"\t{k}: {v}", flush=True)
        print("---------- END STEP 0: GET INPUT PROBLEM ----------", flush=True)

    return to_return
87+
88+
89+
def query_for_flop_count(state: BaselineQueryState, config):
    """Graph node 1: ask the LLM for SP/DP FLOP counts of the kernel.

    Builds a prompt from the state, invokes the configured LLM with
    structured output (FLOPCounts), and returns a partial state update with
    the parsed predictions, the raw response, and query-cost bookkeeping.
    """
    verbose = config.get("configurable", {}).get("verbose_printing", False)

    # Forward per-run configurable fields (model, temp, ...) into the LLM and
    # request structured FLOPCounts output alongside the raw message.
    configured_llm = llm.with_config(configurable=config.get("configurable", {})).with_structured_output(FLOPCounts, include_raw=True)

    prompt = make_prompt(state['prompt_type'])

    chain = prompt | configured_llm

    if verbose:
        print("---------- BEGIN STEP 1: QUERY FOR FLOP COUNT ----------", flush=True)
        print(f"\tQuerying for FLOP count of kernel: {state['combined_name']}", flush=True)

    # Fill the prompt template with every problem attribute gathered in step 0.
    result = chain.invoke({
        "source_code": state['source_code'],
        "kernel_name": state['kernel_name'],
        "exec_args": state['exec_args'],
        "grid_size": state['grid_size'],
        "block_size": state['block_size'],
        "total_num_threads": state['total_num_threads'],
        "compile_commands": state['compile_commands']
    })

    # NOTE(review): with include_raw=True, result['parsed'] may be None when
    # structured-output parsing fails — the attribute accesses below would
    # then raise AttributeError; confirm upstream guarantees or add a guard.
    parsed_result = result['parsed']

    if verbose:
        result['raw'].pretty_print()
        # check if the sp_flop_count attributes are present and not None
        if parsed_result.sp_flop_count is not None and parsed_result.dp_flop_count is not None:
            print(f"\tGot an LLM response!: \n\tSP_FLOP:[{parsed_result.sp_flop_count}], \n\tDP_FLOP:[{parsed_result.dp_flop_count}]\n", flush=True)

    # Token counts / dollar cost / timing for this query (dict of state keys).
    query_cost = get_query_cost(result['raw'], verbose)

    return query_cost | {'predicted_sp_flop_count': parsed_result.sp_flop_count,
                         'predicted_dp_flop_count': parsed_result.dp_flop_count,
                         'predicted_sp_flop_count_explanation': parsed_result.sp_flop_explanation,
                         'predicted_dp_flop_count_explanation': parsed_result.dp_flop_explanation,
                         'raw_flop_counts': [result['raw']]
                         }
128+
129+
130+
def make_graph(sqlite_db_path: str):
    """Build and compile the two-node baseline query graph.

    sqlite_db_path: path to the sqlite database file used by SqliteSaver to
    persist graph checkpoints (intermediate results) between runs.
    """
    builder = StateGraph(BaselineQueryState, context_schema=Configuration)

    # node 0 loads the problem row; node 1 queries the LLM for FLOP counts
    builder.add_node("get_input_problem_0", get_input_problem)
    builder.add_node("query_for_flop_count_1", query_for_flop_count)

    builder.set_entry_point("get_input_problem_0")
    builder.add_edge("get_input_problem_0", "query_for_flop_count_1")
    builder.add_edge("query_for_flop_count_1", END)

    # check_same_thread=False lets the connection be used from langgraph's
    # worker threads; the checkpointer saves state after each node.
    db_connection = sqlite3.connect(sqlite_db_path, check_same_thread=False)
    return builder.compile(checkpointer=SqliteSaver(db_connection))
148+
149+
150+

llm-querying/configuration.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from pydantic import BaseModel, Field
2+
from typing import Annotated, Literal
3+
4+
# Graph nodes that actually call an LLM; used to scope LLM-only settings.
llm_nodes = ["query_for_flop_count_1"]

# Every node in the graph; settings like verbose printing apply to all.
all_nodes = [*llm_nodes, "get_input_problem_0"]
9+
10+
class Configuration(BaseModel):
    """Runtime-configurable settings for the baseline querying agent.

    Fields tagged with langgraph_nodes=llm_nodes only affect LLM-calling
    nodes; verbose_printing applies to every node in the graph.
    """

    temp: float = Field(
        default=0.2,
        description="The temperature to use for the LLM. Higher values make the output more random, lower values make it more deterministic.",
        json_schema_extra={"langgraph_nodes": llm_nodes},
    )

    top_p: float = Field(
        default=0.1,
        description="The top_p value to use for the LLM. Higher values make the output more random, lower values make it more deterministic. This is used in conjunction with temperature to control the randomness of the output.",
        json_schema_extra={"langgraph_nodes": llm_nodes},
    )

    # FIX: the OpenRouter API lives at openrouter.ai, not openrouter.com
    # (dataset_and_llm.py connects to https://openrouter.ai/api/v1).
    provider_url: str = Field(
        default="https://openrouter.ai/api/v1",
        description="The URL of the provider's API endpoint. This is used to connect to the LLM provider.",
        json_schema_extra={"langgraph_nodes": llm_nodes},
    )

    provider_api_key: str = Field(
        default="",
        description="The API key for the LLM provider. This is used to authenticate requests to the provider's API.",
        json_schema_extra={"langgraph_nodes": llm_nodes},
    )

    api_version: str = Field(
        default="",
        description="(Azure only) The API version to use when connecting to the Azure OpenAI service.",
        json_schema_extra={"langgraph_nodes": llm_nodes},
    )

    model: Annotated[
        Literal[
            "openai/gpt-4.1-nano",                # in $0.1 out $0.4
            "openai/gpt-4.1-mini",                # in $0.4 out $1.6
            "openai/gpt-4o-mini",                 # in $0.15 out $0.6
            "openai/o4-mini-high",                # in $1.1 out $4.4
            "openai/o4-mini",                     # in $1.1 out $4.4
            "openai/o3-mini-high",                # in $1.1 out $4.4
            "openai/o3-mini",                     # in $1.1 out $4.4
            "google/gemini-flash-1.5",            # in $0.075 out $0.3
            "google/gemini-2.0-flash-lite-001",   # in $0.075 out $0.3
            "google/gemini-2.0-flash-001",        # in $0.1 out $0.4
            "google/gemini-2.5-flash",            # in $0.3 out $2.5
            # FIX: a missing comma here used to fuse the next two string
            # literals into one invalid model name
            # "anthropic/claude-3.5-haikugpt-5-mini".
            "anthropic/claude-3.5-haiku",         # in $0.8 out $4.0
            "gpt-5-mini",                         # in $
        ],
        {"__template_metadata__": {"kind": "llm"}},
    ] = Field(
        default="openai/gpt-4.1-mini",
        description="The name of the language model to use for the agent's main interactions. "
        "Should be in the form: provider/model-name.",
        json_schema_extra={"langgraph_nodes": llm_nodes},
    )

    verbose_printing: bool = Field(
        default=False,
        description="If True, the agent will print detailed information about each step of the analysis.",
        json_schema_extra={"langgraph_nodes": all_nodes},
    )

llm-querying/dataset_and_llm.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Load the FLOP-benchmark datasets and build the configurable chat models.

Importing this module has side effects: it reads two dataset CSVs from
$GPU_FLOPBENCH_ROOT/dataset-creation and instantiates the Azure and
OpenRouter chat models, exposing the combined `llm` at module level.
"""
import pandas as pd

from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_core.runnables import ConfigurableField

import os
import csv


# Root of the GPU-FLOPBench checkout; the dataset CSVs live under
# <root>/dataset-creation/.
GPU_FLOPBENCH_ROOT = os.environ.get('GPU_FLOPBENCH_ROOT')
# FIX: fail fast with an actionable message. Previously an unset variable
# surfaced as an opaque TypeError inside os.path.join below.
if not GPU_FLOPBENCH_ROOT:
    raise RuntimeError("GPU_FLOPBENCH_ROOT environment variable is not set; it must point to the repository root.")
print(GPU_FLOPBENCH_ROOT)

# "Hard" split: unbalanced kernel set with compile commands.
hard_dataset_path = os.path.join(GPU_FLOPBENCH_ROOT, 'dataset-creation', 'hard_kernels_to_inference_unbalanced_with_compile_commands.csv')
print('hard_dataset_path', hard_dataset_path)
hard_df_to_query = pd.read_csv(hard_dataset_path, quotechar='"', quoting=csv.QUOTE_NONNUMERIC)

# "Easy" split: balanced kernel set with compile commands.
easy_dataset_path = os.path.join(GPU_FLOPBENCH_ROOT, 'dataset-creation', 'kernels_to_inference_balanced_with_compile_commands.csv')
print('easy_dataset_path', easy_dataset_path)
easy_df_to_query = pd.read_csv(easy_dataset_path, quotechar='"', quoting=csv.QUOTE_NONNUMERIC)

try:
    # for some reason, the AzureChatOpenAI class fails to initialize properly
    # because it seems like it tries to reach out to the node to get metadata or check alive state
    # if the node is not set up for a particular model, we get a 404 error
    # we put this guard here to avoid erroring out when we are not using Azure
    azureModel = AzureChatOpenAI(
        openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        azure_endpoint="https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com",
        openai_api_version="2025-04-01-preview",
        temperature=1,
        top_p=1,
        model_name="gpt-5-mini",
        timeout=120,
    ).configurable_fields(
        # these ids match the field names declared in configuration.Configuration
        model_name=ConfigurableField(id="model"),
        temperature=ConfigurableField(id="temp"),
        top_p=ConfigurableField(id="top_p"),
        azure_endpoint=ConfigurableField(id="provider_url"),
        openai_api_key=ConfigurableField(id="provider_api_key"),
        openai_api_version=ConfigurableField(id="api_version"),
        request_timeout=ConfigurableField(id="timeout"),
    )
except Exception as e:
    # Best-effort fallback so non-Azure runs can still import this module.
    print(f"Azure model could not be setup correctly! Falling back to OpenAI model in its place.", flush=True)
    print(f"Error: {e}", flush=True)

    azureModel = ChatOpenAI()


openrouterModel = ChatOpenAI(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    openai_api_base="https://openrouter.ai/api/v1",
    temperature=0.2,
    top_p=0.1,
    model_name="openai/gpt-5-mini",
    timeout=120,
).configurable_fields(
    # NOTE(review): these ids carry an opr_ prefix, while the Azure model and
    # configuration.Configuration use unprefixed ids (model, temp, ...) —
    # confirm callers set opr_* keys when targeting OpenRouter.
    model_name=ConfigurableField(id="opr_model"),
    temperature=ConfigurableField(id="opr_temp"),
    top_p=ConfigurableField(id="opr_top_p"),
    openai_api_base=ConfigurableField(id="opr_provider_url"),
    openai_api_key=ConfigurableField(id="opr_provider_api_key"),
    request_timeout=ConfigurableField(id="opr_timeout"),
)

# `llm` defaults to the OpenRouter model; select Azure at runtime via
# config {"configurable": {"llm": "azure"}}.
llm = openrouterModel.configurable_alternatives(
    ConfigurableField(id="llm"),
    default_key="openai",
    azure=azureModel,
)

llm-querying/doAllBaselineRuns.sh

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
3+
# Baseline LLM-query runs for the FLOP-count benchmark. Each command queries
# one model over a dataset split (easy by default, --hardDataset for hard)
# and tees all output to a per-run log file.

# Earlier runs, kept for reference (already completed):
#python3 ./run_llm_queries.py --skipConfirm --modelName openai/gpt-5-mini --numTrials 3 --verbose 2>&1 | tee -a ./gpt-5-mini-simplePrompt.log
#python3 ./run_llm_queries.py --skipConfirm --modelName openai/gpt-5-mini --useFullPrompt --numTrials 3 --verbose 2>&1 | tee -a ./gpt-5-mini-fullPrompt.log
#python3 ./run_llm_queries.py --hardDataset --useAzure --api_version 2025-04-01-preview --provider_url https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com --skipConfirm --modelName gpt-5-mini --useFullPrompt --numTrials 3 --top_p 1.0 --temp 1.0 --verbose 2>&1 | tee -a ./gpt-5-mini-fullPrompt-hardDataset.log


#python3 ./run_llm_queries.py --hardDataset --useAzure --api_version 2025-04-01-preview --provider_url https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com --skipConfirm --modelName gpt-5-mini --numTrials 3 --top_p 1.0 --temp 1.0 --verbose 2>&1 | tee -a ./gpt-5-mini-simplePrompt-hardDataset.log

# Azure deployment endpoints for reference:
# 4o-mini
# https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2025-01-01-preview

# o1-mini -- doesn't support system messages -- can't work on our platform
# https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2025-01-01-preview

# o3-mini
# https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com/openai/deployments/o3-mini/chat/completions?api-version=2025-01-01-preview


# Active runs. Note: the o3-mini runs use top_p=1.0/temp=1.0 because the
# reasoning models do not accept the lower sampling settings.

# 4o-mini
python3 ./run_llm_queries.py --useAzure --api_version 2025-01-01-preview --provider_url https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com --skipConfirm --modelName gpt-4o-mini --numTrials 3 --top_p 0.5 --temp 0.2 --verbose 2>&1 | tee -a ./gpt-4o-mini-simplePrompt-easyDataset.log
python3 ./run_llm_queries.py --useAzure --api_version 2025-01-01-preview --provider_url https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com --skipConfirm --modelName gpt-4o-mini --numTrials 3 --top_p 0.5 --temp 0.2 --verbose --hardDataset 2>&1 | tee -a ./gpt-4o-mini-simplePrompt-hardDataset.log

# o3-mini
python3 ./run_llm_queries.py --useAzure --api_version 2025-01-01-preview --provider_url https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com --skipConfirm --modelName o3-mini --numTrials 3 --top_p 1.0 --temp 1.0 --verbose 2>&1 | tee -a ./o3-mini-simplePrompt-easyDataset.log
python3 ./run_llm_queries.py --useAzure --api_version 2025-01-01-preview --provider_url https://galor-m8yvytc2-swedencentral.cognitiveservices.azure.com --skipConfirm --modelName o3-mini --numTrials 3 --top_p 1.0 --temp 1.0 --verbose --hardDataset 2>&1 | tee -a ./o3-mini-simplePrompt-hardDataset.log

0 commit comments

Comments
 (0)