Contributed by Austin Moehle on 2024-03-04
This guide demonstrates how we developed Braintrust’s AI-powered search bar, harnessing the power of Braintrust’s evaluation workflow along the way. If you’ve used Braintrust before, you may be familiar with the project page, which serves as a home base for collections of eval experiments:

We want a user to be able to type a search like "experiments run on git commit 2a43fd1" or "score under 0.5" and see a corresponding SQL query appear automatically. Let's achieve this using AI, with assistance from Braintrust's eval framework.
We’ll start by installing some packages and setting up our OpenAI client.
%pip install -U Levenshtein autoevals braintrust chevron duckdb openai pydantic
import os
import braintrust
import openai
PROJECT_NAME = "AI Search Cookbook"
# We use the Braintrust proxy here to get access to caching, but this is totally optional!
openai_opts = dict(
base_url="https://api.braintrust.dev/v1/proxy",
api_key=os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
)
client = braintrust.wrap_openai(openai.AsyncOpenAI(default_headers={"x-bt-use-cache": "always"}, **openai_opts))
braintrust.login(api_key=os.environ.get("BRAINTRUST_API_KEY", "YOUR_BRAINTRUST_API_KEY"))
dataset = braintrust.init_dataset(PROJECT_NAME, "AI Search Cookbook Data", use_output=False)
Load the data and render the templates
When we ask GPT to translate a search query, we have to account for multiple output options: (1) a SQL filter, (2) a SQL sort, (3) both of the above, or (4) an unsuccessful translation (e.g. for a nonsensical user input). We'll use function calling to robustly handle each distinct scenario, with the following output format:
match: Whether or not the model was able to translate the search into a valid SQL filter/sort.
filter: A WHERE clause.
sort: An ORDER BY clause.
explanation: Explanation for the choices above; this is useful for debugging and evaluation.
import dataclasses
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field, create_model
@dataclasses.dataclass
class FunctionCallOutput:
match: Optional[bool] = None
filter: Optional[str] = None
sort: Optional[str] = None
explanation: Optional[str] = None
error: Optional[str] = None
class Match(BaseModel):
type: Literal["MATCH"] = "MATCH"
explanation: str = Field(
..., description="Explanation of why I called the MATCH function"
)
class SQL(BaseModel):
type: Literal["SQL"] = "SQL"
filter: Optional[str] = Field(..., description="SQL filter clause")
sort: Optional[str] = Field(..., description="SQL sort clause")
explanation: str = Field(
...,
description="Explanation of why I called the SQL function and how I chose the filter and/or sort clauses",
)
class Query(BaseModel):
value: Union[Match, SQL] = Field(
...,
)
def function_choices():
return [
{
"name": "QUERY",
"description": "Break down the query either into a MATCH or SQL call",
"parameters": Query.model_json_schema(),
},
]
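To make the output format concrete, here is the kind of arguments payload we'd expect the model to produce for a search like "sql score under 0.5". This is an illustrative sketch, not an actual model response; avg_sql_score is one of the score columns we add to the schema later.
# Illustrative: sample arguments for the query "sql score under 0.5",
# validated against the Query schema defined above.
sample_arguments = {
    "value": {
        "type": "SQL",
        "filter": "avg_sql_score < 0.5",
        "sort": None,
        "explanation": "The query asks for a numeric filter on a score column, with no sort.",
    }
}
print(Query.model_validate(sample_arguments))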
Prepare prompts for evaluation in Braintrust
Let's evaluate two different prompts: a shorter prompt with a brief explanation of the problem statement and description of the experiment schema, and a longer prompt that additionally contains few-shot examples to guide the model. There's nothing special about either of these prompts, and that's OK; we can iterate and improve the prompts when we use Braintrust to drill down into the results.
import json
SHORT_PROMPT_FILE = "./assets/short_prompt.tmpl"
LONG_PROMPT_FILE = "./assets/long_prompt.tmpl"
FEW_SHOT_EXAMPLES_FILE = "./assets/few_shot.json"
with open(SHORT_PROMPT_FILE) as f:
short_prompt = f.read()
with open(LONG_PROMPT_FILE) as f:
long_prompt = f.read()
with open(FEW_SHOT_EXAMPLES_FILE, "r") as f:
few_shot_examples = json.load(f)
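If you'd like to inspect the templates yourself, a quick peek at the first few lines of each (purely illustrative) looks like this:
# Illustrative: preview the start of each prompt template.
for name, prompt in [("short", short_prompt), ("long", long_prompt)]:
    print(f"--- {name} prompt ---")
    print("\n".join(prompt.splitlines()[:5]))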
Some search queries, like "more than 40 examples" or "score < 0.5", don't directly reference a column in the base table. We need to tell the model how the data is structured and what each field actually means. We'll construct a descriptive schema using pydantic and paste it into each prompt to provide the model with this information.
from typing import Any, Callable, Dict, List
import chevron
class ExperimentGitState(BaseModel):
commit: str = Field(
...,
description="Git commit hash. Any prefix of this hash at least 7 characters long should be considered an exact match, so use a substring filter rather than string equality to check the commit, e.g. `(source->>'commit') ILIKE '{COMMIT}%'`",
)
branch: str = Field(..., description="Git branch name")
tag: Optional[str] = Field(..., description="Git commit tag")
commit_time: int = Field(..., description="Git commit timestamp")
author_name: str = Field(..., description="Author of git commit")
author_email: str = Field(..., description="Email address of git commit author")
commit_message: str = Field(..., description="Git commit message")
dirty: Optional[bool] = Field(
...,
description="Whether the git state was dirty when the experiment was run. If false, the git state was clean",
)
class Experiment(BaseModel):
id: str = Field(..., description="Experiment ID, unique")
name: str = Field(..., description="Name of the experiment")
last_updated: int = Field(
...,
description="Timestamp marking when the experiment was last updated. If the query deals with some notion of relative time, like age or recency, refer to this timestamp and, if appropriate, compare it to the current time `get_current_time()` by adding or subtracting an interval.",
)
creator: Dict[str, str] = Field(..., description="Information about the experiment creator")
source: ExperimentGitState = Field(..., description="Git state that the experiment was run on")
metadata: Dict[str, Any] = Field(
...,
description="Custom metadata provided by the user. Ignore this field unless the query mentions metadata or refers to a metadata key specifically",
)
def build_experiment_schema(score_fields: List[str]):
ExperimentWithScoreFields = create_model(
"Experiment",
__base__=Experiment,
**{field: (Optional[float], ...) for field in score_fields},
)
return json.dumps(ExperimentWithScoreFields.model_json_schema())
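For reference, here's what rendering that schema looks like for the two score columns we use later in this guide. The output is a long JSON schema string, truncated here; this snippet is just a sketch for inspection.
# Illustrative: build the schema string that gets templated into each prompt.
schema_json = build_experiment_schema(score_fields=["avg_sql_score", "avg_factuality_score"])
print(schema_json[:300] + "...")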
Load sample data
Let's load our examples. Each example case contains input (the search query) and expected (the expected function call output).
import json
@dataclasses.dataclass
class Example:
input: str
expected: FunctionCallOutput
metadata: Optional[Dict[str, Any]] = None
EXAMPLES_FILE = "./assets/examples.json"
with open(EXAMPLES_FILE) as f:
examples_json = json.load(f)
templates = [
Example(input=e["input"], expected=FunctionCallOutput(**e["expected"])) for e in examples_json["examples"]
]
# Each example contains a few dynamic fields that depend on the experiments
# we're searching over. For simplicity, we'll hard-code these fields here.
SCORE_FIELDS = ["avg_sql_score", "avg_factuality_score"]
def render_example(example: Example, args: Dict[str, Any]) -> Example:
render_optional = lambda template: (chevron.render(template, args, warn=True) if template is not None else None)
return Example(
input=render_optional(example.input),
expected=FunctionCallOutput(
match=example.expected.match,
filter=render_optional(example.expected.filter),
sort=render_optional(example.expected.sort),
explanation=render_optional(example.expected.explanation),
),
)
examples = [render_example(t, {"score_fields": SCORE_FIELDS}) for t in templates]
Next, let's split the examples into train and test sets, so we can later fine-tune on just the training split:
for i, e in enumerate(examples):
if i < 0.8 * len(examples):
e.metadata = {"split": "train"}
else:
e.metadata = {"split": "test"}
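As a quick optional sanity check, you can confirm the split sizes (this assumes the examples list built above):
# Illustrative: count examples per split (expect roughly an 80/20 train/test split).
from collections import Counter
print(Counter(e.metadata["split"] for e in examples))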
Now, insert the examples into a Braintrust dataset:
for example in examples:
dataset.insert(
input=example.input, expected=example.expected, metadata=example.metadata
)
dataset.flush()
records = list(dataset)
print(f"Generated {len(records)} records. Here are the first 2...")
for record in records[:2]:
print(record)
Generated 45 records. Here are the first 2...
{'id': '05e44f2c-da5c-4f5e-a253-d6ce1d081ca4', 'span_id': 'c2329825-10d3-462f-890b-ef54323f8060', 'root_span_id': 'c2329825-10d3-462f-890b-ef54323f8060', '_xact_id': '1000192628646491178', 'created': '2024-03-04T08:08:12.977238Z', 'project_id': '61ce386b-1dac-4027-980f-2f3baf32c9f4', 'dataset_id': 'cbb856d4-b2d9-41ea-a5a7-ba5b78be6959', 'input': 'name is foo', 'expected': {'sort': None, 'error': None, 'match': False, 'filter': "name = 'foo'", 'explanation': 'I interpret the query as a string equality filter on the "name" column. The query does not have any sort semantics, so there is no sort.'}, 'metadata': {'split': 'train'}, 'tags': None}
{'id': '0d127613-505c-404c-8140-2c287313b682', 'span_id': '1e72c902-fe72-4438-adf4-19950f8a2c57', 'root_span_id': '1e72c902-fe72-4438-adf4-19950f8a2c57', '_xact_id': '1000192628646491178', 'created': '2024-03-04T08:08:12.981295Z', 'project_id': '61ce386b-1dac-4027-980f-2f3baf32c9f4', 'dataset_id': 'cbb856d4-b2d9-41ea-a5a7-ba5b78be6959', 'input': "'highest score'", 'expected': {'sort': None, 'error': None, 'match': True, 'filter': None, 'explanation': 'According to directive 2, a query entirely wrapped in quotes should use the MATCH function.'}, 'metadata': {'split': 'train'}, 'tags': None}
Define scoring functions
How do we score our outputs against the ground-truth queries? We can't rely on an exact text match, since there are multiple correct ways to express the same query in SQL. Instead, we'll use two approximate scoring methods: (1) SQLScorer, which roundtrips each query through DuckDB's json_serialize_sql function to normalize it before attempting a direct comparison, and (2) AutoScorer, which delegates the scoring task to gpt-4.
import duckdb
from braintrust import current_span, traced
from Levenshtein import distance
from autoevals import Score, Scorer, Sql
EXPERIMENTS_TABLE = "./assets/experiments.parquet"
SUMMARY_TABLE = "./assets/experiments_summary.parquet"
duckdb.sql(f"DROP TABLE IF EXISTS experiments; CREATE TABLE experiments AS SELECT * FROM '{EXPERIMENTS_TABLE}'")
duckdb.sql(
f"DROP TABLE IF EXISTS experiments_summary; CREATE TABLE experiments_summary AS SELECT * FROM '{SUMMARY_TABLE}'"
)
def _test_clause(*, filter=None, sort=None) -> bool:
clause = f"""
SELECT
experiments.id AS id,
experiments.name,
experiments_summary.last_updated,
experiments.user AS creator,
experiments.repo_info AS source,
experiments_summary.* EXCLUDE (experiment_id, last_updated),
FROM experiments
LEFT JOIN experiments_summary ON experiments.id = experiments_summary.experiment_id
{'WHERE ' + filter if filter else ''}
{'ORDER BY ' + sort if sort else ''}
"""
current_span().log(metadata=dict(test_clause=clause))
try:
duckdb.sql(clause).fetchall()
return True
except Exception:
return False
def _single_quote(s):
return f"""'{s.replace("'", "''")}'"""
def _roundtrip_filter(s):
return duckdb.sql(
f"""
SELECT json_deserialize_sql(json_serialize_sql({_single_quote(f"SELECT 1 WHERE {s}")}))
"""
).fetchall()[0][0]
def _roundtrip_sort(s):
return duckdb.sql(
f"""
SELECT json_deserialize_sql(json_serialize_sql({_single_quote(f"SELECT 1 ORDER BY {s}")}))
"""
).fetchall()[0][0]
def score_clause(
output: Optional[str],
expected: Optional[str],
roundtrip: Callable[[str], str],
test_clause: Callable[[str], bool],
) -> float:
exact_match = 1 if output == expected else 0
current_span().log(scores=dict(exact_match=exact_match))
if exact_match:
return 1
roundtrip_match = 0
try:
if roundtrip(output) == roundtrip(expected):
roundtrip_match = 1
except Exception as e:
current_span().log(metadata=dict(roundtrip_error=str(e)))
current_span().log(scores=dict(roundtrip_match=roundtrip_match))
if roundtrip_match:
return 1
# If the queries aren't equivalent after roundtripping, it's not immediately clear
# whether they are semantically equivalent. Let's at least check that the generated
# clause is valid SQL by running the `test_clause` function defined above, which
# runs a test query against our sample data.
valid_clause_score = 1 if test_clause(output) else 0
current_span().log(scores=dict(valid_clause=valid_clause_score))
if valid_clause_score == 0:
return 0
max_len = max(len(clause) for clause in [output, expected])
if max_len == 0:
current_span().log(metadata=dict(error="Bad example: empty clause"))
return 0
return 1 - (distance(output, expected) / max_len)
class SQLScorer(Scorer):
"""SQLScorer uses DuckDB's `json_serialize_sql` function to determine whether
the model's chosen filter/sort clause(s) are equivalent to the expected
outputs. If not, we assign partial credit to each clause depending on
(1) whether the clause is valid SQL, as determined by running it against
the actual data and seeing if it errors, and (2) a distance-wise comparison
to the expected text.
"""
def _run_eval_sync(
self,
output,
expected=None,
**kwargs,
):
if expected is None:
raise ValueError("SQLScorer requires an expected value")
name = "SQLScorer"
expected = FunctionCallOutput(**expected)
function_choice_score = 1 if output.match == expected.match else 0
current_span().log(scores=dict(function_choice=function_choice_score))
if function_choice_score == 0:
return Score(name=name, score=0)
if expected.match:
return Score(name=name, score=1)
filter_score = None
if output.filter and expected.filter:
with current_span().start_span("SimpleFilter") as span:
filter_score = score_clause(
output.filter,
expected.filter,
_roundtrip_filter,
lambda s: _test_clause(filter=s),
)
elif output.filter or expected.filter:
filter_score = 0
current_span().log(scores=dict(filter=filter_score))
sort_score = None
if output.sort and expected.sort:
with current_span().start_span("SimpleSort") as span:
sort_score = score_clause(
output.sort,
expected.sort,
_roundtrip_sort,
lambda s: _test_clause(sort=s),
)
elif output.sort or expected.sort:
sort_score = 0
current_span().log(scores=dict(sort=sort_score))
scores = [s for s in [filter_score, sort_score] if s is not None]
if len(scores) == 0:
return Score(
name=name,
score=0,
error="Bad example: no filter or sort for SQL function call",
)
return Score(name=name, score=sum(scores) / len(scores))
@traced("auto_score_filter")
def auto_score_filter(openai_opts, **kwargs):
return Sql(**openai_opts)(**kwargs)
@traced("auto_score_sort")
def auto_score_sort(openai_opts, **kwargs):
return Sql(**openai_opts)(**kwargs)
class AutoScorer(Scorer):
"""AutoScorer uses the `Sql` scorer from the autoevals library to auto-score
the model's chosen filter/sort clause(s) against the expected outputs
using an LLM.
"""
def __init__(self, **openai_opts):
self.openai_opts = openai_opts
def _run_eval_sync(
self,
output,
expected=None,
**kwargs,
):
if expected is None:
raise ValueError("AutoScorer requires an expected value")
input = kwargs.get("input")
if input is None or not isinstance(input, str):
raise ValueError("AutoScorer requires an input value of type str")
name = "AutoScorer"
expected = FunctionCallOutput(**expected)
function_choice_score = 1 if output.match == expected.match else 0
current_span().log(scores=dict(function_choice=function_choice_score))
if function_choice_score == 0:
return Score(name=name, score=0)
if expected.match:
return Score(name=name, score=1)
filter_score = None
if output.filter and expected.filter:
result = auto_score_filter(
openai_opts=self.openai_opts,
input=input,
output=output.filter,
expected=expected.filter,
)
filter_score = result.score or 0
elif output.filter or expected.filter:
filter_score = 0
current_span().log(scores=dict(filter=filter_score))
sort_score = None
if output.sort and expected.sort:
result = auto_score_sort(
openai_opts=self.openai_opts,
input=input,
output=output.sort,
expected=expected.sort,
)
sort_score = result.score or 0
elif output.sort or expected.sort:
sort_score = 0
current_span().log(scores=dict(sort=sort_score))
scores = [s for s in [filter_score, sort_score] if s is not None]
if len(scores) == 0:
return Score(
name=name,
score=0,
error="Bad example: no filter or sort for SQL function call",
)
return Score(name=name, score=sum(scores) / len(scores))
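To build intuition for the roundtrip comparison in SQLScorer, here's a small illustrative check: two filter clauses that differ only in whitespace normalize to the same canonical form after serializing and deserializing.
# Illustrative: the serialize/deserialize roundtrip normalizes formatting differences,
# so equivalent clauses compare equal even when the raw text differs.
print(_roundtrip_filter("avg_sql_score > 0.5"))
print(_roundtrip_filter("avg_sql_score   >   0.5"))  # same canonical form as above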
Run the evals!
We'll use Braintrust's Eval framework to set up our experiments with the prompts, dataset, and scoring functions defined above.
def build_completion_kwargs(
*,
query: str,
model: str,
prompt: str,
score_fields: List[str],
**kwargs,
):
# Inject the JSON schema into the prompt to assist the model.
schema = build_experiment_schema(score_fields=score_fields)
system_message = chevron.render(
prompt.strip(), {"schema": schema, "examples": few_shot_examples}, warn=True
)
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": f"Query: {query}"},
]
# We use the legacy function choices format for now, because fine-tuning still requires it.
return dict(
model=model,
temperature=0,
messages=messages,
functions=function_choices(),
function_call={"name": "QUERY"},
)
def format_output(completion):
try:
function_call = completion.choices[0].message.function_call
arguments = json.loads(function_call.arguments)["value"]
match = arguments.pop("type").lower() == "match"
return FunctionCallOutput(match=match, **arguments)
except Exception as e:
return FunctionCallOutput(error=str(e))
GRADER = "gpt-4" # Used by AutoScorer to grade the model outputs
def make_task(model, prompt, score_fields):
async def task(input):
completion_kwargs = build_completion_kwargs(
query=input,
model=model,
prompt=prompt,
score_fields=score_fields,
)
return format_output(await client.chat.completions.create(**completion_kwargs))
return task
async def run_eval(experiment_name, prompt, model, score_fields=SCORE_FIELDS):
task = make_task(model, prompt, score_fields)
await braintrust.Eval(
name=PROJECT_NAME,
experiment_name=experiment_name,
data=dataset,
task=task,
scores=[SQLScorer(), AutoScorer(**openai_opts, model=GRADER)],
)
args = build_completion_kwargs(
query=list(dataset)[0]["input"],
model="gpt-3.5-turbo",
prompt=short_prompt,
score_fields=SCORE_FIELDS,
)
response = await client.chat.completions.create(**args)
format_output(response)
FunctionCallOutput(match=False, filter="(name) = 'foo'", sort=None, explanation="Filtered for experiments where the name is 'foo'.", error=None)
Now let's run the evals on the short and long prompts, using gpt-3.5-turbo for both.
await run_eval("Short Prompt", short_prompt, "gpt-3.5-turbo")
Experiment Short Prompt is running at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Short%20Prompt
AI Search Cookbook [experiment_name=Short Prompt] (data): 45it [00:00, 73071.50it/s]
AI Search Cookbook [experiment_name=Short Prompt] (tasks): 0%| | 0/45 [00:00<?, ?it/s]
=========================SUMMARY=========================
Short Prompt compared to Long Prompt 2.0:
46.28% (-21.68%) 'SQLScorer' score (10 improvements, 25 regressions)
15.00% (-36.52%) 'exact_match' score (2 improvements, 7 regressions)
40.89% (-32.19%) 'sort' score (0 improvements, 4 regressions)
16.67% (+01.96%) 'roundtrip_match' score (2 improvements, 3 regressions)
69.36% (-04.67%) 'filter' score (6 improvements, 10 regressions)
60.00% (-22.22%) 'function_choice' score (5 improvements, 15 regressions)
70.00% (-16.67%) 'valid_clause' score (1 improvements, 0 regressions)
43.33% (-12.22%) 'AutoScorer' score (9 improvements, 15 regressions)
4.54s (-210.10%) 'duration' (28 improvements, 17 regressions)
See results for Short Prompt at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Short%20Prompt
await run_eval("Long Prompt", long_prompt, "gpt-3.5-turbo")
Experiment Long Prompt is running at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Long%20Prompt
AI Search Cookbook [experiment_name=Long Prompt] (data): 45it [00:00, 35385.02it/s]
AI Search Cookbook [experiment_name=Long Prompt] (tasks): 0%| | 0/45 [00:00<?, ?it/s]
=========================SUMMARY=========================
Long Prompt compared to Short Prompt:
67.99% (+21.71%) 'SQLScorer' score (21 improvements, 5 regressions)
50.00% (+35.00%) 'exact_match' score (6 improvements, 1 regressions)
71.92% (+31.02%) 'sort' score (3 improvements, 0 regressions)
03.12% (-13.54%) 'roundtrip_match' score (1 improvements, 2 regressions)
71.53% (+02.17%) 'filter' score (10 improvements, 5 regressions)
77.78% (+17.78%) 'function_choice' score (9 improvements, 1 regressions)
84.38% (+14.38%) 'valid_clause' score (1 improvements, 1 regressions)
55.56% (+12.22%) 'AutoScorer' score (9 improvements, 4 regressions)
5.90s (+136.66%) 'duration' (11 improvements, 34 regressions)
See results for Long Prompt at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Long%20Prompt
View the results in Braintrust
The evals will generate a link to the experiment page. Click into an experiment to view the results! If you’ve just been following along, you can check out some sample results here. Type some searches into the search bar to see AI search in action. :)
Fine-tuning
Let's try to fine-tune the model with an exceedingly short prompt. We'll use the same dataset and scoring functions, but we'll change the prompt to be more concise. To start, let's play with one example:
first = list(dataset.fetch())[0]
print(first["input"])
print(json.dumps(first["expected"], indent=2))
name is foo
{
"sort": null,
"error": null,
"match": false,
"filter": "name = 'foo'",
"explanation": "I interpret the query as a string equality filter on the \"name\" column. The query does not have any sort semantics, so there is no sort."
}
from dataclasses import asdict
from pprint import pprint
long_prompt_args = build_completion_kwargs(
query=first["input"],
model="gpt-3.5-turbo",
prompt=long_prompt,
score_fields=SCORE_FIELDS,
)
output = await client.chat.completions.create(**long_prompt_args)
function_call = output.choices[0].message.function_call
print(function_call.name)
pprint(json.loads(function_call.arguments))
QUERY
{'value': {'explanation': "The query refers to the 'name' field in the "
"'experiments' table, so I used ILIKE to check if "
"the name contains 'foo'. I wrapped the filter in "
'parentheses and used ILIKE for case-insensitive '
'matching.',
'filter': "name ILIKE 'foo'",
'sort': None,
'type': 'SQL'}}
Next, let's convert each expected output into the function call format we'll use for the assistant message in our fine-tuning data:
def transform_function_call(expected_value):
return {
"name": "QUERY",
"arguments": json.dumps(
{
"value": {
"type": (
expected_value.get("function")
if expected_value.get("function")
else "MATCH" if expected_value.get("match") else "SQL"
),
**{
k: v
for (k, v) in expected_value.items()
if k in ("filter", "sort", "explanation") and v is not None
},
}
}
),
}
transform_function_call(first["expected"])
{'name': 'QUERY',
'arguments': '{"value": {"type": "SQL", "filter": "name = \'foo\'", "explanation": "I interpret the query as a string equality filter on the \\"name\\" column. The query does not have any sort semantics, so there is no sort."}}'}
transform_function_call(few_shot_examples[0])
{'name': 'QUERY',
'arguments': '{"value": {"type": "SQL", "filter": "(metrics->>\'accuracy\')::NUMERIC < 0.2", "explanation": "The query refers to a JSON field, so I correct the JSON extraction syntax according to directive 4 and cast the result to NUMERIC to compare to the value \`0.2\` as per directive 9."}}'}
We'll use a much shorter prompt template for fine-tuning:
FINE_TUNING_PROMPT_FILE = "./assets/fine_tune.tmpl"
with open(FINE_TUNING_PROMPT_FILE) as f:
fine_tune_prompt = f.read()
Now we can assemble each training example: the system and user messages from our prompt, plus the expected assistant function call.
def build_expected_messages(query, expected, prompt, score_fields):
    # Build the same system/user messages we'd send at inference time for this query.
    args = build_completion_kwargs(
        query=query,
        model="gpt-3.5-turbo",
        prompt=prompt,
        score_fields=score_fields,
    )
function_call = transform_function_call(expected)
return {
"messages": args["messages"]
+ [{"role": "assistant", "function_call": function_call}],
"functions": args["functions"],
}
build_expected_messages(
first["input"], first["expected"], fine_tune_prompt, SCORE_FIELDS
)
{'messages': [{'role': 'system',
'content': 'Table: experiments\n\n<Schema>\n{"$defs": {"ExperimentGitState": {"properties": {"commit": {"description": "Git commit hash. Any prefix of this hash at least 7 characters long should be considered an exact match, so use a substring filter rather than string equality to check the commit, e.g. \`(source->>\'commit\') ILIKE \'{COMMIT}%\'\`", "title": "Commit", "type": "string"}, "branch": {"description": "Git branch name", "title": "Branch", "type": "string"}, "tag": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "Git commit tag", "title": "Tag"}, "commit_time": {"description": "Git commit timestamp", "title": "Commit Time", "type": "integer"}, "author_name": {"description": "Author of git commit", "title": "Author Name", "type": "string"}, "author_email": {"description": "Email address of git commit author", "title": "Author Email", "type": "string"}, "commit_message": {"description": "Git commit message", "title": "Commit Message", "type": "string"}, "dirty": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "description": "Whether the git state was dirty when the experiment was run. If false, the git state was clean", "title": "Dirty"}}, "required": ["commit", "branch", "tag", "commit_time", "author_name", "author_email", "commit_message", "dirty"], "title": "ExperimentGitState", "type": "object"}}, "properties": {"id": {"description": "Experiment ID, unique", "title": "Id", "type": "string"}, "name": {"description": "Name of the experiment", "title": "Name", "type": "string"}, "last_updated": {"description": "Timestamp marking when the experiment was last updated. If the query deals with some notion of relative time, like age or recency, refer to this timestamp and, if appropriate, compare it to the current time \`get_current_time()\` by adding or subtracting an interval.", "title": "Last Updated", "type": "integer"}, "creator": {"additionalProperties": {"type": "string"}, "description": "Information about the experiment creator", "title": "Creator", "type": "object"}, "source": {"allOf": [{"$ref": "#/$defs/ExperimentGitState"}], "description": "Git state that the experiment was run on"}, "metadata": {"description": "Custom metadata provided by the user. Ignore this field unless the query mentions metadata or refers to a metadata key specifically", "title": "Metadata", "type": "object"}, "avg_sql_score": {"anyOf": [{"type": "number"}, {"type": "null"}], "title": "Avg Sql Score"}, "avg_factuality_score": {"anyOf": [{"type": "number"}, {"type": "null"}], "title": "Avg Factuality Score"}}, "required": ["id", "name", "last_updated", "creator", "source", "metadata", "avg_sql_score", "avg_factuality_score"], "title": "Experiment", "type": "object"}\n</Schema>'},
{'role': 'user', 'content': 'Query: name is foo'},
{'role': 'assistant',
'function_call': {'name': 'QUERY',
'arguments': '{"value": {"type": "SQL", "filter": "name = \'foo\'", "explanation": "I interpret the query as a string equality filter on the \\"name\\" column. The query does not have any sort semantics, so there is no sort."}}'}}],
'functions': [{'name': 'QUERY',
'description': 'Break down the query either into a MATCH or SQL call',
'parameters': {'$defs': {'Match': {'properties': {'type': {'const': 'MATCH',
'default': 'MATCH',
'title': 'Type'},
'explanation': {'description': 'Explanation of why I called the MATCH function',
'title': 'Explanation',
'type': 'string'}},
'required': ['explanation'],
'title': 'Match',
'type': 'object'},
'SQL': {'properties': {'type': {'const': 'SQL',
'default': 'SQL',
'title': 'Type'},
'filter': {'anyOf': [{'type': 'string'}, {'type': 'null'}],
'description': 'SQL filter clause',
'title': 'Filter'},
'sort': {'anyOf': [{'type': 'string'}, {'type': 'null'}],
'description': 'SQL sort clause',
'title': 'Sort'},
'explanation': {'description': 'Explanation of why I called the SQL function and how I chose the filter and/or sort clauses',
'title': 'Explanation',
'type': 'string'}},
'required': ['filter', 'sort', 'explanation'],
'title': 'SQL',
'type': 'object'}},
'properties': {'value': {'anyOf': [{'$ref': '#/$defs/Match'},
{'$ref': '#/$defs/SQL'}],
'title': 'Value'}},
'required': ['value'],
'title': 'Query',
'type': 'object'}}]}
Let's build the fine-tuning training set from the train split of our dataset, plus the few-shot examples used in the long prompt:
train_records = [r for r in records if r["metadata"]["split"] == "train"] + [
{"input": r["query"], "expected": r} for r in few_shot_examples
]
all_expected_messages = [
build_expected_messages(r["input"], r["expected"], fine_tune_prompt, SCORE_FIELDS)
for r in train_records
]
print(len(all_expected_messages))
all_expected_messages[1]
49
{'messages': [{'role': 'system',
'content': 'Table: experiments\n\n<Schema>\n{"$defs": {"ExperimentGitState": {"properties": {"commit": {"description": "Git commit hash. Any prefix of this hash at least 7 characters long should be considered an exact match, so use a substring filter rather than string equality to check the commit, e.g. \`(source->>\'commit\') ILIKE \'{COMMIT}%\'\`", "title": "Commit", "type": "string"}, "branch": {"description": "Git branch name", "title": "Branch", "type": "string"}, "tag": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "Git commit tag", "title": "Tag"}, "commit_time": {"description": "Git commit timestamp", "title": "Commit Time", "type": "integer"}, "author_name": {"description": "Author of git commit", "title": "Author Name", "type": "string"}, "author_email": {"description": "Email address of git commit author", "title": "Author Email", "type": "string"}, "commit_message": {"description": "Git commit message", "title": "Commit Message", "type": "string"}, "dirty": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "description": "Whether the git state was dirty when the experiment was run. If false, the git state was clean", "title": "Dirty"}}, "required": ["commit", "branch", "tag", "commit_time", "author_name", "author_email", "commit_message", "dirty"], "title": "ExperimentGitState", "type": "object"}}, "properties": {"id": {"description": "Experiment ID, unique", "title": "Id", "type": "string"}, "name": {"description": "Name of the experiment", "title": "Name", "type": "string"}, "last_updated": {"description": "Timestamp marking when the experiment was last updated. If the query deals with some notion of relative time, like age or recency, refer to this timestamp and, if appropriate, compare it to the current time \`get_current_time()\` by adding or subtracting an interval.", "title": "Last Updated", "type": "integer"}, "creator": {"additionalProperties": {"type": "string"}, "description": "Information about the experiment creator", "title": "Creator", "type": "object"}, "source": {"allOf": [{"$ref": "#/$defs/ExperimentGitState"}], "description": "Git state that the experiment was run on"}, "metadata": {"description": "Custom metadata provided by the user. Ignore this field unless the query mentions metadata or refers to a metadata key specifically", "title": "Metadata", "type": "object"}, "avg_sql_score": {"anyOf": [{"type": "number"}, {"type": "null"}], "title": "Avg Sql Score"}, "avg_factuality_score": {"anyOf": [{"type": "number"}, {"type": "null"}], "title": "Avg Factuality Score"}}, "required": ["id", "name", "last_updated", "creator", "source", "metadata", "avg_sql_score", "avg_factuality_score"], "title": "Experiment", "type": "object"}\n</Schema>'},
{'role': 'user', 'content': "Query: 'highest score'"},
{'role': 'assistant',
'function_call': {'name': 'QUERY',
'arguments': '{"value": {"type": "MATCH", "explanation": "According to directive 2, a query entirely wrapped in quotes should use the MATCH function."}}'}}],
'functions': [{'name': 'QUERY',
'description': 'Break down the query either into a MATCH or SQL call',
'parameters': {'$defs': {'Match': {'properties': {'type': {'const': 'MATCH',
'default': 'MATCH',
'title': 'Type'},
'explanation': {'description': 'Explanation of why I called the MATCH function',
'title': 'Explanation',
'type': 'string'}},
'required': ['explanation'],
'title': 'Match',
'type': 'object'},
'SQL': {'properties': {'type': {'const': 'SQL',
'default': 'SQL',
'title': 'Type'},
'filter': {'anyOf': [{'type': 'string'}, {'type': 'null'}],
'description': 'SQL filter clause',
'title': 'Filter'},
'sort': {'anyOf': [{'type': 'string'}, {'type': 'null'}],
'description': 'SQL sort clause',
'title': 'Sort'},
'explanation': {'description': 'Explanation of why I called the SQL function and how I chose the filter and/or sort clauses',
'title': 'Explanation',
'type': 'string'}},
'required': ['filter', 'sort', 'explanation'],
'title': 'SQL',
'type': 'object'}},
'properties': {'value': {'anyOf': [{'$ref': '#/$defs/Match'},
{'$ref': '#/$defs/SQL'}],
'title': 'Value'}},
'required': ['value'],
'title': 'Query',
'type': 'object'}}]}
Now write the training messages to a JSONL file (one example per line) and upload it to OpenAI:
import io
# Use the direct OpenAI client, not a proxy
sync_client = openai.OpenAI(
api_key=os.environ.get("OPENAI_API_KEY", "<Your OpenAI API Key>"),
base_url="https://api.openai.com/v1",
)
file_string = "\n".join(json.dumps(messages) for messages in all_expected_messages)
file = sync_client.files.create(
file=io.BytesIO(file_string.encode()), purpose="fine-tune"
)
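Before launching the fine-tuning job, it's worth a quick optional sanity check that every line of the training file we just built parses and contains the fields we expect:
# Illustrative: verify each JSONL line is valid JSON with the expected keys.
for line in file_string.splitlines():
    record = json.loads(line)
    assert "messages" in record and "functions" in record
print(f"{len(all_expected_messages)} training examples look well-formed")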
Now kick off the fine-tuning job:
job = sync_client.fine_tuning.jobs.create(training_file=file.id, model="gpt-3.5-turbo")
Poll the job until it completes:
import time
start = time.time()
job_id = job.id
while True:
info = sync_client.fine_tuning.jobs.retrieve(job_id)
if info.finished_at is not None:
break
print(f"{time.time() - start:.0f}s elapsed", end="\t")
print(str(info), end="\r")
time.sleep(10)
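If the loop exits but the model isn't usable, the job object includes status and error details. A minimal illustrative check, assuming the standard fields on the OpenAI fine-tuning job object:
# Illustrative: inspect the final job state; a failed job carries an error payload.
info = sync_client.fine_tuning.jobs.retrieve(job_id)
print(info.status)
if info.status != "succeeded":
    print(info.error)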
info = sync_client.fine_tuning.jobs.retrieve(job_id)
fine_tuned_model = info.fine_tuned_model
fine_tuned_model
Let's try the fine-tuned model on a single example before running the full eval:
ft_prompt_args = build_completion_kwargs(
query=first["input"],
model=fine_tuned_model,
prompt=fine_tune_prompt,
score_fields=SCORE_FIELDS,
)
del ft_prompt_args["temperature"]
print(ft_prompt_args)
output = await client.chat.completions.create(**ft_prompt_args)
print(output)
print(format_output(output))
Finally, run the eval with the fine-tuned model:
await run_eval("Fine tuned model", fine_tune_prompt, fine_tuned_model)
Experiment Fine tuned model is running at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Fine%20tuned%20model
AI Search Cookbook [experiment_name=Fine tuned model] (data): 45it [00:00, 15835.53it/s]
AI Search Cookbook [experiment_name=Fine tuned model] (tasks): 0%| | 0/45 [00:00<?, ?it/s]
=========================SUMMARY=========================
Fine tuned model compared to Long Prompt:
77.78% (-) 'function_choice' score (8 improvements, 8 regressions)
75.93% (-08.45%) 'valid_clause' score (0 improvements, 2 regressions)
30.00% (-20.00%) 'exact_match' score (2 improvements, 9 regressions)
48.09% (-23.44%) 'filter' score (5 improvements, 15 regressions)
53.44% (-18.47%) 'sort' score (1 improvements, 4 regressions)
32.22% (-23.33%) 'AutoScorer' score (7 improvements, 18 regressions)
05.36% (+02.23%) 'roundtrip_match' score (1 improvements, 1 regressions)
48.22% (-19.77%) 'SQLScorer' score (10 improvements, 25 regressions)
79.41s (+7350.58%) 'duration' (0 improvements, 45 regressions)
See results for Fine tuned model at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Fine%20tuned%20model