Unverified Commit af2138e8 authored by John Wang's avatar John Wang Committed by GitHub

fix: json parse in router chain output (#243)

parent 091beffa
...@@ -84,13 +84,16 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]): ...@@ -84,13 +84,16 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
def parse_json_markdown(self, json_string: str) -> dict: def parse_json_markdown(self, json_string: str) -> dict:
# Remove the triple backticks if present # Remove the triple backticks if present
json_string = json_string.replace("```json", "").replace("```", "") start_index = json_string.find("```json")
end_index = json_string.find("```", start_index + len("```json"))
# Strip whitespace and newlines from the start and end if start_index != -1 and end_index != -1:
json_string = json_string.strip() extracted_content = json_string[start_index + len("```json"):end_index].strip()
# Parse the JSON string into a Python dictionary # Parse the JSON string into a Python dictionary
parsed = json.loads(json_string) parsed = json.loads(extracted_content)
else:
raise Exception("Could not find JSON block in the output.")
return parsed return parsed
......
...@@ -90,7 +90,7 @@ class MultiDatasetRouterChain(Chain): ...@@ -90,7 +90,7 @@ class MultiDatasetRouterChain(Chain):
callback_manager=llm_callback_manager callback_manager=llm_callback_manager
) )
destinations = [f"{d.id}: {d.description}" for d in datasets] destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ')) for d in datasets]
destinations_str = "\n".join(destinations) destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format( router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
destinations=destinations_str destinations=destinations_str
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment