// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
import {Graph} from '../../../graph';
import {NUMBER_TYPES, OperatorImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';

/**
 * Attributes of the Gather operator as consumed by this WebGL implementation.
 *
 * Extends `AttributeWithCacheKey`; the inherited `cacheKey` is forwarded as the program's `cacheHint`
 * by `createGatherProgramInfoLoader`.
 */
interface GatherAttributes extends AttributeWithCacheKey {
  /**
   * Axis of the data tensor to gather along. Parsed from the node's 'axis' attribute (default 0).
   * `validateInputs` accepts values in [-rank, rank - 1] of the data tensor.
   */
  readonly axis: number;
}

/**
 * Gather operator entry point for the WebGL backend.
 *
 * Validates the inputs first (throwing on any violation), then builds the program info loader and
 * executes it through the inference handler.
 *
 * @param inferenceHandler handler used to run the generated program.
 * @param inputs the two operator inputs: `inputs[0]` is the data tensor, `inputs[1]` the indices tensor.
 * @param attributes parsed Gather attributes (axis and cache key).
 * @returns a single-element array holding the tensor returned by `inferenceHandler.run`.
 */
export const gather: OperatorImplementation<GatherAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: GatherAttributes): Tensor[] => {
      validateInputs(inputs, attributes.axis);
      const programInfoLoader = createGatherProgramInfoLoader(inferenceHandler, inputs, attributes);
      return [inferenceHandler.run(programInfoLoader, inputs)];
    };

/**
 * Reads the Gather attributes from a graph node.
 *
 * Only the integer attribute 'axis' is read; it defaults to 0 when absent. The result is wrapped with
 * `createAttributeWithCacheKey`.
 *
 * @param node graph node whose attributes are read.
 * @returns the value produced by `createAttributeWithCacheKey` for `{axis}`.
 */
export const parseGatherAttributes: OperatorInitialization<GatherAttributes> = (node: Graph.Node): GatherAttributes => {
  const axis = node.attributes.getInt('axis', 0);
  return createAttributeWithCacheKey({axis});
};

/**
 * Static part of the Gather program metadata: the program name plus the names and texture types of its
 * two inputs. `createGatherProgramInfoLoader` spreads this object and adds a per-attribute `cacheHint`.
 */
const gatherProgramMetadata = {
  name: 'Gather',
  // Two inputs, in operator order (data, indices). The generated shader reads them through `_A(...)` and
  // `_B(...)`; presumably the framework derives those accessor names from these entries - confirm.
  inputNames: ['A', 'B'],
  // Both inputs are declared as unpacked textures.
  inputTypes: [TextureType.unpacked, TextureType.unpacked],
};

/**
 * Builds the program info (output description and GLSL source) for Gather.
 *
 * The output shape is the data shape with the dimension at the gather axis replaced by the full indices
 * shape:
 *
 *   output = data[0 .. axis) ++ indices ++ data(axis .. end)
 *
 * The generated `process` function maps an output coordinate back to an indices coordinate and a data
 * coordinate, reads the index through `_B`, adds the size of the gathered data dimension when the index is
 * negative, and returns the data element through `_A`.
 *
 * @param handler not used by this function.
 * @param metadata program metadata; spread into the result ahead of `output` and `shaderSource`.
 * @param inputs `inputs[0]` is the data tensor, `inputs[1]` the indices tensor.
 * @param axis gather axis; passed through `ShapeUtil.normalizeAxis` together with the data rank
 *     (presumably resolving negative values - confirm against ShapeUtil).
 * @returns `metadata` extended with the unpacked output description and the shader source.
 */
const createGatherProgramInfo =
    (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => {
      const dataDims = inputs[0].dims.slice();
      const indicesDims = inputs[1].dims.slice();
      const dataRank = dataDims.length;
      const indicesRank = indicesDims.length;
      const outputDims = new Array(dataRank + indicesRank - 1);

      const gatherAxis = ShapeUtil.normalizeAxis(axis, dataRank);
      // First output dimension that lies past the block taken from the indices shape.
      const indicesEnd = gatherAxis + indicesRank;

      // One GLSL assignment per output dimension, copying the output coordinate either into the data
      // coordinate (`inputIdx`) or into the indices coordinate (`indexDataIdx`).
      const coordinateCopies: string[] = [];
      for (let outDim = 0; outDim < outputDims.length; outDim++) {
        if (outDim < gatherAxis) {
          // Leading part: data dimensions before the gather axis map one-to-one.
          outputDims[outDim] = dataDims[outDim];
          coordinateCopies.push(`inputIdx[${outDim}] = outputIdx[${outDim}];`);
          continue;
        }
        if (outDim < indicesEnd) {
          // Middle part: dimensions contributed by the indices tensor.
          const indicesDim = outDim - gatherAxis;
          outputDims[outDim] = indicesDims[indicesDim];
          coordinateCopies.push(`indexDataIdx[${indicesDim}] = outputIdx[${outDim}];`);
          continue;
        }
        // Trailing part: data dimensions after the gather axis (the axis itself is skipped, hence + 1).
        const dataDim = outDim - indicesRank + 1;
        outputDims[outDim] = dataDims[dataDim];
        coordinateCopies.push(`inputIdx[${dataDim}] = outputIdx[${outDim}];`);
      }

      // The declared GLSL array lengths fall back to 1 when the corresponding rank is 0.
      const outputRank = outputDims.length || 1;
      const indicesArrayLength = indicesRank || 1;
      const shaderSource = `
      float process(int outputIdx[${outputRank}]) {
        int inputIdx[${dataRank}];
        int indexDataIdx[${indicesArrayLength}];
        indexDataIdx[0] = 0;
        ${coordinateCopies.join('\n        ')}
        int idx = int(_B(indexDataIdx));
        inputIdx[${gatherAxis}] = idx < 0 ? idx + ${dataDims[gatherAxis]} : idx;
        return _A(inputIdx);
      }`;

      const output = {dims: outputDims, type: inputs[0].type, textureType: TextureType.unpacked};
      return {...metadata, output, shaderSource};
    };

/**
 * Creates the program info loader for Gather.
 *
 * The loader carries the static Gather metadata plus a `cacheHint` taken from `attributes.cacheKey`, and a
 * `get` callback that builds the full program info only when invoked.
 *
 * @param handler inference handler forwarded to `createGatherProgramInfo`.
 * @param inputs the operator inputs (data tensor, indices tensor).
 * @param attributes Gather attributes; `cacheKey` is read immediately, `axis` when `get` is called.
 * @returns the metadata extended with the `get` callback.
 */
const createGatherProgramInfoLoader =
    (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: GatherAttributes): ProgramInfoLoader => {
      const {cacheKey} = attributes;
      const metadata = {...gatherProgramMetadata, cacheHint: cacheKey};
      const buildProgramInfo = (): ProgramInfo => createGatherProgramInfo(handler, metadata, inputs, attributes.axis);
      return {...metadata, get: buildProgramInfo};
    };

/**
 * Validates the inputs of Gather and throws on the first violation found.
 *
 * Checks, in order:
 *  1. exactly two inputs are supplied;
 *  2. the data tensor (`inputs[0]`) has rank >= 1;
 *  3. `axis` lies in [-rank, rank - 1] of the data tensor;
 *  4. the data tensor type is one of `NUMBER_TYPES`;
 *  5. the indices tensor (`inputs[1]`) type is 'int32' or 'int16'.
 *
 * NOTE(review): the ONNX Gather specification lists int32/int64 for indices; the 'int16' accepted here
 * presumably reflects the tensor types available in this runtime - confirm.
 *
 * @param inputs operator inputs: data tensor followed by indices tensor.
 * @param axis gather axis as given by the attribute (may be negative).
 * @throws Error describing the first failed check.
 */
const validateInputs = (inputs: Tensor[], axis: number): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('Gather requires 2 inputs.');
  }
  const tensorRank = inputs[0].dims.length;
  if (tensorRank < 1) {
    throw new Error('Invalid input shape.');
  }
  if (axis < -tensorRank || axis > tensorRank - 1) {
    throw new Error('Invalid axis.');
  }
  if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
    throw new Error('Invalid input type.');
  }
  if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') {
    throw new Error('Invalid input type.');
  }
};
