From 1c7f2433c5f73f4a83c14e790d03e11caad2f761 Mon Sep 17 00:00:00 2001 From: Alexandre Strube <a.strube@fz-juelich.de> Date: Mon, 18 Jan 2021 16:05:36 +0100 Subject: [PATCH] Merge --- .pre-commit-config.yaml | 3 +- Custom_EasyBlocks/jax.py | 89 +++++++++++++++++++ .../JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb | 31 +------ 3 files changed, 92 insertions(+), 31 deletions(-) create mode 100644 Custom_EasyBlocks/jax.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9e2cc1e3e..3e08d14d2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,4 +2,5 @@ repos: - repo: https://github.com/pre-commit/mirrors-autopep8 rev: '' # Use the sha / tag you want to point at hooks: - - id: autopep8 \ No newline at end of file + - id: autopep8 + args: [-v] \ No newline at end of file diff --git a/Custom_EasyBlocks/jax.py b/Custom_EasyBlocks/jax.py new file mode 100644 index 000000000..861304592 --- /dev/null +++ b/Custom_EasyBlocks/jax.py @@ -0,0 +1,89 @@ +""" +EasyBuild support for building and installing JAX, implemented as an easyblock + +@author: Andrew Edmondson (University of Birmingham) +""" +from easybuild.framework.easyblock import EasyBlock +from easybuild.tools.systemtools import POWER, get_cpu_architecture +from easybuild.tools.modules import get_software_root +from easybuild.tools.run import run_cmd +from easybuild.easyblocks.generic.pythonpackage import det_pylibdir + + +class EB_JAX(EasyBlock): + """Support for building/installing JAX.""" + + def configure_step(self): + """No configuration for JAX.""" + pass + + def build_step(self): + """Build JAX""" + cuda = get_software_root('CUDA') + cudnn = get_software_root('CuDNN') + bazel = get_software_root('Bazel') + binutils = get_software_root('binutils') + if not self.cfg['prebuildopts']: + self.cfg['prebuildopts'] = 'export BAZEL_LINKOPTS=-static-libstdc++:-static-libgcc; ' + self.cfg['prebuildopts'] = 'export BAZEL_LINKLIBS=-l%:libstdc++.a:-lm BAZEL_CXXOPTS=-std=gnu++0x ;' + self.cfg['prebuildopts'] = 'export TF_CUDA_PATHS="{}" '.format(cuda) + self.cfg['prebuildopts'] += 'GCC_HOST_COMPILER_PREFIX="{}/bin" '.format(binutils) + # To prevent bazel builds on different hosts/architectures conflicting with each other + # we'll set HOME, inside which Bazel puts active files (in ~/.cache/bazel/...) + self.cfg['prebuildopts'] += 'HOME="{}/fake_home" && '.format(self.builddir) + + if not self.cfg['buildopts']: + self.cfg['buildopts'] = '--enable_cuda --cuda_path "{}" '.format(cuda) + self.cfg['buildopts'] += '--cudnn_path "{}" '.format(cudnn) + self.cfg['buildopts'] += '--bazel_path "{}/bin/bazel" '.format(bazel) + # Tell Bazel to pass PYTHONPATH through to what it's building, so it can find scipy etc. + self.cfg['buildopts'] += '--bazel_options=--action_env=PYTHONPATH ' + self.cfg['buildopts'] += '--cuda_compute_capabilities "7.0,7.5,8.0" ' + self.cfg['buildopts'] += '--noenable_mkl_dnn ' + if get_cpu_architecture() == POWER: + # Tell Bazel to tell NVCC to tell the compiler to use -mno-float128 + self.cfg['buildopts'] += (r'--bazel_options=--per_file_copt=.*cu\.cc.*' + '@-nvcc_options=compiler-options=-mno-float128 ') + + cmd = ' '.join([ + self.cfg['prebuildopts'], + 'python build/build.py', + self.cfg['buildopts'], + ]) + + (out, _) = run_cmd(cmd, log_all=True, simple=False) + + return out + + def install_step(self): + """Install JAX""" + cmd = ' '.join([ + self.cfg['preinstallopts'], + '(cd build && pip install --prefix {} .) && pip install --prefix {} .'.format(self.installdir, + self.installdir), + self.cfg['installopts'], + ]) + + (out, _) = run_cmd(cmd, log_all=True, simple=False) + + return out + + def sanity_check_step(self): + """Custom sanity check for JAX.""" + pylibdir = det_pylibdir() + custom_paths = { + 'files': [], + 'dirs': ['{}/jax'.format(pylibdir), + '{}/jaxlib'.format(pylibdir)], + } + super(EB_JAX, self).sanity_check_step(custom_paths=custom_paths) + + def make_module_extra(self): + """Add module entries specific to Amber/AmberTools""" + txt = super(EB_JAX, self).make_module_extra() + cuda = get_software_root('CUDA') + if cuda: + txt += self.module_generator.set_environment('XLA_FLAGS', "--xla_gpu_cuda_data_dir={}/bin".format(cuda)) + txt += self.module_generator.prepend_paths('PYTHONPATH', det_pylibdir()) + + return txt diff --git a/Golden_Repo/j/JAX/JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb b/Golden_Repo/j/JAX/JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb index fbb8a5ddc..1dbf18aeb 100644 --- a/Golden_Repo/j/JAX/JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb +++ b/Golden_Repo/j/JAX/JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb @@ -1,6 +1,4 @@ # This easyconfig was created by the BEAR Software team at the University of Birmingham. -easyblock = 'ConfigureMake' - name = 'JAX' version = '0.1.77' versionsuffix = '-Python-%(pyver)s' @@ -18,39 +16,12 @@ dependencies = [ ('Bazel', '3.6.0'), ('cuDNN', '8.0.2.39', '-CUDA-%s' % local_cudaver, True), ('flatbuffers', '1.12.0'), + ('LLVM', '10.0.1'), ] source_urls = ['https://github.com/google/jax/archive'] sources = ['jax-v%(version)s.tar.gz'] -skipsteps = ['configure', 'build'] - -install_cmd = 'export TF_CUDA_PATHS=${EBROOTCUDA} && ' -# To prevent bazel builds on different hosts/architectures conflicting with each other -# we'll set HOME, inside which Bazel puts active files (in ~/.cache/bazel/...) -install_cmd += 'export HOME=%(builddir)s/.home/ && ' -# Trying to prevent the ld.gold problem -install_cmd += 'export GCC_HOST_COMPILER_PREFIX=$EBROOTBINUTILS/bin && ' -install_cmd += 'env BAZEL_LINKOPTS=-static-libstdc++:-static-libgcc ' -install_cmd += 'BAZEL_LINKLIBS=-l%:libstdc++.a:-lm BAZEL_CXXOPTS=-std=gnu++0x && ' -install_cmd += 'python build/build.py --enable_cuda --cuda_path ${EBROOTCUDA} ' -install_cmd += '--cudnn_path ${EBROOTCUDNN} ' -install_cmd += '--bazel_path ${EBROOTBAZEL}/bin/bazel ' -# Tell Bazel to pass PYTHONPATH through to what it's building, so it can find scipy etc. -install_cmd += '--bazel_options=--action_env=PYTHONPATH ' -install_cmd += '--bazel_options=--action_env=PATH ' -install_cmd += '--cuda_compute_capabilities "7.0,7.5,8.0" ' - - -import subprocess as _subprocess # NOQA -_arch = _subprocess.check_output(['uname', '-m'], universal_newlines=True).strip() - -install_cmd += '--noenable_mkl_dnn && ' -install_cmd += '(cd build && pip install --prefix %(installdir)s .) && ' # install jaxlib -install_cmd += 'pip install --prefix %(installdir)s .' # install jax - -modextravars = {'XLA_FLAGS': "--xla_gpu_cuda_data_dir=${STAGES}/${STAGE}/software/CUDA/%s" % local_cudaver} -modextrapaths = {'PYTHONPATH': 'lib/python%(pyshortver)s/site-packages'} sanity_check_paths = { 'files': [], -- GitLab