|
6 | 6 |
|
7 | 7 | from langchain.prompts import PromptTemplate |
8 | 8 | from langchain_aws import ChatBedrock |
| 9 | +from langchain_community.chat_models import ChatOllama |
9 | 10 | from langchain_core.output_parsers import JsonOutputParser |
10 | 11 | from langchain_core.runnables import RunnableParallel |
11 | 12 | from langchain_mistralai import ChatMistralAI |
12 | | -from langchain_openai import AzureChatOpenAI, ChatOpenAI |
| 13 | +from langchain_openai import ChatOpenAI |
13 | 14 | from tqdm import tqdm |
14 | 15 |
|
15 | 16 | from ..prompts import ( |
@@ -55,6 +56,13 @@ def __init__( |
55 | 56 | super().__init__(node_name, "node", input, output, 2, node_config) |
56 | 57 |
|
57 | 58 | self.llm_model = node_config["llm_model"] |
| 59 | + |
| 60 | + if isinstance(node_config["llm_model"], ChatOllama): |
| 61 | + if node_config.get("schema", None) is None: |
| 62 | + self.llm_model.format = "json" |
| 63 | + else: |
| 64 | + self.llm_model.format = self.node_config["schema"].model_json_schema() |
| 65 | + |
58 | 66 | self.embedder_model = node_config.get("embedder_model", None) |
59 | 67 | self.verbose = node_config.get("verbose", False) |
60 | 68 | self.force = node_config.get("force", False) |
@@ -92,8 +100,7 @@ def execute(self, state: dict) -> dict: |
92 | 100 | format_instructions = "" |
93 | 101 |
|
94 | 102 | if ( |
95 | | - isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) |
96 | | - and not self.script_creator |
| 103 | + not self.script_creator |
97 | 104 | or self.force |
98 | 105 | and not self.script_creator |
99 | 106 | or self.is_md_scraper |
|
0 commit comments