AI Search Bar

This guide demonstrates how we developed Braintrust's AI-powered search bar, harnessing the power of Braintrust's evaluation workflow along the way. If you've used Braintrust before, you may be familiar with the project page, which serves as a home base for collections of eval experiments:

Braintrust Project Page

To find a particular experiment, you can type filter and sort queries into the search bar, using standard SQL syntax. But SQL can be finicky -- it's very easy to run into syntax errors like single quotes instead of double, incorrect JSON extraction syntax, or typos. Users would prefer to just type in an intuitive search like experiments run on git commit 2a43fd1 or score under 0.5 and see a corresponding SQL query appear automatically. Let's achieve this using AI, with assistance from Braintrust's eval framework.
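
Concretely, here are the kinds of translations we're aiming for (a minimal sketch; the column names, such as avg_sql_score and last_updated, come from the experiment schema we'll define later in this guide, and the exact SQL shown is just one valid answer):

# Illustrative input -> output pairs for the search translator. Column names are
# borrowed from the experiment schema introduced later; treat these as examples,
# not the definitive output of the final system.
desired_translations = {
    "experiments run on git commit 2a43fd1": {"filter": "(source->>'commit') ILIKE '2a43fd1%'"},
    "score under 0.5": {"filter": "avg_sql_score < 0.5"},
    "most recently updated first": {"sort": "last_updated DESC"},
}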

We'll start by installing some packages and setting up our OpenAI client.

%pip install -U Levenshtein autoevals braintrust chevron duckdb openai pydantic
import os
 
import braintrust
import openai
 
PROJECT_NAME = "AI Search Cookbook"
 
# We use the Braintrust proxy here to get access to caching, but this is totally optional!
openai_opts = dict(
    base_url="https://api.braintrust.dev/v1/proxy",
    api_key=os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
)
client = braintrust.wrap_openai(openai.AsyncOpenAI(default_headers={"x-bt-use-cache": "always"}, **openai_opts))
 
braintrust.login(api_key=os.environ.get("BRAINTRUST_API_KEY", "YOUR_BRAINTRUST_API_KEY"))
dataset = braintrust.init_dataset(PROJECT_NAME, "AI Search Cookbook Data", use_output=False)

Load the data and render the templates

When we ask GPT to translate a search query, we have to account for multiple output options: (1) a SQL filter, (2) a SQL sort, (3) both of the above, or (4) an unsuccessful translation (e.g. for a nonsensical user input). We'll use function calling to robustly handle each distinct scenario, with the following output format:

  • match: Whether or not the model was able to translate the search into a valid SQL filter/sort.
  • filter: A WHERE clause.
  • sort: An ORDER BY clause.
  • explanation: Explanation for the choices above -- this is useful for debugging and evaluation.
import dataclasses
from typing import Literal, Optional, Union
 
from pydantic import BaseModel, Field, create_model
 
 
@dataclasses.dataclass
class FunctionCallOutput:
    match: Optional[bool] = None
    filter: Optional[str] = None
    sort: Optional[str] = None
    explanation: Optional[str] = None
    error: Optional[str] = None
 
 
class Match(BaseModel):
    type: Literal["MATCH"] = "MATCH"
    explanation: str = Field(
        ..., description="Explanation of why I called the MATCH function"
    )
 
 
class SQL(BaseModel):
    type: Literal["SQL"] = "SQL"
    filter: Optional[str] = Field(..., description="SQL filter clause")
    sort: Optional[str] = Field(..., description="SQL sort clause")
    explanation: str = Field(
        ...,
        description="Explanation of why I called the SQL function and how I chose the filter and/or sort clauses",
    )
 
 
class Query(BaseModel):
    value: Union[Match, SQL] = Field(
        ...,
    )
 
 
def function_choices():
    return [
        {
            "name": "QUERY",
            "description": "Break down the query either into a MATCH or SQL call",
            "parameters": Query.model_json_schema(),
        },
    ]
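
To make the output format concrete, here's a hypothetical function-call payload for a query like score under 0.5, validated against the Query model (the column name avg_sql_score is an assumption borrowed from the schema section later in this guide):

# A hypothetical arguments payload the model might return for "score under 0.5".
# The column name avg_sql_score is an assumption; see the schema section below.
sample_arguments = {
    "value": {
        "type": "SQL",
        "filter": "avg_sql_score < 0.5",
        "sort": None,
        "explanation": "Interpreted 'score under 0.5' as a numeric filter on avg_sql_score.",
    }
}
print(Query.model_validate(sample_arguments))  # parses into a Query wrapping a SQL call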

Prepare prompts for evaluation in Braintrust

Let's evaluate two different prompts: a shorter prompt with a brief explanation of the problem statement and a description of the experiment schema, and a longer prompt that additionally contains a set of few-shot examples to guide the model. There's nothing special about either of these prompts, and that's OK -- we can iterate and improve them once we use Braintrust to drill down into the results.

import json
 
SHORT_PROMPT_FILE = "./assets/short_prompt.tmpl"
LONG_PROMPT_FILE = "./assets/long_prompt.tmpl"
FEW_SHOT_EXAMPLES_FILE = "./assets/few_shot.json"
 
with open(SHORT_PROMPT_FILE) as f:
    short_prompt = f.read()
 
with open(LONG_PROMPT_FILE) as f:
    long_prompt = f.read()
 
with open(FEW_SHOT_EXAMPLES_FILE, "r") as f:
    few_shot_examples = json.load(f)

One detail worth mentioning: each prompt contains a stub for dynamic insertion of the data schema. This is motivated by the need to handle semantic searches like more than 40 examples or score < 0.5 that don't directly reference a column in the base table. We need to tell the model how the data is structured and what each field actually means. We'll construct a descriptive schema using pydantic and paste it into each prompt to provide the model with this information.
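
To illustrate the stub idea in isolation, here's a minimal sketch of schema injection (the real templates live in ./assets; the placeholder name and surrounding text below are assumptions):

import chevron

# Minimal sketch of schema injection into a prompt template. The placeholder name
# and surrounding text are assumptions; the real templates live in ./assets.
stub = "You translate searches over experiments with the following schema:\n{{{schema}}}"
print(chevron.render(stub, {"schema": '{"properties": {"name": {"type": "string"}}}'}))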

from typing import Any, Callable, Dict, List
 
import chevron
 
 
class ExperimentGitState(BaseModel):
    commit: str = Field(
        ...,
        description="Git commit hash. Any prefix of this hash at least 7 characters long should be considered an exact match, so use a substring filter rather than string equality to check the commit, e.g. `(source->>'commit') ILIKE '{COMMIT}%'`",
    )
    branch: str = Field(..., description="Git branch name")
    tag: Optional[str] = Field(..., description="Git commit tag")
    commit_time: int = Field(..., description="Git commit timestamp")
    author_name: str = Field(..., description="Author of git commit")
    author_email: str = Field(..., description="Email address of git commit author")
    commit_message: str = Field(..., description="Git commit message")
    dirty: Optional[bool] = Field(
        ...,
        description="Whether the git state was dirty when the experiment was run. If false, the git state was clean",
    )
 
 
class Experiment(BaseModel):
    id: str = Field(..., description="Experiment ID, unique")
    name: str = Field(..., description="Name of the experiment")
    last_updated: int = Field(
        ...,
        description="Timestamp marking when the experiment was last updated. If the query deals with some notion of relative time, like age or recency, refer to this timestamp and, if appropriate, compare it to the current time `get_current_time()` by adding or subtracting an interval.",
    )
    creator: Dict[str, str] = Field(..., description="Information about the experiment creator")
    source: ExperimentGitState = Field(..., description="Git state that the experiment was run on")
    metadata: Dict[str, Any] = Field(
        ...,
        description="Custom metadata provided by the user. Ignore this field unless the query mentions metadata or refers to a metadata key specifically",
    )
 
 
def build_experiment_schema(score_fields: List[str]):
    ExperimentWithScoreFields = create_model(
        "Experiment",
        __base__=Experiment,
        **{field: (Optional[float], ...) for field in score_fields},
    )
    return json.dumps(ExperimentWithScoreFields.model_json_schema())
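
As a quick sanity check, we can render the schema for the score fields we'll use later and peek at the result (this cell is purely illustrative):

# A quick peek at the schema we'll inject into the prompts. These are the same
# score fields used throughout this cookbook.
schema_json = build_experiment_schema(score_fields=["avg_sql_score", "avg_factuality_score"])
print(json.dumps(json.loads(schema_json), indent=2)[:500])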

Our prompts are ready! Before we run our evals, we just need to load some sample data and define our scoring functions.

Load sample data

Let's load our examples. Each example case contains input (the search query) and expected (function call output).
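
For reference, each entry in examples.json is shaped roughly like this (the values below mirror the first dataset record printed later in this section; the real file may also contain template stubs for the score fields):

# Rough shape of a single entry in examples.json. The concrete values mirror the
# first dataset record shown below; treat the exact file contents as illustrative.
example_entry = {
    "input": "name is foo",
    "expected": {
        "match": False,
        "filter": "name = 'foo'",
        "sort": None,
        "explanation": "I interpret the query as a string equality filter on the \"name\" column.",
    },
}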

import json
 
 
@dataclasses.dataclass
class Example:
    input: str
    expected: FunctionCallOutput
    metadata: Optional[Dict[str, Any]] = None
 
 
EXAMPLES_FILE = "./assets/examples.json"
with open(EXAMPLES_FILE) as f:
    examples_json = json.load(f)
 
templates = [
    Example(input=e["input"], expected=FunctionCallOutput(**e["expected"])) for e in examples_json["examples"]
]
 
# Each example contains a few dynamic fields that depend on the experiments
# we're searching over. For simplicity, we'll hard-code these fields here.
SCORE_FIELDS = ["avg_sql_score", "avg_factuality_score"]
 
 
def render_example(example: Example, args: Dict[str, Any]) -> Example:
    render_optional = lambda template: (chevron.render(template, args, warn=True) if template is not None else None)
    return Example(
        input=render_optional(example.input),
        expected=FunctionCallOutput(
            match=example.expected.match,
            filter=render_optional(example.expected.filter),
            sort=render_optional(example.expected.sort),
            explanation=render_optional(example.expected.explanation),
        ),
    )
 
 
examples = [render_example(t, {"score_fields": SCORE_FIELDS}) for t in templates]

Let's also split the examples into a training set and test set. For now, this won't matter, but later on when we fine-tune the model, we'll want to use the test set to evaluate the model's performance.

for i, e in enumerate(examples):
    if i < 0.8 * len(examples):
        e.metadata = {"split": "train"}
    else:
        e.metadata = {"split": "test"}

Insert our examples into a Braintrust dataset so we can introspect and reuse the data later.

for example in examples:
    dataset.insert(
        input=example.input, expected=example.expected, metadata=example.metadata
    )
dataset.flush()
 
records = list(dataset)
print(f"Generated {len(records)} records. Here are the first 2...")
for record in records[:2]:
    print(record)
Generated 45 records. Here are the first 2...
{'id': '05e44f2c-da5c-4f5e-a253-d6ce1d081ca4', 'span_id': 'c2329825-10d3-462f-890b-ef54323f8060', 'root_span_id': 'c2329825-10d3-462f-890b-ef54323f8060', '_xact_id': '1000192628646491178', 'created': '2024-03-04T08:08:12.977238Z', 'project_id': '61ce386b-1dac-4027-980f-2f3baf32c9f4', 'dataset_id': 'cbb856d4-b2d9-41ea-a5a7-ba5b78be6959', 'input': 'name is foo', 'expected': {'sort': None, 'error': None, 'match': False, 'filter': "name = 'foo'", 'explanation': 'I interpret the query as a string equality filter on the "name" column. The query does not have any sort semantics, so there is no sort.'}, 'metadata': {'split': 'train'}, 'tags': None}
{'id': '0d127613-505c-404c-8140-2c287313b682', 'span_id': '1e72c902-fe72-4438-adf4-19950f8a2c57', 'root_span_id': '1e72c902-fe72-4438-adf4-19950f8a2c57', '_xact_id': '1000192628646491178', 'created': '2024-03-04T08:08:12.981295Z', 'project_id': '61ce386b-1dac-4027-980f-2f3baf32c9f4', 'dataset_id': 'cbb856d4-b2d9-41ea-a5a7-ba5b78be6959', 'input': "'highest score'", 'expected': {'sort': None, 'error': None, 'match': True, 'filter': None, 'explanation': 'According to directive 2, a query entirely wrapped in quotes should use the MATCH function.'}, 'metadata': {'split': 'train'}, 'tags': None}

Define scoring functions

How do we score our outputs against the ground truth queries? We can't rely on an exact text match, since there are often multiple correct ways to express the same query in SQL. Instead, we'll use two approximate scoring methods: (1) SQLScorer, which roundtrips each query through DuckDB's json_serialize_sql to normalize it before attempting a direct comparison, and (2) AutoScorer, which delegates the scoring task to gpt-4.
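
To make the roundtrip idea concrete, here's a minimal sketch of how json_serialize_sql should normalize away cosmetic differences between equivalent clauses (the clauses themselves are made up):

import duckdb

# Two cosmetically different but equivalent filters should serialize to the same
# AST, since whitespace disappears once the clause is parsed. Illustrative only.
normalize = lambda clause: duckdb.sql(
    f"SELECT json_serialize_sql('SELECT 1 WHERE {clause}')"
).fetchall()[0][0]

print(normalize("avg_sql_score<0.5") == normalize("avg_sql_score < 0.5"))  # expected: True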

import duckdb
from braintrust import current_span, traced
from Levenshtein import distance
 
from autoevals import Score, Scorer, Sql
 
EXPERIMENTS_TABLE = "./assets/experiments.parquet"
SUMMARY_TABLE = "./assets/experiments_summary.parquet"
duckdb.sql(f"DROP TABLE IF EXISTS experiments; CREATE TABLE experiments AS SELECT * FROM '{EXPERIMENTS_TABLE}'")
duckdb.sql(
    f"DROP TABLE IF EXISTS experiments_summary; CREATE TABLE experiments_summary AS SELECT * FROM '{SUMMARY_TABLE}'"
)
 
 
def _test_clause(*, filter=None, sort=None) -> bool:
    clause = f"""
        SELECT
          experiments.id AS id,
          experiments.name,
          experiments_summary.last_updated,
          experiments.user AS creator,
          experiments.repo_info AS source,
          experiments_summary.* EXCLUDE (experiment_id, last_updated),
        FROM experiments
        LEFT JOIN experiments_summary ON experiments.id = experiments_summary.experiment_id
        {'WHERE ' + filter if filter else ''}
        {'ORDER BY ' + sort if sort else ''}
    """
    current_span().log(metadata=dict(test_clause=clause))
    try:
        duckdb.sql(clause).fetchall()
        return True
    except Exception:
        return False
 
 
def _single_quote(s):
    return f"""'{s.replace("'", "''")}'"""
 
 
def _roundtrip_filter(s):
    return duckdb.sql(
        f"""
        SELECT json_deserialize_sql(json_serialize_sql({_single_quote(f"SELECT 1 WHERE {s}")}))
    """
    ).fetchall()[0][0]
 
 
def _roundtrip_sort(s):
    return duckdb.sql(
        f"""
        SELECT json_deserialize_sql(json_serialize_sql({_single_quote(f"SELECT 1 ORDER BY {s}")}))
    """
    ).fetchall()[0][0]
 
 
def score_clause(
    output: Optional[str],
    expected: Optional[str],
    roundtrip: Callable[[str], str],
    test_clause: Callable[[str], bool],
) -> float:
    exact_match = 1 if output == expected else 0
    current_span().log(scores=dict(exact_match=exact_match))
    if exact_match:
        return 1
 
    roundtrip_match = 0
    try:
        if roundtrip(output) == roundtrip(expected):
            roundtrip_match = 1
    except Exception as e:
        current_span().log(metadata=dict(roundtrip_error=str(e)))
 
    current_span().log(scores=dict(roundtrip_match=roundtrip_match))
    if roundtrip_match:
        return 1
 
    # If the queries aren't equivalent after roundtripping, it's not immediately clear
    # whether they are semantically equivalent. Let's at least check that the generated
    # clause is valid SQL by running the `test_clause` function defined above, which
    # runs a test query against our sample data.
    valid_clause_score = 1 if test_clause(output) else 0
    current_span().log(scores=dict(valid_clause=valid_clause_score))
    if valid_clause_score == 0:
        return 0
 
    max_len = max(len(clause) for clause in [output, expected])
    if max_len == 0:
        current_span().log(metadata=dict(error="Bad example: empty clause"))
        return 0
    return 1 - (distance(output, expected) / max_len)
 
 
class SQLScorer(Scorer):
    """SQLScorer uses DuckDB's `json_serialize_sql` function to determine whether
    the model's chosen filter/sort clause(s) are equivalent to the expected
    outputs. If not, we assign partial credit to each clause depending on
    (1) whether the clause is valid SQL, as determined by running it against
    the actual data and seeing if it errors, and (2) a distance-wise comparison
    to the expected text.
    """
 
    def _run_eval_sync(
        self,
        output,
        expected=None,
        **kwargs,
    ):
        if expected is None:
            raise ValueError("SQLScorer requires an expected value")
 
        name = "SQLScorer"
        expected = FunctionCallOutput(**expected)
 
        function_choice_score = 1 if output.match == expected.match else 0
        current_span().log(scores=dict(function_choice=function_choice_score))
        if function_choice_score == 0:
            return Score(name=name, score=0)
        if expected.match:
            return Score(name=name, score=1)
 
        filter_score = None
        if output.filter and expected.filter:
            with current_span().start_span("SimpleFilter") as span:
                filter_score = score_clause(
                    output.filter,
                    expected.filter,
                    _roundtrip_filter,
                    lambda s: _test_clause(filter=s),
                )
        elif output.filter or expected.filter:
            filter_score = 0
        current_span().log(scores=dict(filter=filter_score))
 
        sort_score = None
        if output.sort and expected.sort:
            with current_span().start_span("SimpleSort") as span:
                sort_score = score_clause(
                    output.sort,
                    expected.sort,
                    _roundtrip_sort,
                    lambda s: _test_clause(sort=s),
                )
        elif output.sort or expected.sort:
            sort_score = 0
        current_span().log(scores=dict(sort=sort_score))
 
        scores = [s for s in [filter_score, sort_score] if s is not None]
        if len(scores) == 0:
            return Score(
                name=name,
                score=0,
                error="Bad example: no filter or sort for SQL function call",
            )
        return Score(name=name, score=sum(scores) / len(scores))
 
 
@traced("auto_score_filter")
def auto_score_filter(openai_opts, **kwargs):
    return Sql(**openai_opts)(**kwargs)
 
 
@traced("auto_score_sort")
def auto_score_sort(openai_opts, **kwargs):
    return Sql(**openai_opts)(**kwargs)
 
 
class AutoScorer(Scorer):
    """AutoScorer uses the `Sql` scorer from the autoevals library to auto-score
    the model's chosen filter/sort clause(s) against the expected outputs
    using an LLM.
    """
 
    def __init__(self, **openai_opts):
        self.openai_opts = openai_opts
 
    def _run_eval_sync(
        self,
        output,
        expected=None,
        **kwargs,
    ):
        if expected is None:
            raise ValueError("AutoScorer requires an expected value")
        input = kwargs.get("input")
        if input is None or not isinstance(input, str):
            raise ValueError("AutoScorer requires an input value of type str")
 
        name = "AutoScorer"
        expected = FunctionCallOutput(**expected)
 
        function_choice_score = 1 if output.match == expected.match else 0
        current_span().log(scores=dict(function_choice=function_choice_score))
        if function_choice_score == 0:
            return Score(name=name, score=0)
        if expected.match:
            return Score(name=name, score=1)
 
        filter_score = None
        if output.filter and expected.filter:
            result = auto_score_filter(
                openai_opts=self.openai_opts,
                input=input,
                output=output.filter,
                expected=expected.filter,
            )
            filter_score = result.score or 0
        elif output.filter or expected.filter:
            filter_score = 0
        current_span().log(scores=dict(filter=filter_score))
 
        sort_score = None
        if output.sort and expected.sort:
            result = auto_score_sort(
                openai_opts=self.openai_opts,
                input=input,
                output=output.sort,
                expected=expected.sort,
            )
            sort_score = result.score or 0
        elif output.sort or expected.sort:
            sort_score = 0
        current_span().log(scores=dict(sort=sort_score))
 
        scores = [s for s in [filter_score, sort_score] if s is not None]
        if len(scores) == 0:
            return Score(
                name=name,
                score=0,
                error="Bad example: no filter or sort for SQL function call",
            )
        return Score(name=name, score=sum(scores) / len(scores))

Run the evals!

We'll use the Braintrust Eval framework to set up our experiments according to the prompts, dataset, and scoring functions defined above.

def build_completion_kwargs(
    *,
    query: str,
    model: str,
    prompt: str,
    score_fields: List[str],
    **kwargs,
):
    # Inject the JSON schema into the prompt to assist the model.
    schema = build_experiment_schema(score_fields=score_fields)
    system_message = chevron.render(
        prompt.strip(), {"schema": schema, "examples": few_shot_examples}, warn=True
    )
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": f"Query: {query}"},
    ]
 
    # We use the legacy function choices format for now, because fine-tuning still requires it.
    return dict(
        model=model,
        temperature=0,
        messages=messages,
        functions=function_choices(),
        function_call={"name": "QUERY"},
    )
 
 
def format_output(completion):
    try:
        function_call = completion.choices[0].message.function_call
        arguments = json.loads(function_call.arguments)["value"]
        match = arguments.pop("type").lower() == "match"
        return FunctionCallOutput(match=match, **arguments)
    except Exception as e:
        return FunctionCallOutput(error=str(e))
 
 
GRADER = "gpt-4"  # Used by AutoScorer to grade the model outputs
 
 
def make_task(model, prompt, score_fields):
    async def task(input):
        completion_kwargs = build_completion_kwargs(
            query=input,
            model=model,
            prompt=prompt,
            score_fields=score_fields,
        )
        return format_output(await client.chat.completions.create(**completion_kwargs))
 
    return task
 
 
async def run_eval(experiment_name, prompt, model, score_fields=SCORE_FIELDS):
    task = make_task(model, prompt, score_fields)
    await braintrust.Eval(
        name=PROJECT_NAME,
        experiment_name=experiment_name,
        data=dataset,
        task=task,
        scores=[SQLScorer(), AutoScorer(**openai_opts, model=GRADER)],
    )

Let's try it on one example before running an eval.

args = build_completion_kwargs(
    query=list(dataset)[0]["input"],
    model="gpt-3.5-turbo",
    prompt=short_prompt,
    score_fields=SCORE_FIELDS,
)
response = await client.chat.completions.create(**args)
format_output(response)
FunctionCallOutput(match=False, filter="(name) = 'foo'", sort=None, explanation="Filtered for experiments where the name is 'foo'.", error=None)

We're ready to run our evals! Let's use gpt-3.5-turbo for both.

await run_eval("Short Prompt", short_prompt, "gpt-3.5-turbo")
Experiment Short Prompt is running at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Short%20Prompt
AI Search Cookbook [experiment_name=Short Prompt] (data): 45it [00:00, 73071.50it/s]
AI Search Cookbook [experiment_name=Short Prompt] (tasks):   0%|          | 0/45 [00:00<?, ?it/s]

=========================SUMMARY=========================
Short Prompt compared to Long Prompt 2.0:
46.28% (-21.68%) 'SQLScorer'       score	(10 improvements, 25 regressions)
15.00% (-36.52%) 'exact_match'     score	(2 improvements, 7 regressions)
40.89% (-32.19%) 'sort'            score	(0 improvements, 4 regressions)
16.67% (+01.96%) 'roundtrip_match' score	(2 improvements, 3 regressions)
69.36% (-04.67%) 'filter'          score	(6 improvements, 10 regressions)
60.00% (-22.22%) 'function_choice' score	(5 improvements, 15 regressions)
70.00% (-16.67%) 'valid_clause'    score	(1 improvements, 0 regressions)
43.33% (-12.22%) 'AutoScorer'      score	(9 improvements, 15 regressions)

4.54s (-210.10%) 'duration'	(28 improvements, 17 regressions)

See results for Short Prompt at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Short%20Prompt
await run_eval("Long Prompt", long_prompt, "gpt-3.5-turbo")
Experiment Long Prompt is running at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Long%20Prompt
AI Search Cookbook [experiment_name=Long Prompt] (data): 45it [00:00, 35385.02it/s]
AI Search Cookbook [experiment_name=Long Prompt] (tasks):   0%|          | 0/45 [00:00<?, ?it/s]

=========================SUMMARY=========================
Long Prompt compared to Short Prompt:
67.99% (+21.71%) 'SQLScorer'       score	(21 improvements, 5 regressions)
50.00% (+35.00%) 'exact_match'     score	(6 improvements, 1 regressions)
71.92% (+31.02%) 'sort'            score	(3 improvements, 0 regressions)
03.12% (-13.54%) 'roundtrip_match' score	(1 improvements, 2 regressions)
71.53% (+02.17%) 'filter'          score	(10 improvements, 5 regressions)
77.78% (+17.78%) 'function_choice' score	(9 improvements, 1 regressions)
84.38% (+14.38%) 'valid_clause'    score	(1 improvements, 1 regressions)
55.56% (+12.22%) 'AutoScorer'      score	(9 improvements, 4 regressions)

5.90s (+136.66%) 'duration'	(11 improvements, 34 regressions)

See results for Long Prompt at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Long%20Prompt

View the results in Braintrust

The evals will generate a link to the experiment page. Click into an experiment to view the results!

If you've just been following along, you can check out some sample results here. Type some searches into the search bar to see AI search in action. :)

Braintrust Project Page

Fine-tuning

Let's try fine-tuning the model so that it works with a much shorter prompt. We'll use the same dataset and scoring functions, but make the prompt more concise. To start, let's play with one example:

first = list(dataset.fetch())[0]
print(first["input"])
print(json.dumps(first["expected"], indent=2))
name is foo
{
  "sort": null,
  "error": null,
  "match": false,
  "filter": "name = 'foo'",
  "explanation": "I interpret the query as a string equality filter on the \"name\" column. The query does not have any sort semantics, so there is no sort."
}
from dataclasses import asdict
from pprint import pprint
 
long_prompt_args = build_completion_kwargs(
    query=first["input"],
    model="gpt-3.5-turbo",
    prompt=long_prompt,
    score_fields=SCORE_FIELDS,
)
output = await client.chat.completions.create(**long_prompt_args)
function_call = output.choices[0].message.function_call
print(function_call.name)
pprint(json.loads(function_call.arguments))
QUERY
{'value': {'explanation': "The query refers to the 'name' field in the "
                          "'experiments' table, so I used ILIKE to check if "
                          "the name contains 'foo'. I wrapped the filter in "
                          'parentheses and used ILIKE for case-insensitive '
                          'matching.',
           'filter': "name ILIKE 'foo'",
           'sort': None,
           'type': 'SQL'}}

Great! Now let's turn the expected outputs from the dataset into the legacy function call format that OpenAI's fine-tuning API expects.

def transform_function_call(expected_value):
    return {
        "name": "QUERY",
        "arguments": json.dumps(
            {
                "value": {
                    "type": (
                        expected_value.get("function")
                        if expected_value.get("function")
                        else "MATCH" if expected_value.get("match") else "SQL"
                    ),
                    **{
                        k: v
                        for (k, v) in expected_value.items()
                        if k in ("filter", "sort", "explanation") and v is not None
                    },
                }
            }
        ),
    }
 
 
transform_function_call(first["expected"])
{'name': 'QUERY',
 'arguments': '{"value": {"type": "SQL", "filter": "name = \'foo\'", "explanation": "I interpret the query as a string equality filter on the \\"name\\" column. The query does not have any sort semantics, so there is no sort."}}'}

This function also works on our few-shot examples:

transform_function_call(few_shot_examples[0])
{'name': 'QUERY',
 'arguments': '{"value": {"type": "SQL", "filter": "(metrics->>\'accuracy\')::NUMERIC < 0.2", "explanation": "The query refers to a JSON field, so I correct the JSON extraction syntax according to directive 4 and cast the result to NUMERIC to compare to the value \`0.2\` as per directive 9."}}'}

Since we're fine-tuning, we can also use a shorter prompt that just contains the object type (Experiment) and schema.
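
We won't reproduce the template verbatim here, but judging from the rendered system message in the outputs below, it's roughly of the following shape (the exact placeholder syntax in ./assets/fine_tune.tmpl is an assumption):

# Rough shape of the fine-tuning prompt, inferred from the rendered system message
# shown in the outputs below. The placeholder syntax is an assumption.
FINE_TUNE_TEMPLATE_SKETCH = """Table: experiments

<Schema>
{{{schema}}}
</Schema>"""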

FINE_TUNING_PROMPT_FILE = "./assets/fine_tune.tmpl"
 
with open(FINE_TUNING_PROMPT_FILE) as f:
    fine_tune_prompt = f.read()
def build_expected_messages(query, expected, prompt, score_fields):
    args = build_completion_kwargs(
        query=query,
        model="gpt-3.5-turbo",
        prompt=prompt,
        score_fields=score_fields,
    )
    function_call = transform_function_call(expected)
    return {
        "messages": args["messages"]
        + [{"role": "assistant", "function_call": function_call}],
        "functions": args["functions"],
    }
 
 
build_expected_messages(
    first["input"], first["expected"], fine_tune_prompt, SCORE_FIELDS
)
{'messages': [{'role': 'system',
   'content': 'Table: experiments\n\n<Schema>\n{"$defs": {"ExperimentGitState": {"properties": {"commit": {"description": "Git commit hash. Any prefix of this hash at least 7 characters long should be considered an exact match, so use a substring filter rather than string equality to check the commit, e.g. \`(source->>\'commit\') ILIKE \'{COMMIT}%\'\`", "title": "Commit", "type": "string"}, "branch": {"description": "Git branch name", "title": "Branch", "type": "string"}, "tag": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "Git commit tag", "title": "Tag"}, "commit_time": {"description": "Git commit timestamp", "title": "Commit Time", "type": "integer"}, "author_name": {"description": "Author of git commit", "title": "Author Name", "type": "string"}, "author_email": {"description": "Email address of git commit author", "title": "Author Email", "type": "string"}, "commit_message": {"description": "Git commit message", "title": "Commit Message", "type": "string"}, "dirty": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "description": "Whether the git state was dirty when the experiment was run. If false, the git state was clean", "title": "Dirty"}}, "required": ["commit", "branch", "tag", "commit_time", "author_name", "author_email", "commit_message", "dirty"], "title": "ExperimentGitState", "type": "object"}}, "properties": {"id": {"description": "Experiment ID, unique", "title": "Id", "type": "string"}, "name": {"description": "Name of the experiment", "title": "Name", "type": "string"}, "last_updated": {"description": "Timestamp marking when the experiment was last updated. If the query deals with some notion of relative time, like age or recency, refer to this timestamp and, if appropriate, compare it to the current time \`get_current_time()\` by adding or subtracting an interval.", "title": "Last Updated", "type": "integer"}, "creator": {"additionalProperties": {"type": "string"}, "description": "Information about the experiment creator", "title": "Creator", "type": "object"}, "source": {"allOf": [{"$ref": "#/$defs/ExperimentGitState"}], "description": "Git state that the experiment was run on"}, "metadata": {"description": "Custom metadata provided by the user. Ignore this field unless the query mentions metadata or refers to a metadata key specifically", "title": "Metadata", "type": "object"}, "avg_sql_score": {"anyOf": [{"type": "number"}, {"type": "null"}], "title": "Avg Sql Score"}, "avg_factuality_score": {"anyOf": [{"type": "number"}, {"type": "null"}], "title": "Avg Factuality Score"}}, "required": ["id", "name", "last_updated", "creator", "source", "metadata", "avg_sql_score", "avg_factuality_score"], "title": "Experiment", "type": "object"}\n</Schema>'},
  {'role': 'user', 'content': 'Query: name is foo'},
  {'role': 'assistant',
   'function_call': {'name': 'QUERY',
    'arguments': '{"value": {"type": "SQL", "filter": "name = \'foo\'", "explanation": "I interpret the query as a string equality filter on the \\"name\\" column. The query does not have any sort semantics, so there is no sort."}}'}}],
 'functions': [{'name': 'QUERY',
   'description': 'Break down the query either into a MATCH or SQL call',
   'parameters': {'$defs': {'Match': {'properties': {'type': {'const': 'MATCH',
        'default': 'MATCH',
        'title': 'Type'},
       'explanation': {'description': 'Explanation of why I called the MATCH function',
        'title': 'Explanation',
        'type': 'string'}},
      'required': ['explanation'],
      'title': 'Match',
      'type': 'object'},
     'SQL': {'properties': {'type': {'const': 'SQL',
        'default': 'SQL',
        'title': 'Type'},
       'filter': {'anyOf': [{'type': 'string'}, {'type': 'null'}],
        'description': 'SQL filter clause',
        'title': 'Filter'},
       'sort': {'anyOf': [{'type': 'string'}, {'type': 'null'}],
        'description': 'SQL sort clause',
        'title': 'Sort'},
       'explanation': {'description': 'Explanation of why I called the SQL function and how I chose the filter and/or sort clauses',
        'title': 'Explanation',
        'type': 'string'}},
      'required': ['filter', 'sort', 'explanation'],
      'title': 'SQL',
      'type': 'object'}},
    'properties': {'value': {'anyOf': [{'$ref': '#/$defs/Match'},
       {'$ref': '#/$defs/SQL'}],
      'title': 'Value'}},
    'required': ['value'],
    'title': 'Query',
    'type': 'object'}}]}

Let's construct messages from our train split and few-shot examples, and then fine-tune the model.

train_records = [r for r in records if r["metadata"]["split"] == "train"] + [
    {"input": r["query"], "expected": r} for r in few_shot_examples
]
all_expected_messages = [
    build_expected_messages(r["input"], r["expected"], fine_tune_prompt, SCORE_FIELDS)
    for r in train_records
]
 
print(len(all_expected_messages))
all_expected_messages[1]
49
{'messages': [{'role': 'system',
   'content': 'Table: experiments\n\n<Schema>\n{"$defs": {"ExperimentGitState": {"properties": {"commit": {"description": "Git commit hash. Any prefix of this hash at least 7 characters long should be considered an exact match, so use a substring filter rather than string equality to check the commit, e.g. \`(source->>\'commit\') ILIKE \'{COMMIT}%\'\`", "title": "Commit", "type": "string"}, "branch": {"description": "Git branch name", "title": "Branch", "type": "string"}, "tag": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "Git commit tag", "title": "Tag"}, "commit_time": {"description": "Git commit timestamp", "title": "Commit Time", "type": "integer"}, "author_name": {"description": "Author of git commit", "title": "Author Name", "type": "string"}, "author_email": {"description": "Email address of git commit author", "title": "Author Email", "type": "string"}, "commit_message": {"description": "Git commit message", "title": "Commit Message", "type": "string"}, "dirty": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "description": "Whether the git state was dirty when the experiment was run. If false, the git state was clean", "title": "Dirty"}}, "required": ["commit", "branch", "tag", "commit_time", "author_name", "author_email", "commit_message", "dirty"], "title": "ExperimentGitState", "type": "object"}}, "properties": {"id": {"description": "Experiment ID, unique", "title": "Id", "type": "string"}, "name": {"description": "Name of the experiment", "title": "Name", "type": "string"}, "last_updated": {"description": "Timestamp marking when the experiment was last updated. If the query deals with some notion of relative time, like age or recency, refer to this timestamp and, if appropriate, compare it to the current time \`get_current_time()\` by adding or subtracting an interval.", "title": "Last Updated", "type": "integer"}, "creator": {"additionalProperties": {"type": "string"}, "description": "Information about the experiment creator", "title": "Creator", "type": "object"}, "source": {"allOf": [{"$ref": "#/$defs/ExperimentGitState"}], "description": "Git state that the experiment was run on"}, "metadata": {"description": "Custom metadata provided by the user. Ignore this field unless the query mentions metadata or refers to a metadata key specifically", "title": "Metadata", "type": "object"}, "avg_sql_score": {"anyOf": [{"type": "number"}, {"type": "null"}], "title": "Avg Sql Score"}, "avg_factuality_score": {"anyOf": [{"type": "number"}, {"type": "null"}], "title": "Avg Factuality Score"}}, "required": ["id", "name", "last_updated", "creator", "source", "metadata", "avg_sql_score", "avg_factuality_score"], "title": "Experiment", "type": "object"}\n</Schema>'},
  {'role': 'user', 'content': "Query: 'highest score'"},
  {'role': 'assistant',
   'function_call': {'name': 'QUERY',
    'arguments': '{"value": {"type": "MATCH", "explanation": "According to directive 2, a query entirely wrapped in quotes should use the MATCH function."}}'}}],
 'functions': [{'name': 'QUERY',
   'description': 'Break down the query either into a MATCH or SQL call',
   'parameters': {'$defs': {'Match': {'properties': {'type': {'const': 'MATCH',
        'default': 'MATCH',
        'title': 'Type'},
       'explanation': {'description': 'Explanation of why I called the MATCH function',
        'title': 'Explanation',
        'type': 'string'}},
      'required': ['explanation'],
      'title': 'Match',
      'type': 'object'},
     'SQL': {'properties': {'type': {'const': 'SQL',
        'default': 'SQL',
        'title': 'Type'},
       'filter': {'anyOf': [{'type': 'string'}, {'type': 'null'}],
        'description': 'SQL filter clause',
        'title': 'Filter'},
       'sort': {'anyOf': [{'type': 'string'}, {'type': 'null'}],
        'description': 'SQL sort clause',
        'title': 'Sort'},
       'explanation': {'description': 'Explanation of why I called the SQL function and how I chose the filter and/or sort clauses',
        'title': 'Explanation',
        'type': 'string'}},
      'required': ['filter', 'sort', 'explanation'],
      'title': 'SQL',
      'type': 'object'}},
    'properties': {'value': {'anyOf': [{'$ref': '#/$defs/Match'},
       {'$ref': '#/$defs/SQL'}],
      'title': 'Value'}},
    'required': ['value'],
    'title': 'Query',
    'type': 'object'}}]}
import io
 
# Use the direct OpenAI client, not a proxy
sync_client = openai.OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", "<Your OpenAI API Key>"),
    base_url="https://api.openai.com/v1",
)
 
file_string = "\n".join(json.dumps(messages) for messages in all_expected_messages)
file = sync_client.files.create(
    file=io.BytesIO(file_string.encode()), purpose="fine-tune"
)
job = sync_client.fine_tuning.jobs.create(training_file=file.id, model="gpt-3.5-turbo")
import time
 
start = time.time()
job_id = job.id
while True:
    info = sync_client.fine_tuning.jobs.retrieve(job_id)
    if info.finished_at is not None:
        break
    print(f"{time.time() - start:.0f}s elapsed", end="\t")
    print(str(info), end="\r")
    time.sleep(10)
info = sync_client.fine_tuning.jobs.retrieve(job_id)
fine_tuned_model = info.fine_tuned_model
fine_tuned_model
ft_prompt_args = build_completion_kwargs(
    query=first["input"],
    model=fine_tuned_model,
    prompt=fine_tune_prompt,
    score_fields=SCORE_FIELDS,
)
del ft_prompt_args["temperature"]
print(ft_prompt_args)
output = await client.chat.completions.create(**ft_prompt_args)
print(output)
print(format_output(output))
await run_eval("Fine tuned model", fine_tune_prompt, fine_tuned_model)
Experiment Fine tuned model is running at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Fine%20tuned%20model
AI Search Cookbook [experiment_name=Fine tuned model] (data): 45it [00:00, 15835.53it/s]
AI Search Cookbook [experiment_name=Fine tuned model] (tasks):   0%|          | 0/45 [00:00<?, ?it/s]

=========================SUMMARY=========================
Fine tuned model compared to Long Prompt:
77.78% (-) 'function_choice' score	(8 improvements, 8 regressions)
75.93% (-08.45%) 'valid_clause'    score	(0 improvements, 2 regressions)
30.00% (-20.00%) 'exact_match'     score	(2 improvements, 9 regressions)
48.09% (-23.44%) 'filter'          score	(5 improvements, 15 regressions)
53.44% (-18.47%) 'sort'            score	(1 improvements, 4 regressions)
32.22% (-23.33%) 'AutoScorer'      score	(7 improvements, 18 regressions)
05.36% (+02.23%) 'roundtrip_match' score	(1 improvements, 1 regressions)
48.22% (-19.77%) 'SQLScorer'       score	(10 improvements, 25 regressions)

79.41s (+7350.58%) 'duration'	(0 improvements, 45 regressions)

See results for Fine tuned model at https://www.braintrust.dev/app/braintrust.dev/p/AI%20Search%20Cookbook/Fine%20tuned%20model