Unverified commit 91ee62d1, authored by Garfield Dai, committed by GitHub

fix: huggingface and replicate. (#1888)

parent ede69b46
@@ -154,20 +154,31 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
                 content=chunk.token.text
             )
 
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+            if chunk.details:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage,
+                        finish_reason=chunk.details.finish_reason,
+                    ),
+                )
+            else:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    ),
+                )
 
     def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
         if isinstance(response, str):
...
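Context for the hunk above: when streaming from a Hugging Face text-generation endpoint with details enabled, only the final event of the stream carries a populated details object (including finish_reason); intermediate events hold just the token text. The fix therefore computes token counts and usage once, on that final event, instead of on every chunk. Below is a minimal, self-contained sketch of the same gating pattern, using stand-in dataclasses and plain dicts rather than Dify's LLMResultChunk/LLMResultChunkDelta types; every name in it is an illustrative assumption, not the project's API.

from dataclasses import dataclass
from typing import Iterator, Optional

@dataclass
class StreamEvent:
    # Stand-in for one streamed text-generation event; finish_reason is
    # only set on the final event, mirroring `chunk.details` above.
    text: str
    finish_reason: Optional[str] = None

def relay(events: Iterator[StreamEvent], count_tokens) -> Iterator[dict]:
    # Yield text-only chunks; compute token counts and attach usage only
    # when the final event (the one carrying details) arrives.
    completion = ""
    for index, event in enumerate(events):
        completion += event.text
        if event.finish_reason is not None:
            usage = {"completion_tokens": count_tokens(completion)}
            yield {"index": index, "text": event.text,
                   "finish_reason": event.finish_reason, "usage": usage}
        else:
            yield {"index": index, "text": event.text}

stream = [StreamEvent("Hel"), StreamEvent("lo"), StreamEvent("!", finish_reason="eos_token")]
for chunk in relay(iter(stream), count_tokens=lambda s: len(s)):
    print(chunk)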
@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
             )
 
         for key, value in input_properties:
-            if key not in ['system_prompt', 'prompt']:
+            if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
                 value_type = value.get('type')
 
                 if not value_type:
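The added 'stop' not in key condition keeps Replicate schema properties whose names contain "stop" (for example a stop_sequences field, if a model exposes one) out of the auto-generated parameter rules, since stop words are supplied separately by the caller. A rough illustration of the filter over a made-up properties dict; the property names are examples, not taken from any particular Replicate model.

input_properties = {
    "prompt": {"type": "string"},
    "system_prompt": {"type": "string"},
    "temperature": {"type": "number"},
    "stop_sequences": {"type": "string"},
}

# Only properties that are neither the prompts nor stop-related survive.
tunable = {
    key: value
    for key, value in input_properties.items()
    if key not in ['system_prompt', 'prompt'] and 'stop' not in key
}
print(tunable)  # {'temperature': {'type': 'number'}}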
@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         index = -1
         current_completion: str = ""
         stop_condition_reached = False
 
+        prediction_output_length = 10000
+        is_prediction_output_finished = False
+
         for output in prediction.output_iterator():
             current_completion += output
 
+            if not is_prediction_output_finished and prediction.status == 'succeeded':
+                prediction_output_length = len(prediction.output) - 1
+                is_prediction_output_finished = True
+
             if stop:
                 for s in stop:
                     if s in current_completion:
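The two new variables track when Replicate reports the prediction as finished: prediction_output_length starts at a sentinel (10000) that no chunk index will reach, and once prediction.status becomes 'succeeded' it is set to the index of the last element of prediction.output. Below is a self-contained sketch of that sentinel pattern; the FakePrediction class is a stand-in for the real Replicate Prediction object, assumed only for illustration.

class FakePrediction:
    # Stand-in for Replicate's Prediction: collects output pieces and
    # flips status to 'succeeded' once the last piece has been produced.
    def __init__(self, pieces):
        self._pieces = pieces
        self.output = []
        self.status = 'processing'

    def output_iterator(self):
        for i, piece in enumerate(self._pieces):
            self.output.append(piece)
            if i == len(self._pieces) - 1:
                self.status = 'succeeded'
            yield piece

prediction = FakePrediction(["Once", " upon", " a", " time"])

prediction_output_length = 10000          # sentinel: last index not known yet
is_prediction_output_finished = False

for index, output in enumerate(prediction.output_iterator()):
    if not is_prediction_output_finished and prediction.status == 'succeeded':
        prediction_output_length = len(prediction.output) - 1
        is_prediction_output_finished = True
    is_last = index >= prediction_output_length
    print(index, repr(output), "last chunk" if is_last else "intermediate chunk")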
@@ -172,20 +180,30 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
                 content=output if output else ''
             )
 
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+            if index < prediction_output_length:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message
+                    )
+                )
+            else:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage
+                    )
+                )
 
     def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
...
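With the last index known, the handler now attaches usage only to the final streamed chunk; intermediate chunks carry just the text, so token counting no longer runs on every piece of output. A condensed sketch of that branching follows, with plain dicts standing in for LLMResultChunk/LLMResultChunkDelta and a count_tokens callable standing in for self.get_num_tokens; both are assumptions of this sketch, not the project's types.

def stream_chunks(pieces, count_tokens):
    # Yield text-only chunks, then attach token usage to the final one.
    last_index = len(pieces) - 1
    completion = ""
    for index, piece in enumerate(pieces):
        completion += piece
        if index < last_index:
            yield {"index": index, "text": piece}
        else:
            usage = {"prompt_tokens": count_tokens("the prompt"),
                     "completion_tokens": count_tokens(completion)}
            yield {"index": index, "text": piece, "usage": usage}

for chunk in stream_chunks(["Hello", ", ", "world"], count_tokens=lambda s: len(s.split())):
    print(chunk)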