Unverified Commit d7ae8679 authored by takatost's avatar takatost Committed by GitHub

feat: support basic feature of OpenAI new models (#1476)

parent 7b26c9e2
...@@ -8,3 +8,4 @@ class ProviderQuotaUnit(Enum): ...@@ -8,3 +8,4 @@ class ProviderQuotaUnit(Enum):
class ModelFeature(Enum): class ModelFeature(Enum):
AGENT_THOUGHT = 'agent_thought' AGENT_THOUGHT = 'agent_thought'
VISION = 'vision'
...@@ -19,6 +19,13 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar ...@@ -19,6 +19,13 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
AZURE_OPENAI_API_VERSION = '2023-07-01-preview' AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
FUNCTION_CALL_MODELS = [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k'
]
class AzureOpenAIModel(BaseLLM): class AzureOpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider, def __init__(self, model_provider: BaseModelProvider,
name: str, name: str,
...@@ -157,3 +164,7 @@ class AzureOpenAIModel(BaseLLM): ...@@ -157,3 +164,7 @@ class AzureOpenAIModel(BaseLLM):
@property @property
def support_streaming(self): def support_streaming(self):
return True return True
@property
def support_function_call(self):
return self.base_model_name in FUNCTION_CALL_MODELS
...@@ -310,6 +310,10 @@ class BaseLLM(BaseProviderModel): ...@@ -310,6 +310,10 @@ class BaseLLM(BaseProviderModel):
def support_streaming(self): def support_streaming(self):
return False return False
@property
def support_function_call(self):
return False
def _get_prompt_from_messages(self, messages: List[PromptMessage], def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]: model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
if not model_mode: if not model_mode:
......
...@@ -23,21 +23,36 @@ COMPLETION_MODELS = [ ...@@ -23,21 +23,36 @@ COMPLETION_MODELS = [
] ]
CHAT_MODELS = [ CHAT_MODELS = [
'gpt-4-1106-preview', # 128,000 tokens
'gpt-4-vision-preview', # 128,000 tokens
'gpt-4', # 8,192 tokens 'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens 'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo-1106', # 16,384 tokens
'gpt-3.5-turbo', # 4,096 tokens 'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens 'gpt-3.5-turbo-16k', # 16,384 tokens
] ]
MODEL_MAX_TOKENS = { MODEL_MAX_TOKENS = {
'gpt-4-1106-preview': 128000,
'gpt-4-vision-preview': 128000,
'gpt-4': 8192, 'gpt-4': 8192,
'gpt-4-32k': 32768, 'gpt-4-32k': 32768,
'gpt-3.5-turbo-1106': 16384,
'gpt-3.5-turbo': 4096, 'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 4097, 'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384, 'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097, 'text-davinci-003': 4097,
} }
FUNCTION_CALL_MODELS = [
'gpt-4-1106-preview',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo-1106',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k'
]
class OpenAIModel(BaseLLM): class OpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider, def __init__(self, model_provider: BaseModelProvider,
...@@ -50,7 +65,6 @@ class OpenAIModel(BaseLLM): ...@@ -50,7 +65,6 @@ class OpenAIModel(BaseLLM):
else: else:
self.model_mode = ModelMode.CHAT self.model_mode = ModelMode.CHAT
# TODO load price config from configs(db)
super().__init__(model_provider, name, model_kwargs, streaming, callbacks) super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any: def _init_client(self) -> Any:
...@@ -100,7 +114,7 @@ class OpenAIModel(BaseLLM): ...@@ -100,7 +114,7 @@ class OpenAIModel(BaseLLM):
:param callbacks: :param callbacks:
:return: :return:
""" """
if self.name == 'gpt-4' \ if self.name.startswith('gpt-4') \
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \ and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value: and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.") raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
...@@ -175,6 +189,10 @@ class OpenAIModel(BaseLLM): ...@@ -175,6 +189,10 @@ class OpenAIModel(BaseLLM):
def support_streaming(self): def support_streaming(self):
return True return True
@property
def support_function_call(self):
return self.name in FUNCTION_CALL_MODELS
# def is_model_valid_or_raise(self): # def is_model_valid_or_raise(self):
# """ # """
# check is a valid model. # check is a valid model.
......
...@@ -41,9 +41,17 @@ class OpenAIProvider(BaseModelProvider): ...@@ -41,9 +41,17 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
}, },
{
'id': 'gpt-3.5-turbo-1106',
'name': 'gpt-3.5-turbo-1106',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{ {
'id': 'gpt-3.5-turbo-instruct', 'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct', 'name': 'gpt-3.5-turbo-instruct',
'mode': ModelMode.COMPLETION.value, 'mode': ModelMode.COMPLETION.value,
}, },
{ {
...@@ -62,6 +70,22 @@ class OpenAIProvider(BaseModelProvider): ...@@ -62,6 +70,22 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
}, },
{
'id': 'gpt-4-1106-preview',
'name': 'gpt-4-1106-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-vision-preview',
'name': 'gpt-4-vision-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.VISION.value
]
},
{ {
'id': 'gpt-4-32k', 'id': 'gpt-4-32k',
'name': 'gpt-4-32k', 'name': 'gpt-4-32k',
...@@ -79,7 +103,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -79,7 +103,7 @@ class OpenAIProvider(BaseModelProvider):
if self.provider.provider_type == ProviderType.SYSTEM.value \ if self.provider.provider_type == ProviderType.SYSTEM.value \
and self.provider.quota_type == ProviderQuotaType.TRIAL.value: and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']] models = [item for item in models if not item['id'].startswith('gpt-4')]
return models return models
elif model_type == ModelType.EMBEDDINGS: elif model_type == ModelType.EMBEDDINGS:
...@@ -141,8 +165,11 @@ class OpenAIProvider(BaseModelProvider): ...@@ -141,8 +165,11 @@ class OpenAIProvider(BaseModelProvider):
:return: :return:
""" """
model_max_tokens = { model_max_tokens = {
'gpt-4-1106-preview': 128000,
'gpt-4-vision-preview': 128000,
'gpt-4': 8192, 'gpt-4': 8192,
'gpt-4-32k': 32768, 'gpt-4-32k': 32768,
'gpt-3.5-turbo-1106': 16384,
'gpt-3.5-turbo': 4096, 'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 4097, 'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384, 'gpt-3.5-turbo-16k': 16384,
......
...@@ -24,12 +24,30 @@ ...@@ -24,12 +24,30 @@
"unit": "0.001", "unit": "0.001",
"currency": "USD" "currency": "USD"
}, },
"gpt-4-1106-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-vision-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
"prompt": "0.0015", "prompt": "0.0015",
"completion": "0.002", "completion": "0.002",
"unit": "0.001", "unit": "0.001",
"currency": "USD" "currency": "USD"
}, },
"gpt-3.5-turbo-1106": {
"prompt": "0.0010",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-instruct": { "gpt-3.5-turbo-instruct": {
"prompt": "0.0015", "prompt": "0.0015",
"completion": "0.002", "completion": "0.002",
......
...@@ -73,8 +73,7 @@ class OrchestratorRuleParser: ...@@ -73,8 +73,7 @@ class OrchestratorRuleParser:
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router')) planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
# only OpenAI chat model (include Azure) support function call, use ReACT instead # only OpenAI chat model (include Azure) support function call, use ReACT instead
if agent_model_instance.model_mode != ModelMode.CHAT \ if not agent_model_instance.support_function_call:
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
if planning_strategy == PlanningStrategy.FUNCTION_CALL: if planning_strategy == PlanningStrategy.FUNCTION_CALL:
planning_strategy = PlanningStrategy.REACT planning_strategy = PlanningStrategy.REACT
elif planning_strategy == PlanningStrategy.ROUTER: elif planning_strategy == PlanningStrategy.ROUTER:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment