From 1c7f2433c5f73f4a83c14e790d03e11caad2f761 Mon Sep 17 00:00:00 2001
From: Alexandre Strube <a.strube@fz-juelich.de>
Date: Mon, 18 Jan 2021 16:05:36 +0100
Subject: [PATCH] Merge

---
 .pre-commit-config.yaml                       |  3 +-
 Custom_EasyBlocks/jax.py                      | 89 +++++++++++++++++++
 .../JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb    | 31 +------
 3 files changed, 92 insertions(+), 31 deletions(-)
 create mode 100644 Custom_EasyBlocks/jax.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9e2cc1e3e..3e08d14d2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,4 +2,5 @@ repos:
 -   repo: https://github.com/pre-commit/mirrors-autopep8
     rev: ''  # Use the sha / tag you want to point at
     hooks:
-    -   id: autopep8
\ No newline at end of file
+    -   id: autopep8
+        args: [-v]
\ No newline at end of file
diff --git a/Custom_EasyBlocks/jax.py b/Custom_EasyBlocks/jax.py
new file mode 100644
index 000000000..861304592
--- /dev/null
+++ b/Custom_EasyBlocks/jax.py
@@ -0,0 +1,89 @@
+"""
+EasyBuild support for building and installing JAX, implemented as an easyblock
+
+@author: Andrew Edmondson (University of Birmingham)
+"""
+from easybuild.framework.easyblock import EasyBlock
+from easybuild.tools.systemtools import POWER, get_cpu_architecture
+from easybuild.tools.modules import get_software_root
+from easybuild.tools.run import run_cmd
+from easybuild.easyblocks.generic.pythonpackage import det_pylibdir
+
+
+class EB_JAX(EasyBlock):
+    """Support for building/installing JAX."""
+
+    def configure_step(self):
+        """No configuration for JAX."""
+        pass
+
+    def build_step(self):
+        """Build JAX"""
+        cuda = get_software_root('CUDA')
+        cudnn = get_software_root('CuDNN')
+        bazel = get_software_root('Bazel')
+        binutils = get_software_root('binutils')
+        if not self.cfg['prebuildopts']:
+            self.cfg['prebuildopts'] = 'export BAZEL_LINKOPTS=-static-libstdc++:-static-libgcc; '
+            self.cfg['prebuildopts'] = 'export BAZEL_LINKLIBS=-l%:libstdc++.a:-lm BAZEL_CXXOPTS=-std=gnu++0x ;' 
+            self.cfg['prebuildopts'] = 'export TF_CUDA_PATHS="{}" '.format(cuda)
+            self.cfg['prebuildopts'] += 'GCC_HOST_COMPILER_PREFIX="{}/bin" '.format(binutils)
+            # To prevent bazel builds on different hosts/architectures conflicting with each other
+            # we'll set HOME, inside which Bazel puts active files (in ~/.cache/bazel/...)
+            self.cfg['prebuildopts'] += 'HOME="{}/fake_home" && '.format(self.builddir)
+
+        if not self.cfg['buildopts']:
+            self.cfg['buildopts'] = '--enable_cuda --cuda_path "{}" '.format(cuda)
+            self.cfg['buildopts'] += '--cudnn_path "{}" '.format(cudnn)
+            self.cfg['buildopts'] += '--bazel_path "{}/bin/bazel" '.format(bazel)
+            # Tell Bazel to pass PYTHONPATH through to what it's building, so it can find scipy etc.
+            self.cfg['buildopts'] += '--bazel_options=--action_env=PYTHONPATH '
+            self.cfg['buildopts'] += '--cuda_compute_capabilities "7.0,7.5,8.0" '
+            self.cfg['buildopts'] += '--noenable_mkl_dnn '
+            if get_cpu_architecture() == POWER:
+                # Tell Bazel to tell NVCC to tell the compiler to use -mno-float128
+                self.cfg['buildopts'] += (r'--bazel_options=--per_file_copt=.*cu\.cc.*'
+                                          '@-nvcc_options=compiler-options=-mno-float128 ')
+
+        cmd = ' '.join([
+            self.cfg['prebuildopts'],
+            'python build/build.py',
+            self.cfg['buildopts'],
+        ])
+
+        (out, _) = run_cmd(cmd, log_all=True, simple=False)
+
+        return out
+
+    def install_step(self):
+        """Install JAX"""
+        cmd = ' '.join([
+            self.cfg['preinstallopts'],
+            '(cd build && pip install --prefix {} .) && pip install --prefix {} .'.format(self.installdir,
+                                                                                          self.installdir),
+            self.cfg['installopts'],
+        ])
+
+        (out, _) = run_cmd(cmd, log_all=True, simple=False)
+
+        return out
+
+    def sanity_check_step(self):
+        """Custom sanity check for JAX."""
+        pylibdir = det_pylibdir()
+        custom_paths = {
+            'files': [],
+            'dirs': ['{}/jax'.format(pylibdir),
+                     '{}/jaxlib'.format(pylibdir)],
+        }
+        super(EB_JAX, self).sanity_check_step(custom_paths=custom_paths)
+
+    def make_module_extra(self):
+        """Add module entries specific to Amber/AmberTools"""
+        txt = super(EB_JAX, self).make_module_extra()
+        cuda = get_software_root('CUDA')
+        if cuda:
+            txt += self.module_generator.set_environment('XLA_FLAGS', "--xla_gpu_cuda_data_dir={}/bin".format(cuda))
+        txt += self.module_generator.prepend_paths('PYTHONPATH', det_pylibdir())
+
+        return txt
diff --git a/Golden_Repo/j/JAX/JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb b/Golden_Repo/j/JAX/JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb
index fbb8a5ddc..1dbf18aeb 100644
--- a/Golden_Repo/j/JAX/JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb
+++ b/Golden_Repo/j/JAX/JAX-0.1.77-gpsmkl-2020-Python-3.8.5.eb
@@ -1,6 +1,4 @@
 # This easyconfig was created by the BEAR Software team at the University of Birmingham.
-easyblock = 'ConfigureMake'
-
 name = 'JAX'
 version = '0.1.77'
 versionsuffix = '-Python-%(pyver)s'
@@ -18,39 +16,12 @@ dependencies = [
     ('Bazel', '3.6.0'),
     ('cuDNN', '8.0.2.39', '-CUDA-%s' % local_cudaver, True),
     ('flatbuffers', '1.12.0'),
+    ('LLVM', '10.0.1'),
 ]
 
 source_urls = ['https://github.com/google/jax/archive']
 sources = ['jax-v%(version)s.tar.gz']
-skipsteps = ['configure', 'build']
-
-install_cmd = 'export TF_CUDA_PATHS=${EBROOTCUDA} && '
-# To prevent bazel builds on different hosts/architectures conflicting with each other
-# we'll set HOME, inside which Bazel puts active files (in ~/.cache/bazel/...)
-install_cmd += 'export HOME=%(builddir)s/.home/ && '
-# Trying to prevent the ld.gold problem
-install_cmd += 'export GCC_HOST_COMPILER_PREFIX=$EBROOTBINUTILS/bin && '
-install_cmd += 'env BAZEL_LINKOPTS=-static-libstdc++:-static-libgcc '
-install_cmd += 'BAZEL_LINKLIBS=-l%:libstdc++.a:-lm BAZEL_CXXOPTS=-std=gnu++0x && '
-install_cmd += 'python build/build.py --enable_cuda --cuda_path ${EBROOTCUDA} '
-install_cmd += '--cudnn_path ${EBROOTCUDNN} '
-install_cmd += '--bazel_path ${EBROOTBAZEL}/bin/bazel '
-# Tell Bazel to pass PYTHONPATH through to what it's building, so it can find scipy etc.
-install_cmd += '--bazel_options=--action_env=PYTHONPATH '
-install_cmd += '--bazel_options=--action_env=PATH '
-install_cmd += '--cuda_compute_capabilities "7.0,7.5,8.0" '
-
-
-import subprocess as _subprocess  # NOQA
-_arch = _subprocess.check_output(['uname', '-m'], universal_newlines=True).strip()
-
-install_cmd += '--noenable_mkl_dnn && '
-install_cmd += '(cd build && pip install --prefix %(installdir)s .) && '  # install jaxlib
-install_cmd += 'pip install --prefix %(installdir)s .'  # install jax
-
-modextravars = {'XLA_FLAGS': "--xla_gpu_cuda_data_dir=${STAGES}/${STAGE}/software/CUDA/%s" % local_cudaver}
 
-modextrapaths = {'PYTHONPATH': 'lib/python%(pyshortver)s/site-packages'}
 
 sanity_check_paths = {
     'files': [],
-- 
GitLab