//
//
//  OpenAI Usage
//
//

import {useEffect, useState} from "react";
import {Team} from "../../../interfaces.ts";
import {Flex, LoadingOverlay, SimpleGrid, Text} from "@mantine/core";
import {BarChart} from "@mantine/charts";
import {formatDateToDayMonth, getAllDatesInMonth, groupBy} from "../../../utils.ts";
import {useAuth0} from "@auth0/auth0-react";
import {API} from "../../../api.ts";


/**
 * One day of OpenAI usage, as plotted on the usage bar charts.
 */
interface OpenAIUsageInterface {
    /** Day label used as the charts' `dataKey` (built with `formatDateToDayMonth`). */
    date: string;
    /** Prompt (input) tokens consumed that day. */
    prompt_tokens: number;
    /** Completion (output) tokens produced that day. */
    completion_tokens: number;
    /** Cost for the day from the per-1M-token prices; rendered with a `$` prefix. */
    cost: number;
}

/** Price of one model per 1M tokens, as consumed by `calculateCost`. */
interface OpenAICostInterface {
    /** Price per 1M input (prompt) tokens. */
    input: number;
    /** Price per 1M output (completion) tokens; 0 for models that only bill input. */
    output: number;
}

/**
 * Price list per 1M tokens, keyed by the OpenAI model name reported in the usage rows.
 * Embedding models only bill input tokens, so their `output` price is 0.
 */
const COST_BY_OPENAI_MODEL: Record<string, OpenAICostInterface> = {
    "gpt-3.5-turbo-0125": {input: 0.5, output: 1.5},
    "gpt-4-0125-preview": {input: 10, output: 30},
    "text-embedding-3-large": {input: 0.13, output: 0},
}

/**
 * Price a batch of tokens from per-million-token rates.
 *
 * @param inputTokens Prompt tokens consumed.
 * @param outputTokens Completion tokens produced.
 * @param costPerMillionInput Price per 1M input tokens.
 * @param costPerMillionOutput Price per 1M output tokens.
 * @returns Input cost plus output cost, in the same currency as the rates.
 */
function calculateCost(inputTokens: number, outputTokens: number, costPerMillionInput: number, costPerMillionOutput: number) {
    // Scale a token count down to millions, then apply that side's rate.
    const priceOf = (tokens: number, perMillion: number) => (tokens / 1e6) * perMillion
    return priceOf(inputTokens, costPerMillionInput) + priceOf(outputTokens, costPerMillionOutput)
}

/**
 * Turn raw OpenAI usage rows into chart-ready daily series for one month.
 *
 * @param usage Raw rows from the admin usage endpoint. Each row is read for `model`, `date`,
 *   `prompt_tokens` and `completion_tokens`.
 * @param date Any date inside the month to chart; the series cover every day returned by
 *   `getAllDatesInMonth(date)`, in that order.
 * @returns A `[global, byModel]` tuple. `global` has one entry per day with tokens and cost summed
 *   across all models (zero-filled, so it spans the month even with no usage). `byModel` maps each
 *   model name to its own zero-filled daily series.
 */
function formatUsage(usage: any, date: Date) {
    const days = getAllDatesInMonth(date)
    const labels: string[] = days.map((day: Date) => formatDateToDayMonth(day))

    // A fresh zero-filled series: one entry per day of the month, in month order.
    const emptySeries = () => labels.map(label => ({
        date: label,
        prompt_tokens: 0,
        completion_tokens: 0,
        cost: 0,
    }))

    // Day-of-month -> position in the series, so each row is placed without rescanning the month.
    const indexByDayOfMonth = new Map<number, number>()
    days.forEach((day: Date, index: number) => indexByDayOfMonth.set(day.getDate(), index))

    // Pre-filled so the global chart covers the whole month even when there is no usage at all.
    const globalUsageValues = emptySeries()
    const usageValuesByModel: Record<string, ReturnType<typeof emptySeries>> = {}

    for (const [model, rows] of Object.entries(groupBy(usage, "model"))) {
        // Models missing from the price table still chart their tokens. They are priced at 0
        // rather than throwing on `undefined.input`, which would discard the whole month's data.
        const prices = COST_BY_OPENAI_MODEL[model] ?? {input: 0, output: 0}
        const series = emptySeries()

        for (const row of rows) {
            // NOTE(review): the calendar day is read in the browser's local time zone. A date-only or
            // UTC value such as "2024-03-05" parses as UTC midnight and may land on the previous day
            // west of UTC. Confirm the API's date format.
            const index = indexByDayOfMonth.get(new Date(row.date).getDate())
            if (index === undefined) continue  // not a day of the charted month, or an unparsable date

            const promptTokens = row.prompt_tokens ?? 0
            const completionTokens = row.completion_tokens ?? 0
            const cost = calculateCost(promptTokens, completionTokens, prices.input, prices.output)

            // Several rows for the same model and day are summed instead of keeping only the first.
            for (const entry of [series[index], globalUsageValues[index]]) {
                if (entry === undefined) continue
                entry.prompt_tokens += promptTokens
                entry.completion_tokens += completionTokens
                entry.cost += cost
            }
        }

        usageValuesByModel[model] = series
    }

    return [globalUsageValues, usageValuesByModel] as const
}


/** Bar-chart series shared by every usage chart: prompt and completion tokens stacked per day. */
const TOKEN_SERIES = [
    { name: 'prompt_tokens', color: 'violet.6', label: "Prompt tokens" },
    { name: 'completion_tokens', color: 'blue.6', label: "Completion tokens" },
]

/**
 * Admin view of OpenAI token usage and cost for one month.
 *
 * Fetches usage for `date`'s month (optionally restricted to `team`). It renders a global stacked
 * bar chart and one chart per model, each headed by its summed cost.
 */
function OpenAIUsage({date, team}: {date: Date, team?: Team}) {
    const {getAccessTokenSilently} = useAuth0()
    const [loading, setLoading] = useState(true)
    const [globalUsage, setGlobalUsage] = useState<OpenAIUsageInterface[]>([])
    const [usageByModel, setUsageByModel] = useState<Record<string, OpenAIUsageInterface[]>>({})
    // Only the id is sent to the API, so only a change of id should trigger a refetch.
    const teamId = team?.id

    useEffect(() => {
        // Flipped by the cleanup when `date`/`teamId` change (or on unmount), so a slower response
        // for an earlier selection cannot overwrite the data of the current one.
        let cancelled = false
        setLoading(true)

        API.getAdminOpenAIUsage(getAccessTokenSilently, date, teamId)
            .then(usage => {
                if (cancelled) return
                const [globalUsageValues, usageValuesByModel] = formatUsage(usage, date)
                setGlobalUsage(globalUsageValues)
                setUsageByModel(usageValuesByModel)
            }).catch(error => {
                if (cancelled) return
                console.error(error)
                // Don't keep showing the previous selection's numbers under the new selection.
                setGlobalUsage([])
                setUsageByModel({})
            }).finally(() => {
                if (!cancelled) setLoading(false)
            })

        return () => {
            cancelled = true
        }
    }, [date, getAccessTokenSilently, teamId])

    if (loading) {
        return (
            <LoadingOverlay visible={loading} />
        )
    }

    const globalCost = globalUsage.reduce((acc, item) => acc + item.cost, 0)

    return (
        <>
            <Flex direction="column" gap={16} mb="xl">
                <Flex gap={8}>
                    <Text fw={500}>Monthly usage</Text>
                    <Text c="dimmed" fw={500}>${globalCost.toFixed(3)}</Text>
                </Flex>
                <BarChart
                    h={300}
                    data={globalUsage}
                    dataKey="date"
                    type="stacked"
                    withLegend
                    series={TOKEN_SERIES}
                    tickLine="y"
                />
            </Flex>
            <SimpleGrid cols={2} spacing="xl" verticalSpacing="xl">
                {Object.entries(usageByModel).map(([model, data]) => {
                    const cost = data.reduce((acc, item) => acc + item.cost, 0)

                    return (
                        <Flex direction="column" gap={16} key={model}>
                            <Flex gap={8}>
                                <Text fw={500}>{model}</Text>
                                <Text c="dimmed" fw={500}>${cost.toFixed(3)}</Text>
                            </Flex>
                            <BarChart
                                h={300}
                                data={data}
                                dataKey="date"
                                type="stacked"
                                series={TOKEN_SERIES}
                                tickLine="y"
                            />
                        </Flex>
                    )
                })}
            </SimpleGrid>
        </>
    )
}

export default OpenAIUsage
