Actual source code: cupmcontext.hpp

  1: #ifndef PETSCDEVICECONTEXTCUPM_HPP
  2: #define PETSCDEVICECONTEXTCUPM_HPP

  4: #include <petsc/private/deviceimpl.h>
  5: #include <petsc/private/cupmblasinterface.hpp>

  7: #include <array>

  9: namespace Petsc
 10: {

 12: namespace Device
 13: {

 15: namespace CUPM
 16: {

 18: namespace Impl
 19: {

 21: namespace detail
 22: {

 24: // for tag-based dispatch of handle retrieval
 25: template <typename T> struct HandleTag { using type = T; };

 27: } // namespace detail

 29: // Forward declare
 30: template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext;

 32: template <DeviceType T>
 33: class DeviceContext : Impl::BlasInterface<T>
 34: {
 35: public:
 36:   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T);

 38: private:
 39:   template <typename H> using HandleTag = typename detail::HandleTag<H>;
 40:   using stream_tag = HandleTag<cupmStream_t>;
 41:   using blas_tag   = HandleTag<cupmBlasHandle_t>;
 42:   using solver_tag = HandleTag<cupmSolverHandle_t>;

 44: public:
 45:   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
 46:   // header, but since we are using the power of templates it must be declared part of
 47:   // this class to have easy access the same typedefs. Technically one can make a
 48:   // templated struct outside the class but it's more code for the same result.
 49:   struct PetscDeviceContext_IMPLS
 50:   {
 51:     cupmStream_t       stream;
 52:     cupmEvent_t        event;
 53:     cupmEvent_t        begin; // timer-only
 54:     cupmEvent_t        end;   // timer-only
 55: #if PetscDefined(USE_DEBUG)
 56:     PetscBool          timerInUse;
 57: #endif
 58:     cupmBlasHandle_t   blas;
 59:     cupmSolverHandle_t solver;

 61:     PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; }
 62:     PETSC_NODISCARD auto get(blas_tag)   const -> decltype(this->blas)   { return this->blas;   }
 63:     PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; }
 64:   };

 66: private:
 67:   static bool initialized_;
 68:   static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   blashandles_;
 69:   static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_;

 71:   PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr))
 72:   {
 73:     return static_cast<PetscDeviceContext_IMPLS*>(ptr->data);
 74:   }

 76:   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle))
 77:   {
 78:     if (handle) return 0;
 79:     for (auto i = 0; i < 3; ++i) {
 80:       auto cberr = cupmBlasCreate(&handle);
 81:       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
 82:       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) cberr;
 83:       if (i != 2) {
 84:         PetscSleep(3);
 85:         continue;
 86:       }
 88:     }
 89:     return 0;
 90:   }

 92:   PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream))
 93:   {
 94:     cupmStream_t    cupmStream;

 96:     cupmBlasGetStream(handle,&cupmStream);
 97:     if (cupmStream != stream) cupmBlasSetStream(handle,stream);
 98:     return 0;
 99:   }

101:   PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_())
102:   {
103:     for (auto&& handle : blashandles_) {
104:       if (handle) {
105:         cupmBlasDestroy(handle);
106:         handle     = nullptr;
107:       }
108:     }
109:     for (auto&& handle : solverhandles_) {
110:       if (handle) {
111:         cupmBlasInterface_t::DestroyHandle(handle);
112:         handle    = nullptr;
113:       }
114:     }
115:     initialized_ = false;
116:     return 0;
117:   }

119:   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci))
120:   {
121:     PetscDeviceCheckDeviceCount_Internal(id);
122:     if (!initialized_) {
123:       initialized_ = true;
124:       PetscRegisterFinalize(finalize_);
125:     }
126:     // use the blashandle as a canary
127:     if (!blashandles_[id]) {
128:       initialize_handle_(blashandles_[id]);
129:       cupmBlasInterface_t::InitializeHandle(solverhandles_[id]);
130:     }
131:     set_handle_stream_(blashandles_[id],dci->stream);
132:     cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream);
133:     dci->blas   = blashandles_[id];
134:     dci->solver = solverhandles_[id];
135:     return 0;
136:   }

138: public:
139:   const struct _DeviceContextOps ops = {
140:     destroy,
141:     changeStreamType,
142:     setUp,
143:     query,
144:     waitForContext,
145:     synchronize,
146:     getHandle<blas_tag>,
147:     getHandle<solver_tag>,
148:     getHandle<stream_tag>,
149:     beginTimer,
150:     endTimer,
151:   };

153:   // All of these functions MUST be static in order to be callable from C, otherwise they
154:   // get the implicit 'this' pointer tacked on
155:   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
156:   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType));
157:   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
158:   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*));
159:   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext));
160:   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
161:   template <typename Handle_t>
162:   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*));
163:   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
164:   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*));
165: };

167: template <DeviceType T>
168: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
169: {
170:   auto           dci = impls_cast_(dctx);

172:   if (dci->stream) cupmStreamDestroy(dci->stream);
173:   if (dci->event)  cupmEventDestroy(dci->event);
174:   if (dci->begin)  cupmEventDestroy(dci->begin);
175:   if (dci->end)    cupmEventDestroy(dci->end);
176:   PetscFree(dctx->data);
177:   return 0;
178: }

180: template <DeviceType T>
181: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
182: {
183:   auto dci = impls_cast_(dctx);

185:   if (dci->stream) {
186:     cupmStreamDestroy(dci->stream);
187:     dci->stream = nullptr;
188:   }
189:   // set these to null so they aren't usable until setup is called again
190:   dci->blas   = nullptr;
191:   dci->solver = nullptr;
192:   return 0;
193: }

195: template <DeviceType T>
196: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
197: {
198:   auto           dci = impls_cast_(dctx);

200:   if (dci->stream) {
201:     cupmStreamDestroy(dci->stream);
202:     dci->stream = nullptr;
203:   }
204:   switch (dctx->streamType) {
205:   case PETSC_STREAM_GLOBAL_BLOCKING:
206:     // don't create a stream for global blocking
207:     break;
208:   case PETSC_STREAM_DEFAULT_BLOCKING:
209:     cupmStreamCreate(&dci->stream);
210:     break;
211:   case PETSC_STREAM_GLOBAL_NONBLOCKING:
212:     cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);
213:     break;
214:   default:
215:     SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]);
216:     break;
217:   }
218:   if (!dci->event) cupmEventCreate(&dci->event);
219: #if PetscDefined(USE_DEBUG)
220:   dci->timerInUse = PETSC_FALSE;
221: #endif
222:   initialize_(dctx->device->deviceId,dci);
223:   return 0;
224: }

226: template <DeviceType T>
227: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
228: {
229:   cupmError_t cerr;

231:   cerr = cupmStreamQuery(impls_cast_(dctx)->stream);
232:   if (cerr == cupmSuccess) *idle = PETSC_TRUE;
233:   else {
234:     // somethings gone wrong
235:     if (PetscUnlikely(cerr != cupmErrorNotReady)) cerr;
236:     *idle = PETSC_FALSE;
237:   }
238:   return 0;
239: }

241: template <DeviceType T>
242: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
243: {
244:   auto        dcib = impls_cast_(dctxb);

246:   cupmEventRecord(dcib->event,dcib->stream);
247:   cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0);
248:   return 0;
249: }

251: template <DeviceType T>
252: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
253: {
254:   auto        dci = impls_cast_(dctx);

256:   // in case anything was queued on the event
257:   cupmStreamWaitEvent(dci->stream,dci->event,0);
258:   cupmStreamSynchronize(dci->stream);
259:   return 0;
260: }

262: template <DeviceType T>
263: template <typename handle_t>
264: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
265: {
266:   *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t());
267:   return 0;
268: }

270: template <DeviceType T>
271: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
272: {
273:   auto        dci = impls_cast_(dctx);

275: #if PetscDefined(USE_DEBUG)
277:   dci->timerInUse = PETSC_TRUE;
278: #endif
279:   if (!dci->begin) {
280:     cupmEventCreate(&dci->begin);
281:     cupmEventCreate(&dci->end);
282:   }
283:   cupmEventRecord(dci->begin,dci->stream);
284:   return 0;
285: }

287: template <DeviceType T>
288: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
289: {
290:   float       gtime;
291:   auto        dci = impls_cast_(dctx);

293: #if PetscDefined(USE_DEBUG)
295:   dci->timerInUse = PETSC_FALSE;
296: #endif
297:   cupmEventRecord(dci->end,dci->stream);
298:   cupmEventSynchronize(dci->end);
299:   cupmEventElapsedTime(&gtime,dci->begin,dci->end);
300:   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
301:   return 0;
302: }

304: // initialize the static member variables
305: template <DeviceType T> bool DeviceContext<T>::initialized_ = false;

307: template <DeviceType T>
308: std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   DeviceContext<T>::blashandles_ = {};

310: template <DeviceType T>
311: std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};

313: } // namespace Impl

315: // shorten this one up a bit (and instantiate the templates)
316: using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>;
317: using CUPMContextHip  = Impl::DeviceContext<DeviceType::HIP>;

319: // shorthand for what is an EXTREMELY long name
320: #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS

322: } // namespace CUPM

324: } // namespace Device

326: } // namespace Petsc

#endif // PETSCDEVICECONTEXTCUPM_HPP