Skip to content

Commit 2102b99

Browse files
committed
add gemini models
1 parent 351fbaf commit 2102b99

2 files changed

Lines changed: 31 additions & 17 deletions

File tree

main.py

mode changed 100644 → 100755
Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
GPT_MODELS,
2525
CLAUDE_MODELS,
2626
DEEPSEEK_MODELS,
27+
GEMINI_MODELS,
2728
get_ant_model,
2829
get_ant_ttc_model,
2930
get_oai_model,
@@ -83,7 +84,12 @@ def main(
8384
"""
8485
assert (
8586
model
86-
in GPT_MODELS + OLLAMA_MODELS + CLAUDE_MODELS + O1_MODELS + DEEPSEEK_MODELS
87+
in GPT_MODELS
88+
+ OLLAMA_MODELS
89+
+ CLAUDE_MODELS
90+
+ O1_MODELS
91+
+ DEEPSEEK_MODELS
92+
+ GEMINI_MODELS
8793
), f"{model} is not supported."
8894

8995
if output_file is None:
@@ -99,11 +105,11 @@ def main(
99105
if model in GPT_MODELS:
100106
assert "OPENAI_API_KEY" in os.environ, "Please set OPEN_API_KEY in environment!"
101107
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
102-
generate = get_oai_model(client, model, seed, temperature, pure)
108+
generate = get_oai_model(client, model, pure)
103109
elif model in O1_MODELS:
104110
assert "OPENAI_API_KEY" in os.environ, "Please set OPEN_API_KEY in environment!"
105111
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
106-
generate = get_o1_model(client, model, seed, temperature, pure)
112+
generate = get_o1_model(client, model, pure)
107113
elif model in CLAUDE_MODELS:
108114
assert (
109115
"ANTHROPIC_API_KEY" in os.environ
@@ -120,7 +126,16 @@ def main(
120126
client = OpenAI(
121127
api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com"
122128
)
123-
generate = get_oai_model(client, model, seed, temperature, pure)
129+
generate = get_oai_model(client, model, pure)
130+
elif model in GEMINI_MODELS:
131+
assert (
132+
"GEMINI_API_KEY" in os.environ
133+
), "Please set GEMINI_API_KEY in environment!"
134+
client = OpenAI(
135+
api_key=os.environ["GEMINI_API_KEY"],
136+
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
137+
)
138+
generate = get_oai_model(client, model)
124139
else:
125140
client = OllamaClient(host=f"http://localhost:{port}")
126141
generate = get_ollama_model(client, model, seed, temperature, pure)

src/experiment.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,9 @@
66
from anthropic import Anthropic, InternalServerError
77
from typing import Callable
88

9+
from google import genai
10+
from google.genai import types
11+
912
from src.common import (
1013
SEED,
1114
TEMPERATURE,
@@ -39,12 +42,19 @@
3942
"deepseek-chat",
4043
]
4144

45+
GEMINI_MODELS = [
46+
"gemini-2.5-pro-preview-03-25",
47+
"gemini-2.0-flash",
48+
"gemini-2.0-flash-lite",
49+
"gemini-1.5-flash",
50+
"gemini-1.5-flash-8b",
51+
"gemini-1.5-pro",
52+
]
53+
4254

4355
def get_o1_model(
4456
client: OpenAI,
4557
model: str = "o1-preview-2024-09-12",
46-
seed: int = SEED,
47-
temperature: float = TEMPERATURE,
4858
pure: bool = False,
4959
) -> Callable[[str], str | None]:
5060
def generate_type_signature(prompt: str) -> str | None:
@@ -64,8 +74,6 @@ def generate_type_signature(prompt: str) -> str | None:
6474
def get_oai_model(
6575
client: OpenAI,
6676
model: str = "gpt-3.5-turbo",
67-
seed: int = SEED,
68-
temperature: float = TEMPERATURE,
6977
pure: bool = False,
7078
) -> Callable[[str], str | None]:
7179
def generate_type_signature(prompt: str) -> str | None:
@@ -78,9 +86,6 @@ def generate_type_signature(prompt: str) -> str | None:
7886
{"role": "user", "content": prompt},
7987
],
8088
model=model,
81-
# Set parameters to ensure reproducibility
82-
seed=seed,
83-
temperature=temperature,
8489
)
8590

8691
content = completion.choices[0].message.content
@@ -123,8 +128,6 @@ def generate_type_signature(prompt: str) -> str | None:
123128
def get_ant_model(
124129
client: Anthropic,
125130
model: str = "claude-3-5-sonnet-20240620",
126-
seed: int = SEED,
127-
temperature: float = TEMPERATURE,
128131
pure: bool = False,
129132
) -> Callable[[str], str | None]:
130133
def generate_type_signature(prompt: str) -> str | None:
@@ -136,10 +139,6 @@ def generate_type_signature(prompt: str) -> str | None:
136139
],
137140
model=model,
138141
max_tokens=1024,
139-
# ! the following parameters are not supported by Claude API
140-
# seed=seed,
141-
# temperature=temperature,
142-
# top_p=top_p,
143142
)
144143
except InternalServerError as e:
145144
print(e)

0 commit comments

Comments (0)