import { SelfieSegmentation, Results } from '@mediapipe/selfie_segmentation';
import { TrackProvider } from '../sfu/track-provider';
import { htmlToElement } from '../util';
import { trackUserActivationFor } from '../util/user-activation';
import { drawCenteredImage } from './camera-draw-util';

/**
 * Track provider that runs MediaPipe selfie segmentation on the local camera
 * feed and exposes the rendered result as a canvas-backed video track.
 *
 * The output canvas is a vertical stack: the segmentation mask in the top half
 * and the original camera frame in the bottom half. Each half is a square whose
 * side is the shorter dimension of the camera frame.
 */
class SegmentationVideoTrackProvider implements TrackProvider {
  type = 'segmentation';

  // Handle of the pending requestAnimationFrame callback that drives the frame loop.
  private animationFrame: number;

  // Bumped whenever a frame loop is started or stopped. A loop whose captured
  // generation no longer matches stops rescheduling itself, so a send() that
  // was in flight during suspend()/getTrack() cannot revive a stale loop.
  private loopGeneration = 0;

  private segmentationSession: SelfieSegmentation;

  // Camera stream currently feeding the segmentation loop; stopped on suspend
  // or when a newer getTrack() call replaces it.
  private inputStream: MediaStream = null;

  private canvasElement: HTMLCanvasElement = null;

  private canvasCtx: CanvasRenderingContext2D = null;

  /** Creates the segmentation session and the output canvas. */
  private init() {
    this.segmentationSession = new SelfieSegmentation({
      locateFile: (file: string) => `https://cdn.jsdelivr.net/npm/@mediapipe/selfie_segmentation/${file}`,
    });

    this.canvasElement = htmlToElement('<canvas width="512" height="1024" />') as HTMLCanvasElement;
    this.canvasCtx = this.canvasElement.getContext('2d');
  }

  /**
   * Opens the camera, starts feeding frames to the segmentation session and
   * returns a video track captured from the output canvas at 24 fps.
   *
   * @param videoMediaConfig video constraints for getUserMedia; `true` is used when falsy.
   * @returns the first video track of the canvas capture stream.
   */
  async getTrack(videoMediaConfig: object) {
    if (!this.segmentationSession) this.init();

    const inputStream = await navigator.mediaDevices.getUserMedia({ video: videoMediaConfig || true });

    // A repeated call replaces the previous loop and camera stream instead of
    // running/leaking them alongside the new ones.
    this.stopFrameLoop();
    this.stopInputStream();
    this.inputStream = inputStream;

    const session = this.segmentationSession;
    session.setOptions({
      modelSelection: 1,
    });
    session.onResults(this.processSegmentationResults.bind(this));

    const videoElement = this.localVideoWorkspace();
    videoElement.srcObject = inputStream;
    videoElement.onloadedmetadata = () => {
      trackUserActivationFor('setupForCamera');
      // play() returns a promise that rejects e.g. when autoplay is blocked.
      videoElement.play().catch((err) => console.warn('Unable to start local camera video', err));
    };

    const generation = ++this.loopGeneration;
    const step = async () => {
      if (generation !== this.loopGeneration) return;
      if (window.rtc.isVideoOn && videoElement.videoWidth > 0 && videoElement.videoHeight > 0) {
        await session.send({ image: videoElement });
      }
      // suspend() or a newer getTrack() may have run while send() was pending.
      if (generation === this.loopGeneration) {
        this.animationFrame = requestAnimationFrame(step);
      }
    };
    this.animationFrame = requestAnimationFrame(step);

    const outputStream = this.canvasElement.captureStream(24);
    return outputStream.getVideoTracks()[0];
  }

  /**
   * Stops the frame loop, releases the camera and closes the segmentation
   * session. Safe to call when nothing is running.
   */
  async suspend() {
    // Stop scheduling first so no further frame is sent to a session that is closing.
    this.stopFrameLoop();
    this.stopInputStream();

    const session = this.segmentationSession;
    this.segmentationSession = null;
    if (session) await session.close();
  }

  /** Invalidates the current frame loop and cancels its pending animation frame. */
  private stopFrameLoop() {
    this.loopGeneration += 1;
    cancelAnimationFrame(this.animationFrame);
  }

  /** Stops every track of the current camera stream, releasing the device. */
  private stopInputStream() {
    if (!this.inputStream) return;
    this.inputStream.getTracks().forEach((track) => track.stop());
    this.inputStream = null;
  }

  /**
   * Renders one segmentation result: the mask in the top half of the canvas
   * and the source frame in the bottom half.
   */
  private processSegmentationResults(results: Results) {
    // Each half is a square sized to the shorter side of the frame. Assigning
    // width/height clears the canvas and resets the 2D context state, so it
    // is done before save() rather than inside the save/restore pair.
    const canvasDim = Math.min(results.image.width, results.image.height);
    this.canvasElement.width = canvasDim;
    this.canvasElement.height = canvasDim * 2;

    this.canvasCtx.save();
    const halfHeight = this.canvasElement.height / 2;

    // Top half: the segmentation mask, drawn onto the freshly cleared canvas.
    drawCenteredImage(
      results.segmentationMask,
      results.segmentationMask.width,
      results.segmentationMask.height,
      this.canvasCtx,
      this.canvasElement.width,
      halfHeight,
      'source-over',
      0,
      0
    );
    // Bottom half: the original camera frame.
    drawCenteredImage(
      results.image,
      results.image.width,
      results.image.height,
      this.canvasCtx,
      this.canvasElement.width,
      halfHeight,
      'source-over',
      0,
      halfHeight
    );
    this.canvasCtx.restore();
  }

  /**
   * Returns the 1x1 px muted <video> element used as segmentation input,
   * creating and appending it to <body> on first use.
   */
  private localVideoWorkspace(): HTMLVideoElement {
    let videoElement = document.getElementById('local-camera-video') as HTMLVideoElement;
    if (!videoElement) {
      // NOTE(review): z-index only takes effect on positioned elements; this
      // element is statically positioned, so confirm whether that is intended.
      videoElement = htmlToElement(
        '<video id="local-camera-video" autoplay muted style="display: block; z-index: -1; width: 1px; height: 1px;" />'
      ) as HTMLVideoElement;
      document.body.appendChild(videoElement);
    }
    return videoElement;
  }
}

const segmentationVideoTrackProvider = new SegmentationVideoTrackProvider();

/** Module-level instance; all camera and segmentation state lives on this one object. */
export default segmentationVideoTrackProvider;
