import { models } from '@trova-trip/trova-models';
import regression, { DataPoint } from 'regression';
import minBy from 'lodash/minBy';
import maxBy from 'lodash/maxBy';
import { CostPerThreshold } from '../PricingCalculator.types';
import { roundUp } from '../Utils/common.utils';
import { PredictivePricing } from './PredictivePricing';

/** Shared itinerary cost-threshold model (fields used here: numberOfTravelers, price, platformFee). */
type CostThreshold = models.itineraries.CostThreshold;

/**
 * Predicts the per-traveler price and platform fee for a group size ("tier")
 * from a list of known cost thresholds.
 *
 * - A tier matching a threshold exactly returns that threshold's values as-is.
 * - A non-positive tier with no exact match is priced at zero.
 * - A tier below the smallest / above the largest threshold reuses that
 *   boundary threshold's price (rounded up) and platform fee.
 * - A tier between two thresholds is linearly interpolated from its nearest
 *   lower and upper neighbours.
 */
class PredictivePricingImpl implements PredictivePricing {
    protected _costThresholds: CostThreshold[];
    protected _tier: number;

    /**
     * @param costThresholds Known price points; order does not matter.
     * @param tier Number of travelers to price.
     */
    constructor(
        costThresholds: CostThreshold[],
        tier: number,
    ) {
        this._costThresholds = costThresholds;
        this._tier = tier;
    }

    /**
     * @returns The (exact or predicted) pricing for the configured tier.
     * @throws Error when the tier is positive, has no exact match, and there
     *   are no usable cost thresholds to predict from.
     */
    public getPredictedPricePerTier(): CostPerThreshold {
        const existingTier = this._costThresholds.find(
            ({ numberOfTravelers }) => numberOfTravelers === this._tier,
        );

        if (existingTier) {
            return {
                numberTravelers: existingTier.numberOfTravelers,
                pricePerTraveler: existingTier.price,
                platformFee: existingTier.platformFee,
            };
        }

        if (this._tier <= 0) {
            return {
                numberTravelers: this._tier,
                pricePerTraveler: 0,
                platformFee: 0,
            };
        }

        // lodash's maxBy/minBy ignore elements whose iteratee returns null, so
        // these select the closest threshold at-or-below / at-or-above the tier.
        const previousTier = maxBy(
            this._costThresholds,
            ({ numberOfTravelers }) =>
                numberOfTravelers <= this._tier ? numberOfTravelers : null,
        );
        const nextTier = minBy(this._costThresholds, ({ numberOfTravelers }) =>
            numberOfTravelers >= this._tier ? numberOfTravelers : null,
        );

        if (previousTier && nextTier) {
            return this._predictPriceBetweenTwoTiers(previousTier, nextTier);
        }

        // Outside the known range: the only neighbour found is necessarily the
        // smallest (tier below range) or largest (tier above range) threshold.
        const boundaryTier = previousTier ?? nextTier;
        if (!boundaryTier) {
            throw new Error(
                `Cannot predict pricing for ${this._tier} travelers: no usable cost thresholds`,
            );
        }
        return this._priceFromBoundaryTier(boundaryTier);
    }

    /** Prices a tier outside the known range from the nearest boundary threshold. */
    private _priceFromBoundaryTier(boundaryTier: CostThreshold): CostPerThreshold {
        const { price, platformFee } = boundaryTier;
        return {
            numberTravelers: this._tier,
            pricePerTraveler: price > 0 ? roundUp(price) : 0,
            platformFee,
        };
    }

    /** Linearly interpolates price (and fee, when both sides define one) between two thresholds. */
    private _predictPriceBetweenTwoTiers(
        previousTier: CostThreshold,
        nextTier: CostThreshold,
    ): CostPerThreshold {
        const pricePerTraveler = this._predictAtTier([
            [previousTier.numberOfTravelers, previousTier.price],
            [nextTier.numberOfTravelers, nextTier.price],
        ]);

        // Interpolate the fee only when both neighbours define one. Checked with
        // `!= null` (not truthiness) so a real fee of 0 still participates;
        // otherwise fall back to the lower neighbour's fee.
        const previousFee = previousTier.platformFee;
        const nextFee = nextTier.platformFee;
        const platformFee =
            previousFee != null && nextFee != null
                ? this._predictAtTier([
                      [previousTier.numberOfTravelers, previousFee],
                      [nextTier.numberOfTravelers, nextFee],
                  ])
                : previousFee;

        return {
            numberTravelers: this._tier,
            pricePerTraveler,
            platformFee,
        };
    }

    /**
     * Fits a straight line through `dataPoints` and evaluates it at the
     * configured tier. Positive results are rounded up; anything else becomes 0.
     * NOTE(review): regression-js rounds its coefficients to its default
     * `precision` (2 decimals); confirm that is acceptable for prices.
     */
    private _predictAtTier(dataPoints: DataPoint[]): number {
        const line = regression.linear(dataPoints, {
            order: 1,
        });

        const [, predicted] = line.predict(this._tier);

        return predicted > 0 ? roundUp(predicted) : 0;
    }
}

export default PredictivePricingImpl;
