package cn.breeze.elleai.facade;

import cn.breeze.elleai.application.dto.langchain.EmbeddingRequest;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * {@link EmbeddingService} implementation backed by LangChain4j's {@link OpenAiEmbeddingModel},
 * calling an OpenAI-compatible embeddings endpoint.
 *
 * <p>Connection settings (base URL, API key, model name) come from the {@code embedding.*}
 * properties and may be overridden per call through {@link EmbeddingRequest}. The model built
 * from the default settings is created once and reused rather than rebuilt on every call.
 *
 * @author yangyw
 */
@Component
@Slf4j
@RequiredArgsConstructor
public class OpenAIEmbeddingFacade implements EmbeddingService {

    @Value("${embedding.api-base-url:https://elle.e-tools.cn/v1}")
    private String apiBaseUrl;

    // NOTE(review): a default API key is hard-coded in source; prefer supplying it only via
    // configuration / secret management.
    @Value("${embedding.api-key:smartbreeze}")
    private String apiKey;

    @Value("${embedding.model-name:bge-m3}")
    private String modelName;

    /**
     * Model built from the configured defaults. Created lazily on first use (after Spring has
     * injected the {@code @Value} fields) and shared by all later default-configuration calls.
     */
    private volatile OpenAiEmbeddingModel defaultModel;

    /**
     * Returns the embedding model to use for {@code request}.
     *
     * <p>When the request overrides none of model name, base URL or API key, the cached default
     * model is returned. Otherwise a new model is built in which each blank override falls back
     * to the configured default.
     *
     * @param request the embedding request carrying optional per-call overrides; must not be null
     * @return a model configured for the request
     */
    private OpenAiEmbeddingModel getModel(EmbeddingRequest request) {
        if (!StrUtil.isAllBlank(request.getModel(), request.getApiBaseUrl(), request.getApiKey())) {
            return buildModel(
                    StrUtil.blankToDefault(request.getModel(), modelName),
                    StrUtil.blankToDefault(request.getApiBaseUrl(), apiBaseUrl),
                    StrUtil.blankToDefault(request.getApiKey(), apiKey));
        }
        OpenAiEmbeddingModel model = defaultModel;
        if (model == null) {
            // Benign race: concurrent first calls may each build an equivalent model; any of them
            // is valid and the last one written is kept for subsequent calls.
            model = buildModel(modelName, apiBaseUrl, apiKey);
            defaultModel = model;
        }
        return model;
    }

    /**
     * Builds a new model for already-resolved connection settings.
     *
     * @param model   model name
     * @param baseUrl API base URL
     * @param key     API key
     * @return a newly built model
     */
    private static OpenAiEmbeddingModel buildModel(String model, String baseUrl, String key) {
        return OpenAiEmbeddingModel.builder()
                .modelName(model)
                .baseUrl(baseUrl)
                .apiKey(key)
                .build();
    }

    /**
     * Embeds each text using the configured default model settings.
     *
     * @param texts texts to embed; null elements are skipped (see {@link #embed(EmbeddingRequest)})
     * @return one embedding per non-null text, in input order
     */
    @Override
    public List<Embedding> embed(List<String> texts) {
        EmbeddingRequest embeddingRequest = new EmbeddingRequest();
        embeddingRequest.setTexts(texts);
        return embed(embeddingRequest);
    }

    /**
     * Embeds a single text using the configured default model settings.
     *
     * @param text text to embed; must not be null ({@link List#of} rejects null elements)
     * @return the embedding, or null if no embedding was produced
     */
    @Override
    public Embedding embed(String text) {
        return CollUtil.getFirst(embed(List.of(text)));
    }

    /**
     * Embeds {@code request.getTexts()}, honoring any per-request model / base URL / API key
     * overrides.
     *
     * <p>Null texts are skipped, so when the input contains nulls the result is shorter than the
     * input and its indices no longer line up with it. If nothing remains to embed (null, empty
     * or all-null texts), an empty list is returned without contacting the remote service.
     *
     * @param request request carrying the texts and optional overrides; must not be null
     * @return embeddings in the order of the non-null input texts
     */
    @Override
    public List<Embedding> embed(EmbeddingRequest request) {
        List<TextSegment> segments = CollUtil.map(request.getTexts(), TextSegment::from, true);
        if (segments.isEmpty()) {
            // An empty batch is pointless and may be rejected by the endpoint; skip the round trip.
            return List.of();
        }
        return getModel(request).embedAll(segments).content();
    }
}
