import { GridValidRowModel } from '@mui/x-data-grid-premium';
import { generateCollectionFilter, ServerSideState } from '@price-for-profit/data-grid';
import {
    ChangeReturn,
    DataAccessPaginatedResponse,
    DataAccessService,
    DataExecuteMassActionProps,
    DataGetCollectionProps,
    IDataAccessService,
    IFilter,
    LogicalFilter,
} from '@price-for-profit/micro-services';
import { CustomerPricesLevelPermissionType } from 'shared/types';

/**
 * Extends {@link IDataAccessService} with read and mass-action entry points that, in addition to
 * the usual request parameters, take the grid's server-side state and the caller's
 * {@link PermittedRowLevels}.
 */
export interface IRowLevelSecurityDasService extends IDataAccessService {
    /**
     * Runs a mass action, given the grid state and the caller's permitted row levels.
     *
     * @param state - server-side grid state.
     * @param permittedRowLevels - the caller's row-level permissions.
     * @param params - the mass-action request.
     * @returns a paginated response of change results.
     */
    executeMassActionWithRls<TRow extends GridValidRowModel, TParams, TResult = undefined>(
        state: ServerSideState,
        permittedRowLevels: PermittedRowLevels,
        params: DataExecuteMassActionProps<TRow, TParams>
    ): Promise<DataAccessPaginatedResponse<ChangeReturn<TRow, TResult>>>;

    /**
     * Reads a page of rows, given the grid state and the caller's permitted row levels.
     *
     * @param state - server-side grid state.
     * @param permittedRowLevels - the caller's row-level permissions.
     * @param dataGetCollectionProps - the collection request.
     * @returns a paginated response of rows.
     */
    getCollectionWithRls<TRow extends GridValidRowModel, TDb>(
        state: ServerSideState,
        permittedRowLevels: PermittedRowLevels,
        dataGetCollectionProps: DataGetCollectionProps<TRow, TDb>
    ): Promise<DataAccessPaginatedResponse<TRow>>;
}

/**
 * Mode that selects how `orgRegion` and `marketSegment` permissions are turned into row filters
 * by `RowLevelSecurityDasService`.
 */
type MarketRegionType =
    | 'exclusive'
    | 'ignore'
    | 'inclusive';

/**
 * Allow-lists of values the caller is permitted to see. `RowLevelSecurityDasService` matches each
 * list against the row properties `businessLine`, `marketSegment`, `orgRegion` and `rlsName`.
 */
type PermittedRowValues = {
    businessLines: string[];
    marketSegments: string[];
    orgRegions: string[];
    rlsNames: string[];
};

/**
 * Row-level permissions of the current caller for one page.
 */
export type PermittedRowLevels = {
    /** Page the request originates from. */
    page: 'customerPrices' | 'productPrices';
    /** Permission level for the customer-prices page. */
    customerPricesLevel: CustomerPricesLevelPermissionType;
    /** How org-region / market-segment permissions are applied. */
    marketRegionType: MarketRegionType;
    /** The allow-lists themselves. */
    permitted: PermittedRowValues;
};

/**
 * Data-access service whose plain `getCollection` / `executeMassAction` are disabled.
 *
 * Callers must use the `*WithRls` variants instead. These set the request's `collectionFilter`
 * to the grid's filter model AND-ed with filters built from the caller's {@link PermittedRowLevels}.
 */
export class RowLevelSecurityDasService extends DataAccessService implements IRowLevelSecurityDasService {
    /**
     * Disabled: reading without row-level security is not allowed. Use `getCollectionWithRls`.
     *
     * @returns a promise that always rejects with `Error('RLS required.')`. It rejects rather than
     *   throwing synchronously so the failure stays inside the promise contract. Callers that only
     *   attach `.catch()`, or build the argument for `Promise.all`, therefore still observe it.
     */
    getCollection<T, DB>(params: DataGetCollectionProps<T, DB>): Promise<DataAccessPaginatedResponse<T>> {
        return Promise.reject(new Error('RLS required.'));
    }

    /**
     * Reads a page of rows with the row-level-security filter applied.
     *
     * @param state - server-side grid state; only `state.filterModel` is read.
     * @param permittedRowLevels - the caller's row-level permissions.
     * @param params - collection request; its `collectionFilter` is replaced (see `fixParamsForRls`).
     * @returns the base service's paginated response.
     */
    getCollectionWithRls<T extends GridValidRowModel, DB>(
        state: ServerSideState,
        permittedRowLevels: PermittedRowLevels,
        params: DataGetCollectionProps<T, DB>
    ): Promise<DataAccessPaginatedResponse<T>> {
        const fixedParams: DataGetCollectionProps<T, DB> = this.fixParamsForRls(state, permittedRowLevels, params);
        return super.getCollection(fixedParams);
    }

    /**
     * Disabled: mass actions without row-level security are not allowed.
     * Use `executeMassActionWithRls`.
     *
     * @returns a promise that always rejects with `Error('RLS required.')`.
     */
    executeMassAction<T, P, R = undefined>(
        params: DataExecuteMassActionProps<T, P>
    ): Promise<DataAccessPaginatedResponse<ChangeReturn<T, R>>> {
        return Promise.reject(new Error('RLS required.'));
    }

    /**
     * Runs a mass action restricted by the row-level-security filter.
     *
     * @param state - server-side grid state; only `state.filterModel` is read.
     * @param permittedRowLevels - the caller's row-level permissions.
     * @param params - mass-action request; its `collectionFilter` is replaced (see `fixParamsForRls`).
     * @returns the base service's paginated change results.
     */
    executeMassActionWithRls<T extends GridValidRowModel, P, R = undefined>(
        state: ServerSideState,
        permittedRowLevels: PermittedRowLevels,
        params: DataExecuteMassActionProps<T, P>
    ): Promise<DataAccessPaginatedResponse<ChangeReturn<T, R>>> {
        const fixedParams: DataExecuteMassActionProps<T, P> = this.fixParamsForRls(state, permittedRowLevels, params);
        return super.executeMassAction(fixedParams);
    }

    /** Converts the grid's filter model to a collection filter and combines it with the RLS clause. */
    private getRLSCollectionFilterFromGrid<T extends GridValidRowModel>(
        state: ServerSideState,
        permittedRowLevels: PermittedRowLevels
    ): IFilter<T> {
        const stateFilter = generateCollectionFilter<T>(state.filterModel);

        return this.getRLSCollectionFilter(stateFilter, permittedRowLevels);
    }

    /**
     * Returns a shallow copy of `params` whose `collectionFilter` is the grid filter AND-ed with
     * the RLS clause.
     *
     * NOTE(review): any `collectionFilter` already present on `params` is overwritten, not merged.
     * Only the grid's `filterModel` survives. Confirm that no caller relies on passing its own
     * `collectionFilter` through the `*WithRls` methods.
     */
    private fixParamsForRls<P>(state: ServerSideState, permittedRowLevels: PermittedRowLevels, params: P): P {
        const rlsCollectionFilter = this.getRLSCollectionFilterFromGrid(state, permittedRowLevels);

        const result: P = {
            ...params,
            collectionFilter: rlsCollectionFilter,
        };
        return result;
    }

    /**
     * Builds the row-level-security clause and ANDs it with `stateFilter` when that filter is non-empty.
     *
     * Clauses, all AND-ed together:
     * - always: `businessLine` is one of `permitted.businessLines`;
     * - on the `customerPrices` page with `customerPricesLevel === 'row'`: `rlsName` is one of
     *   `permitted.rlsNames`;
     * - `'inclusive'`: `orgRegion` matches AND `marketSegment` matches;
     * - `'ignore'`: `orgRegion` matches (market segment is not filtered);
     * - `'exclusive'`: `orgRegion` matches OR `marketSegment` matches. Empty lists are left out
     *   of the OR, and nothing is added when both lists are empty.
     *
     * NOTE(review): outside the 'exclusive' branch, an empty permitted list still yields an `or`
     * filter with no children. How the backend evaluates an empty `or` is not visible here, so
     * confirm that it denies rather than allows. Also confirm that adding no restriction when both
     * region and segment lists are empty in 'exclusive' mode is intended.
     *
     * @throws Error when `marketRegionType` is not one of the declared modes. This fails closed
     *   instead of silently skipping the region/segment restriction.
     */
    private getRLSCollectionFilter<T extends GridValidRowModel>(
        stateFilter: IFilter<T> | undefined,
        { page, marketRegionType, permitted, customerPricesLevel }: PermittedRowLevels
    ): IFilter<T> {
        const { businessLines, orgRegions, marketSegments, rlsNames } = permitted;

        const businessLineOrFilter = this.buildLogicalOrFilter('businessLine', businessLines);
        const orgRegionOrFilter = this.buildLogicalOrFilter('orgRegion', orgRegions);
        const marketSegmentOrFilter = this.buildLogicalOrFilter('marketSegment', marketSegments);
        const rlsNameFilter = this.buildLogicalOrFilter('rlsName', rlsNames);

        // Every clause pushed below is AND-ed with the business-line restriction.
        const rlsFilter: IFilter<T> = {
            logicalOperator: 'and',
            filters: [businessLineOrFilter],
        };

        if (page === 'customerPrices' && customerPricesLevel === 'row') rlsFilter.filters.push(rlsNameFilter);

        switch (marketRegionType) {
            case 'inclusive':
                rlsFilter.filters.push(orgRegionOrFilter);
                rlsFilter.filters.push(marketSegmentOrFilter);
                break;
            case 'ignore':
                rlsFilter.filters.push(orgRegionOrFilter);
                break;
            case 'exclusive': {
                // Drop empty OR groups so they do not take part in the region/segment alternative.
                const marketSegmentExclusiveFilters: IFilter<T>[] = [orgRegionOrFilter, marketSegmentOrFilter].filter(
                    orFilter => orFilter.filters.length > 0
                );
                if (marketSegmentExclusiveFilters.length > 0) {
                    rlsFilter.filters.push({
                        logicalOperator: 'or',
                        filters: marketSegmentExclusiveFilters,
                    });
                }
                break;
            }
            default: {
                // Compile-time exhaustiveness check; at runtime an unexpected value fails closed.
                const unsupported: never = marketRegionType;
                throw new Error(`Unsupported marketRegionType: ${String(unsupported)}`);
            }
        }

        const isLogicalFilter = (filter: IFilter<T>): filter is LogicalFilter<T> =>
            Boolean(filter) && Object.prototype.hasOwnProperty.call(filter, 'logicalOperator');

        // An absent grid filter, or an empty logical one, contributes nothing: return the RLS clause alone.
        if (!stateFilter || (isLogicalFilter(stateFilter) && !stateFilter.filters?.length)) return rlsFilter;

        const filter: IFilter<T> = {
            logicalOperator: 'and',
            filters: [stateFilter, rlsFilter],
        };

        return filter;
    }

    /**
     * Builds `property eq v1 OR property eq v2 OR ...` over `stringList`.
     * An empty list yields an `or` filter with an empty `filters` array.
     */
    private buildLogicalOrFilter<T>(property: string, stringList: string[]) {
        const result: IFilter<T> = {
            logicalOperator: 'or',
            filters: stringList.map(stringItem => ({
                property: property as keyof T,
                operator: 'eq',
                value: stringItem,
            })),
        };

        return result;
    }
}
