module seq_flux_mct
  
  use shr_kind_mod,      only: r8 => shr_kind_r8, in=>shr_kind_in
  use shr_sys_mod,       only: shr_sys_abort
  use shr_flux_mod,      only: shr_flux_atmocn
  use shr_orb_mod,       only: shr_orb_params, shr_orb_cosz, shr_orb_decl

  use mct_mod
  use seq_flds_indices
  use seq_comm_mct
  use seq_cdata_mod
  use seq_infodata_mod

  implicit none
  private 	
  save

!--------------------------------------------------------------------------
! Public interfaces
!--------------------------------------------------------------------------

  public seq_flux_init_mct
  public seq_flux_atmocn_mct

!--------------------------------------------------------------------------
! Private data
!--------------------------------------------------------------------------

  ! --- Grid description, allocated and filled by seq_flux_init_mct ---------
  ! (mask is additionally forced to 1 everywhere by seq_flux_atmocn_mct when
  !  dead components are in use)
  real(r8),    dimension(:), pointer     :: lats    ! latitudes  (degrees)
  real(r8),    dimension(:), pointer     :: lons    ! longitudes (degrees)
  integer(in), dimension(:), allocatable :: mask    ! ocn domain mask: 0 <=> inactive cell

  ! --- Ocean state handed to shr_flux_atmocn -------------------------------
  real(r8),    dimension(:), allocatable :: uocn    ! ocn velocity, zonal      (m/s)
  real(r8),    dimension(:), allocatable :: vocn    ! ocn velocity, meridional (m/s)
  real(r8),    dimension(:), allocatable :: tocn    ! ocn temperature          (Kelvin)

  ! --- Atmosphere state handed to shr_flux_atmocn --------------------------
  real(r8),    dimension(:), allocatable :: zbot    ! atm height of bottom layer (m)
  real(r8),    dimension(:), allocatable :: ubot    ! atm velocity, zonal        (m/s)
  real(r8),    dimension(:), allocatable :: vbot    ! atm velocity, meridional   (m/s)
  real(r8),    dimension(:), allocatable :: thbot   ! atm potential temperature  (Kelvin)
  real(r8),    dimension(:), allocatable :: shum    ! atm specific humidity      (kg/kg)
  real(r8),    dimension(:), allocatable :: dens    ! atm density                (kg/m^3)
  real(r8),    dimension(:), allocatable :: tbot    ! atm temperature            (Kelvin)

  ! --- Results returned by shr_flux_atmocn ---------------------------------
  real(r8),    dimension(:), allocatable :: sen     ! heat flux: sensible
  real(r8),    dimension(:), allocatable :: lat     ! heat flux: latent
  real(r8),    dimension(:), allocatable :: lwup    ! lwup over ocean
  real(r8),    dimension(:), allocatable :: evap    ! water flux: evaporation
  real(r8),    dimension(:), allocatable :: taux    ! wind stress, zonal
  real(r8),    dimension(:), allocatable :: tauy    ! wind stress, meridional
  real(r8),    dimension(:), allocatable :: tref    ! diagnostic:  2m ref T
  real(r8),    dimension(:), allocatable :: qref    ! diagnostic:  2m ref Q
  real(r8),    dimension(:), allocatable :: duu10n  ! diagnostic: 10m wind speed squared

  ! --- Additional saved outputs of shr_flux_atmocn -------------------------
  real(r8),    dimension(:), allocatable :: ustar   ! friction velocity
  real(r8),    dimension(:), allocatable :: re      ! reynolds number
  real(r8),    dimension(:), allocatable :: ssq     ! s.hum. saturation at Ts

  ! --- Degrees-to-radians conversion ---------------------------------------
  real(r8), parameter :: const_pi      = SHR_CONST_PI       ! pi
  real(r8), parameter :: const_deg2rad = const_pi/180.0_R8  ! multiply degrees by this to get radians

!===============================================================================
contains
!===============================================================================


  subroutine seq_flux_init_mct( cdata, xao)

    !-----------------------------------------------------------------------
    !
    ! Purpose: one-time set-up of the atm/ocn flux module.
    !   * initializes and zeroes the attribute vector xao with the real
    !     fields listed in seq_flds_xao_fields, sized to the local part of
    !     the decomposition described by cdata's gsMap
    !   * allocates every module-level work array to that local size
    !   * exports the time-invariant 'lat', 'lon' and 'mask' attributes of
    !     cdata's domain into the module arrays lats, lons and mask
    !
    ! Any allocation failure is reported through mct_die together with the
    ! name of the array that could not be allocated.
    !
    ! Arguments
    !
    type(seq_cdata), intent(in)  :: cdata   ! source of gsMap, domain and MPI communicator
    type(mct_aVect), intent(out) :: xao     ! attribute vector created and zeroed here
    !
    ! Local variables
    !

    integer(in)                     :: nloc      ! local size of xao = length of all work arrays
    type(mct_gGrid), pointer        :: dom       ! domain (grid) obtained from cdata
    type(mct_gsMap), pointer        :: gsMap     ! global segment map obtained from cdata
    integer                         :: lsize     ! local size reported by the gsMap
    integer                         :: mpicom    ! MPI communicator obtained from cdata
    integer                         :: ier       ! allocate status
    real(r8), pointer     ::  rmask(:)  ! ocn domain mask as exported (real valued)
    character(*),parameter :: subName =   '(seq_flux_init_mct) '
    !-----------------------------------------------------------------------

    ! Set cdata pointers
    call seq_cdata_setptrs(cdata, gsMap=gsMap, dom=dom, mpicom=mpicom)
    lsize = mct_gsMap_lsize(gsMap, mpicom)

    ! Initialize attribute vector
    call mct_aVect_init(xao, rList=seq_flds_xao_fields, lsize=lsize)
    call mct_aVect_zero(xao)

    ! Determine size of buffers
    nloc = mct_avect_lSize(xao)

    ! Input fields atm
    allocate(zbot(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate zbot',ier)
    allocate(ubot(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate ubot',ier)
    ! stat=ier is required here: without it the check below would test the
    ! status left over from the ubot allocation and a failure would bypass
    ! mct_die.
    allocate(vbot(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate vbot',ier)
    allocate(thbot(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate thbot',ier)
    allocate(shum(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate shum',ier)
    allocate(dens(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate dens',ier)
    allocate(tbot(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate tbot',ier)

    ! Saved flux-calculation outputs
    allocate(ustar(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate ustar',ier)
    allocate(re(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate re',ier)
    allocate(ssq(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate ssq',ier)

    ! Input fields ocn
    allocate(uocn(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate uocn',ier)
    allocate(vocn(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate vocn',ier)
    allocate(tocn(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate tocn',ier)

    ! Output fields
    allocate(sen(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate sen',ier)
    allocate(lat(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate lat',ier)
    allocate(evap(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate evap',ier)
    allocate(lwup(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate lwup',ier)
    allocate(taux(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate taux',ier)
    allocate(tauy(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate tauy',ier)
    allocate(tref(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate tref',ier)
    allocate(qref(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate qref',ier)
    allocate(duu10n(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate duu10n',ier)

    ! Grid fields
    allocate(lats(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate lats',ier)
    allocate(lons(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate lons',ier)
    allocate(mask(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate mask',ier)

    ! Get lat, lon, mask, which is time-invariant.
    ! The mask is exported as a real field into a scratch buffer and then
    ! rounded to the nearest integer for the module-level integer mask.
    ! NOTE(review): the last argument of mct_gGrid_exportRAttr is presumably
    ! MCT's output "lsize", i.e. nloc is overwritten with the exported
    ! length - confirm against the MCT GeneralGrid interface.
    allocate(rmask(nloc), stat=ier)
    if (ier /= 0) call mct_die(subName,'allocate rmask',ier)
    call mct_gGrid_exportRAttr(dom, 'lat' , lats , nloc)
    call mct_gGrid_exportRAttr(dom, 'lon' , lons , nloc)
    call mct_gGrid_exportRAttr(dom, 'mask', rmask, nloc)
    mask = nint(rmask)
    deallocate(rmask)

  end subroutine seq_flux_init_mct

!===============================================================================


  subroutine seq_flux_atmocn_mct( cdata, a2x, o2x, xao, fractions_o, albedo_only )

    !-----------------------------------------------------------------------
    !
    ! Purpose: update the ocean albedos and the atm/ocn surface fluxes
    ! stored in xao.
    !
    ! 1) Albedos (So_avsdr, So_anidr, So_avsdf, So_anidf):
    !    * flux_albav = .true.  : constant reference albedos are written.
    !    * flux_albav = .false. : albedos are computed from the cosine of the
    !      solar zenith angle at calendar day nextsw_cday, but only when
    !      nextsw_cday >= -0.5; otherwise the albedos are left untouched.
    !    Whenever the albedos are written, fractions_o "ifrad"/"ofrad" are
    !    refreshed from "ifrac"/"ofrac".
    ! 2) If albedo_only is present and .true., a message is logged on the
    !    root task and the routine returns.
    ! 3) Otherwise the module work arrays are loaded (fabricated values when
    !    dead components are used, a2x/o2x values on active cells otherwise),
    !    shr_flux_atmocn is called, and its results are copied into xao on
    !    cells where mask /= 0.
    !
    ! The module arrays (lats, lons, mask and the work arrays) are used
    ! here; they are allocated by seq_flux_init_mct.
    !
    ! Arguments
    !
    type(seq_cdata), intent(in)    :: cdata        ! source of infodata and comm ID
    type(mct_aVect), intent(in)    :: a2x          ! atm state read on active cells
    type(mct_aVect), intent(in)    :: o2x          ! ocn state read on active cells
    type(mct_aVect), intent(inout) :: xao          ! receives albedos and fluxes
    type(mct_aVect), intent(inout) :: fractions_o  ! ifrad/ofrad updated from ifrac/ofrac
    logical, optional,intent(in)   :: albedo_only  ! .true. => update albedos only
    !
    ! Local variables
    !
    type(seq_infodata_type),pointer :: infodata
    integer(in) :: n            ! gridcell index
    real(r8)    :: rlat         ! gridcell latitude in radians
    real(r8)    :: rlon         ! gridcell longitude in radians
    real(r8)    :: cosz         ! Cosine of solar zenith angle
    real(r8)    :: eccen        ! Earth orbit eccentricity
    real(r8)    :: mvelpp       ! Earth orbit
    real(r8)    :: lambm0       ! Earth orbit
    real(r8)    :: obliqr       ! Earth orbit
    real(r8)    :: delta        ! Solar declination angle  in radians
    real(r8)    :: eccf         ! Earth orbit eccentricity factor
    real(r8)    :: calday       ! calendar day including fraction, at 0e
    real(r8)    :: nextsw_cday  ! calendar day of next atm shortwave
    real(r8)    :: anidr        ! albedo: near infrared, direct
    real(r8)    :: avsdr        ! albedo: visible      , direct
    real(r8)    :: anidf        ! albedo: near infrared, diffuse
    real(r8)    :: avsdf        ! albedo: visible      , diffuse
    integer(in) :: nloc         ! number of gridcells
    integer(in) :: ID           ! comm ID
    logical     :: flux_albav   ! flux avg option
    logical     :: dead_comps   ! .true.  => dead components are used
    logical     :: albedo_set   ! .true.  => albedos were written during this call
    integer     :: kx,kr        ! fractions indices
    !
    real(R8),parameter :: albdif = 0.06_R8 ! 60 deg reference albedo, diffuse
    real(R8),parameter :: albdir = 0.07_R8 ! 60 deg reference albedo, direct
    character(*),parameter :: subName =   '(seq_flux_atmocn_mct) '
    !
    !-----------------------------------------------------------------------

    call seq_cdata_setptrs(cdata, infodata=infodata, ID=ID)
    call seq_infodata_GetData(infodata,flux_albav=flux_albav)

    nloc = mct_aVect_lsize(xao)

    !-----------------------------------------------------
    ! Ocean albedos
    !-----------------------------------------------------

    if (flux_albav) then

       ! Constant reference albedos on every gridcell.
       !
       ! Albedo is now function of latitude (will be new implementation)
       !rlat = const_deg2rad * lats(n)
       !anidr = 0.069_R8 - 0.011_R8 * cos(2._R8 * rlat)
       !avsdr = anidr
       !anidf = anidr
       !avsdf = anidr
       do n = 1,nloc
          xao%rAttr(index_xao_So_avsdr,n) = albdir
          xao%rAttr(index_xao_So_anidr,n) = albdir
          xao%rAttr(index_xao_So_avsdf,n) = albdif
          xao%rAttr(index_xao_So_anidf,n) = albdif
       end do
       albedo_set = .true.

    else

       ! Solar declination
       ! Will only do albedo calculation if nextsw_cday is not -1.

       call seq_infodata_GetData(infodata,nextsw_cday=nextsw_cday,orb_eccen=eccen, &
          orb_mvelpp=mvelpp, orb_lambm0=lambm0, orb_obliqr=obliqr)

       albedo_set = (nextsw_cday >= -0.5_r8)

       if (albedo_set) then
          calday = nextsw_cday
          call shr_orb_decl(calday, eccen, mvelpp,lambm0, obliqr, delta, eccf)

          ! Compute albedos
          do n = 1,nloc
             rlat = const_deg2rad * lats(n)
             rlon = const_deg2rad * lons(n)
             cosz = shr_orb_cosz( calday, rlat, rlon, delta )
             if (cosz  >  0.0_R8) then !--- sun hit --
                anidr = (.026_R8/(cosz**1.7_R8 + 0.065_R8)) +   &
                        (.150_R8*(cosz         - 0.100_R8 ) *   &
                                 (cosz         - 0.500_R8 ) *   &
                                 (cosz         - 1.000_R8 )  )
                avsdr = anidr
                anidf = albdif
                avsdf = albdif
             else !--- dark side of earth ---
                anidr = 1.0_R8
                avsdr = 1.0_R8
                anidf = 1.0_R8
                avsdf = 1.0_R8
             end if
             xao%rAttr(index_xao_So_avsdr,n) = avsdr
             xao%rAttr(index_xao_So_anidr,n) = anidr
             xao%rAttr(index_xao_So_avsdf,n) = avsdf
             xao%rAttr(index_xao_So_anidf,n) = anidf
          end do   ! nloc
       end if   ! nextsw_cday

    end if   ! flux_albav

    ! Whenever the albedos were written above, refresh the "ifrad"/"ofrad"
    ! fractions from the current "ifrac"/"ofrac" fractions.
    if (albedo_set) then
       kx = mct_aVect_indexRA(fractions_o,"ifrac")
       kr = mct_aVect_indexRA(fractions_o,"ifrad")
       fractions_o%rAttr(kr,:) = fractions_o%rAttr(kx,:)
       kx = mct_aVect_indexRA(fractions_o,"ofrac")
       kr = mct_aVect_indexRA(fractions_o,"ofrad")
       fractions_o%rAttr(kr,:) = fractions_o%rAttr(kx,:)
    end if

    if (present(albedo_only)) then
       if (albedo_only) then
          if (seq_comm_iamroot(ID)) &
             write(logunit,'(2A)') trim(subname),' computing only ocn albedos '
          return
       end if
    end if

    !-----------------------------------------------------
    ! Update ocean surface fluxes
    !-----------------------------------------------------

    call seq_infodata_GetData(infodata, dead_comps=dead_comps)

    if (dead_comps) then
       ! Must fabricate "reasonable" data (using dead components).
       ! Note that this also overwrites the module-level mask.
       do n = 1,nloc
          mask(n) =   1      ! ocn domain mask            ~ 0 <=> inactive cell
          tocn(n) = 290.0_R8 ! ocn temperature            ~ Kelvin
          uocn(n) =   0.0_R8 ! ocn velocity, zonal        ~ m/s
          vocn(n) =   0.0_R8 ! ocn velocity, meridional   ~ m/s
          zbot(n) =  55.0_R8 ! atm height of bottom layer ~ m
          ubot(n) =   0.0_R8 ! atm velocity, zonal        ~ m/s
          vbot(n) =   2.0_R8 ! atm velocity, meridional   ~ m/s
          thbot(n)= 301.0_R8 ! atm potential temperature  ~ Kelvin
          shum(n) = 1.e-2_R8 ! atm specific humidity      ~ kg/kg
          dens(n) =   1.0_R8 ! atm density                ~ kg/m^3
          tbot(n) = 300.0_R8 ! atm temperature            ~ Kelvin
       end do
    else
       ! Load atm and ocn state on active cells only; entries on inactive
       ! cells (mask == 0) are not touched here.
       do n = 1,nloc
          if (mask(n) /= 0) then
             zbot(n) = a2x%rAttr(index_a2x_Sa_z   ,n)
             ubot(n) = a2x%rAttr(index_a2x_Sa_u   ,n)
             vbot(n) = a2x%rAttr(index_a2x_Sa_v   ,n)
             thbot(n)= a2x%rAttr(index_a2x_Sa_ptem,n)
             shum(n) = a2x%rAttr(index_a2x_Sa_shum,n)
             dens(n) = a2x%rAttr(index_a2x_Sa_dens,n)
             tbot(n) = a2x%rAttr(index_a2x_Sa_tbot,n)
             tocn(n) = o2x%rAttr(index_o2x_So_t   ,n)
             uocn(n) = o2x%rAttr(index_o2x_So_u   ,n)
             vocn(n) = o2x%rAttr(index_o2x_So_v   ,n)
          end if
       end do
    end if

    call shr_flux_atmocn (nloc , zbot , ubot, vbot, thbot, &
                          shum , dens , tbot, uocn, vocn , &
                          tocn , mask , sen , lat , lwup , &
                          evap , taux , tauy, tref, qref , &
                          duu10n,ustar, re  , ssq )

    ! Copy the results into xao on active cells only.
    do n = 1,nloc
       if (mask(n) /= 0) then
          xao%rAttr(index_xao_Faox_sen ,n) = sen(n)
          xao%rAttr(index_xao_Faox_lat ,n) = lat(n)
          xao%rAttr(index_xao_Faox_taux,n) = taux(n)
          xao%rAttr(index_xao_Faox_tauy,n) = tauy(n)
          xao%rAttr(index_xao_Faox_evap,n) = evap(n)
          xao%rAttr(index_xao_So_tref  ,n) = tref(n)
          xao%rAttr(index_xao_So_qref  ,n) = qref(n)
          xao%rAttr(index_xao_So_ustar ,n) = ustar(n)  ! friction velocity
          xao%rAttr(index_xao_So_re    ,n) = re(n)     ! reynolds number
          xao%rAttr(index_xao_So_ssq   ,n) = ssq(n)    ! s.hum. saturation at Ts
          xao%rAttr(index_xao_Faox_lwup,n) = lwup(n)
          xao%rAttr(index_xao_Sx_duu10n,n) = duu10n(n)
       end if
    end do

  end subroutine seq_flux_atmocn_mct

end module seq_flux_mct