
Commit 7d230eb

add: Google-genai for Gemini models (#58)
* chore: mv ollama_pull to scripts
* fix: openai ttc models
* use google.genai sdk for gemini models
* fix: google-genai sdk version
* remove max_token for max performance
* use direnv
* chore: remove unused docstring
1 parent b7a8c7f commit 7d230eb

11 files changed

Lines changed: 172 additions & 112 deletions


.envrc

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+export PYTHONPATH=$PYTHONPATH:`pwd`
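
This `.envrc` is a direnv hook (the commit message's "use direnv"): after a one-time `direnv allow`, entering the repository appends its root to `PYTHONPATH`, so the `src.*` imports used by the scripts resolve without an editable install.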

README.md

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ uv sync # create a virtual environment, and install dependencies
 This script will build the benchmark (Prelude with NL) from the raw data.
 
 ```sh
-uv run --project . scripts/preprocess_benchmark.py
+uv run scripts/preprocess_benchmark.py
 ```
 
 ### TF-Bench_pure
@@ -68,7 +68,7 @@ We use [Ollama](https://ollama.com/) to manage and run the OSS models.
 ```sh
 curl -fsSL https://ollama.com/install.sh | sh # install ollama, you need sudo for this
 ollama serve # start your own instance instead of a system service
-uv run ollama_pull.sh # install required models
+uv run --project . scripts/ollama_pull.sh # install required models
 ```
 
 ```sh

main.py

Lines changed: 50 additions & 58 deletions
@@ -10,115 +10,102 @@
 from openai import OpenAI
 from ollama import Client as OllamaClient
 from anthropic import Anthropic
+from google import genai
 
 import fire
 from src.common import (
     BenchmarkTask,
-    SEED,
-    TEMPERATURE,
     get_prompt,
 )
 
 from src.experiment import (
-    O1_MODELS,
-    GPT_MODELS,
+    OAI_MODELS,
+    OAI_TTC_MODELS,
     CLAUDE_MODELS,
+    CLAUDE_TTC_MODELS,
     DEEPSEEK_MODELS,
     GEMINI_MODELS,
+    GEMINI_TTC_MODELS,
     get_ant_model,
     get_ant_ttc_model,
     get_oai_model,
-    get_o1_model,
+    get_oai_ttc_model,
+    get_gemini_model,
+    get_gemini_ttc_model,
 )
-from src.experiment_ollama import OLLAMA_MODELS, get_model as get_ollama_model
+from src.experiment_ollama import OLLAMA_MODELS, get_ollama_model
 from src.postprocessing import postprocess, RESPONSE_STRATEGIES
 from src.evaluation import evaluate
 
 
 def main(
-    input_file: str = "Benchmark-F.removed.json",
-    output_file: str | None = None,
-    log_file: str | None = None,
-    full_type: bool = True,
-    model: str = "gpt-3.5-turbo",
-    seed: int = SEED,
-    temperature: float = TEMPERATURE,
+    model: str,
     port: int = 11434,
     pure: bool = False,
-    reasoning: bool = False,
+    thinking_budget: int = 1000,
+    output_file: str | None = None,
+    log_file: str = "evaluation_log.jsonl",
 ):
     """
     Run an experiment using various AI models to generate and evaluate type signatures.
 
     Parameters:
-    input_file (str): Path to the input JSON file containing benchmark tasks.
-        Default is "Benchmark-F.removed.json".
-
-    output_file (str | None): Path to the output file where generated type signatures will be saved.
-        If None, the output will be saved to "result/{model}.txt". Default is None.
-
-    log_file (str | None): Path to the log file where evaluation metrics will be appended.
-        If None, defaults to "evaluation_log.jsonl". Default is None.
-
-    full_type (bool): Determines whether to ask the model to predict the full type signature in the prompt.
-        If True, the model will be asked to complete full type signature.
-        If False, the model will be asked to complete the return type in type signature. Default is True.
-
     model (str): Name of the model to use for generating type signatures. Must be one of:
         - GPT_MODELS: ["gpt-3.5-turbo-0125", "gpt-4-turbo-2024-04-09", ...]
         - OLLAMA_MODELS, CLAUDE_MODELS, or O1_MODELS.
         Default is "gpt-3.5-turbo".
 
-    seed (int): Random seed to ensure reproducibility in experiments. Default is 0.
-
-    temperature (float): Sampling temperature for the model's outputs. Higher values
-        produce more diverse outputs. Default is 0.0 (deterministic outputs).
-
     port (int): Port number for connecting to the Ollama server (if using Ollama models).
         Ignored for other models. Default is 11434.
 
     pure (bool): If True, uses the original variable naming in type inference.
        If False, uses rewritten variable naming (e.g., `v1`, `v2`, ...). Default is False.
-
-    reasoning (bool): If True, uses the reasoning prompt for the model. NOTE: this is not for claude-3-7-sonnet.
     """
     assert (
         model
-        in GPT_MODELS
+        in OAI_MODELS
+        + OAI_TTC_MODELS
         + OLLAMA_MODELS
-        + CLAUDE_MODELS
-        + O1_MODELS
         + DEEPSEEK_MODELS
+        + CLAUDE_MODELS
+        + CLAUDE_TTC_MODELS
         + GEMINI_MODELS
+        + GEMINI_TTC_MODELS
     ), f"{model} is not supported."
 
+    # hard-coding benchmark file path for experiment
+    input_file = "tfb.pure.json" if pure else "tfb.json"
+    input_file = os.path.abspath(input_file)
+    assert os.path.exists(
+        input_file
+    ), f"{input_file} does not exist! Please download or build it first."
+
     if output_file is None:
         os.makedirs("result", exist_ok=True)
         output_file = f"result/{model}.txt"
 
-    if log_file is None:
-        log_file = "evaluation_log.jsonl"
-
-    client: OpenAI | Anthropic | OllamaClient
+    client: OpenAI | Anthropic | OllamaClient | genai.Client
    generate: Callable[[str], str | None]
 
-    if model in GPT_MODELS:
+    if model in OAI_MODELS:
         assert "OPENAI_API_KEY" in os.environ, "Please set OPEN_API_KEY in environment!"
         client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
         generate = get_oai_model(client, model, pure)
-    elif model in O1_MODELS:
+
+    elif model in OAI_TTC_MODELS:
         assert "OPENAI_API_KEY" in os.environ, "Please set OPEN_API_KEY in environment!"
         client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
-        generate = get_o1_model(client, model, pure)
+        generate = get_oai_ttc_model(client, model, pure)
     elif model in CLAUDE_MODELS:
         assert (
             "ANTHROPIC_API_KEY" in os.environ
         ), "Please set ANTHROPIC_API_KEY in environment!"
         client = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
-        if reasoning:
-            generate = get_ant_ttc_model(client, model, pure)
-        else:
-            generate = get_ant_model(client, model, pure)
+        generate = get_ant_model(client, model, pure)
+    elif model in CLAUDE_TTC_MODELS:
+        client = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
+        generate = get_ant_ttc_model(client, model, pure, thinking_budget)
+
     elif model in DEEPSEEK_MODELS:
         assert (
             "DEEPSEEK_API_KEY" in os.environ
@@ -127,23 +114,28 @@ def main(
             api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com"
         )
         generate = get_oai_model(client, model, pure)
+
     elif model in GEMINI_MODELS:
         assert (
-            "GEMINI_API_KEY" in os.environ
-        ), "Please set GEMINI_API_KEY in environment!"
-        client = OpenAI(
-            api_key=os.environ["GEMINI_API_KEY"],
-            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-        )
-        generate = get_oai_model(client, model)
+            "GOOGLE_API_KEY" in os.environ
+        ), "Please set GOOGLE_API_KEY in environment!"
+        client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
+        generate = get_gemini_model(client, model, pure)
+    elif model in GEMINI_TTC_MODELS:
+        assert (
+            "GOOGLE_API_KEY" in os.environ
+        ), "Please set GOOGLE_API_KEY in environment!"
+        client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
+        generate = get_gemini_ttc_model(client, model, pure, thinking_budget)
+
     else:
         client = OllamaClient(host=f"http://localhost:{port}")
-        generate = get_ollama_model(client, model, seed, temperature, pure)
+        generate = get_ollama_model(client, model, pure)
 
     with open(input_file, "r") as fp:
         tasks = [from_dict(data_class=BenchmarkTask, data=d) for d in json.load(fp)]
 
-    prompts = lmap(lambda x: get_prompt(x, full_type), tasks)
+    prompts = lmap(get_prompt, tasks)
     responses = lmap(generate, tqdm(prompts, desc=model))
     gen_results = (
         Chain(responses)
@@ -161,7 +153,7 @@ def main(
 
     os.makedirs(os.path.dirname(output_file), exist_ok=True)
     with open(log_file, "a") as fp:
-        logging_result = {"model_name": model, **eval_acc}
+        logging_result = {"model_name": model, **eval_acc, "pure": pure}
         fp.write(f"{json.dumps(logging_result)}\n")
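
The `get_gemini_model` and `get_gemini_ttc_model` factories imported above live in `src/experiment.py`, which is not among the files shown in this view. Below is a minimal sketch of what such wrappers could look like with the google-genai SDK, with the signatures inferred from the call sites here; the bodies are illustrative, not the repository's actual code.

```python
from typing import Callable

from google import genai
from google.genai import types

# NOTE: illustrative only. The real factories live in src/experiment.py,
# which is not part of the diff shown on this page.


def get_gemini_model(
    client: genai.Client, model: str, pure: bool
) -> Callable[[str], str | None]:
    """Plain Gemini call; `pure` presumably selects the prompt variant."""

    def generate(prompt: str) -> str | None:
        response = client.models.generate_content(
            model=model,  # e.g. "gemini-2.0-flash"
            contents=prompt,
        )
        return response.text

    return generate


def get_gemini_ttc_model(
    client: genai.Client, model: str, pure: bool, thinking_budget: int
) -> Callable[[str], str | None]:
    """Test-time-compute (TTC) variant: enable thinking with a token budget."""

    def generate(prompt: str) -> str | None:
        response = client.models.generate_content(
            model=model,  # e.g. a thinking-capable model such as "gemini-2.5-flash"
            contents=prompt,
            config=types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(
                    thinking_budget=thinking_budget
                ),
            ),
        )
        return response.text

    return generate
```

The bump to `google-genai>=1.11.0` in `pyproject.toml` below likely exists to pick up this thinking configuration in the SDK.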

pyproject.toml

Lines changed: 5 additions & 3 deletions
@@ -10,14 +10,14 @@ dependencies = [
     "fire==0.5.0",
     "funcy==2.0",
     "funcy-chain==0.2.0",
-    "google-genai>=1.2.0",
+    "google-genai>=1.11.0",
     "groq==0.8.0",
     "hypothesis>=6.98.6",
     "markdown-to-json==2.1.2",
     "matplotlib>=3.8.3",
     "numpy>=1.26.4",
-    "ollama==0.2.1",
-    "openai==1.30.5",
+    "ollama>=0.2.1",
+    "openai==1.75.0",
     "pathos>=0.3.3",
     "pylint>=3.3.6",
     "pytest>=8.0.0",
@@ -35,3 +35,5 @@ dependencies = [
 [tool.pytest.ini_options]
 pythonpath = ["."]
 
+[tool.uv]
+package = true
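
The new `[tool.uv] package = true` setting tells uv to build and install the project itself into the managed environment, rather than treating it as a non-package ("virtual") project.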
ollama_pull.sh → scripts/ollama_pull.sh

File renamed without changes.

scripts/result_token_stat.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import tiktoken
+import fire
+from dacite import from_dict
+import json
+from funcy_chain import Chain
+import pandas
+
+from src.common import BenchmarkTask, get_prompt
+
+
+def main(input_file="tfb.json"):
+    with open(input_file, "r") as fp:
+        tasks = [from_dict(data_class=BenchmarkTask, data=d) for d in json.load(fp)]
+    # count the max, min, and average token length of the task signatures
+
+    enc = tiktoken.encoding_for_model("gpt-4o")
+    token_counts = [len(enc.encode(task.signature)) for task in tasks]
+    df = pandas.DataFrame(token_counts, columns=["token_count"])
+    print(f"max: {df.token_count.max()}")
+    print(f"min: {df.token_count.min()}")
+    print(f"avg: {df.token_count.mean()}")
+
+if __name__ == "__main__":
+    fire.Fire(main)
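
Because the entry point is wrapped in `fire.Fire`, the script doubles as a small CLI: something like `uv run scripts/result_token_stat.py --input_file tfb.json` prints the max, min, and mean GPT-4o token counts over the benchmark's type signatures.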

src/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+import logging
+
+logging.getLogger("openai").setLevel(logging.ERROR)
+logging.getLogger("httpx").setLevel(logging.ERROR)

src/common.py

Lines changed: 3 additions & 6 deletions
@@ -8,8 +8,7 @@
 import markdown_to_json
 
 # Default hyper-parameters
-SEED = 0
-TEMPERATURE = 0.0
+MAX_TOKENS = 1024
 
 SYSTEM_PROMPT = """
 Act as a static analysis tool for type inference.
@@ -105,11 +104,9 @@ def remove_return_type(sig: str) -> str:
     return sig
 
 
-def get_prompt(task: BenchmarkTask, full_type: bool = True) -> str:
+def get_prompt(task: BenchmarkTask) -> str:
     """get prompt from a task instance"""
 
-    signature = "" if full_type else remove_return_type(task.signature)
-
     fn_name = extract_function_name(task)
     assert fn_name is not None
 
@@ -125,7 +122,7 @@ def get_prompt(task: BenchmarkTask, full_type: bool = True) -> str:
     \n\n
     {code}
     --complete the following type signature for '{fn_name}'
-    {fn_name} :: {signature}
+    {fn_name} ::
     """
     return prompt
 
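
With `full_type` and the `signature` stub removed, every prompt now ends in a bare `{fn_name} ::`, so the model must produce the whole type signature. For a Prelude-style task such as `id`, the prompt tail would read `id ::` and a correct completion is `a -> a` (an illustrative example, not taken from the benchmark file).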
