import type { AnnotationType } from '@/core/annotationTypes'
import { getMainAnnotationType, getSubAnnotationTypes } from '@/core/annotationTypes'
import type { NeuralModelAction } from '@/store/modules/neuralModel/types'
import { ModelDevice } from '@/store/modules/neuralModel/types'
import type { AnnotationClassPayload } from '@/store/types'
import type { ValidationError } from '@/store/types/ValidationError'
import type { ParsedValidationError } from '@/backend/error'
import { WindErrorCodes } from '@/backend/error/errors'
import { trainModel as windTrainModel } from '@/backend/wind'
import type { TrainingClass, TrainingSessionPayload } from '@/backend/wind/types'

/**
 * Converts a Darwin annotation class into the class shape sent to the
 * training backend.
 *
 * The result has no `id`. Instead, the annotation class id is carried as
 * `darwin_id`, and only `name` is copied over from the class itself.
 *
 * @param annotationClass - source annotation class (only `id` and `name` are read)
 * @param type - the class's main annotation type
 * @param subs - the class's sub-annotation types
 */
const annotationClassToTrainingClass = (
  annotationClass: AnnotationClassPayload,
  type: AnnotationType,
  subs: AnnotationType[],
): Omit<TrainingClass, 'id'> => ({
  darwin_id: annotationClass.id,
  name: annotationClass.name,
  type,
  subs,
})

/**
 * Action signature for `trainModel`: takes no payload and resolves to a
 * response whose success data is a `TrainingSessionPayload`.
 */
export type TrainModel = NeuralModelAction<void, TrainingSessionPayload>

/**
 * Starts training a new model from the "new model" state of this module.
 *
 * Steps:
 * 1. If `state.newModelValidationErrors` is non-empty, returns it wrapped in a
 *    `ParsedValidationError` without calling the backend.
 * 2. Builds training classes from the annotation classes whose ids are in
 *    `state.newModelSelectedClassIds`.
 * 3. Calls the wind `trainModel` endpoint with `ModelDevice.GPU`.
 * 4. Handles the response:
 *    - on a backend validation error, commits it to state;
 *    - on the known billing error codes, returns the error with a user-facing
 *      message;
 *    - on success, commits the new training session.
 *
 * @returns `{ error }` for validation/billing failures, otherwise the raw
 *   backend response.
 * @throws Error if the dataset or template is not set, or if a selected class
 *   has no main annotation type.
 */
export const trainModel: TrainModel = async ({ commit, state }) => {
  const { newModelValidationErrors: errors } = state

  // Client-side validation errors short-circuit before any network call.
  if (Object.keys(errors).length > 0) {
    const error: ParsedValidationError = {
      errors: errors as ValidationError,
      isValidationError: true,
      message: 'Data is invalid',
    }

    return { error }
  }

  const {
    newModelAnnotationClasses,
    newModelDataset,
    newModelName,
    newModelSelectedClassIds,
    newModelTemplate,
  } = state

  if (!newModelDataset || !newModelTemplate) {
    throw new Error('Invalid action outcome. Template and dataset should be set')
  }

  const { id: datasetId, slug: datasetSlug, team_id: teamId, team_slug: teamSlug } = newModelDataset

  // NOTE(review): this action previously had an "is the dataset part of the
  // current team?" guard. It compared `teamId` with `newModelDataset.team_id`,
  // i.e. the value with itself, so it could never fire and has been removed.
  // A real guard needs the current team id (presumably from root state, which
  // is not visible here). TODO: restore it against the correct source.

  // Set lookup keeps the selection filter linear in the number of classes.
  const selectedClassIds = new Set(newModelSelectedClassIds)

  const classes = newModelAnnotationClasses
    .filter((c) => selectedClassIds.has(c.id))
    .map((c) => {
      const mainType = getMainAnnotationType(c.annotation_types)
      if (!mainType) {
        throw new Error('Model training encountered a class with no main type')
      }
      return annotationClassToTrainingClass(c, mainType, getSubAnnotationTypes(c.annotation_types))
    })

  const response = await windTrainModel({
    datasetId,
    datasetSlug,
    device: ModelDevice.GPU,
    classes,
    modelTemplateId: newModelTemplate.id,
    name: newModelName,
    teamId,
    teamSlug,
  })

  if ('error' in response) {
    const { error } = response

    // Surface backend field errors in the new-model form.
    if ('isValidationError' in error && error.isValidationError) {
      commit('SET_NEW_MODEL_VALIDATION_ERRORS_FROM_BACKEND', error)
    }

    // Replace the backend message for known billing-related failures.
    let friendlyMessage: string | undefined
    switch (error.code) {
      case WindErrorCodes.NO_PAYMENT_METHOD:
        friendlyMessage = 'You need to provide a valid payment method in order to train models'
        break
      case WindErrorCodes.PARTNER_DOES_NOT_COVER_NEURAL_NETWORKS:
        friendlyMessage =
          'Your partner does not cover your neural network costs. ' +
          'You will have to discuss and change your relationship with them to start training models.'
        break
    }

    if (friendlyMessage !== undefined) {
      return { error: { ...error, message: friendlyMessage } }
    }
  }

  if ('data' in response) {
    commit('PUSH_TRAINING_SESSION', response.data)
  }

  return response
}
