import * as THREE from 'three'

/**
 * Thin wrapper around THREE.Raycaster that tests rays against a registered
 * list of meshes and reports the nearest hit.
 */
export default class RayCaster {
  /**
   * @param {THREE.Camera} camera - Camera used by findIntersections to build
   *   rays from screen positions.
   */
  constructor (camera) {
    this.rayCaster = new THREE.Raycaster()
    this.objects = []
    this.camera = camera
    this.mouse = new THREE.Vector2()
  }

  /**
   * Registers `mesh` as a raycast target.
   * @param {THREE.Object3D} mesh
   */
  add (mesh) {
    this.objects.push(mesh)
  }

  /**
   * Unregisters every occurrence of `mesh`.
   * @param {THREE.Object3D} mesh
   */
  remove (mesh) {
    // TODO: @Performance This copies the array. Swap-and-pop would avoid the
    // copy but would also change iteration order, which decides which object
    // wins when two hits are at exactly the same distance.
    this.objects = this.objects.filter(v => v !== mesh)
  }

  /**
   * Casts a ray from the camera through a normalized screen position.
   *
   * @param {{x: number, y: number}} point - Screen position in [0, 1], with
   *   (0, 0) at the top-left corner.
   * @param {Function} [filter] - Optional predicate; objects for which it
   *   returns a falsy value are skipped.
   * @param {Array} [meshIds] - Optional scene graph IDs whose instances are
   *   ignored when picking the closest hit.
   * @returns {Object|null} The nearest hit, or null if nothing was hit.
   */
  findIntersections (point, filter, meshIds) {
    // Map [0, 1] screen coordinates (y down) to normalized device
    // coordinates in [-1, 1] (y up).
    this.mouse.set((point.x * 2) - 1, -(point.y * 2) + 1)
    // Only accept hits farther away than the camera's near distance.
    this.rayCaster.near = this.camera.near
    this.rayCaster.setFromCamera(this.mouse, this.camera)
    return _trace(this.rayCaster, this.objects, filter, meshIds)
  }

  /**
   * Casts an arbitrary world-space ray.
   *
   * @param {THREE.Vector3} origin - Ray origin in world space.
   * @param {THREE.Vector3} direction - Ray direction; THREE.Raycaster.set
   *   expects it to be normalized.
   * @param {Function} [filter] - Optional predicate; objects for which it
   *   returns a falsy value are skipped.
   * @returns {Object|null} The nearest hit, or null if nothing was hit.
   */
  trace (origin, direction, filter) {
    this.rayCaster.set(origin, direction)
    // Raycaster.set() leaves `near` untouched. Reset it to the Raycaster
    // default so a previous findIntersections call (which sets it to
    // camera.near) does not leak into this query.
    this.rayCaster.near = 0
    return _trace(this.rayCaster, this.objects, filter)
  }
}

/**
 * Casts `raycaster`'s current ray against `objects` and returns the nearest hit.
 *
 * Objects are treated as instanced meshes. Per-instance matrices are read with
 * `getMatrixAt`, and the winning hit's `object` field is replaced by the value
 * from `object.userData.instancedSceneGraphIDs.get(instanceId)`.
 *
 * @param {THREE.Raycaster} raycaster - Raycaster whose ray is already set.
 * @param {Array} objects - Candidate meshes; invisible ones are skipped.
 * @param {Function} [filter] - Optional predicate; objects for which it
 *   returns a falsy value are skipped.
 * @param {Iterable} [meshIds] - Optional scene graph IDs; instances whose ID
 *   is listed never count as hits.
 * @returns {Object|null} The closest intersection (with `object` replaced by
 *   its scene graph ID), or null when nothing was hit.
 */
const _trace = (() => {
  // Scratch objects reused across calls to avoid per-call allocations.
  const sphere = new THREE.Sphere()
  const instanceMatrix = new THREE.Matrix4()

  // Returns the set of instance indices of `mesh` whose scene graph ID is in `excludedIds`.
  function excludedInstances (mesh, excludedIds) {
    const indices = new Set()
    if (!excludedIds || excludedIds.size === 0) return indices

    const ids = mesh.userData.instancedSceneGraphIDs
    for (let i = 0; i < mesh.count; i++) {
      if (excludedIds.has(ids.get(i))) indices.add(i)
    }
    return indices
  }

  // Conservative cull: true if some non-excluded instance's world-space
  // bounding sphere comes closer to `origin` than `maxDistance`.
  function mayBeCloser (mesh, excluded, origin, maxDistance) {
    const geometry = mesh.geometry
    // three.js leaves boundingSphere null until computeBoundingSphere() runs,
    // and this mesh may not have been raycast or rendered yet.
    if (geometry.boundingSphere === null) geometry.computeBoundingSphere()

    for (let i = 0; i < mesh.count; i++) {
      if (excluded.has(i)) continue
      mesh.getMatrixAt(i, instanceMatrix)
      // Instance matrices are relative to the mesh. Bring the sphere into
      // world space (where the ray lives), as InstancedMesh.raycast does.
      instanceMatrix.premultiply(mesh.matrixWorld)
      sphere.copy(geometry.boundingSphere).applyMatrix4(instanceMatrix)
      if (sphere.distanceToPoint(origin) < maxDistance) return true
    }
    return false
  }

  return function (raycaster, objects, filter = undefined, meshIds) {
    const excludedIds = meshIds ? new Set(meshIds) : null
    // Shared output buffer for o.raycast(); drained after every object.
    const intersects = []
    let distance = Infinity
    let closest = null

    for (const o of objects) {
      if (!o.visible) continue

      // If we have a filter, discard objects that do not pass.
      if (filter && !filter(o)) continue

      const excluded = excludedInstances(o, excludedIds)

      // Discard objects that cannot contain a hit nearer than the current closest one.
      if (closest && !mayBeCloser(o, excluded, raycaster.ray.origin, closest.distance)) continue

      o.raycast(raycaster, intersects)

      // Pop from the end (same order as before) so ties at equal distance
      // resolve identically; strict `<` keeps the first one seen.
      while (intersects.length > 0) {
        const hit = intersects.pop()
        if (excluded.has(hit.instanceId)) continue
        if (hit.distance < distance) {
          closest = hit
          distance = hit.distance
        }
      }
    }

    if (closest) {
      closest.object = closest.object.userData.instancedSceneGraphIDs.get(closest.instanceId)
    }

    return closest
  }
})()
