import {
    ChatCompletionChunk, ChatCompletionFunction,
    ChatCompletionRequestMessage,
    ChatCompletionResponseMessage,
    ChatModel
} from "./types";
import {Plugin} from "../plugins/types";
import {ApiError, Configuration, fetchWithTimeout as fetch} from "./index";
import {isEmpty} from "lodash-es";
import {handleSseStream} from "../core/fetch-sse";
import PluginManager from "./PluginManager";

/** Constructor options for `ChatClient`. */
type ChatClientInit = {
    /** Sent as a `Bearer` token when non-empty. */
    apiKey: string;
    /** Model identifier placed in every request body. */
    model: ChatModel;
    /** Forwarded to the timeout-aware fetch wrapper. */
    timeout?: number;
    /** Overrides the default API base URL when non-empty. */
    baseUrl?: string;
};

/** Payload handed to the `beforeFunctionCall` hook. */
type BeforeFunctionCallParams = {
    /** Function name requested by the model. */
    name: string;
    /** Raw argument string produced by the model, if any. */
    params?: string;
    /** Identifier of the plugin that registered the function. */
    pluginId: string;
};

/** Payload handed to the `afterFunctionCall` hook. */
type AfterFunctionCallParams = {
    /** Function name requested by the model. */
    name: string;
    /** Raw argument string produced by the model, if any. */
    params?: string;
    /** Value returned by the plugin function. */
    result: string;
    /** Identifier of the plugin that registered the function. */
    pluginId: string;
};

/** Arguments for `ChatClient.chat`. */
type ChatParams = {
    /** Conversation history sent to the model. */
    messages: Array<ChatCompletionRequestMessage>;
    /** Aborts the underlying HTTP request(s). */
    signal?: AbortSignal;
    /** Plugins whose functions are offered to the model. */
    plugins?: Array<Plugin>;
    /** Receives the partially streamed assistant message. */
    onMessageUpdate?: (message: ChatCompletionResponseMessage) => unknown;
    /** Callback for a failed plugin function call (see `ChatClient.chat`). */
    onFunctionCallFailed?: (e: Error) => unknown;
    /** Hook run before a plugin function is executed. */
    beforeFunctionCall?: (params: BeforeFunctionCallParams) => unknown;
    /** Hook run after a plugin function has returned. */
    afterFunctionCall?: (params: AfterFunctionCallParams) => unknown;
};

/** Arguments for a single streamed completion request (`ChatClient.requestMessage`). */
type RequestParams = {
    /** Conversation history sent to the model. */
    messages: Array<ChatCompletionRequestMessage>;
    /** Aborts the HTTP request. */
    signal?: AbortSignal;
    /** Function definitions offered to the model. */
    functions?: Array<ChatCompletionFunction>;
    /** Receives the partially streamed assistant message. */
    onMessageUpdate?: (message: ChatCompletionResponseMessage) => unknown;
};

/**
 * Streaming client for the `/v1/chat/completions` endpoint that can execute
 * plugin-provided functions requested by the model and feed their results back.
 */
export default class ChatClient {

    apiKey: string;
    timeout: number | undefined;
    baseUrl: string = Configuration.DEFAULT_BASE_URL;
    model: ChatModel;

    /**
     * @param init API key, model, and optional timeout / base URL. `baseUrl`
     *             replaces `Configuration.DEFAULT_BASE_URL` only when non-empty.
     */
    constructor(init: ChatClientInit) {
        this.apiKey = init.apiKey;
        if (init.baseUrl) {
            this.baseUrl = init.baseUrl;
        }
        this.timeout = init.timeout;
        this.model = init.model;
    }

    /**
     * Requests a completion and, while the model answers with a `function_call`,
     * executes that function through the plugins and asks again.
     *
     * Note: `params.messages` is mutated in place. For every function round trip
     * the assistant function-call message and the function result are appended.
     *
     * @returns the first assistant message that carries no `function_call`.
     * @throws request errors from `requestMessage`; errors for an unknown function
     *         name or a failing plugin function (after `onFunctionCallFailed`
     *         has been notified).
     */
    async chat(params: ChatParams): Promise<ChatCompletionResponseMessage> {
        const {
            messages,
            signal,
            plugins = [],
            onMessageUpdate,
            onFunctionCallFailed,
            beforeFunctionCall,
            afterFunctionCall,
        } = params;
        const pluginManager = new PluginManager();
        for (const plugin of plugins) {
            pluginManager.addPlugin(plugin);
        }
        while (true) {
            const message = await this.requestMessage({
                messages,
                signal,
                functions: pluginManager.getFunctions(),
                onMessageUpdate,
            });
            const functionCall = message.function_call;
            if (!functionCall) {
                return message;
            }
            const name = functionCall.name;
            const args = functionCall.arguments;

            // The model may hallucinate a function name; fail with a clear error
            // instead of a TypeError on an undefined lookup.
            const pluginId = pluginManager.functions[name]?.pluginId;
            if (pluginId === undefined) {
                const error = new Error(`Model requested unknown function "${name}"`);
                onFunctionCallFailed?.(error);
                throw error;
            }

            beforeFunctionCall?.({name, params: args, pluginId});
            let result: string;
            try {
                result = await pluginManager.callFunction(name, args as string);
            } catch (e) {
                onFunctionCallFailed?.(e instanceof Error ? e : new Error(String(e)));
                throw e;
            }
            afterFunctionCall?.({name, params: args, result, pluginId});

            messages.push({
                role: 'assistant',
                content: null,
                function_call: functionCall,
            });
            messages.push({
                role: 'function',
                name,
                content: result,
            });
        }
    }

    /**
     * Performs one streamed chat-completion request and accumulates the deltas.
     *
     * `onMessageUpdate` is invoked after each chunk, but only once the message has
     * non-empty `content`. Chunks that fail to parse are logged and skipped.
     *
     * @returns the accumulated assistant message once the stream ends.
     * @throws `ApiError` when a non-OK response carries an `error` object,
     *         otherwise a plain `Error` with the HTTP status.
     */
    async requestMessage(params: RequestParams): Promise<ChatCompletionResponseMessage> {
        const {messages, signal, functions, onMessageUpdate} = params;
        // Tolerate a configured base URL with trailing slash(es).
        const url = this.baseUrl.replace(/\/+$/, '') + '/v1/chat/completions';
        const headers: Record<string, string> = {
            Accept: 'text/event-stream',
            'Content-Type': 'application/json',
        };
        if (this.apiKey) {
            headers['Authorization'] = `Bearer ${this.apiKey}`;
        }
        const options = {
            method: 'post',
            signal,
            headers,
            body: JSON.stringify({
                model: this.model,
                max_tokens: 800,
                temperature: 0.6,
                stream: true,
                messages,
                // The API rejects an empty `functions` array; `undefined` makes
                // JSON.stringify drop the field entirely.
                functions: isEmpty(functions) ? undefined : functions,
            }),
            timeout: this.timeout,
        };
        const response = await fetch(url, options);
        if (!response.ok) {
            // The error body may be missing, non-JSON, or JSON `null`.
            const payload = await response.json().catch(() => null);
            const error = payload?.error;
            if (isEmpty(error)) {
                throw new Error(`${response.status} ${response.statusText}`);
            }
            throw new ApiError(error);
        }
        const responseMessage: ChatCompletionResponseMessage = {
            role: 'assistant',
        };
        await handleSseStream(response.body, (data) => {
            if (data === '[DONE]') {
                return;
            }
            try {
                const chunk = parseChatCompletionChunk(data);
                updateResponseMessage(responseMessage, chunk);
                if (responseMessage.content) {
                    onMessageUpdate?.(responseMessage);
                }
            } catch (err) {
                // Best effort: a malformed chunk should not abort the whole stream.
                console.error('Failed to process chat completion chunk:', data, err);
            }
        });
        return responseMessage;
    }
}

/**
 * Decodes the JSON payload of one SSE `data:` line.
 * The shape is not validated; a malformed payload throws `SyntaxError`.
 */
function parseChatCompletionChunk(data: string): ChatCompletionChunk {
    const parsed: ChatCompletionChunk = JSON.parse(data);
    return parsed;
}

/**
 * Merges one streamed chat-completion chunk into the message being accumulated.
 *
 * Only the first choice is read. `content` and `function_call.arguments` are
 * appended to what has been received so far. A chunk carrying a function `name`
 * starts a fresh `function_call`. Chunks with no first choice or delta, such as
 * the prompt-filter chunk with `choices: []` that Azure OpenAI emits, are ignored
 * instead of throwing.
 *
 * @param message accumulator, mutated in place
 * @param chunk parsed SSE chunk
 * @returns the same `message` object
 */
function updateResponseMessage(message: ChatCompletionResponseMessage, chunk: ChatCompletionChunk) {
    const delta = chunk.choices?.[0]?.delta;
    if (!delta) {
        return message;
    }
    if (delta.role) {
        message.role = delta.role;
    }
    if (delta.content) {
        message.content = (message.content ?? '') + delta.content;
    }
    const functionCallDelta = delta.function_call;
    if (functionCallDelta) {
        if (functionCallDelta.name) {
            message.function_call = {
                name: functionCallDelta.name,
            };
        }
        // Argument fragments are only kept once a function name has been seen.
        if (functionCallDelta.arguments && message.function_call) {
            message.function_call.arguments =
                (message.function_call.arguments ?? '') + functionCallDelta.arguments;
        }
    }
    return message;
}