# SDK clients for each provider backend.
from openai import OpenAI
from ollama import Client as OllamaClient
from anthropic import Anthropic
from google import genai

import fire

from src.common import (
    BenchmarkTask,
    get_prompt,
)
from src.experiment import (
    CLAUDE_MODELS,
    CLAUDE_TTC_MODELS,
    DEEPSEEK_MODELS,
    GEMINI_MODELS,
    GEMINI_TTC_MODELS,
    OAI_MODELS,
    OAI_TTC_MODELS,
    get_ant_model,
    get_ant_ttc_model,
    get_gemini_model,
    get_gemini_ttc_model,
    get_oai_model,
    get_oai_ttc_model,
)
from src.experiment_ollama import OLLAMA_MODELS, get_ollama_model
from src.postprocessing import RESPONSE_STRATEGIES, postprocess
from src.evaluation import evaluate
3740
3841def main (
39- input_file : str = "Benchmark-F.removed.json" ,
40- output_file : str | None = None ,
41- log_file : str | None = None ,
42- full_type : bool = True ,
43- model : str = "gpt-3.5-turbo" ,
44- seed : int = SEED ,
45- temperature : float = TEMPERATURE ,
42+ model : str ,
4643 port : int = 11434 ,
4744 pure : bool = False ,
48- reasoning : bool = False ,
45+ thinking_budget : int = 1000 ,
46+ output_file : str | None = None ,
47+ log_file : str = "evaluation_log.jsonl" ,
4948):
5049 """
5150 Run an experiment using various AI models to generate and evaluate type signatures.
5251
5352 Parameters:
54- input_file (str): Path to the input JSON file containing benchmark tasks.
55- Default is "Benchmark-F.removed.json".
56-
57- output_file (str | None): Path to the output file where generated type signatures will be saved.
58- If None, the output will be saved to "result/{model}.txt". Default is None.
59-
60- log_file (str | None): Path to the log file where evaluation metrics will be appended.
61- If None, defaults to "evaluation_log.jsonl". Default is None.
62-
63- full_type (bool): Determines whether to ask the model to predict the full type signature in the prompt.
64- If True, the model will be asked to complete full type signature.
65- If False, the model will be asked to complete the return type in type signature. Default is True.
66-
6753 model (str): Name of the model to use for generating type signatures. Must be one of:
6854 - GPT_MODELS: ["gpt-3.5-turbo-0125", "gpt-4-turbo-2024-04-09", ...]
6955 - OLLAMA_MODELS, CLAUDE_MODELS, or O1_MODELS.
7056 Default is "gpt-3.5-turbo".
7157
72- seed (int): Random seed to ensure reproducibility in experiments. Default is 0.
73-
74- temperature (float): Sampling temperature for the model's outputs. Higher values
75- produce more diverse outputs. Default is 0.0 (deterministic outputs).
76-
7758 port (int): Port number for connecting to the Ollama server (if using Ollama models).
7859 Ignored for other models. Default is 11434.
7960
8061 pure (bool): If True, uses the original variable naming in type inference.
8162 If False, uses rewritten variable naming (e.g., `v1`, `v2`, ...). Default is False.
82-
83- reasoning (bool): If True, uses the reasoning prompt for the model. NOTE: this is not for claude-3-7-sonnet.
8463 """
8564 assert (
8665 model
87- in GPT_MODELS
66+ in OAI_MODELS
67+ + OAI_TTC_MODELS
8868 + OLLAMA_MODELS
89- + CLAUDE_MODELS
90- + O1_MODELS
9169 + DEEPSEEK_MODELS
70+ + CLAUDE_MODELS
71+ + CLAUDE_TTC_MODELS
9272 + GEMINI_MODELS
73+ + GEMINI_TTC_MODELS
9374 ), f"{ model } is not supported."
9475
76+ # hard-coding benchmark file path for experiment
77+ input_file = "tfb.pure.json" if pure else "tfb.json"
78+ input_file = os .path .abspath (input_file )
79+ assert os .path .exists (
80+ input_file
81+ ), f"{ input_file } does not exist! Please download or build it first."
82+
9583 if output_file is None :
9684 os .makedirs ("result" , exist_ok = True )
9785 output_file = f"result/{ model } .txt"
9886
99- if log_file is None :
100- log_file = "evaluation_log.jsonl"
101-
102- client : OpenAI | Anthropic | OllamaClient
87+ client : OpenAI | Anthropic | OllamaClient | genai .Client
10388 generate : Callable [[str ], str | None ]
10489
105- if model in GPT_MODELS :
90+ if model in OAI_MODELS :
10691 assert "OPENAI_API_KEY" in os .environ , "Please set OPEN_API_KEY in environment!"
10792 client = OpenAI (api_key = os .environ ["OPENAI_API_KEY" ])
10893 generate = get_oai_model (client , model , pure )
109- elif model in O1_MODELS :
94+
95+ elif model in OAI_TTC_MODELS :
11096 assert "OPENAI_API_KEY" in os .environ , "Please set OPEN_API_KEY in environment!"
11197 client = OpenAI (api_key = os .environ ["OPENAI_API_KEY" ])
112- generate = get_o1_model (client , model , pure )
98+ generate = get_oai_ttc_model (client , model , pure )
11399 elif model in CLAUDE_MODELS :
114100 assert (
115101 "ANTHROPIC_API_KEY" in os .environ
116102 ), "Please set ANTHROPIC_API_KEY in environment!"
117103 client = Anthropic (api_key = os .environ ["ANTHROPIC_API_KEY" ])
118- if reasoning :
119- generate = get_ant_ttc_model (client , model , pure )
120- else :
121- generate = get_ant_model (client , model , pure )
104+ generate = get_ant_model (client , model , pure )
105+ elif model in CLAUDE_TTC_MODELS :
106+ client = Anthropic (api_key = os .environ ["ANTHROPIC_API_KEY" ])
107+ generate = get_ant_ttc_model (client , model , pure , thinking_budget )
108+
122109 elif model in DEEPSEEK_MODELS :
123110 assert (
124111 "DEEPSEEK_API_KEY" in os .environ
@@ -127,23 +114,28 @@ def main(
127114 api_key = os .environ ["DEEPSEEK_API_KEY" ], base_url = "https://api.deepseek.com"
128115 )
129116 generate = get_oai_model (client , model , pure )
117+
130118 elif model in GEMINI_MODELS :
131119 assert (
132- "GEMINI_API_KEY" in os .environ
133- ), "Please set GEMINI_API_KEY in environment!"
134- client = OpenAI (
135- api_key = os .environ ["GEMINI_API_KEY" ],
136- base_url = "https://generativelanguage.googleapis.com/v1beta/openai/" ,
137- )
138- generate = get_oai_model (client , model )
120+ "GOOGLE_API_KEY" in os .environ
121+ ), "Please set GOOGLE_API_KEY in environment!"
122+ client = genai .Client (api_key = os .environ ["GOOGLE_API_KEY" ])
123+ generate = get_gemini_model (client , model , pure )
124+ elif model in GEMINI_TTC_MODELS :
125+ assert (
126+ "GOOGLE_API_KEY" in os .environ
127+ ), "Please set GOOGLE_API_KEY in environment!"
128+ client = genai .Client (api_key = os .environ ["GOOGLE_API_KEY" ])
129+ generate = get_gemini_ttc_model (client , model , pure , thinking_budget )
130+
139131 else :
140132 client = OllamaClient (host = f"http://localhost:{ port } " )
141- generate = get_ollama_model (client , model , seed , temperature , pure )
133+ generate = get_ollama_model (client , model , pure )
142134
143135 with open (input_file , "r" ) as fp :
144136 tasks = [from_dict (data_class = BenchmarkTask , data = d ) for d in json .load (fp )]
145137
146- prompts = lmap (lambda x : get_prompt ( x , full_type ) , tasks )
138+ prompts = lmap (get_prompt , tasks )
147139 responses = lmap (generate , tqdm (prompts , desc = model ))
148140 gen_results = (
149141 Chain (responses )
@@ -161,7 +153,7 @@ def main(
161153
162154 os .makedirs (os .path .dirname (output_file ), exist_ok = True )
163155 with open (log_file , "a" ) as fp :
164- logging_result = {"model_name" : model , ** eval_acc }
156+ logging_result = {"model_name" : model , ** eval_acc , "pure" : pure }
165157 fp .write (f"{ json .dumps (logging_result )} \n " )
166158
167159