import { ImageSegmenter, FilesetResolver, ImageSegmenterResult } from '@mediapipe/tasks-vision'
// Fraction of the head bounding box's extent (max - min along each axis) that extractHeadImage
// adds as a transparent margin on every side of the cropped head.
const PADDING_FACTOR = 0.12

/**
 * Crops a person's head (hair + face-skin pixels) out of a photo using MediaPipe's
 * selfie_multiclass segmentation model, and returns it as a transparent-background PNG data URL
 * resized to fit a target size.
 *
 * The WASM runtime and the model are fetched as soon as an instance is constructed;
 * `segmentHead` waits for that to finish before segmenting.
 */
export class HeadSegmentation {
  /**
   * Default selfie_multiclass category indices kept by `clearBackground`
   * (1 = hair, 3 = face-skin). Pixels of every other category are made transparent.
   */
  static readonly HEAD_CATEGORIES: ReadonlySet<number> = new Set([1, 3])

  private imageSegmenter: ImageSegmenter | null = null
  private readonly initializationPromise: Promise<void>

  constructor() {
    this.initializationPromise = this.initialize()
    // A load failure is rethrown to callers by the `await` in segmentHead; this no-op handler
    // only stops it from being reported as an unhandled rejection before the first call.
    this.initializationPromise.catch(() => undefined)
  }

  /**
   * Returns the smallest box (inclusive pixel coordinates) containing every pixel whose alpha
   * is non-zero.
   *
   * When no pixel is visible the result is the empty sentinel
   * `{ minx: width, maxx: 0, miny: height, maxy: 0 }`, so `minx > maxx` whenever the image is at
   * least one pixel wide.
   */
  static calculateBoundingBox(imageData: ImageData): {
    minx: number
    maxx: number
    miny: number
    maxy: number
  } {
    const { width, height, data } = imageData
    let minx = width
    let maxx = 0
    let miny = height
    let maxy = 0

    for (let y = 0; y < height; y++) {
      const rowStart = y * width
      for (let x = 0; x < width; x++) {
        if (data[(rowStart + x) * 4 + 3] > 0) {
          if (x < minx) minx = x
          if (x > maxx) maxx = x
          if (y < miny) miny = y
          if (y > maxy) maxy = y
        }
      }
    }

    return { minx, maxx, miny, maxy }
  }

  /**
   * Copies the pixels inside `boundingBox` into a new ImageData surrounded by a transparent
   * margin of `PADDING_FACTOR` times the box extent on each side.
   *
   * @param imageData - Source RGBA pixels.
   * @param boundingBox - Inclusive, non-empty box (`minx <= maxx`, `miny <= maxy`) lying inside
   *   `imageData`, as produced by `calculateBoundingBox`.
   * @returns The padded crop.
   */
  static extractHeadImage(
    imageData: ImageData,
    boundingBox: ReturnType<typeof HeadSegmentation.calculateBoundingBox>
  ): ImageData {
    const { minx, maxx, miny, maxy } = boundingBox
    const boxWidth = maxx - minx + 1
    const boxHeight = maxy - miny + 1
    const paddingX = Math.floor((maxx - minx) * PADDING_FACTOR)
    const paddingY = Math.floor((maxy - miny) * PADDING_FACTOR)
    const width = boxWidth + paddingX * 2
    const height = boxHeight + paddingY * 2

    // A newly constructed ImageData is already transparent black, so only the box is copied.
    const headData = new ImageData(width, height)
    const rowBytes = boxWidth * 4

    // Copy whole rows at once, offset by the padding.
    for (let y = 0; y < boxHeight; y++) {
      const sourceStart = ((miny + y) * imageData.width + minx) * 4
      const targetStart = ((y + paddingY) * width + paddingX) * 4
      headData.data.set(imageData.data.subarray(sourceStart, sourceStart + rowBytes), targetStart)
    }

    return headData
  }

  /**
   * Loads the MediaPipe vision WASM runtime and the selfie_multiclass model, then creates the
   * segmenter in IMAGE mode with only the category mask enabled.
   */
  private async initialize(): Promise<void> {
    // NOTE(review): `@latest` can drift from the installed @mediapipe/tasks-vision version;
    // consider pinning this URL to the package version in use.
    const vision = await FilesetResolver.forVisionTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm'
    )
    const segmenter = await ImageSegmenter.createFromOptions(vision, {
      baseOptions: {
        modelAssetPath:
          'https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite',
      },
      outputCategoryMask: true,
      outputConfidenceMasks: false,
      runningMode: 'IMAGE',
    })

    this.imageSegmenter = segmenter
  }

  /**
   * Draws `imageData` onto a new canvas scaled uniformly to fit inside
   * `targetWidth` x `targetHeight` (the image may be scaled up or down).
   *
   * @returns A canvas whose integer size is the fitted size (never smaller than 1x1).
   * @throws Error if a 2D context cannot be obtained.
   */
  static resizeImageCanvas(
    imageData: ImageData,
    targetWidth: number,
    targetHeight: number
  ): HTMLCanvasElement {
    const { width, height } = imageData
    // "Contain" fit: the limiting dimension decides the scale, so the result never exceeds
    // either target dimension. For a square target this equals scaling the longer side to it.
    const scale = Math.min(targetWidth / width, targetHeight / height)
    // Canvas dimensions are integers; round explicitly so drawImage fills the canvas exactly.
    const newWidth = Math.max(1, Math.round(width * scale))
    const newHeight = Math.max(1, Math.round(height * scale))

    const sourceCanvas = document.createElement('canvas')
    sourceCanvas.width = width
    sourceCanvas.height = height
    const sourceCtx = sourceCanvas.getContext('2d')
    if (!sourceCtx) throw new Error('Failed to get 2D context')

    const resizedCanvas = document.createElement('canvas')
    resizedCanvas.width = newWidth
    resizedCanvas.height = newHeight
    const ctx = resizedCanvas.getContext('2d')
    if (!ctx) throw new Error('Failed to get 2D context')

    sourceCtx.putImageData(imageData, 0, 0)
    ctx.drawImage(sourceCanvas, 0, 0, newWidth, newHeight)

    // Shrink the temporary canvas so its backing store can be released promptly.
    sourceCanvas.width = 0
    sourceCanvas.height = 0

    return resizedCanvas
  }

  /**
   * Makes every non-head pixel of `ctx`'s canvas transparent, crops the remaining head with
   * padding, resizes it to fit the target size and replaces the canvas content with it.
   *
   * @returns The resulting canvas as a PNG data URL, or `undefined` if `result` has no category
   *   mask. If no visible head pixel remains, the canvas keeps its size and only the background
   *   is cleared.
   */
  processSegmentationResult = (
    result: ImageSegmenterResult,
    ctx: CanvasRenderingContext2D,
    targetWidth: number,
    targetHeight: number
  ): string | undefined => {
    const categoryMask = result.categoryMask
    if (!categoryMask) return

    const { width, height } = categoryMask
    // NOTE(review): assumes the mask has the same dimensions as the image drawn on the canvas.
    const imageData = ctx.getImageData(0, 0, width, height)
    HeadSegmentation.clearBackground(imageData.data, categoryMask.getAsUint8Array())

    const boundingBox = HeadSegmentation.calculateBoundingBox(imageData)
    // `<=`: a box that is a single pixel wide or tall is still a valid crop; an empty box has min > max.
    if (boundingBox.minx <= boundingBox.maxx && boundingBox.miny <= boundingBox.maxy) {
      const headImage = HeadSegmentation.extractHeadImage(imageData, boundingBox)
      const resizedCanvas = HeadSegmentation.resizeImageCanvas(headImage, targetWidth, targetHeight)

      // Assigning the canvas size also clears its bitmap.
      ctx.canvas.width = resizedCanvas.width
      ctx.canvas.height = resizedCanvas.height
      ctx.drawImage(resizedCanvas, 0, 0)

      resizedCanvas.width = 0
      resizedCanvas.height = 0
    } else {
      // Nothing visible to crop: still write back the masked pixels so the background is removed.
      ctx.putImageData(imageData, 0, 0)
    }

    return ctx.canvas.toDataURL('image/png')
  }

  /**
   * Sets the alpha of every pixel whose mask category is not in `keepCategories` to 0.
   *
   * @param imageData - RGBA bytes, modified in place.
   * @param mask - One category index per pixel, in the same pixel order as `imageData`.
   * @param keepCategories - Categories to keep visible; defaults to hair and face-skin.
   */
  static clearBackground(
    imageData: Uint8ClampedArray,
    mask: ArrayLike<number>,
    keepCategories: ReadonlySet<number> = HeadSegmentation.HEAD_CATEGORIES
  ) {
    for (let i = 0; i < mask.length; i++) {
      const category = mask[i]
      if (category === undefined || !keepCategories.has(category)) {
        imageData[i * 4 + 3] = 0
      }
    }
  }

  /** Resolves with the loaded `<img>` for `src`, or rejects if the browser cannot load it. */
  private static loadImage(src: string): Promise<HTMLImageElement> {
    return new Promise((resolve, reject) => {
      const image = new Image()
      // Attach handlers before assigning `src` so neither event can be missed.
      image.onload = () => resolve(image)
      image.onerror = () => reject(new Error('Failed to load image'))
      image.src = src
    })
  }

  /**
   * Segments the head out of an image data URL.
   *
   * @param base64DataUrl - Image source (typically a `data:` URL).
   * @param targetWidth - Maximum width of the output image.
   * @param targetHeight - Maximum height of the output image.
   * @returns A PNG data URL of the cropped head on a transparent background.
   * @throws Error('Failed to load image') if the source cannot be decoded,
   *   Error('No face found') if the mask contains no hair/face-skin pixel, or the error that
   *   made the segmenter fail to initialize.
   */
  segmentHead = async (
    base64DataUrl: string,
    targetWidth: number,
    targetHeight: number
  ): Promise<string> => {
    const image = await HeadSegmentation.loadImage(base64DataUrl)

    // Wait for initialization; this rethrows if the runtime or model failed to load.
    await this.initializationPromise
    const segmenter = this.imageSegmenter
    if (!segmenter) {
      throw new Error('Segmenter failed to initialize')
    }

    const canvas = document.createElement('canvas')
    let result: ImageSegmenterResult | undefined
    try {
      canvas.width = image.width
      canvas.height = image.height
      const ctx = canvas.getContext('2d')
      if (!ctx) throw new Error('Failed to get 2D context')

      ctx.drawImage(image, 0, 0)

      // segment() is synchronous in IMAGE running mode.
      result = segmenter.segment(image)
      const categoryMask = result.categoryMask
      if (!categoryMask) {
        throw new Error('No segmentation result')
      }

      // Require at least one hair/face-skin pixel: checking only for "all background" would
      // accept images containing just body or clothes, which have no head to crop.
      const mask = categoryMask.getAsUint8Array()
      if (!mask.some((category) => HeadSegmentation.HEAD_CATEGORIES.has(category))) {
        throw new Error('No face found')
      }

      return this.processSegmentationResult(result, ctx, targetWidth, targetHeight) ?? ''
    } finally {
      // Release the mask copy made by segment() and the working canvas on every path.
      result?.close()
      canvas.width = 0
      canvas.height = 0
    }
  }
}
