-
Notifications
You must be signed in to change notification settings - Fork 961
Expand file tree
/
Copy pathevaluate.py
More file actions
100 lines (85 loc) · 4.19 KB
/
evaluate.py
File metadata and controls
100 lines (85 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import logging
import os
import re
from pathlib import Path
from typing import Any
import azure.identity
from dotenv import load_dotenv
from evaltools.eval.evaluate import run_evaluate_from_config
from evaltools.eval.evaluate_metrics import register_metric
from evaltools.eval.evaluate_metrics.base_metric import BaseMetric
from rich.logging import RichHandler
logger = logging.getLogger("ragapp")
class CitationsMatchedMetric(BaseMetric):
METRIC_NAME = "citations_matched"
@classmethod
def evaluator_fn(cls, **kwargs):
def citations_overlap(*, response, ground_truth, **kwargs):
if response is None:
logger.warning("Received response of None, can't compute citation_match metric. Setting to -1.")
return {cls.METRIC_NAME: -1}
truth_citations = set(re.findall(r"\[(\d+)\]", ground_truth))
response_citations = set(re.findall(r"\[(\d+)\]", response))
# Count the percentage of citations that are present in the response
num_citations = len(truth_citations)
num_matched_citations = len(truth_citations.intersection(response_citations))
return {cls.METRIC_NAME: num_matched_citations / num_citations}
return citations_overlap
@classmethod
def get_aggregate_stats(cls, df):
df = df[df[cls.METRIC_NAME] != -1]
return {"mean": round(df[cls.METRIC_NAME].mean(), 2)}
def get_openai_config() -> dict:
openai_config: dict[str, Any]
if os.environ.get("OPENAI_CHAT_HOST") == "azure":
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
azure_deployment = os.environ["AZURE_OPENAI_EVAL_DEPLOYMENT"]
if os.environ.get("AZURE_OPENAI_KEY"):
logger.info("Using Azure OpenAI Service with API Key from AZURE_OPENAI_KEY")
openai_config = {
"azure_endpoint": azure_endpoint,
"azure_deployment": azure_deployment,
"api_key": os.environ["AZURE_OPENAI_KEY"],
}
else:
if tenant_id := os.getenv("AZURE_TENANT_ID"):
logger.info("Authenticating to Azure using Azure Developer CLI Credential for tenant %s", tenant_id)
azure_credential = azure.identity.AzureDeveloperCliCredential(tenant_id=tenant_id, process_timeout=60)
else:
logger.info("Authenticating to Azure using Azure Developer CLI Credential")
azure_credential = azure.identity.AzureDeveloperCliCredential(process_timeout=60)
openai_config = {
"azure_endpoint": azure_endpoint,
"azure_deployment": azure_deployment,
"credential": azure_credential,
}
# azure-ai-evaluate will call DefaultAzureCredential behind the scenes,
# so we must be logged in to Azure CLI with the correct tenant
openai_config["model"] = os.environ["AZURE_OPENAI_EVAL_MODEL"]
elif os.environ.get("OPENAI_CHAT_HOST") == "ollama":
raise NotImplementedError("Ollama is not supported. Switch to Azure or OpenAI.com")
else:
logger.info("Using OpenAI Service with API Key from OPENAICOM_KEY")
openai_config = {"api_key": os.environ["OPENAICOM_KEY"], "model": "gpt-4"}
return openai_config
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
)
load_dotenv(".env", override=True)
parser = argparse.ArgumentParser(description="Run evaluation with OpenAI configuration.")
parser.add_argument("--targeturl", type=str, help="Specify the target URL.")
parser.add_argument("--resultsdir", type=Path, help="Specify the results directory.")
parser.add_argument("--numquestions", type=int, help="Specify the number of questions.")
args = parser.parse_args()
openai_config = get_openai_config()
register_metric(CitationsMatchedMetric)
run_evaluate_from_config(
working_dir=Path(__file__).parent,
config_path="eval_config.json",
num_questions=args.numquestions,
target_url=args.targeturl,
results_dir=args.resultsdir,
openai_config=openai_config,
)