import * as tf from "@tensorflow/tfjs";

/**
 * Load a pretrained Layers model stored at a remote URL.
 *
 * Errors are not propagated: if loading fails, the error is logged to the
 * console and the function resolves to `undefined`. Callers should check the
 * result before using it.
 *
 * @param {string} url URL of the model topology JSON, passed unchanged to
 *   `tf.loadLayersModel`.
 * @return {Promise<tf.LayersModel|undefined>} The model with its topology and
 *   weights loaded, or `undefined` if loading failed.
 */
export async function loadHostedPretrainedModel(url) {
  console.log(`Loading pretrained model from ${url}`);
  try {
    const model = await tf.loadLayersModel(url);
    console.log("Done loading pretrained model.");
    // NOTE(review): loading a model twice has been problematic (see
    // https://github.com/tensorflow/tfjs/issues/34), so UI code calling this
    // may want to disable its load controls after success. Nothing in this
    // function does that.
    return model;
  } catch (err) {
    console.error(err);
    console.log("Loading pretrained model failed.");
    return undefined;
  }
}

/**
 * Load a JSON metadata file stored at a remote URL.
 *
 * Errors are not propagated: on a network failure, a non-2xx HTTP status, or
 * a body that is not valid JSON, the error is logged to the console and the
 * function resolves to `undefined`.
 *
 * @param {string} url URL of the JSON metadata file.
 * @return {Promise<Object|undefined>} The parsed metadata (key-value pairs),
 *   or `undefined` if loading failed.
 */
export async function loadHostedMetadata(url) {
  console.log(`Loading metadata from ${url}`);
  try {
    const response = await fetch(url);
    // fetch() only rejects on network failure. HTTP errors such as 404 resolve
    // normally, so check the status explicitly. Otherwise an error page or an
    // error JSON body would be parsed and returned as if it were metadata.
    if (!response.ok) {
      throw new Error(
        `HTTP ${response.status} ${response.statusText} while fetching ${url}`
      );
    }
    const metadata = await response.json();
    console.log("Done loading metadata.");
    return metadata;
  } catch (err) {
    console.error(err);
    console.log("Loading metadata failed.");
    return undefined;
  }
}

export const PAD_INDEX = 0; // Index of the padding character.
export const OOV_INDEX = 2; // Index of the out-of-vocabulary (OOV) character.

/**
 * Pad and truncate all sequences to the same length.
 *
 * The input arrays are never modified, and every returned sequence is a
 * new array, even when no padding or truncation was needed.
 *
 * @param {number[][]} sequences The sequences represented as an array of array
 *   of numbers.
 * @param {number} maxLen Maximum length. Sequences longer than `maxLen` will be
 *   truncated. Sequences shorter than `maxLen` will be padded.
 * @param {'pre'|'post'} padding Padding type. `'pre'` inserts padding at the
 *   start of a sequence; any other value appends it at the end.
 * @param {'pre'|'post'} truncating Truncation type. `'pre'` drops elements
 *   from the start of a sequence; any other value drops them from the end.
 * @param {number} value Padding value.
 * @return {number[][]} The padded/truncated sequences, each of length `maxLen`
 *   (assuming `maxLen` is a non-negative integer).
 */
export function padSequences(
  sequences,
  maxLen,
  padding = "pre",
  truncating = "pre",
  value = PAD_INDEX
) {
  // TODO(cais): This perhaps should be refined and moved into tfjs-preproc.
  return sequences.map((seq) => {
    // Truncate with slice() rather than splice() so the caller's arrays are
    // left untouched. Note that slice(-maxLen) would be wrong for maxLen === 0,
    // so compute the start index explicitly.
    let result;
    if (seq.length > maxLen) {
      result =
        truncating === "pre"
          ? seq.slice(seq.length - maxLen)
          : seq.slice(0, maxLen);
    } else {
      result = seq.slice();
    }

    // Pad up to maxLen.
    if (result.length < maxLen) {
      const pad = Array.from({ length: maxLen - result.length }, () => value);
      result = padding === "pre" ? [...pad, ...result] : [...result, ...pad];
    }

    return result;
  });
}
