Commit bcaf2274 authored by John Wang

feat: remove _generate for llm

parent 4de6e955
@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
+from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -79,7 +79,6 @@ def initialize_extensions(app):
     ext_database.init_app(app)
     ext_migrate.init(app, db)
     ext_redis.init_app(app)
-    ext_vector_store.init_app(app)
     ext_storage.init_app(app)
     ext_celery.init_app(app)
     ext_session.init_app(app)
...
@@ -53,31 +53,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
     ) -> None:
         self.start_at = time.perf_counter()

-        # todo chat serialized maybe deprecated in future
-        if 'Chat' in serialized['name']:
-            real_prompts = []
-            messages = []
-
-            for prompt in prompts:
-                role, content = prompt.split(': ', maxsplit=1)
-                if role == 'human':
-                    role = 'user'
-                    message = HumanMessage(content=content)
-                elif role == 'ai':
-                    role = 'assistant'
-                    message = AIMessage(content=content)
-                else:
-                    message = SystemMessage(content=content)
-
-                real_prompt = {
-                    "role": role,
-                    "text": content
-                }
-
-                real_prompts.append(real_prompt)
-                messages.append(message)
-
-            self.llm_message.prompt = real_prompts
-            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
-        else:
-            self.llm_message.prompt = [{
-                "role": 'user',
-                "text": prompts[0]
+        self.llm_message.prompt = [{
+            "role": 'user',
+            "text": prompts[0]
...
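The removed branch existed because chat prompts used to arrive here as "role: content" strings that had to be re-parsed; with chat models now reporting through on_chat_model_start with real message objects, the string-splitting reconstruction is dead code. If the same role mapping were ever needed against message objects, a minimal sketch (the helper name is illustrative, not part of this commit):

from typing import Dict, List
from langchain.schema import BaseMessage

def messages_to_prompts(messages: List[BaseMessage]) -> List[Dict[str, str]]:
    # Map LangChain message types onto the role names stored on llm_message.
    role_map = {'human': 'user', 'ai': 'assistant', 'system': 'system'}
    return [
        {"role": role_map.get(m.type, m.type), "text": m.content}
        for m in messages
    ]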
@@ -19,6 +19,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         messages: List[List[BaseMessage]],
         **kwargs: Any
     ) -> Any:
+        print_text("\n[on_chat_model_start]\n", color='blue')
         for sub_messages in messages:
             for sub_message in sub_messages:
                 print_text(str(sub_message) + "\n", color='blue')
@@ -28,11 +29,6 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
-        if 'Chat' in serialized['name']:
-            for prompt in prompts:
-                print_text(prompt + "\n", color='blue')
-        else:
-            print_text(prompts[0] + "\n", color='blue')
+        print_text(prompts[0] + "\n", color='blue')

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
...
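Taken together, the two hooks now split cleanly by model family. A condensed sketch of the resulting handler (same logic as the diff above; the class name is illustrative):

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text

class StdOutSketch(BaseCallbackHandler):
    def on_chat_model_start(self, serialized, messages, **kwargs):
        # Chat models report structured messages here, not stringified prompts.
        print_text("\n[on_chat_model_start]\n", color='blue')
        for sub_messages in messages:
            for sub_message in sub_messages:
                print_text(str(sub_message) + "\n", color='blue')

    def on_llm_start(self, serialized, prompts, **kwargs):
        # Completion models still arrive here with plain prompt strings.
        print_text("\n[on_llm_start]\n", color='blue')
        print_text(prompts[0] + "\n", color='blue')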
@@ -3,7 +3,7 @@ from collections import defaultdict
 from typing import Any, List, Optional, Dict
 from langchain.schema import Document, BaseRetriever
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, Extra
 from core.index.base import BaseIndex
 from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
@@ -170,6 +170,12 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
     index: KeywordTableIndex
     search_kwargs: dict = Field(default_factory=dict)

+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
     def get_relevant_documents(self, query: str) -> List[Document]:
         """Get documents relevant for a query.
...
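For context, the added Config block is the standard pydantic v1 idiom LangChain retrievers use: Extra.forbid makes unexpected constructor kwargs raise instead of being silently dropped, and arbitrary_types_allowed permits non-pydantic field types such as KeywordTableIndex. A self-contained sketch (the class names here are illustrative):

from pydantic import BaseModel, Extra, ValidationError

class CustomIndex:
    """A plain class, not a pydantic model."""

class Retriever(BaseModel):
    index: CustomIndex

    class Config:
        extra = Extra.forbid            # unknown kwargs raise ValidationError
        arbitrary_types_allowed = True  # allows the CustomIndex annotation

try:
    Retriever(index=CustomIndex(), typo_field=1)
except ValidationError as e:
    print(e)  # "extra fields not permitted"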
 from typing import Union, Optional, List
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.llms.fake import FakeListLLM
 from core.constant import llm_constant
 from core.llm.error import ProviderTokenNotInitError
@@ -32,10 +31,7 @@ class LLMBuilder:
     """

     @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
-        if model_name == 'fake':
-            return FakeListLLM(responses=[])
-
+    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         provider = cls.get_default_provider(tenant_id)
         mode = cls.get_mode_by_model(model_name)
...
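One behavioral consequence: to_llm(tenant_id, 'fake') no longer hands back a stub, so anything that relied on it for testing has to build the fake directly. A minimal sketch of the equivalent (assumed test usage, not part of this commit):

from langchain.llms.fake import FakeListLLM

def test_with_fake_llm():
    # FakeListLLM returns its canned responses in order, without API calls.
    llm = FakeListLLM(responses=["canned answer"])
    assert llm("any prompt") == "canned answer"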
-from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -69,68 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return message_tokens

-    def _generate(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
...
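The net effect: instead of hand-firing on_llm_start/on_llm_end inside a _generate override, callers now pass handlers to the public generate(), and LangChain's callback manager does the dispatching. A sketch of the new call shape (handler class and setup are illustrative):

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import HumanMessage

class LoggingHandler(BaseCallbackHandler):
    def on_llm_start(self, serialized, prompts, **kwargs):
        print("llm start:", prompts)

    def on_llm_end(self, response, **kwargs):
        print("llm end:", response.llm_output)

# Assuming a configured instance:
# llm = StreamableAzureChatOpenAI(deployment_name="...", openai_api_key="...")
# result = llm.generate([[HumanMessage(content="hi")]], callbacks=[LoggingHandler()])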
-import os
+from langchain.callbacks.manager import Callbacks
 from langchain.llms import AzureOpenAI
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Mapping, Any
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):
     @handle_llm_exceptions
     def generate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
...
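handle_llm_exceptions and handle_llm_exceptions_async (defined elsewhere in core.llm and untouched by this commit) still wrap the calls; their shape is an ordinary error-translating decorator. An illustrative sketch only, using a hypothetical DomainLLMError, not the project's actual implementation:

import functools

class DomainLLMError(Exception):
    """Hypothetical stand-in for the project's LLM error types."""

def handle_llm_exceptions(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:  # the real code matches specific provider errors
            raise DomainLLMError(str(e)) from e
    return wrapper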
 import os
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import Callbacks
-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -71,65 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):
         return message_tokens

-    def _generate(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
...
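On the async side, the deleted is_async branching is equally unnecessary now: LangChain's async callback machinery picks the right dispatch path itself. Illustrative usage of the new agenerate (setup assumed):

import asyncio
from langchain.schema import HumanMessage

async def ask(llm, text: str) -> str:
    result = await llm.agenerate(
        [[HumanMessage(content=text)]],
        callbacks=None,  # or a list of handlers, as with generate()
    )
    return result.generations[0][0].text

# asyncio.run(ask(StreamableChatOpenAI(openai_api_key="..."), "hello"))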
 import os
+from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}

     @handle_llm_exceptions
     def generate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
...
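Across the four wrappers the call shape differs by model family: the completion models (StreamableOpenAI, StreamableAzureOpenAI) take prompts: List[str], while the chat models take messages: List[List[BaseMessage]]. A sketch of both invocations side by side (setup assumed, names illustrative):

from langchain.schema import HumanMessage

def run_both(completion_llm, chat_llm, handler):
    # completion family: a batch of prompt strings
    r1 = completion_llm.generate(["Say hi."], callbacks=[handler])
    # chat family: a batch of message lists
    r2 = chat_llm.generate([[HumanMessage(content="Say hi.")]], callbacks=[handler])
    return r1, r2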
 from typing import Any, List, Dict
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
...
-from core.vector_store.vector_store import VectorStore
-
-vector_store = VectorStore()
-
-
-def init_app(app):
-    vector_store.init_app(app)