Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 3 additions & 24 deletions outlines/models/dottxt.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
"""Integration with Dottxt's API."""

import json
from typing import TYPE_CHECKING, Any, Optional

from pydantic import TypeAdapter
from typing import TYPE_CHECKING, Any, Optional, cast

from outlines.models.base import Model, ModelTypeAdapter
from outlines.types import CFG, JsonSchema, Regex
from outlines.types.utils import (
is_dataclass,
is_genson_schema_builder,
is_pydantic_model,
is_typed_dict,
)

if TYPE_CHECKING:
from dottxt import Dottxt as DottxtClient
Expand Down Expand Up @@ -77,20 +68,8 @@ def format_output_type(self, output_type: Optional[Any] = None) -> str:
"CFG-based structured outputs will soon be available with "
"Dottxt. Use an open source model in the meantime."
)

elif isinstance(output_type, JsonSchema):
return output_type.schema
elif is_dataclass(output_type):
schema = TypeAdapter(output_type).json_schema()
return json.dumps(schema)
elif is_typed_dict(output_type):
schema = TypeAdapter(output_type).json_schema()
return json.dumps(schema)
elif is_pydantic_model(output_type):
schema = output_type.model_json_schema()
return json.dumps(schema)
elif is_genson_schema_builder(output_type):
return output_type.to_json()
elif JsonSchema.is_json_schema(output_type):
return cast(str, JsonSchema.convert_to(output_type, ["str"]))
else:
type_name = getattr(output_type, "__name__", output_type)
raise TypeError(
Expand Down
48 changes: 17 additions & 31 deletions outlines/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@
from outlines.models.base import Model, ModelTypeAdapter
from outlines.types import CFG, Choice, JsonSchema, Regex
from outlines.types.utils import (
is_dataclass,
is_enum,
get_enum_from_choice,
get_enum_from_literal,
is_genson_schema_builder,
is_literal,
is_pydantic_model,
is_typed_dict,
is_typing_list,
)

Expand Down Expand Up @@ -171,28 +168,18 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
"CFG-based structured outputs are not available with Gemini. "
"Use an open source model or dottxt instead."
)
elif is_genson_schema_builder(output_type):
raise TypeError(
"The Gemini SDK does not accept Genson schema builders as an "
"input. Pass a Pydantic model, typed dict or dataclass "
"instead."
)
elif isinstance(output_type, JsonSchema):
raise TypeError(
"The Gemini SDK does not accept Json Schemas as an input. "
"Pass a Pydantic model, typed dict or dataclass instead."
)

if output_type is None:
return {}

# Structured types
elif is_dataclass(output_type):
return self.format_json_output_type(output_type)
elif is_typed_dict(output_type):
return self.format_json_output_type(output_type)
elif is_pydantic_model(output_type):
return self.format_json_output_type(output_type)
# JSON schema types
elif JsonSchema.is_json_schema(output_type):
return self.format_json_output_type(
JsonSchema.convert_to(
output_type,
["dataclass", "typeddict", "pydantic"]
)
)

# List of structured types
elif is_typing_list(output_type):
Expand Down Expand Up @@ -233,21 +220,20 @@ def format_list_output_type(self, output_type: Optional[Any]) -> dict:
if len(args) == 1:
item_type = args[0]

# Check if list item type is supported
if (
is_pydantic_model(item_type)
or is_typed_dict(item_type)
or is_dataclass(item_type)
):
if JsonSchema.is_json_schema(item_type):
return {
"response_mime_type": "application/json",
"response_schema": output_type,
"response_schema": list[ # type: ignore
JsonSchema.convert_to(
item_type,
["dataclass", "typeddict", "pydantic"]
)
],
}

else:
raise TypeError(
"The only supported types for list items are Pydantic "
+ "models, typed dicts and dataclasses."
"The list items output type must contain a JSON schema "
"type."
)

raise TypeError(
Expand Down
43 changes: 15 additions & 28 deletions outlines/models/ollama.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
"""Integration with the `ollama` library."""

import json
from functools import singledispatchmethod
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union

from pydantic import TypeAdapter
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
Optional,
Union,
cast,
)

from outlines.inputs import Chat, Image
from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
from outlines.types import CFG, JsonSchema, Regex
from outlines.types.utils import (
is_dataclass,
is_genson_schema_builder,
is_pydantic_model,
is_typed_dict,
)

if TYPE_CHECKING:
from ollama import Client
Expand Down Expand Up @@ -109,7 +108,7 @@ def _create_message(self, role: str, content: str | list) -> dict:

def format_output_type(
self, output_type: Optional[Any] = None
) -> Optional[str]:
) -> Optional[dict]:
"""Format the output type to pass to the client.

TODO: `int`, `float` and other Python types could be supported via
Expand All @@ -126,7 +125,9 @@ def format_output_type(
The formatted output type to be passed to the model.

"""
if isinstance(output_type, Regex):
if output_type is None:
return None
elif isinstance(output_type, Regex):
raise TypeError(
"Regex-based structured outputs are not supported by Ollama. "
"Use an open source model in the meantime."
Expand All @@ -136,22 +137,8 @@ def format_output_type(
"CFG-based structured outputs are not supported by Ollama. "
"Use an open source model in the meantime."
)

if output_type is None:
return None
elif isinstance(output_type, JsonSchema):
return json.loads(output_type.schema)
elif is_dataclass(output_type):
schema = TypeAdapter(output_type).json_schema()
return schema
elif is_typed_dict(output_type):
schema = TypeAdapter(output_type).json_schema()
return schema
elif is_pydantic_model(output_type):
schema = output_type.model_json_schema()
return schema
elif is_genson_schema_builder(output_type):
return output_type.to_json()
elif JsonSchema.is_json_schema(output_type):
return cast(dict, JsonSchema.convert_to(output_type, ["dict"]))
else:
type_name = getattr(output_type, "__name__", output_type)
raise TypeError(
Expand Down
30 changes: 7 additions & 23 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
"""Integration with OpenAI's API."""

import json
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
Optional,
Union,
cast,
)
from functools import singledispatchmethod

from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel

from outlines.inputs import Chat, Image
from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
from outlines.models.utils import set_additional_properties_false_json_schema
from outlines.types import JsonSchema, Regex, CFG
from outlines.types.utils import (
is_dataclass,
is_typed_dict,
is_pydantic_model,
is_genson_schema_builder,
is_native_dict
)
from outlines.types.utils import is_native_dict

if TYPE_CHECKING:
from openai import (
Expand Down Expand Up @@ -176,20 +170,10 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
return {}
elif is_native_dict(output_type):
return self.format_json_mode_type()
elif is_dataclass(output_type):
output_type = TypeAdapter(output_type).json_schema()
return self.format_json_output_type(output_type)
elif is_typed_dict(output_type):
output_type = TypeAdapter(output_type).json_schema()
return self.format_json_output_type(output_type)
elif is_pydantic_model(output_type):
output_type = output_type.model_json_schema()
return self.format_json_output_type(output_type)
elif is_genson_schema_builder(output_type):
schema = json.loads(output_type.to_json())
return self.format_json_output_type(schema)
elif isinstance(output_type, JsonSchema):
return self.format_json_output_type(json.loads(output_type.schema))
elif JsonSchema.is_json_schema(output_type):
return self.format_json_output_type(
cast(dict, JsonSchema.convert_to(output_type, ["dict"]))
)
else:
type_name = getattr(output_type, "__name__", output_type)
raise TypeError(
Expand Down
Loading
Loading