diff --git a/src/linalg/mat.cu b/src/linalg/mat.cu
index f75b0dc..745ddfb 100644
--- a/src/linalg/mat.cu
+++ b/src/linalg/mat.cu
@@ -21,18 +21,68 @@ __device__ float sampleVolumeNearest(float* volumeData, const int volW, const in
     return volumeData[idx];
 }
 
-__device__ Vec3 computeGradient(float* volumeData, const int volW, const int volH, const int volD, int vx, int vy, int vz) {
-    // Finite difference for partial derivatives.
+// Tri-linear interpolation of the volume at a fractional voxel coordinate
+// (ready if necessary, but no visible improvement for full volume).
+// Axis convention: x => height => lat, y => width => lon, z => depth => alt.
+// The coordinate is clamped into the volume before the cell corners and the
+// weights are derived, so every weight stays within [0, 1] for any input.
+__device__ float sampleVolumeTrilinear(float* volumeData, const int volW, const int volH, const int volD, float fx, float fy, float fz) {
+    // Clamp the continuous coordinate first (fmaxf also maps a NaN to 0)
+    fx = fminf(fmaxf(fx, 0.0f), (float)(volH - 1));
+    fy = fminf(fmaxf(fy, 0.0f), (float)(volW - 1));
+    fz = fminf(fmaxf(fz, 0.0f), (float)(volD - 1));
+
+    // Lower corner of the enclosing cell
+    int ix = (int)floorf(fx);
+    int iy = (int)floorf(fy);
+    int iz = (int)floorf(fz);
+
+    // Upper corner, kept inside the volume
+    int ix1 = min(ix + 1, volH - 1);
+    int iy1 = min(iy + 1, volW - 1);
+    int iz1 = min(iz + 1, volD - 1);
+
+    // Compute weights
+    float dx = fx - ix;
+    float dy = fy - iy;
+    float dz = fz - iz;
+
+    // Sample values: interpolate along x on the four x-aligned cell edges
+    float c00 = sampleVolumeNearest(volumeData, volW, volH, volD, ix, iy, iz) * (1.0f - dx) +
+                sampleVolumeNearest(volumeData, volW, volH, volD, ix1, iy, iz) * dx;
+    float c10 = sampleVolumeNearest(volumeData, volW, volH, volD, ix, iy1, iz) * (1.0f - dx) +
+                sampleVolumeNearest(volumeData, volW, volH, volD, ix1, iy1, iz) * dx;
+    float c01 = sampleVolumeNearest(volumeData, volW, volH, volD, ix, iy, iz1) * (1.0f - dx) +
+                sampleVolumeNearest(volumeData, volW, volH, volD, ix1, iy, iz1) * dx;
+    float c11 = sampleVolumeNearest(volumeData, volW, volH, volD, ix, iy1, iz1) * (1.0f - dx) +
+                sampleVolumeNearest(volumeData, volW, volH, volD, ix1, iy1, iz1) * dx;
+
+    // Blend along y, then along z
+    float c0 = c00 * (1.0f - dy) + c10 * dy;
+    float c1 = c01 * (1.0f - dy) + c11 * dy;
+
+    return c0 * (1.0f - dz) + c1 * dz;
+}
+
+// Gradient of the volume at a fractional voxel coordinate, computed with
+// central differences over tri-linearly interpolated samples.
+__device__ Vec3 computeGradient(float* volumeData, const int volW, const int volH, const int volD, float fx, float fy, float fz) {
+    // fx/fy/fz are voxel indices, so the samples are taken one voxel to either
+    // side; DLAT/DLON/DLEV only scale the divisor. NOTE(review): assumes these
+    // constants are the grid spacing per voxel - confirm against consts.h.
+    const float step = 1.0f;
+    float hx = DLAT; // x => height => lat
+    float hy = DLON; // y => width => lon
+    float hz = DLEV; // z => depth => alt
+    float dfdx = (sampleVolumeTrilinear(volumeData, volW, volH, volD, fx + step, fy, fz) -
+                  sampleVolumeTrilinear(volumeData, volW, volH, volD, fx - step, fy, fz)) / (2.0f * step * hx);
 
-    float dfdx = (sampleVolumeNearest(volumeData, VOLUME_WIDTH, VOLUME_HEIGHT, VOLUME_DEPTH, vx + 1, vy, vz) -
-                  sampleVolumeNearest(volumeData, VOLUME_WIDTH, VOLUME_HEIGHT, VOLUME_DEPTH, vx - 1, vy, vz)) / (2.0f * DLAT); // x => height => lat
 
+    float dfdy = (sampleVolumeTrilinear(volumeData, volW, volH, volD, fx, fy + step, fz) -
+                  sampleVolumeTrilinear(volumeData, volW, volH, volD, fx, fy - step, fz)) / (2.0f * step * hy);
 
-    float dfdy = (sampleVolumeNearest(volumeData, VOLUME_WIDTH, VOLUME_HEIGHT, VOLUME_DEPTH, vx, vy + 1, vz) -
-                  sampleVolumeNearest(volumeData, VOLUME_WIDTH, VOLUME_HEIGHT, VOLUME_DEPTH, vx, vy - 1, vz)) / (2.0f * DLON); // y => width => lon
-
-    float dfdz = (sampleVolumeNearest(volumeData, VOLUME_WIDTH, VOLUME_HEIGHT, VOLUME_DEPTH, vx, vy, vz + 1) -
-                  sampleVolumeNearest(volumeData, VOLUME_WIDTH, VOLUME_HEIGHT, VOLUME_DEPTH, vx, vy, vz - 1)) / (2.0f * DLEV);
+    float dfdz = (sampleVolumeTrilinear(volumeData, volW, volH, volD, fx, fy, fz + step) -
+                  sampleVolumeTrilinear(volumeData, volW, volH, volD, fx, fy, fz - step)) / (2.0f * step * hz);
 
     return Vec3::init(dfdx, dfdy, dfdz);
 };
diff --git a/src/linalg/mat.h b/src/linalg/mat.h
index 02234c0..7591761 100644
--- a/src/linalg/mat.h
+++ b/src/linalg/mat.h
@@ -4,7 +4,10 @@
 #include "vec.h"
 #include "consts.h"
 
-__device__ Vec3 computeGradient(float* volumeData, const int volW, const int volH, const int volD, int x, int y, int z);
+__device__ float sampleVolumeNearest(float* volumeData, const int volW, const int volH, const int volD, int vx, int vy, int vz);
+__device__ float sampleVolumeTrilinear(float* volumeData, const int volW, const int volH, const int volD, float fx, float fy, float fz);
+
+__device__ Vec3 computeGradient(float* volumeData, const int volW, const int volH, const int volD, float fx, float fy, float fz);
 
 __device__ unsigned int packUnorm4x8(float r, float g, float b, float a);
 