2424 GPT_MODELS ,
2525 CLAUDE_MODELS ,
2626 DEEPSEEK_MODELS ,
27+ GEMINI_MODELS ,
2728 get_ant_model ,
2829 get_ant_ttc_model ,
2930 get_oai_model ,
@@ -83,7 +84,12 @@ def main(
8384 """
8485 assert (
8586 model
86- in GPT_MODELS + OLLAMA_MODELS + CLAUDE_MODELS + O1_MODELS + DEEPSEEK_MODELS
87+ in GPT_MODELS
88+ + OLLAMA_MODELS
89+ + CLAUDE_MODELS
90+ + O1_MODELS
91+ + DEEPSEEK_MODELS
92+ + GEMINI_MODELS
8793 ), f"{ model } is not supported."
8894
8995 if output_file is None :
@@ -99,11 +105,11 @@ def main(
99105 if model in GPT_MODELS :
100106 assert "OPENAI_API_KEY" in os .environ , "Please set OPEN_API_KEY in environment!"
101107 client = OpenAI (api_key = os .environ ["OPENAI_API_KEY" ])
102- generate = get_oai_model (client , model , seed , temperature , pure )
108+ generate = get_oai_model (client , model , pure )
103109 elif model in O1_MODELS :
104110 assert "OPENAI_API_KEY" in os .environ , "Please set OPEN_API_KEY in environment!"
105111 client = OpenAI (api_key = os .environ ["OPENAI_API_KEY" ])
106- generate = get_o1_model (client , model , seed , temperature , pure )
112+ generate = get_o1_model (client , model , pure )
107113 elif model in CLAUDE_MODELS :
108114 assert (
109115 "ANTHROPIC_API_KEY" in os .environ
@@ -120,7 +126,16 @@ def main(
120126 client = OpenAI (
121127 api_key = os .environ ["DEEPSEEK_API_KEY" ], base_url = "https://api.deepseek.com"
122128 )
123- generate = get_oai_model (client , model , seed , temperature , pure )
129+ generate = get_oai_model (client , model , pure )
130+ elif model in GEMINI_MODELS :
131+ assert (
132+ "GEMINI_API_KEY" in os .environ
133+ ), "Please set GEMINI_API_KEY in environment!"
134+ client = OpenAI (
135+ api_key = os .environ ["GEMINI_API_KEY" ],
136+ base_url = "https://generativelanguage.googleapis.com/v1beta/openai/" ,
137+ )
138+ generate = get_oai_model (client , model )
124139 else :
125140 client = OllamaClient (host = f"http://localhost:{ port } " )
126141 generate = get_ollama_model (client , model , seed , temperature , pure )
0 commit comments