Unverified Commit 91ee62d1 authored by Garfield Dai's avatar Garfield Dai Committed by GitHub

fix: huggingface and replicate. (#1888)

parent ede69b46
...@@ -154,6 +154,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ...@@ -154,6 +154,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
content=chunk.token.text content=chunk.token.text
) )
if chunk.details:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
...@@ -166,6 +167,16 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ...@@ -166,6 +167,16 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
index=index, index=index,
message=assistant_prompt_message, message=assistant_prompt_message,
usage=usage, usage=usage,
finish_reason=chunk.details.finish_reason,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
), ),
) )
......
...@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): ...@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
) )
for key, value in input_properties: for key, value in input_properties:
if key not in ['system_prompt', 'prompt']: if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
value_type = value.get('type') value_type = value.get('type')
if not value_type: if not value_type:
...@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): ...@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
index = -1 index = -1
current_completion: str = "" current_completion: str = ""
stop_condition_reached = False stop_condition_reached = False
prediction_output_length = 10000
is_prediction_output_finished = False
for output in prediction.output_iterator(): for output in prediction.output_iterator():
current_completion += output current_completion += output
if not is_prediction_output_finished and prediction.status == 'succeeded':
prediction_output_length = len(prediction.output) - 1
is_prediction_output_finished = True
if stop: if stop:
for s in stop: for s in stop:
if s in current_completion: if s in current_completion:
...@@ -172,6 +180,16 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): ...@@ -172,6 +180,16 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
content=output if output else '' content=output if output else ''
) )
if index < prediction_output_length:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
)
else:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
...@@ -183,8 +201,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): ...@@ -183,8 +201,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=index, index=index,
message=assistant_prompt_message, message=assistant_prompt_message,
usage=usage, usage=usage
), )
) )
def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str], def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment