load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# Pinned upstream revisions and their archive checksums. The checksums are
# currently empty strings, so no hash is pinned for these downloads.
NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "8487993f5e3547ee29505ea03e4c832b68023e59"

ENZYMEXLA_SHA256 = ""

# nsync, fetched from the wsmoses fork at NSYNC_COMMIT.
http_archive(
    name = "nsync",
    urls = ["https://github.com/wsmoses/nsync/archive/" + NSYNC_COMMIT + ".tar.gz"],
    strip_prefix = "nsync-{}".format(NSYNC_COMMIT),
    sha256 = NSYNC_SHA256,
)

# rules_android is declared here because Perfetto requires it.
http_archive(
    name = "rules_android",
    url = "https://github.com/bazelbuild/rules_android/releases/download/v0.7.1/rules_android-v0.7.1.tar.gz",
    strip_prefix = "rules_android-0.7.1",
    sha256 = "",
)

# Both Perfetto repositories below are unpacked from the same v53.0 tarball.
_PERFETTO_ROOT = "perfetto-53.0"

_PERFETTO_URLS = ["https://github.com/google/perfetto/archive/refs/tags/v53.0.tar.gz"]

http_archive(
    name = "perfetto",
    urls = _PERFETTO_URLS,
    strip_prefix = _PERFETTO_ROOT,
    sha256 = "",
)

# Re-roots the same archive at bazel/standalone and exports perfetto_cfg.bzl
# through a generated BUILD file.
http_archive(
    name = "perfetto_cfg",
    urls = _PERFETTO_URLS,
    strip_prefix = _PERFETTO_ROOT + "/bazel/standalone",
    build_file_content = "exports_files([\"perfetto_cfg.bzl\"])",
    sha256 = "",
)

# Enzyme-JAX sources at the pinned ENZYMEXLA_COMMIT, registered as @enzyme_ad.
# The workspace macros loaded from "@enzyme_ad//third_party/..." later in this
# file come from this repository.
http_archive(
    name = "enzyme_ad",
    # Post-extraction shell commands. Both use `sed -i.bak0`, i.e. they edit the
    # files in place and leave a backup copy with a `.bak0` suffix.
    #   1. Rewrites a heavily backslash-escaped spelling of `//:patches` in the
    #      top-level workspace.bzl so that it is prefixed with `@enzyme_ad`.
    #      NOTE(review): the escaping passes through Starlark and shell
    #      unescaping before sed sees it; work out the effective pattern and
    #      confirm it still matches before touching the backslashes.
    #   2. Replaces every plain `//:patches` with `@enzyme_ad//:patches` in each
    #      third_party/*/workspace.bzl.
    patch_cmds = [
        """
sed -i.bak0 "s/\\\\\\\\\\\\\\\\\\/\\\\\\\\\\\\\\\\\\/:patches/@enzyme_ad\\\\\\\\\\\\\\\\\\/\\\\\\\\\\\\\\\\\\/:patches/g" workspace.bzl
sed -i.bak0 "s,//:patches,@enzyme_ad//:patches,g" third_party/*/workspace.bzl
""",
    ],
    # ENZYMEXLA_SHA256 is currently the empty string, so no checksum is pinned.
    sha256 = ENZYMEXLA_SHA256,
    strip_prefix = "Enzyme-JAX-" + ENZYMEXLA_COMMIT,
    urls = ["https://github.com/EnzymeAD/Enzyme-JAX/archive/{commit}.tar.gz".format(commit = ENZYMEXLA_COMMIT)],
)

# Patch list handed to xla_workspace() further down; currently empty.
NEW_XLA_PATCHES = []

# LLVM backends handed to llvm_configure() further down: the two GPU targets
# first, then the CPU targets. "PowerPC" and "SystemZ" existed only as a
# commented-out extension of this list and remain disabled.
LLVM_TARGETS = [
    "AMDGPU",
    "NVPTX",
    "AArch64",
    "X86",
    "ARM",
]

# Uncomment these lines to use a custom LLVM commit
# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3"
# LLVM_SHA256 = ""
# http_archive(
#     name = "llvm-raw",
#     build_file_content = "# empty",
#     sha256 = LLVM_SHA256,
#     strip_prefix = "llvm-project-" + LLVM_COMMIT,
#     urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
# )
#
#
# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
# maybe(
#     http_archive,
#     name = "llvm_zlib",
#     build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
#     sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
#     strip_prefix = "zlib-ng-2.0.7",
#     urls = [
#         "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
#     ],
# )
#
# maybe(
#     http_archive,
#     name = "llvm_zstd",
#     build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
#     sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
#     strip_prefix = "zstd-1.5.2",
#     urls = [
#         "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
#     ],
# )

# Workspace macros shipped under @enzyme_ad//third_party. Each file exports its
# macro as `repo`, so every load aliases it to a distinct local name.
# NOTE(review): the call order is kept exactly as written; confirm that nothing
# depends on it before reordering.
load("@enzyme_ad//third_party/xprof:workspace.bzl", xprof_workspace = "repo")

xprof_workspace()

load("@enzyme_ad//third_party/ml_toolchain:workspace.bzl", ml_toolchain_workspace = "repo")

ml_toolchain_workspace()

load("@enzyme_ad//third_party/jax:workspace.bzl", jax_workspace = "repo")

# Called with an empty list.
# NOTE(review): presumably a list of extra patches, mirroring xla_workspace()
# below — confirm against third_party/jax/workspace.bzl.
jax_workspace([])

load("@enzyme_ad//third_party/xla:workspace.bzl", xla_workspace = "repo")

# NEW_XLA_PATCHES is the (currently empty) list defined earlier in this file.
xla_workspace(NEW_XLA_PATCHES)

#
# load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip")
# python_init_pip()
#
# load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
# python_init_rules()
#
# load("@rules_python//python:repositories.bzl", "py_repositories")
#
# py_repositories()
#
# load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies")
#
# pip_install_dependencies()

# Enzyme and cuda_tile workspace macros, also exported as `repo` by @enzyme_ad
# and aliased on load.
load("@enzyme_ad//third_party/enzyme:workspace.bzl", enzyme_workspace = "repo")
load("@enzyme_ad//third_party/cuda_tile:workspace.bzl", cuda_tile_workspace = "repo")

enzyme_workspace()

# The cuda_tile macro is given the repository name "@enzyme_ad".
# NOTE(review): presumably used to build labels inside that repository — confirm
# against third_party/cuda_tile/workspace.bzl.
cuda_tile_workspace("@enzyme_ad")

# http_archive(
#     name = "upb",
#     sha256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454",
#     strip_prefix = "upb-9effcbcb27f0a665f9f345030188c0b291e32482",
#     patch_cmds = [
#         "sed -i.bak0 's/@bazel_tools\\/\\/platforms:windows/@platforms\\/\\/os:windows/g' BUILD",
#         "sed -i.bak0 's/-Werror//g' BUILD"
#     ],
#     url = "https://github.com/protocolbuffers/upb/archive/9effcbcb27f0a665f9f345030188c0b291e32482.tar.gz"
# )

# JAX's own XLA workspace macro, followed by the first two XLA workspace stages.
# The XLA stages are invoked in descending order: 4 and 3 here, then 2, 1 and 0
# further down in this file.
# NOTE(review): presumably each stage relies on repositories set up by the
# previous one; confirm before reordering.
load("@jax//third_party/xla:workspace.bzl", jax_xla_workspace = "repo")

jax_xla_workspace()

load("@xla//:workspace4.bzl", "xla_workspace4")

xla_workspace4()

load("@xla//:workspace3.bzl", "xla_workspace3")

xla_workspace3()

# Python setup via XLA's helper macros: rules first, then repositories, then
# toolchains.
load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")

python_init_rules()

load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")

# Maps each listed Python version to its lock file under //build, e.g.
# "3.9" -> "//build:requirements_lock_3_9.txt". Insertion order follows the list.
_PYTHON_REQUIREMENTS = {
    py_version: "//build:requirements_lock_{}.txt".format(py_version.replace(".", "_"))
    for py_version in ["3.9", "3.10", "3.11", "3.12", "3.13"]
}

python_init_repositories(requirements = _PYTHON_REQUIREMENTS)

load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")

python_init_toolchains()

# LLVM setup: XLA's llvm macro is invoked with the repository name "llvm-raw",
# then llvm_configure (loaded from that repository) declares "llvm-project",
# passing LLVM_TARGETS as its target list.
load("@xla//third_party/llvm:workspace.bzl", llvm = "repo")

llvm("llvm-raw")

load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")

llvm_configure(
    name = "llvm-project",
    targets = LLVM_TARGETS,
)

# Remaining XLA workspace stages, continuing the descending sequence (2, 1, 0)
# that started with stages 4 and 3 earlier in this file.
# NOTE(review): presumably order-sensitive; confirm before reordering.
load("@xla//:workspace2.bzl", "xla_workspace2")

xla_workspace2()

load("@xla//:workspace1.bzl", "xla_workspace1")

xla_workspace1()

load("@xla//:workspace0.bzl", "xla_workspace0")

xla_workspace0()

# Repositories used by JAX's wheel build and tests. All rules here are loaded
# from @jax or @xla.

load("@jax//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository")

# Declares "jax_wheel", configured with @jax//jax:version.py as the version
# source and "_version" as the key to look up in it.
jax_python_wheel_repository(
    name = "jax_wheel",
    version_key = "_version",
    version_source = "@jax//jax:version.py",
)

load(
    "@xla//third_party/py:python_wheel.bzl",
    "nvidia_wheel_versions_repository",
    "python_wheel_version_suffix_repository",
)

# Declares "nvidia_wheel_versions" with @jax//build:nvidia-requirements.txt as
# its versions source.
nvidia_wheel_versions_repository(
    name = "nvidia_wheel_versions",
    versions_source = "@jax//build:nvidia-requirements.txt",
)

python_wheel_version_suffix_repository(
    name = "jax_wheel_version_suffix",
)

# JAX's flatbuffers workspace macro (exported as `repo`, aliased on load).
load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")

flatbuffers()

load("@jax//third_party/external_deps:workspace.bzl", "external_deps_repository")

# NOTE(review): the name suggests ROCm test dependencies — confirm against the
# rule's definition in @jax.
external_deps_repository(name = "rocm_external_test_deps")

load("@jax//:test_shard_count.bzl", "test_shard_count_repository")

test_shard_count_repository(
    name = "test_shard_count",
)

# C/C++ toolchain dependencies provided by rules_ml_toolchain.
load("@rules_ml_toolchain//cc/deps:cc_toolchain_deps.bzl", "cc_toolchain_deps")

cc_toolchain_deps()

# Register both toolchain labels in a single call, keeping the original order:
# the plain linux_x86_64 toolchain first, then its CUDA variant.
register_toolchains(
    "@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64",
    "@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_cuda",
)

# CUDA / cuDNN setup using rules_ml_toolchain.
load(
    "@rules_ml_toolchain//gpu/cuda:cuda_json_init_repository.bzl",
    "cuda_json_init_repository",
)

# NOTE(review): presumably this call is what defines @cuda_redist_json, which
# the very next load reads from — confirm before moving either statement.
cuda_json_init_repository()

load(
    "@cuda_redist_json//:distributions.bzl",
    "CUDA_REDISTRIBUTIONS",
    "CUDNN_REDISTRIBUTIONS",
)
load(
    "@rules_ml_toolchain//gpu/cuda:cuda_redist_init_repositories.bzl",
    "cuda_redist_init_repositories",
    "cudnn_redist_init_repository",
)
load(
    "@rules_ml_toolchain//gpu/cuda:cuda_redist_versions.bzl",
    "REDIST_VERSIONS_TO_BUILD_TEMPLATES",
)

# CCCL distribution data comes from @xla rather than from rules_ml_toolchain.
load("@xla//third_party/cccl:workspace.bzl", "CCCL_3_2_0_DIST_DICT", "CCCL_GITHUB_VERSIONS_TO_BUILD_TEMPLATES")

# Both arguments combine the rules_ml_toolchain value with the corresponding
# CCCL value using the `|` operator.
cuda_redist_init_repositories(
    cuda_redistributions = CUDA_REDISTRIBUTIONS | CCCL_3_2_0_DIST_DICT,
    redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES | CCCL_GITHUB_VERSIONS_TO_BUILD_TEMPLATES,
)

cudnn_redist_init_repository(
    cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
)

load(
    "@rules_ml_toolchain//gpu/cuda:cuda_configure.bzl",
    "cuda_configure",
)

# Declares the "local_config_cuda" repository.
cuda_configure(name = "local_config_cuda")

# NCCL setup using rules_ml_toolchain: initialise the redistribution
# repository, then declare "local_config_nccl".
load(
    "@rules_ml_toolchain//gpu/nccl:nccl_redist_init_repository.bzl",
    "nccl_redist_init_repository",
)

nccl_redist_init_repository()

load(
    "@rules_ml_toolchain//gpu/nccl:nccl_configure.bzl",
    "nccl_configure",
)

nccl_configure(name = "local_config_nccl")

# NVSHMEM setup, following the same json-init -> load distributions -> init
# sequence as the CUDA setup in this file.
load(
    "@rules_ml_toolchain//gpu/nvshmem:nvshmem_json_init_repository.bzl",
    "nvshmem_json_init_repository",
)

# NOTE(review): presumably this call is what defines @nvshmem_redist_json, which
# the next load reads from — confirm before moving either statement.
nvshmem_json_init_repository()

load(
    "@nvshmem_redist_json//:distributions.bzl",
    "NVSHMEM_REDISTRIBUTIONS",
)
load(
    "@rules_ml_toolchain//gpu/nvshmem:nvshmem_redist_init_repository.bzl",
    "nvshmem_redist_init_repository",
)

nvshmem_redist_init_repository(
    nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS,
)

# Hedron's Compile Commands Extractor for Bazel
# https://github.com/hedronvision/bazel-compile-commands-extractor
http_archive(
    name = "hedron_compile_commands",
    strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e",

    # To update, replace the commit hash (4f28899228fb3ad0126897876f147ca15026151e) in both places (strip_prefix above and url below) with the latest one from https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main.
    # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README).
    url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz",
    # No sha256 is set for this archive. When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..."
)

# Setup chain for @hedron_compile_commands: four setup macros, each loaded from
# its own .bzl file and called immediately, from the base setup through three
# successive "_transitive" levels.
# NOTE(review): presumably each level needs repositories declared by the one
# before it; keep the order unless the upstream README says otherwise.
load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup")

hedron_compile_commands_setup()

load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive")

hedron_compile_commands_setup_transitive()

load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive")

hedron_compile_commands_setup_transitive_transitive()

load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive")

hedron_compile_commands_setup_transitive_transitive_transitive()
