cmake_minimum_required(VERSION 3.25)

# Resolve the project version.
#
# - Without MLX_VERSION: read MLX_VERSION_{MAJOR,MINOR,PATCH} from
#   mlx/version.h and use "<major>.<minor>.<patch>" for both
#   MLX_PROJECT_VERSION and MLX_VERSION.
# - With MLX_VERSION (which may carry a suffix): MLX_PROJECT_VERSION is its
#   leading "<major>.<minor>.<patch>" part, since project(VERSION) only accepts
#   numeric components.
if(NOT MLX_VERSION)
  set(_mlx_version_header "${CMAKE_CURRENT_LIST_DIR}/mlx/version.h")
  file(STRINGS "${_mlx_version_header}" _mlx_h_version
       REGEX "^#define MLX_VERSION_.*$")
  string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_major "${CMAKE_MATCH_1}")
  string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_minor "${CMAKE_MATCH_1}")
  string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
  set(_patch "${CMAKE_MATCH_1}")
  # Fail with a clear message instead of letting project() reject a malformed
  # version string such as "..".
  if("${_major}" STREQUAL ""
     OR "${_minor}" STREQUAL ""
     OR "${_patch}" STREQUAL "")
    message(
      FATAL_ERROR
        "Could not parse MLX_VERSION_{MAJOR,MINOR,PATCH} from ${_mlx_version_header}"
    )
  endif()
  set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
  set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
  # "\\." is a literal dot in the regex. A bare "\." inside a quoted CMake
  # argument collapses to ".", which would match any character.
  string(REGEX REPLACE "^([0-9]+\\.[0-9]+\\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
                       "${MLX_VERSION}")
endif()

project(
  mlx
  LANGUAGES C CXX
  VERSION ${MLX_PROJECT_VERSION})

# ----------------------------- Setup -----------------------------
# Project-local CMake modules are looked up in ./cmake.
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

# C++20 is a hard requirement, not a best-effort request.
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_STANDARD 20)

# Emit compile_commands.json for tooling and build position-independent code.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Keep the install step quiet.
set(CMAKE_INSTALL_MESSAGE NEVER)

# ----------------------------- Configuration -----------------------------
# What gets built.
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_PYTHON_STUBS "Build stub files for python bindings" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

# Compute backends and backend tuning.
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)

# Serialization formats.
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)

# Toolchain and dependency behavior.
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)

# Sanitizers (ASan and TSan are rejected when enabled together).
option(USE_ASAN "Enable AddressSanitizer (ASan)" OFF)
option(USE_UBSAN "Enable UndefinedBehaviorSanitizer (UBSan)" OFF)
option(USE_TSAN "Enable ThreadSanitizer (TSan)" OFF)

# Optional vendored third-party sources in the parent directory of this file.
# Each location is probed with EXISTS later in this file; when it is absent the
# dependency is obtained another way (download, or the system for fmt).
set(MLX_VENDOR_DIR "${CMAKE_CURRENT_LIST_DIR}/..")
set(MLX_LOCAL_FMT_DIR "${MLX_VENDOR_DIR}/fmt")
set(MLX_LOCAL_JSON_DIR "${MLX_VENDOR_DIR}/json")
set(MLX_LOCAL_METAL_CPP_DIR "${MLX_VENDOR_DIR}/metal-cpp")

# On Apple platforms, derive from the sysroot:
#   MLX_METAL_SDK             - xcrun SDK name (iphoneos / iphonesimulator / macosx)
#   MLX_MIN_DEPLOYMENT_TARGET - minimum OS version MLX supports on that SDK
#   MLX_METAL_MIN_FLAG        - metal compiler -m*-version-min flag, set only
#                               when CMAKE_OSX_DEPLOYMENT_TARGET is non-empty
if(APPLE)
  # CMAKE_OSX_SYSROOT may be a full SDK path or a bare SDK name. Fall back to
  # the raw value when it has no file-name part, and to macosx when it is empty.
  get_filename_component(_mlx_sysroot_name "${CMAKE_OSX_SYSROOT}" NAME)
  if(_mlx_sysroot_name STREQUAL "")
    set(_mlx_sysroot_name "${CMAKE_OSX_SYSROOT}")
  endif()
  if(_mlx_sysroot_name STREQUAL "")
    set(_mlx_sysroot_name "macosx")
  endif()
  string(TOLOWER "${_mlx_sysroot_name}" _mlx_sysroot_name)

  if(_mlx_sysroot_name MATCHES "iphoneos")
    set(MLX_METAL_SDK "iphoneos")
    set(MLX_MIN_DEPLOYMENT_TARGET "17.0")
    set(_mlx_min_flag_prefix "-mios-version-min=")
  elseif(_mlx_sysroot_name MATCHES "iphonesimulator")
    set(MLX_METAL_SDK "iphonesimulator")
    set(MLX_MIN_DEPLOYMENT_TARGET "17.0")
    set(_mlx_min_flag_prefix "-mios-simulator-version-min=")
  else()
    set(MLX_METAL_SDK "macosx")
    set(MLX_MIN_DEPLOYMENT_TARGET "14.0")
    set(_mlx_min_flag_prefix "-mmacosx-version-min=")
  endif()

  # Quoted on purpose: an unquoted name that is not defined as a variable
  # (CMAKE_OSX_DEPLOYMENT_TARGET may be unset, e.g. for iOS cross-builds)
  # would be compared as the literal string and produce an empty "-m...-min=".
  if(NOT "${CMAKE_OSX_DEPLOYMENT_TARGET}" STREQUAL "")
    set(MLX_METAL_MIN_FLAG
        "${_mlx_min_flag_prefix}${CMAKE_OSX_DEPLOYMENT_TARGET}")
  endif()
endif()

# --------------------- Processor tests -------------------------
message(
  STATUS
    "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)

# Metal is only built on Apple platforms. Intel macOS is rejected unless
# MLX_ENABLE_X64_MAC is set, in which case the Metal backend is disabled.
if(APPLE)
  # Quoted so that an empty CMAKE_SYSTEM_PROCESSOR (possible when
  # cross-compiling) does not collapse into a malformed "MATCHES ..." if().
  if("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "x86_64"
     AND "${MLX_METAL_SDK}" STREQUAL "macosx")
    if(NOT MLX_ENABLE_X64_MAC)
      message(
        FATAL_ERROR
          "Building for x86_64 on macOS is not supported."
          " If you are on an Apple silicon system, check the build"
          " documentation for possible fixes: "
          "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
      )
    else()
      set(MLX_BUILD_METAL OFF)
      message(WARNING "Building for x86_64 arch is not officially supported.")
    endif()
  endif()
else()
  set(MLX_BUILD_METAL OFF)
endif()

# Route C, C++ and CUDA compiler invocations through ccache when it is
# requested and can be found.
if(MLX_USE_CCACHE)
  find_program(CCACHE_PROGRAM ccache)
  if(CCACHE_PROGRAM)
    message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
    foreach(_mlx_launcher_lang IN ITEMS C CXX CUDA)
      set(CMAKE_${_mlx_launcher_lang}_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
    endforeach()
  endif()
endif()

# ASan and TSan cannot be combined.
if(USE_ASAN AND USE_TSAN)
  message(
    FATAL_ERROR
      "AddressSanitizer (ASan) and ThreadSanitizer (TSan) are mutually exclusive and cannot be enabled at the same time."
  )
endif()

# Accumulated sanitizer flags; they are attached to the mlx target once it is
# created.
set(SANITIZER_COMPILE_FLAGS "")
set(SANITIZER_LINK_FLAGS "")

# AddressSanitizer: MSVC spells the switch with a slash, other compilers with a
# dash (plus an explicit pthread link on Linux).
if(USE_ASAN)
  if(WIN32 AND MSVC)
    set(_mlx_asan_flag /fsanitize=address)
  else()
    set(_mlx_asan_flag -fsanitize=address)
  endif()
  list(APPEND SANITIZER_COMPILE_FLAGS ${_mlx_asan_flag})
  list(APPEND SANITIZER_LINK_FLAGS ${_mlx_asan_flag})
  if(NOT (WIN32 AND MSVC) AND CMAKE_SYSTEM_NAME STREQUAL "Linux")
    list(APPEND SANITIZER_LINK_FLAGS -lpthread)
  endif()
endif()

# UndefinedBehaviorSanitizer: only warned about for MSVC-style toolchains whose
# compiler id is not Clang.
if(USE_UBSAN)
  if(WIN32 AND MSVC AND NOT (CMAKE_CXX_COMPILER_ID STREQUAL "Clang"))
    message(
      WARNING
        "UndefinedBehaviorSanitizer (UBSan) is not directly supported via a simple flag in MSVC."
    )
  else()
    list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
    list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
  endif()
endif()

# ThreadSanitizer: refused on MSVC and on macOS. The FATAL_ERRORs stop
# configuration, so the flags below are reached only on supported setups.
if(USE_TSAN)
  if(WIN32 AND MSVC)
    message(
      FATAL_ERROR
        "ThreadSanitizer (TSan) is not supported by the MSVC compiler. Please use Clang or GCC."
    )
  elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
    message(FATAL_ERROR "ThreadSanitizer (TSan) is not supported on macOS.")
  endif()
  list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=thread)
  list(APPEND SANITIZER_LINK_FLAGS -fsanitize=thread)
  if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
    list(APPEND SANITIZER_LINK_FLAGS -lpthread)
  endif()
endif()

# ----------------------------- Lib -----------------------------

include(FetchContent)
# Explicitly select the NEW DOWNLOAD_EXTRACT_TIMESTAMP behavior (CMP0135,
# introduced in CMake 3.24) so URL downloads do not emit a policy warning.
cmake_policy(SET CMP0135 NEW)

# The core library target. Its sources are presumably attached by the mlx/
# subdirectory that is added further below.
add_library(mlx)

# Sanitizer flags are PUBLIC so they also reach everything that links mlx.
target_link_options(mlx PUBLIC ${SANITIZER_LINK_FLAGS})
target_compile_options(mlx PUBLIC ${SANITIZER_COMPILE_FLAGS})

# The CUDA backend needs the CUDA language enabled plus the CUDA toolkit and
# cuDNN packages.
if(MLX_BUILD_CUDA)
  enable_language(CUDA)
  find_package(CUDAToolkit REQUIRED)
  find_package(CUDNN REQUIRED)
endif()

# Metal backend: locate the Apple frameworks, validate SDK and deployment
# target versions, query the Metal language version, and make metal-cpp
# (vendored or downloaded) available to the mlx target.
if(MLX_BUILD_METAL)
  find_library(METAL_LIB Metal)
  find_library(FOUNDATION_LIB Foundation)
  find_library(QUARTZ_LIB QuartzCore)
  if(METAL_LIB)
    message(STATUS "Metal found ${METAL_LIB}")
  else()
    message(
      FATAL_ERROR
        "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
  endif()
  # The supporting frameworks are linked below as well. Fail here with a clear
  # message instead of a later "variable set to NOTFOUND" error.
  foreach(_mlx_framework_var IN ITEMS FOUNDATION_LIB QUARTZ_LIB)
    if(NOT ${_mlx_framework_var})
      message(
        FATAL_ERROR
          "${_mlx_framework_var} not found. Set MLX_BUILD_METAL=OFF to build without GPU"
      )
    endif()
  endforeach()

  if(MLX_METAL_DEBUG)
    add_compile_definitions(MLX_METAL_DEBUG)
  endif()

  # Throw an error if xcrun not found
  execute_process(
    COMMAND zsh "-c" "/usr/bin/xcrun -sdk ${MLX_METAL_SDK} --show-sdk-version"
    OUTPUT_VARIABLE MLX_APPLE_SDK_VERSION
    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

  # SDK and OS versions are dotted version strings: compare them with
  # VERSION_LESS. Plain LESS compares them as floating-point numbers (so 14.10
  # would sort below 14.2), and an unquoted empty value breaks the if().
  if("${MLX_APPLE_SDK_VERSION}" VERSION_LESS "${MLX_MIN_DEPLOYMENT_TARGET}")
    message(
      FATAL_ERROR
        "MLX requires ${MLX_METAL_SDK} SDK >= ${MLX_MIN_DEPLOYMENT_TARGET} to be built with MLX_BUILD_METAL=ON"
    )
  endif()
  message(STATUS "Building with ${MLX_METAL_SDK} SDK version ${MLX_APPLE_SDK_VERSION}")

  set(METAL_CPP_URL
      https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)

  # An explicit deployment target must meet the minimum. It is also forwarded
  # to the metal compiler through MLX_METAL_MIN_FLAG.
  if(NOT "${CMAKE_OSX_DEPLOYMENT_TARGET}" STREQUAL "")
    if("${CMAKE_OSX_DEPLOYMENT_TARGET}" VERSION_LESS
       "${MLX_MIN_DEPLOYMENT_TARGET}")
      message(FATAL_ERROR "MLX requires deployment target >= ${MLX_MIN_DEPLOYMENT_TARGET}")
    endif()
    set(XCRUN_FLAGS "${MLX_METAL_MIN_FLAG}")
  endif()
  # Preprocess __METAL_VERSION__ with the metal compiler to learn its value.
  execute_process(
    COMMAND
      zsh "-c"
      "echo \"__METAL_VERSION__\" | xcrun -sdk ${MLX_METAL_SDK} metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
  if(EXISTS "${MLX_LOCAL_METAL_CPP_DIR}/Metal/Metal.hpp")
    message(STATUS "Using vendored metal-cpp from ${MLX_LOCAL_METAL_CPP_DIR}")
    set(metal_cpp_SOURCE_DIR ${MLX_LOCAL_METAL_CPP_DIR})
  else()
    FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
    FetchContent_MakeAvailable(metal_cpp)
  endif()
  target_include_directories(
    mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
               $<INSTALL_INTERFACE:include/metal_cpp>)
  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()

if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
  # With newer clang/gcc versions following libs are implicitly linked, but when
  # building on old distributions they need to be explicitly listed.
  target_link_libraries(mlx PRIVATE dl pthread)
elseif(WIN32)
  if(MSVC)
    # GGUF does not build with MSVC.
    set(MLX_BUILD_GGUF OFF)
  endif()
  # Generate DLL and EXE in the same dir, otherwise EXE will not be able to run.
  # This is only done when MLX is built as the top project.
  if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
  endif()
  # Windows implementation of dlfcn.h APIs, always built as a static library.
  FetchContent_Declare(
    dlfcn-win32
    GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
    GIT_TAG v1.4.2
    EXCLUDE_FROM_ALL)
  # block() confines the BUILD_SHARED_LIBS override. It also discards the
  # dlfcn-win32_* variables FetchContent_MakeAvailable sets in that scope.
  block()
  set(BUILD_SHARED_LIBS OFF)
  FetchContent_MakeAvailable(dlfcn-win32)
  endblock()
  # Re-read the populated source dir outside the block so the include path is
  # not the bare "/src".
  FetchContent_GetProperties(dlfcn-win32)
  target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
  target_link_libraries(mlx PRIVATE dl)
endif()

# CPU backend BLAS/LAPACK selection:
#   1. Apple's Accelerate framework when found;
#   2. on Windows, prebuilt OpenBLAS binaries downloaded at configure time;
#   3. otherwise system LAPACK and BLAS via find_package.
if(MLX_BUILD_CPU)
  find_library(ACCELERATE_LIBRARY Accelerate)
  if(ACCELERATE_LIBRARY)
    message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
    set(MLX_BUILD_ACCELERATE ON)
  else()
    message(STATUS "Accelerate not found, using default backend.")
    set(MLX_BUILD_ACCELERATE OFF)
  endif()

  if(MLX_BUILD_ACCELERATE)
    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
    add_compile_definitions(MLX_USE_ACCELERATE)
    add_compile_definitions(ACCELERATE_NEW_LAPACK)
  elseif(WIN32)
    # Download and link prebuilt binaries of OpenBLAS. Note that we can only
    # link with the dynamic library, the prebuilt binaries were built with MinGW
    # so static-linking would require linking with MinGW's runtime.
    FetchContent_Declare(
      openblas
      URL "https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.31/OpenBLAS-0.3.31-x64.zip"
    )
    FetchContent_MakeAvailable(openblas)
    target_link_libraries(mlx
                          PRIVATE "${openblas_SOURCE_DIR}/lib/libopenblas.lib")
    target_include_directories(mlx PRIVATE "${openblas_SOURCE_DIR}/include")
    # Make sure the DLL file is placed in the same dir with executables.
    set(OPENBLAS_DLL_FILE "${openblas_SOURCE_DIR}/bin/libopenblas.dll")
    # VERBATIM + quoting keep paths containing spaces intact on every
    # generator.
    add_custom_command(
      TARGET mlx
      POST_BUILD
      COMMAND "${CMAKE_COMMAND}" -E copy_if_different "${OPENBLAS_DLL_FILE}"
              "${CMAKE_BINARY_DIR}"
      VERBATIM)
  else()
    if(CMAKE_HOST_APPLE)
      # The blas shipped in macOS SDK is not supported, search homebrew for
      # openblas instead (Intel and Apple silicon install prefixes).
      set(BLA_VENDOR OpenBLAS)
      set(LAPACK_ROOT
          "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas;/opt/homebrew/opt/openblas"
      )
    endif()
    # Search and link with lapack. REQUIRED aborts configuration when missing.
    find_package(LAPACK REQUIRED)
    find_path(
      LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
      /usr/local/opt/openblas/include /opt/homebrew/opt/openblas/include)
    message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
    message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
    target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
    # List blas after lapack otherwise we may accidentally include an old
    # version of lapack.h from the include dirs of blas.
    find_package(BLAS REQUIRED)
    # TODO find a cleaner way to do this
    find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
              $ENV{BLAS_HOME}/include)
    message(STATUS "Blas lib " ${BLAS_LIBRARIES})
    message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
    target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
  endif()
else()
  set(MLX_BUILD_ACCELERATE OFF)
endif()

# nlohmann/json: download the release archive unless a vendored copy is
# present. Its include path is private to mlx and limited to the build tree.
if(NOT EXISTS "${MLX_LOCAL_JSON_DIR}/single_include/nlohmann/json.hpp")
  message(STATUS "Downloading json")
  FetchContent_Declare(
    json
    URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
  FetchContent_MakeAvailable(json)
else()
  message(STATUS "Using vendored json from ${MLX_LOCAL_JSON_DIR}")
  set(json_SOURCE_DIR ${MLX_LOCAL_JSON_DIR})
endif()
target_include_directories(
  mlx
  PRIVATE
    $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)

# Library sources presumably live in the mlx/ subdirectory.
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

# Public include root: this source directory while building, and the
# installed include/ directory for consumers of the exported target.
target_include_directories(
  mlx
  PUBLIC
    $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
    $<INSTALL_INTERFACE:include>)

# fmt is consumed through the fmt::fmt-header-only target, linked privately and
# only in the build tree. It comes from the system (USE_SYSTEM_FMT), from a
# vendored checkout, or from a pinned git download, checked in that order.
if(USE_SYSTEM_FMT)
  find_package(fmt REQUIRED)
elseif(EXISTS "${MLX_LOCAL_FMT_DIR}/include/fmt/format.h")
  message(STATUS "Using vendored fmt from ${MLX_LOCAL_FMT_DIR}")
  # Recreate fmt's header-only interface target unless it already exists.
  if(NOT TARGET fmt::fmt-header-only)
    add_library(fmt-header-only INTERFACE)
    target_compile_definitions(fmt-header-only INTERFACE FMT_HEADER_ONLY)
    target_include_directories(fmt-header-only
                               INTERFACE ${MLX_LOCAL_FMT_DIR}/include)
    add_library(fmt::fmt-header-only ALIAS fmt-header-only)
  endif()
else()
  FetchContent_Declare(
    fmt
    GIT_REPOSITORY https://github.com/fmtlib/fmt.git
    GIT_TAG 12.1.0
    EXCLUDE_FROM_ALL)
  FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)

# Python bindings: need Python >= 3.10 (interpreter + module development
# files) and nanobind, fetched at a pinned tag.
if(MLX_BUILD_PYTHON_BINDINGS)
  message(STATUS "Building Python bindings.")
  find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
  FetchContent_Declare(
    nanobind
    GIT_REPOSITORY https://github.com/wjakob/nanobind.git
    GIT_TAG v2.10.2
    GIT_SHALLOW TRUE
    EXCLUDE_FROM_ALL)
  FetchContent_MakeAvailable(nanobind)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

# Tests: include(CTest) enables testing before the test directory is added.
if(MLX_BUILD_TESTS)
  include(CTest)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()

# C++ examples and benchmarks.
if(MLX_BUILD_EXAMPLES)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()
if(MLX_BUILD_BENCHMARKS)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()

# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)

if(WIN32)
  # Install DLLs to the same dir with extension file (core.pyd) on Windows.
  set(CMAKE_INSTALL_BINDIR ".")
  # OPENBLAS_DLL_FILE is only defined when the CPU backend fell back to the
  # downloaded OpenBLAS binaries. Guard on it so install() never receives an
  # empty file list.
  if(MLX_BUILD_CPU AND OPENBLAS_DLL_FILE)
    install(FILES "${OPENBLAS_DLL_FILE}" TYPE BIN)
  endif()
endif()

# Install library: runtime (DLL), shared and static/import libraries, recorded
# in the MLXTargets export set together with the include directory.
install(
  TARGETS mlx
  EXPORT MLXTargets
  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
  INCLUDES
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})

# Install headers: every *.h under mlx/ except backend/metal/kernels.h.
install(
  DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  COMPONENT headers
  FILES_MATCHING
  PATTERN "*.h"
  PATTERN "backend/metal/kernels.h" EXCLUDE)

# Metal builds also ship the metal-cpp sources under <includedir>/metal_cpp,
# presumably to back the include/metal_cpp INSTALL_INTERFACE path of mlx.
if(MLX_BUILD_METAL)
  install(
    DIRECTORY ${metal_cpp_SOURCE_DIR}/
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
    COMPONENT metal_cpp_source)
endif()

# Install cmake config
# The generated files live in this project's binary dir rather than the
# top-level CMAKE_BINARY_DIR, so they stay inside MLX's build tree when MLX is
# added as a subproject.
set(MLX_CMAKE_BUILD_CONFIG ${PROJECT_BINARY_DIR}/MLXConfig.cmake)
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${PROJECT_BINARY_DIR}/MLXConfigVersion.cmake)
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)

# Exported targets (the MLXTargets export set) for find_package(MLX) users.
install(
  EXPORT MLXTargets
  FILE MLXTargets.cmake
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

include(CMakePackageConfigHelpers)

# Version file: any release with the same major version is compatible.
write_basic_package_version_file(
  ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  COMPATIBILITY SameMajorVersion
  VERSION ${MLX_VERSION})

# MLXConfig.cmake is generated from the mlx.pc.in template.
configure_package_config_file(
  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
  INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
  NO_CHECK_REQUIRED_COMPONENTS_MACRO
  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
            MLX_CMAKE_INSTALL_MODULE_DIR)

install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

# Ship the project's ./cmake helper modules next to the config. Name the
# directory directly instead of going through CMAKE_MODULE_PATH, which is a
# search-path list and would break install() if it held more than one entry.
install(DIRECTORY ${PROJECT_SOURCE_DIR}/cmake/
        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
