Add Formatron framework #5

Open · wants to merge 3 commits into main

Changes from 1 commit
31 changes: 30 additions & 1 deletion config.yaml
@@ -203,6 +203,35 @@ ModelsmithFramework:
llm_model_family: "openai"
retries: 2

FormatronFramework:
- task: "multilabel_classification"
n_runs: 10
init_kwargs:
prompt: "Classify the following text: {text}.\nRespond in the following json schema: {json_schema}:\n"
llm_model: "unsloth/llama-3-8b-Instruct-bnb-4bit"
llm_model_family: "transformers"
retries: 0
source_data_pickle_path: "data/multilabel_classification.pkl"
max_length: 4096
# sample_rows: 2
# - task: "ner_required_fields"
# n_runs: 10
# init_kwargs:
# prompt: "Extract and resolve a list of entities from the following text: {text}.\nRespond in the following json schema: {json_schema}:\n"
# llm_model: "unsloth/llama-3-8b-Instruct-bnb-4bit"
# llm_model_family: "transformers"
# retries: 0
# source_data_pickle_path: "data/ner.pkl"
# max_length: 4096
# # sample_rows: 2
- task: "synthetic_data_generation"
n_runs: 100
init_kwargs:
prompt: "Generate a random person's information. The name must be chosen at random. Make it something you wouldn't normally choose.\nRespond in the following json schema: {json_schema}:\n"
llm_model: "unsloth/llama-3-8b-Instruct-bnb-4bit"
llm_model_family: "transformers"
max_length: 4096

# ModelsmithFramework:
# - task: "ner"
# n_runs: 10
@@ -234,4 +263,4 @@
#       llm_model_family: "transformers"
#       retries: 0 # Outlines transformers has no retry parameter
#       source_data_pickle_path: "data/ner.pkl"
#       # sample_rows: 2
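
For context, here is a minimal sketch of how a config entry like the one above presumably reaches the framework. The benchmark runner itself is not part of this diff, so the loader below is a hypothetical illustration; only the config keys (task, n_runs, init_kwargs) and the FormatronFramework class come from the PR:

import yaml

from frameworks import FormatronFramework  # exported via frameworks/__init__.py below

with open("config.yaml") as fh:
    config = yaml.safe_load(fh)

for entry in config["FormatronFramework"]:
    # Each list item configures one benchmarked task; init_kwargs carries
    # prompt, llm_model, llm_model_family, max_length, etc.
    framework = FormatronFramework(task=entry["task"], **entry["init_kwargs"])
    # entry["n_runs"] would then be passed to framework.run(...) by the runner.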
1 change: 1 addition & 0 deletions frameworks/__init__.py
@@ -1,6 +1,7 @@
from typing import Any

from frameworks.base import experiment
from frameworks.formatron_framework import FormatronFramework
from frameworks.fructose_framework import FructoseFramework
from frameworks.instructor_framework import InstructorFramework
from frameworks.llamaindex_framework import LlamaIndexFramework
72 changes: 72 additions & 0 deletions frameworks/formatron_framework.py
@@ -0,0 +1,72 @@
import json
from typing import Any

from formatron.formatter import FormatterBuilder
from formatron.integrations.transformers import create_formatter_logits_processor_list
from formatron.schemas import json_schema
from outlines.fsm.json_schema import build_regex_from_schema
from transformers import AutoModelForCausalLM, AutoTokenizer

from frameworks.base import BaseFramework, experiment


class FormatronFramework(BaseFramework):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.max_length = kwargs.get("max_length", 4096)
        # Whether the capture comes back as a regex Match object (only used
        # by the fallback regex path below).
        self.load_json_from_re_match = False

        if self.llm_model_family != "transformers":
            raise ValueError(f"Model family: {self.llm_model_family} not supported")

        f = FormatterBuilder()
        if self.task != "multilabel_classification":
            model_schema = self.response_model.model_json_schema()
            # pydantic (v2.7.1) has no good way to include $schema,
            # cf. https://github.com/pydantic/pydantic/issues/1478
            model_schema["$id"] = model_schema["title"]
            model_schema["$schema"] = "https://json-schema.org/draft/2020-12/schema"
            schema = json_schema.create_schema(model_schema)
            response = f.json(schema, capture_name="json")
        else:
            # Fall back to outlines' JSON-schema regex for the
            # multilabel_classification task: enums appear to be supported by
            # formatron v0.4, but the captured output looks buggy at first
            # glance (a mixture of [] and [set()]).
            self.load_json_from_re_match = True
            schema = json.dumps(self.response_model.model_json_schema())
            whitespace_pattern = r" ?"
            regex_str = build_regex_from_schema(schema, whitespace_pattern)
            response = f.regex(regex_str, capture_name="json")
        f.append_line(f"{response}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model)
        self.logits_processor = create_formatter_logits_processor_list(
            self.tokenizer, f
        )
        self.model = AutoModelForCausalLM.from_pretrained(self.llm_model)

    def run(
        self, task: str, n_runs: int, expected_response: Any = None, inputs: dict = {}
    ) -> tuple[list[Any], float, dict, list[list[float]]]:
        @experiment(n_runs=n_runs, expected_response=expected_response, task=task)
        def run_experiment(inputs):
            prompt = self.prompt.format(
                json_schema=self.response_model.model_json_schema(), **inputs
            )
            tokens = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            self.logits_processor[0].reset()
            self.model.generate(
                **tokens,
                logits_processor=self.logits_processor,
                max_length=self.max_length,
            )
            response = self.logits_processor[0].formatters_captures[0]["json"]
            if self.load_json_from_re_match:
                response = json.loads(response.group(0))
            return response

        predictions, percent_successful, metrics, latencies = run_experiment(inputs)
        return predictions, percent_successful, metrics, latencies
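
As a standalone sanity check, the Formatron flow above can be exercised outside the benchmark harness. This minimal sketch uses only calls that appear in the diff (FormatterBuilder, regex, append_line, create_formatter_logits_processor_list, formatters_captures); the toy regex and prompt are illustrative stand-ins:

from formatron.formatter import FormatterBuilder
from formatron.integrations.transformers import create_formatter_logits_processor_list
from transformers import AutoModelForCausalLM, AutoTokenizer

f = FormatterBuilder()
# Constrain the completion to a toy regex and capture it under a name.
answer = f.regex(r"(yes|no)", capture_name="answer")
f.append_line(f"{answer}")

model_id = "unsloth/llama-3-8b-Instruct-bnb-4bit"  # same model pinned in config.yaml
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
logits_processor = create_formatter_logits_processor_list(tokenizer, f)

tokens = tokenizer("Is water wet? ", return_tensors="pt").to(model.device)
model.generate(**tokens, logits_processor=logits_processor, max_length=64)

# Regex captures come back as re.Match objects, hence .group(0) --
# the same reason FormatronFramework sets load_json_from_re_match.
print(logits_processor[0].formatters_captures[0]["answer"].group(0))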
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,6 +1,7 @@
accelerate==0.31.0
bitsandbytes==0.43.1
datasets==2.19.1
formatron==0.4.2
fructose==0.0.13
instructor==1.3.5
llama-index==0.10.56
@@ -15,4 +16,4 @@ plotly==5.22.0
tabulate==0.9.0
torch==2.3.1
transformers==4.42.4
typer==0.12.3
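
With the pin added above, the new dependency installs with the rest of the benchmark environment:

pip install -r requirements.txt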