blob: b7b6556dad61060ba648cc8629083d9153deda89 [file] [log] [blame]
#pragma once
//==============================================================================================
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//==============================================================================================
#include "hipUtils.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
// Owns N elements of T held twice: once on the host (std::vector) and once in
// a HIP device allocation. The two copies are synchronized only on request via
// toDevice() / toHost().
//
// The class is move-only. The device allocation has exactly one owner and is
// released with hipFree when that owner is destroyed. Implicit copying would
// make two objects free the same device pointer.
template <typename T> class DeviceArray {
public:
  // Default constructor: no host elements and no device allocation.
  DeviceArray() = default;

  // Allocates N value-initialized host elements and N uninitialized device
  // elements. If hostData is non-null, its first N elements are copied into
  // the host buffer and then uploaded to the device.
  // HIP failures are passed to checkHIP (from hipUtils.h).
  DeviceArray(size_t N, const T *hostData = nullptr)
      : HostData(N),
        DeviceData(allocateDevice(N, "Unable to allocate device memory")) {
    if (hostData != nullptr) {
      std::copy(hostData, hostData + N, HostData.begin());
      toDevice(); // Copy data from host to device
    }
  }

  // Copying would duplicate ownership of the device allocation.
  DeviceArray(const DeviceArray &) = delete;
  DeviceArray &operator=(const DeviceArray &) = delete;

  // Moving transfers ownership of both buffers. The moved-from object no
  // longer owns a device allocation. Move-assignment first releases the
  // target's previous device buffer.
  DeviceArray(DeviceArray &&) = default;
  DeviceArray &operator=(DeviceArray &&) = default;

  // The device buffer is released by DeviceDeleter when DeviceData is
  // destroyed.
  ~DeviceArray() = default;

  // Copies every host element into the device buffer.
  void toDevice() {
    checkHIP(hipMemcpy(DeviceData.get(), HostData.data(),
                       HostData.size() * sizeof(T), hipMemcpyHostToDevice),
             "Unable to copy data from host to device");
  }

  // Copies every device element back into the host buffer.
  void toHost() {
    checkHIP(hipMemcpy(HostData.data(), DeviceData.get(),
                       HostData.size() * sizeof(T), hipMemcpyDeviceToHost),
             "Unable to copy data from device to host");
  }

  // Resizes both buffers to N elements and preserves the first
  // min(old size, N) elements on each side. New host elements are
  // value-initialized; new device elements are uninitialized.
  //
  // The device buffer is only replaced after every fallible step has
  // succeeded. If a step unwinds, the new device buffer is freed and the
  // object keeps its previous buffers.
  void resize(size_t N) {
    // Reserving first throws std::length_error for an impossible N before
    // N * sizeof(T) can wrap around in the device allocation below. It also
    // makes the later HostData.resize() allocation-free.
    HostData.reserve(N);
    const size_t oldSize = HostData.size();

    DevicePtr newDeviceData(allocateDevice(
        N, "Unable to allocate new device memory during resize"));
    const size_t copyCount = std::min(oldSize, N);
    if (DeviceData && copyCount > 0) {
      checkHIP(hipMemcpy(newDeviceData.get(), DeviceData.get(),
                         copyCount * sizeof(T), hipMemcpyDeviceToDevice),
               "Unable to copy data within device during resize");
    }

    HostData.resize(N);
    // Commit: unique_ptr move-assignment frees the old device buffer, whether
    // or not it held elements. This fixes the leak when the old size was 0.
    DeviceData = std::move(newDeviceData);
  }

  // Number of elements in each buffer.
  size_t size() const { return HostData.size(); }

  // Get pointer to host data
  T *getHostPtr() { return HostData.data(); }
  const T *getHostPtr() const { return HostData.data(); }

  // Get pointer to device data (null when nothing has been allocated)
  T *getDevicePtr() { return DeviceData.get(); }
  const T *getDevicePtr() const { return DeviceData.get(); }

private:
  // Deleter that lets std::unique_ptr release device memory with hipFree.
  // The status is discarded because there is no caller to report it to
  // during destruction.
  struct DeviceDeleter {
    void operator()(T *ptr) const noexcept { (void)hipFree(ptr); }
  };
  using DevicePtr = std::unique_ptr<T, DeviceDeleter>;

  // Allocates uninitialized device storage for N elements of T.
  // hipMalloc's status is passed to checkHIP together with errorMsg.
  static T *allocateDevice(size_t N, const char *errorMsg) {
    T *ptr = nullptr;
    checkHIP(hipMalloc(reinterpret_cast<void **>(&ptr), N * sizeof(T)),
             errorMsg);
    return ptr;
  }

  std::vector<T> HostData; // Host data
  DevicePtr DeviceData;    // Owning pointer to device data
};