From 7e1d127ee6751488f31d60deaa5e5791e8484ae7 Mon Sep 17 00:00:00 2001
From: Alexandre Strube <a.strube@fz-juelich.de>
Date: Thu, 4 Mar 2021 17:05:19 +0100
Subject: [PATCH] set XLA_FLAGS for cuda

---
 ...-0.2.9-gcccoremkl-9.3.0-2020.2.254-Python-3.8.5.eb | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/Golden_Repo/j/JAX/JAX-0.2.9-gcccoremkl-9.3.0-2020.2.254-Python-3.8.5.eb b/Golden_Repo/j/JAX/JAX-0.2.9-gcccoremkl-9.3.0-2020.2.254-Python-3.8.5.eb
index 217069da3..eba72dd93 100644
--- a/Golden_Repo/j/JAX/JAX-0.2.9-gcccoremkl-9.3.0-2020.2.254-Python-3.8.5.eb
+++ b/Golden_Repo/j/JAX/JAX-0.2.9-gcccoremkl-9.3.0-2020.2.254-Python-3.8.5.eb
@@ -1,3 +1,4 @@
+import os
 easyblock = 'PythonPackage'
 
 name = 'JAX'
@@ -14,6 +15,7 @@ site_contacts = 'a.strube@fz-juelich.de'
 
 dependencies = [
     ('binutils', '2.34'),
+    ('CUDA', '11.0', '', SYSTEM),
     ('Python', '3.8.5'),
     ('SciPy-Stack', '2020', versionsuffix, ('gcccoremkl', '9.3.0-2020.2.254')),
     ('cuDNN', '8.0.2.39', '-CUDA-%s' % local_cudaver, True),
@@ -56,10 +58,11 @@ exts_list = [
     }),
 ]
 
-
-modextravars = {
-    'XLA_FLAGS': '--xla_gpu_cuda_data_dir=$EBROOTCUDA',
-}
+# This should be modextravars, but life is unfair
+modluafooter = """
+setenv("CUDA_DIR", os.getenv("EBROOTCUDA"))
+setenv("XLA_FLAGS", "--xla_gpu_cuda_data_dir=" .. os.getenv("EBROOTCUDA"))
+"""
 
 sanity_check_paths = {
     'files': [],
-- 
GitLab