import { defineStore } from 'pinia'
import { ref } from 'vue'

import { ModelType } from '@/core/annotations'
import { FilterSubject } from '@/modules/AdvancedFilters/FilterSubject'
import { FilterOperator } from '@/modules/AdvancedFilters/FilterOperator'
import { WorkflowStatus } from '@/modules/AdvancedFilters/WorkflowStatus'
import { createAdvancedFilter } from '@/modules/AdvancedFilters/filterFactory'
import type {
  ModelTemplatePayload,
  TrainingClass,
  TrainingSessionPayload,
} from '@/backend/wind/types'
import { ModelDevice } from '@/backend/wind/types'
import type { ApiResult } from '@/backend/darwin/types'
import { loadV2DatasetGeneralCounts } from '@/backend/darwin/loadV2DatasetGeneralCounts'
import { assert } from '@/core/utils/assert'
import type { AnnotationClassPayload } from '@/store/types/AnnotationClassPayload'
import type { DatasetPayload } from '@/store/types/DatasetPayload'

import {
  loadPublishedModelTemplates as loadPublishedModelTemplatesFromWind,
  trainModel as trainModelFromWind,
} from '@/backend/wind'
import { loadV2DatasetItems } from '@/backend/darwin/loadV2DatasetItems'
import {
  fetchMainAnnotationType,
  getSubAnnotationTypes,
  type AnnotationType,
} from '@/core/annotationTypes'
import type { ValidationError } from '@/store/types/ValidationError'
import type { ParsedValidationError } from '@/backend/error/types'

import { WindErrorCodes } from '@/backend/error/errors'
import type { NeuralModelValidationErrors } from './modelCreationTypes'
import { validateNewModel as validate } from './modelCreationUtils'
import type { V2DatasetItemPayload } from '@/store/types/V2DatasetItemPayload'
import type { DatasetReportPayload } from '@/store/types'

/**
 * Maps a Darwin annotation class onto the training-class shape used by Wind.
 * The Wind-side `id` is left out; the Darwin class id is carried as `darwin_id`.
 *
 * @param annotationClass - Darwin annotation class to convert
 * @param type - main annotation type of the class
 * @param subs - sub-annotation types of the class
 * @returns the training class without its Wind `id`
 */
const annotationClassToTrainingClass = (
  annotationClass: AnnotationClassPayload,
  type: AnnotationType,
  subs: AnnotationType[],
): Omit<TrainingClass, 'id'> => ({
  darwin_id: annotationClass.id,
  name: annotationClass.name,
  type,
  subs,
})

/**
 * This Pinia store is used when training a new model. It holds the state of the
 * model to be trained, and sends a request to Wind to train the model.
 */
export const useModelCreationStore = defineStore('modelCreation', () => {
  const modelTemplates = ref<ModelTemplatePayload[]>([])

  /**
   * Loads the model templates published for the given team from Wind into
   * `modelTemplates`. If the response carries no `data`, the current list is kept.
   */
  const loadPublishedModelTemplates = async (teamId: number): Promise<void> => {
    const response = await loadPublishedModelTemplatesFromWind({ teamId })

    if ('data' in response) {
      modelTemplates.value = response.data
    }
  }

  /**
   * Annotation classes available for the new model. Only those whose ids are in
   * `newModelSelectedClassIds` are sent to Wind by `trainModel`.
   */
  const newModelAnnotationClasses = ref<AnnotationClassPayload[]>([])
  const newModelClassCounts = ref<DatasetReportPayload | null>(null)
  const newModelSelectedClassIds = ref<number[]>([])
  /**
   * Number of completed items in the selected dataset matching the selected
   * classes, as loaded by `loadNewModelTrainingCounts`. Reset to `null` whenever
   * the dataset or the class selection changes, because the value is then stale.
   */
  const newModelTrainingCounts = ref<number | null>(null)

  const deselectAllNewModelClasses = (): void => {
    newModelTrainingCounts.value = null
    newModelSelectedClassIds.value = []
  }

  /**
   * Selects the given class if it is not selected yet, deselects it otherwise.
   * Training counts are filtered by the selected classes, so they are
   * invalidated in both cases.
   */
  const toggleNewModelClassSelection = (annotationClass: AnnotationClassPayload): void => {
    const idx = newModelSelectedClassIds.value.indexOf(annotationClass.id)
    if (idx === -1) {
      newModelSelectedClassIds.value.push(annotationClass.id)
    } else {
      newModelSelectedClassIds.value.splice(idx, 1)
    }
    newModelTrainingCounts.value = null
  }

  /**
   * Used during model creation.
   * Dataset the new model will be trained on.
   */
  const newModelDataset = ref<DatasetPayload | null>(null)

  /**
   * Builds a filter restricting results to items in the Complete workflow
   * status. A new object is returned on every call so no two requests share one.
   */
  const completedItemsFilter = () => ({
    subject: FilterSubject.WorkflowStatus,
    matcher: { name: FilterOperator.AnyOf, values: [WorkflowStatus.Complete] },
  })

  const newModelSampleItems = ref<V2DatasetItemPayload[]>([])
  /** Cursor for the next page of sample items; `null` requests without a `from` cursor. */
  const newModelSampleItemsCursor = ref<string | null>(null)

  /**
   * Loads the next page (20 items) of completed items from the selected dataset
   * and appends it to `newModelSampleItems`, then advances the cursor.
   * Items already present are replaced by their freshly loaded version.
   * Requires a selected dataset (checked via `assert`).
   */
  const loadSampleDatasetItems = async (): Promise<void> => {
    const { team_slug: teamSlug, id: datasetId } = assert(
      newModelDataset.value,
      'No dataset selected',
    )

    const response = await loadV2DatasetItems({
      include_thumbnails: true,
      page: { from: newModelSampleItemsCursor.value || undefined, size: 20 },
      filter: createAdvancedFilter(completedItemsFilter()),
      dataset_ids: [datasetId],
      teamSlug,
    })

    // The dataset may have been switched or reset (which clears items and cursor)
    // while the request was in flight; don't mix its items into the new state.
    if (newModelDataset.value?.id !== datasetId) {
      return
    }

    if (response.ok && 'items' in response.data) {
      const { items, page } = response.data
      const freshIds = new Set(items.map((i) => i.id))
      // append the new page, dropping older copies of any re-fetched item
      newModelSampleItems.value = newModelSampleItems.value
        .filter((i) => !freshIds.has(i.id))
        .concat(items)
      newModelSampleItemsCursor.value = page.next
    }
  }

  /**
   * Sets the dataset the new model will be trained on. Switching to a dataset
   * with a different id clears the state derived from the previous one:
   * training counts, sample items and their cursor, and selected classes.
   */
  const setNewModelDataset = (dataset: DatasetPayload | null): void => {
    if (newModelDataset.value?.id !== dataset?.id) {
      newModelTrainingCounts.value = null
      newModelSampleItemsCursor.value = null
      newModelSampleItems.value = []
      newModelSelectedClassIds.value = []
    }

    newModelDataset.value = dataset
  }

  const newModelTemplate = ref<ModelTemplatePayload | null>(null)

  /**
   * Used during model creation.
   * Allows us to filter selectable model templates.
   */
  const newModelType = ref<ModelType>(ModelType.INSTANCE_SEGMENTATION)

  const newModelName = ref('')
  const setNewModelName = (name: string): void => {
    // strip newlines from name before committing to store
    // wind doesn't really care about this, so not enforced from that end
    // but we face less render issues if we sanitize on frontend
    newModelName.value = name.replace(/(\r\n|\n|\r)/gm, '')
  }

  const newModelValidationErrors = ref<NeuralModelValidationErrors>({})

  /**
   * Loads the number of completed items in the selected dataset that contain
   * any of the selected classes into `newModelTrainingCounts` (0 when the
   * backend returns no count). Requires a selected dataset (checked via `assert`).
   */
  const loadNewModelTrainingCounts = async (): Promise<void> => {
    const { team_slug: teamSlug, id: datasetId } = assert(
      newModelDataset.value,
      'No dataset selected',
    )

    // we load counts for the selected dataset, selected classes, completed items only
    const classFilter = {
      subject: FilterSubject.AnnotationClass,
      // copy so the request does not hold a reference to the live selection array
      matcher: { name: FilterOperator.AnyOf, values: [...newModelSelectedClassIds.value] },
    }

    const response = await loadV2DatasetGeneralCounts({
      teamSlug,
      dataset_ids: [datasetId],
      filter: createAdvancedFilter(classFilter, completedItemsFilter()),
    })

    // ignore the result if the dataset was switched or reset while loading
    if (newModelDataset.value?.id !== datasetId) {
      return
    }

    if ('data' in response) {
      newModelTrainingCounts.value = response.data.simple_counts[0]?.filtered_item_count ?? 0
    }
  }

  /**
   * Sends a training request to Wind for the configured model.
   *
   * If `newModelValidationErrors` is non-empty, no request is sent and a
   * validation-error result is returned instead. A backend validation error on
   * `name` is copied into `newModelValidationErrors`. Payment-method and
   * partner-coverage errors are returned with a user-facing message.
   */
  const trainModel = async (): Promise<ApiResult<TrainingSessionPayload>> => {
    const localErrors = newModelValidationErrors.value

    if (Object.keys(localErrors).length > 0) {
      const error: ParsedValidationError = {
        errors: localErrors as ValidationError,
        isValidationError: true,
        message: 'Data is invalid',
      }

      return { error, ok: false }
    }

    const {
      team_slug: teamSlug,
      slug: datasetSlug,
      team_id: teamId,
      id: datasetId,
    } = assert(newModelDataset.value, 'No dataset selected')
    const { id: modelTemplateId } = assert(newModelTemplate.value, 'No template selected')

    const classes = newModelAnnotationClasses.value
      .filter((c) => newModelSelectedClassIds.value.includes(c.id))
      .map((c) => {
        const mainType = fetchMainAnnotationType(c.annotation_types)
        const subTypes = getSubAnnotationTypes(c.annotation_types)
        return annotationClassToTrainingClass(c, mainType, subTypes)
      })

    const params = {
      datasetId,
      datasetSlug,
      device: ModelDevice.GPU,
      classes,
      modelTemplateId,
      name: newModelName.value,
      teamId,
      teamSlug,
    }

    const response = await trainModelFromWind(params)
    if (response.ok) {
      return response
    }

    if ('isValidationError' in response.error && response.error.isValidationError) {
      const backendErrors = response.error.errors as ValidationError
      const fieldErrors: NeuralModelValidationErrors = {}
      if (backendErrors.name) {
        fieldErrors.name = backendErrors.name as string
      }
      newModelValidationErrors.value = fieldErrors
      return response
    }

    if (response.error.code === WindErrorCodes.NO_PAYMENT_METHOD) {
      const message = 'You need to provide a valid payment method in order to train models'
      return { error: { ...response.error, message }, ok: false }
    }

    if (response.error.code === WindErrorCodes.PARTNER_DOES_NOT_COVER_NEURAL_NETWORKS) {
      const message = [
        'Your partner does not cover your neural network costs.',
        'You will have to discuss and change your relationship with them to start training models.',
      ].join(' ')

      return { error: { ...response.error, message }, ok: false }
    }

    return response
  }

  /** Runs local validation on the current model configuration and stores the result. */
  const validateNewModel = (): void => {
    const errors = validate({
      newModelType: newModelType.value,
      newModelName: newModelName.value,
      newModelDataset: newModelDataset.value,
      newModelTemplate: newModelTemplate.value,
      newModelSampleItems: newModelSampleItems.value,
      newModelSelectedClassIds: newModelSelectedClassIds.value,
    })
    newModelValidationErrors.value = errors
  }

  /** Restores every piece of store state to its initial value. */
  const reset = (): void => {
    modelTemplates.value = []
    newModelValidationErrors.value = {}
    newModelType.value = ModelType.INSTANCE_SEGMENTATION
    newModelName.value = ''
    newModelDataset.value = null
    newModelTemplate.value = null
    newModelTrainingCounts.value = null
    newModelSampleItemsCursor.value = null
    newModelSampleItems.value = []
    newModelAnnotationClasses.value = []
    newModelClassCounts.value = null
    newModelSelectedClassIds.value = []
  }

  return {
    modelTemplates,
    loadPublishedModelTemplates,

    newModelTemplate,

    newModelType,

    newModelName,
    setNewModelName,

    newModelAnnotationClasses,
    newModelSelectedClassIds,
    newModelClassCounts,
    deselectAllNewModelClasses,
    toggleNewModelClassSelection,

    newModelDataset,
    setNewModelDataset,

    newModelSampleItems,
    newModelSampleItemsCursor,
    loadSampleDatasetItems,

    newModelTrainingCounts,
    loadNewModelTrainingCounts,

    newModelValidationErrors,
    trainModel,

    validateNewModel,

    reset,
  }
})
