import {useContext, useEffect, useMemo, useRef, useState} from 'react';
import {DataConfig} from '../../types';
import {dataRows} from '../../../lib/data-types';
import {
    formatData,
    formatXAxisByType,
    generateAxisBottomConfig,
    getTextWidth,
    limitRowsForChart,
    PlotColors,
    PlotTheme,
} from '../common';
import ChartContainer from '../container';
import ChartContext, {ChartStore} from '../context';
import {observer} from 'mobx-react-lite';
import {QueryTableContext} from '../../query-table';
import QueryDataContext from '../../renderers/query-data-renderer/context';
import Loading from '../../../components/loading';
import {ResponsiveScatterPlot} from '@nivo/scatterplot';
import {ScatterPlotRawSerie, ScatterPlotValue} from '@nivo/scatterplot';
import {dataToPlot} from '../plot-chart';
import AdditionalToolTipVariables from '../additional-tooltip-variables';

/**
 * Row cap for the scatter chart: passed to `limitRowsForChart`, to the `ChartStore`
 * constructor and as `maxRows` to `ChartContainer`.
 */
const MAX_LINES = 99;


/**
 * Scatter-plot visualisation of the current query result.
 *
 * Reads the chart configuration from `ChartContext` and rows/metadata from
 * `QueryDataContext`. Rows are grouped into one nivo series per distinct value of
 * `idVariable`. The x value comes from `xAxisName` (via `formatXAxisByType`) and the
 * y value from the first selected y variable. When `nodeSize` is enabled, node size
 * is scaled by `nodeSizeVariable`. Until the id, x and y variables are all selected,
 * a placeholder message is rendered instead of the chart.
 */
export const ScatterChartVis = observer(() => {
    const {yAxisNames, rowsToShow, xAxisName, valueMeta, showLatestYears, isYearChart, isYearMonthChart, maxXLabelLength, idVariable, nodeSize, nodeSizeVariable, metadata: chartMetadata} = useContext(ChartContext);
    const {data, metadata, loading} = useContext(QueryDataContext);
    const ref = useRef<HTMLDivElement>(null);
    const [containerWidth, setContainerWidth] = useState(800);
    const yVars = valueMeta;
    // Undefined while no x-axis variable is selected; every use below is guarded.
    const xVar = metadata.find((m) => m.name === xAxisName);

    /**
     * Group rows into nivo scatter series, one per distinct `idVariable` value.
     *
     * Rows whose x or y value is null are skipped. Each point keeps `rowIndex`, the
     * position of its source row in `data`, so the tooltip can look that row up.
     * nivo's own `node.index` counts nodes across all series, so it does not match
     * the row position once rows are grouped or skipped.
     *
     * @returns array of `{id, data: [{x, y, z, rowIndex}]}` series
     */
    function createScatterPlotData(data: dataRows, idVariable: string, xVariable: string, yVariable: string, zVariable: string): any {
        // Null-prototype object: group values such as "constructor" or "__proto__" must
        // not collide with Object.prototype members. Object.values keeps the same key
        // ordering as a plain object, so series order (and colours) are unchanged.
        const result: Record<string, ScatterPlotRawSerie<any>> = Object.create(null);
        let nextRowIndex = 0;
        for (const row of data) {
            const rowIndex = nextRowIndex++;
            const group: string | number = row[idVariable] as string | number;
            const key = String(group);
            if (!(key in result)) {
                result[key] = {id: group, data: []};
            }
            const xValue: ScatterPlotValue = formatXAxisByType(xVariable, row);
            const yValue: ScatterPlotValue = row[yVariable] as any;
            const zValue = row[zVariable];
            if (xValue !== null && yValue !== null) {
                result[key].data.push({x: xValue, y: yValue, z: zValue, rowIndex});
            }
        }
        return Object.values(result);
    }

    const graphData = useMemo(() => {
        return createScatterPlotData(data, idVariable, xAxisName, yAxisNames[0], nodeSizeVariable);
    }, [data, idVariable, xAxisName, yAxisNames[0], nodeSizeVariable]);

    // [min, max] of the numeric node-size values. It feeds nivo's size scale and the
    // size legend. Missing or non-numeric z values are ignored rather than counted as
    // a fake 10, and 0 stays 0. Falls back to [10, 10] when there is no numeric value.
    const minMaxZValues: [number, number] = useMemo(() => {
        let minZ = Infinity;
        let maxZ = -Infinity;
        for (const serie of graphData) {
            for (const point of serie.data) {
                const z = parseFloat(String(point.z));
                if (!Number.isFinite(z))
                    continue;
                minZ = Math.min(minZ, z);
                maxZ = Math.max(maxZ, z);
            }
        }
        return minZ === Infinity ? [10, 10] : [minZ, maxZ];
    }, [graphData]);

    // Every value read inside the factory is listed as a dependency so the memo
    // cannot return stale rows when y variables, chart type or label length change.
    const rows = useMemo(
        () => dataToPlot(
            formatData(data, showLatestYears, rowsToShow, xAxisName, maxXLabelLength),
            xAxisName,
            yVars,
            isYearChart,
        ),
        [data, rowsToShow, xAxisName, yAxisNames[0], showLatestYears, maxXLabelLength, yVars, isYearChart]);

    const maxXLegendLabelLength = xVar
        ? data
            .slice(0, rowsToShow)
            .reduce((maxLength, row) => Math.max(maxLength, getTextWidth(xVar.formatSimple(row))), 0)
        : 0;

    const maxYLabelLength = rows.reduce((maxLength, line, i) => {
        // NOTE(review): assumes dataToPlot emits one line per entry of yVars; lines
        // without a matching y variable are ignored instead of throwing.
        const formatter = yVars[i]?.formatter;
        if (!formatter)
            return maxLength;
        return line.data.reduce(
            (max, point) => Math.max(max, getTextWidth(formatter.formatSimple(point.y))),
            maxLength,
        );
    }, 0);

    const showLegend = true;
    // nivo's scatter legend has one entry per series (idVariable group), so legend
    // items are sized and counted from the series ids, not from the y-variable labels.
    const maxLegendLabelLength = graphData.reduce(
        (maxLength: number, serie: any) => Math.max(maxLength, getTextWidth(String(serie.id))),
        0,
    );
    // Clamped to at least 1 so a narrow container cannot cause a division by zero.
    const maxLegendsOnLine = Math.max(1, Math.floor(containerWidth / (maxLegendLabelLength + 22)));
    const legendLineCount = Math.ceil(graphData.length / maxLegendsOnLine);

    useEffect(() => {
        const element = ref.current;
        if (!element)
            return;

        // Named resizeObserver so it does not shadow mobx-react-lite's `observer`.
        const resizeObserver = new ResizeObserver(entries => {
            setContainerWidth(entries[0].contentRect.width);
        });
        resizeObserver.observe(element);
        // disconnect() stops all observation even if ref.current has changed since.
        return () => resizeObserver.disconnect();
    }, [ref]);

    const label = useMemo(() => {
        return chartMetadata.find(v => v.name === nodeSizeVariable)?.label;
    }, [chartMetadata, nodeSizeVariable]);

    if (loading)
        return <Loading />;

    return (
        <ChartContainer divRef={ref} maxRows={MAX_LINES}>
            {(idVariable && xAxisName && yAxisNames[0]) ? <div className='flex h-full'>
                <ResponsiveScatterPlot
                    data={graphData}
                    margin={{
                        top: 15,
                        right: maxXLegendLabelLength + 5,
                        bottom: maxXLegendLabelLength + 5 + (showLegend ? 20 * legendLineCount : 0),
                        left: maxYLabelLength + (yVars.length === 1 ? 30 : 10),
                    }}
                    theme={PlotTheme}
                    colors={PlotColors}
                    axisLeft={{
                        format: (x) => yVars[0].formatter.formatSimple(x),
                        legend: yVars[0].label,
                        legendPosition: 'middle',
                        legendOffset: -(maxYLabelLength + 25),
                    }}
                    axisBottom={generateAxisBottomConfig(isYearChart, isYearMonthChart)}
                    // NOTE(review): month index 12 rolls over to January 31 of the following
                    // year, and the minimum is December 30 of the first year; confirm
                    // these bounds match where formatXAxisByType places yearly points.
                    xScale={isYearChart || isYearMonthChart ? {
                        type: 'time',
                        format: 'native',
                        min: new Date(rows.reduce((min, r) => Math.min(min, ...r.data.map(d => (d.x as Date).getFullYear())), 9999), 11, 30),
                        max: new Date(rows.reduce((max, r) => Math.max(max, ...r.data.map(d => (d.x as Date).getFullYear())), 0), 12, 31)
                    } : {
                        type: 'linear',
                        min: 'auto',
                        max: 'auto',
                    }}
                    useMesh={true}
                    tooltip={(d) => (
                        <div className='max-w-xs bg-white border shadow-lg rounded'>
                            <div
                                className='px-3 pt-2 pb-1 rounded-t text-white font-semibold'
                                style={{backgroundColor: d.node.color}}
                            >
                                {d.node.serieId}
                            </div>
                            <div className='px-3 pb-3 pt-3 rounded-b'>
                                {/* Format the raw y value; formattedY is already a formatted string. */}
                                {yVars[0].formatter.formatSimple(d.node.data.y)}
                            </div>
                            {/* rowIndex is the source row's position in `data` (set in createScatterPlotData). */}
                            <AdditionalToolTipVariables rowIndex={Number(d.node.data.rowIndex)}/>
                        </div>
                    )}
                    legends={showLegend ? [
                        {
                            anchor: 'bottom-left',
                            direction: 'row',
                            justify: false,
                            translateX: 0,
                            translateY: maxXLegendLabelLength + 20,
                            itemsSpacing: 10,
                            itemDirection: 'left-to-right',
                            itemWidth: maxLegendLabelLength + 12 + 10,
                            itemHeight: 12,
                            symbolSize: 12,
                            symbolShape: 'circle',
                            symbolBorderColor: 'rgba(0, 0, 0, .5)',
                        },
                    ] : []}
                    nodeSize={nodeSize ? {
                        key: 'data.z',
                        sizes: [12, 40],
                        values: minMaxZValues
                    } : 10}
                />
                {nodeSize &&
                <div className="flex flex-col items-center">
                    <div className="flex items-center">
                        <p>{label}</p>
                    </div>
                    <div className="flex items-center">
                        <div className="w-3 h-3 bg-orange-500 rounded-full mr-2"></div>
                        <p className="text-sm">{minMaxZValues[0]}</p>
                    </div>
                    <div className="flex items-center">
                        <div className="w-10 h-10 bg-orange-500 rounded-full mr-2"></div>
                        <p className="text-sm">{minMaxZValues[1]}</p>
                    </div>
                </div>
                }
            </div> : <div>Select all variables to show chart</div>}
        </ChartContainer>
    );
});

/**
 * Creates a `ChartStore` from the enclosing query's config, rows and metadata and
 * exposes it to `ScatterChartVis` via `ChartContext`. The store is rebuilt only when
 * one of those three inputs changes. While the query is loading a spinner is shown.
 */
const ScatterChartContext = observer(() => {
    const {dataConfig, data, metadata, loading} = useContext(QueryDataContext);
    const store = useMemo(
        () => new ChartStore({dataConfig, data, metadata}, MAX_LINES),
        [dataConfig, data, metadata],
    );

    return loading
        ? <Loading />
        : (
            <ChartContext.Provider value={store}>
                <ScatterChartVis />
            </ChartContext.Provider>
        );
});

/**
 * Entry point for the scatter chart.
 *
 * Runs the query described by `dataConfig`, capped at `MAX_LINES` rows via
 * `limitRowsForChart`, and renders the chart once data is available.
 *
 * @param props.dataConfig query/chart configuration to render
 * @param props.loadCB optional callback, forwarded unchanged to `QueryTableContext`
 */
export default function ScatterChartRenderer(props: {
    dataConfig: DataConfig
    loadCB?: (promise: Promise<void>) => void
}) {
    const {dataConfig, loadCB} = props;
    const limitedConfig = limitRowsForChart(dataConfig, MAX_LINES);

    return (
        <QueryTableContext dataConfig={limitedConfig} loadCB={loadCB}>
            <ScatterChartContext />
        </QueryTableContext>
    );
}
