-
Notifications
You must be signed in to change notification settings - Fork 60
Expand file tree
/
Copy pathcorrelation_kernels.cu
More file actions
executable file
·196 lines (153 loc) · 5.4 KB
/
correlation_kernels.cu
File metadata and controls
executable file
·196 lines (153 loc) · 5.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
/*
* Copyright (c) 2021, Princeton Vision & Learning Lab (DROID-SLAM Authors)
* All rights reserved.
*
* This source code is licensed under the BSD 3-Clause License found in the
* LICENSE file in the root directory of this source tree.
*
* References:
* https://github.com/princeton-vl/DROID-SLAM
*/
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// Edge length (in threads) of the square 2D thread blocks used by the launchers below.
#define BLOCK 16
// True iff (h, w) lies in the half-open rectangle [0, H) x [0, W).
__forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) {
  const bool row_ok = (h >= 0) && (h < H);
  const bool col_ok = (w >= 0) && (w < W);
  return row_ok && col_ok;
}
/*
 * Bilinearly samples a (2r+1) x (2r+1) window of a 4D correlation volume
 * around a per-pixel sample coordinate.
 *
 * Launch layout (as configured by corr_index_cuda_forward): a 2D grid of 2D
 * thread blocks covering the (h1, w1) pixel grid, with blockIdx.z selecting
 * the batch element n. Threads outside [0, h1) x [0, w1) return immediately.
 *
 * volume: indexed volume[n][y][x][y1][x1]; size(1..4) are h1, w1, h2, w2.
 * coords: coords[n][0][y][x] is the sample coordinate along the last volume
 *         dimension (w2) and coords[n][1][y][x] along dimension 3 (h2).
 * corr:   output indexed corr[n][i][j][y][x], where i is the offset along the
 *         coords[.][0] axis and j along the coords[.][1] axis, both in
 *         [0, 2r]. Values are accumulated with +=, so corr must be
 *         zero-initialized by the caller.
 * r:      window radius.
 *
 * Each thread writes only corr[n][*][*][y][x] for its own (n, y, x), so no
 * atomics are required.
 */
template <typename scalar_t>
__global__ void corr_index_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> volume,
    const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
    int r)
{
  // Pixel (x, y) in the (h1, w1) grid; n is the batch index.
  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  const int y = blockIdx.y * blockDim.y + threadIdx.y;
  const int n = blockIdx.z;

  const int h1 = volume.size(1);
  const int w1 = volume.size(2);
  const int h2 = volume.size(3);
  const int w2 = volume.size(4);

  if (!within_bounds(y, x, h1, w1)) {
    return;
  }

  const float x0 = coords[n][0][y][x];
  const float y0 = coords[n][1][y][x];

  // Integer cell and fractional offset of the sample point. These are
  // loop-invariant, so they are computed once instead of per tap.
  const float fx = floorf(x0);
  const float fy = floorf(y0);
  const float dx = x0 - fx;
  const float dy = y0 - fy;
  const int xbase = static_cast<int>(fx) - r;
  const int ybase = static_cast<int>(fy) - r;

  // Bilinear weights, also loop-invariant. Suffix: first letter = x corner,
  // second = y corner, h = upper (+1) neighbour, l = lower neighbour.
  const scalar_t w_hh = scalar_t(dx * dy);
  const scalar_t w_hl = scalar_t(dx * (1.0f - dy));
  const scalar_t w_lh = scalar_t((1.0f - dx) * dy);
  const scalar_t w_ll = scalar_t((1.0f - dx) * (1.0f - dy));

  const int rd = 2 * r + 1;
  // Visit the (rd+1) x (rd+1) integer taps that feed the rd x rd window.
  for (int i = 0; i < rd + 1; i++) {
    const int x1 = xbase + i;
    for (int j = 0; j < rd + 1; j++) {
      const int y1 = ybase + j;
      if (!within_bounds(y1, x1, h2, w2)) {
        continue;  // out-of-range taps contribute zero
      }
      const scalar_t s = volume[n][y][x][y1][x1];
      // Tap (i, j) is the upper/upper, upper/lower, lower/upper and
      // lower/lower corner of windows (i-1, j-1), (i-1, j), (i, j-1) and
      // (i, j) respectively; skip windows that fall outside [0, rd).
      if (i > 0 && j > 0)
        corr[n][i-1][j-1][y][x] += s * w_hh;
      if (i > 0 && j < rd)
        corr[n][i-1][j][y][x] += s * w_hl;
      if (i < rd && j > 0)
        corr[n][i][j-1][y][x] += s * w_lh;
      if (i < rd && j < rd)
        corr[n][i][j][y][x] += s * w_ll;
    }
  }
}
/*
 * Adjoint of corr_index_forward_kernel: scatters the gradient of the sampled
 * window back onto the correlation volume.
 *
 * Launch layout (as configured by corr_index_cuda_backward): a 2D grid of 2D
 * thread blocks covering the (h1, w1) pixel grid, with blockIdx.z selecting
 * the batch element n. Threads outside [0, h1) x [0, w1) return immediately.
 *
 * coords:      same layout as in the forward kernel (channel 0 indexes the
 *              last volume dimension, channel 1 indexes dimension 3).
 * corr_grad:   gradient w.r.t. corr, indexed corr_grad[n][i][j][y][x].
 * volume_grad: output indexed volume_grad[n][y][x][y1][x1]; size(1..4) are
 *              h1, w1, h2, w2. Written with +=, so it must be
 *              zero-initialized by the caller.
 * r:           window radius.
 *
 * Each thread writes only volume_grad[n][y][x][*][*] for its own (n, y, x),
 * and touches each (y1, x1) at most once, so no atomics are required.
 */
template <typename scalar_t>
__global__ void corr_index_backward_kernel(
    const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> volume_grad,
    int r)
{
  // Pixel (x, y) in the (h1, w1) grid; n is the batch index.
  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  const int y = blockIdx.y * blockDim.y + threadIdx.y;
  const int n = blockIdx.z;

  const int h1 = volume_grad.size(1);
  const int w1 = volume_grad.size(2);
  const int h2 = volume_grad.size(3);
  const int w2 = volume_grad.size(4);

  if (!within_bounds(y, x, h1, w1)) {
    return;
  }

  const float x0 = coords[n][0][y][x];
  const float y0 = coords[n][1][y][x];

  // Loop-invariant cell/fraction of the sample point, computed once.
  const float fx = floorf(x0);
  const float fy = floorf(y0);
  const float dx = x0 - fx;
  const float dy = y0 - fy;
  const int xbase = static_cast<int>(fx) - r;
  const int ybase = static_cast<int>(fy) - r;

  // Bilinear weights (same naming as the forward kernel: h = upper
  // neighbour, l = lower neighbour; first letter x, second letter y).
  const scalar_t w_hh = scalar_t(dx * dy);
  const scalar_t w_hl = scalar_t(dx * (1.0f - dy));
  const scalar_t w_lh = scalar_t((1.0f - dx) * dy);
  const scalar_t w_ll = scalar_t((1.0f - dx) * (1.0f - dy));

  const int rd = 2 * r + 1;
  for (int i = 0; i < rd + 1; i++) {
    const int x1 = xbase + i;
    for (int j = 0; j < rd + 1; j++) {
      const int y1 = ybase + j;
      if (!within_bounds(y1, x1, h2, w2)) {
        continue;  // taps outside the volume received no forward contribution
      }
      // Gather every window entry this tap contributed to in the forward pass.
      scalar_t g = scalar_t(0.0f);
      if (i > 0 && j > 0)
        g += corr_grad[n][i-1][j-1][y][x] * w_hh;
      if (i > 0 && j < rd)
        g += corr_grad[n][i-1][j][y][x] * w_hl;
      if (i < rd && j > 0)
        g += corr_grad[n][i][j-1][y][x] * w_lh;
      if (i < rd && j < rd)
        g += corr_grad[n][i][j][y][x] * w_ll;
      volume_grad[n][y][x][y1][x1] += g;
    }
  }
}
/*
 * Host entry point for the forward correlation lookup.
 *
 * volume: 5D CUDA tensor (batch, h1, w1, h2, w2); float, double or half.
 * coords: 4D float32 CUDA tensor (batch, >=2, h1, w1); channels 0 and 1 hold
 *         the per-pixel sample coordinates (see corr_index_forward_kernel).
 * radius: window radius r, must be >= 0.
 *
 * Returns {corr}, where corr has shape (batch, 2r+1, 2r+1, h1, w1) and the
 * dtype/device of volume. The kernel runs on PyTorch's current CUDA stream
 * of volume's device. Throws (via TORCH_CHECK) on invalid inputs or if the
 * kernel launch fails.
 */
std::vector<torch::Tensor> corr_index_cuda_forward(
    torch::Tensor volume,
    torch::Tensor coords,
    int radius)
{
  TORCH_CHECK(volume.is_cuda() && coords.is_cuda(),
              "corr_index_cuda_forward: volume and coords must be CUDA tensors");
  TORCH_CHECK(coords.device() == volume.device(),
              "corr_index_cuda_forward: volume and coords must be on the same device");
  TORCH_CHECK(volume.dim() == 5,
              "corr_index_cuda_forward: volume must be 5D, got ", volume.dim(), "D");
  TORCH_CHECK(coords.dim() == 4,
              "corr_index_cuda_forward: coords must be 4D, got ", coords.dim(), "D");
  TORCH_CHECK(radius >= 0,
              "corr_index_cuda_forward: radius must be non-negative, got ", radius);

  const auto batch_size = volume.size(0);
  const auto ht = volume.size(1);
  const auto wd = volume.size(2);

  // The kernel reads coords[n][0..1][y][x] for every pixel of the (ht, wd)
  // grid; a smaller coords tensor would be read out of bounds.
  TORCH_CHECK(coords.size(0) == batch_size && coords.size(1) >= 2 &&
              coords.size(2) == ht && coords.size(3) == wd,
              "corr_index_cuda_forward: coords must have shape (", batch_size,
              ", 2, ", ht, ", ", wd, "), got ", coords.sizes());
  // The batch is mapped onto gridDim.z, which is limited to 65535.
  TORCH_CHECK(batch_size <= 65535,
              "corr_index_cuda_forward: batch size ", batch_size,
              " exceeds the CUDA grid z-dimension limit (65535)");

  // Launch on the device that owns the tensors, not whatever is current.
  const c10::cuda::CUDAGuard device_guard(volume.device());

  auto opts = volume.options();
  torch::Tensor corr = torch::zeros(
      {batch_size, 2*radius+1, 2*radius+1, ht, wd}, opts);

  // A zero-sized grid is an invalid launch configuration, and there is
  // nothing to sample anyway.
  if (corr.numel() == 0) {
    return {corr};
  }

  const dim3 blocks((wd + BLOCK - 1) / BLOCK,
                    (ht + BLOCK - 1) / BLOCK,
                    batch_size);
  const dim3 threads(BLOCK, BLOCK);
  // Use PyTorch's current stream so the kernel is ordered with surrounding ops.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_forward_kernel", ([&] {
    corr_index_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        volume.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
        corr.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        radius);
  }));

  // Kernel launches do not return errors; surface configuration errors here.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "corr_index_forward_kernel launch failed: ", cudaGetErrorString(err));

  return {corr};
}
/*
 * Host entry point for the backward pass of the correlation lookup.
 *
 * volume:    5D CUDA tensor (batch, h1, w1, h2, w2); only its shape, dtype
 *            and device are used.
 * coords:    4D float32 CUDA tensor (batch, >=2, h1, w1), the same one that
 *            was passed to the forward pass.
 * corr_grad: gradient w.r.t. corr, shape (batch, 2r+1, 2r+1, h1, w1), same
 *            dtype as volume.
 * radius:    window radius r, must be >= 0.
 *
 * Returns {volume_grad}, shaped like volume. The kernel runs on PyTorch's
 * current CUDA stream of volume's device. Throws (via TORCH_CHECK) on
 * invalid inputs or if the kernel launch fails.
 */
std::vector<torch::Tensor> corr_index_cuda_backward(
    torch::Tensor volume,
    torch::Tensor coords,
    torch::Tensor corr_grad,
    int radius)
{
  TORCH_CHECK(volume.is_cuda() && coords.is_cuda() && corr_grad.is_cuda(),
              "corr_index_cuda_backward: all inputs must be CUDA tensors");
  TORCH_CHECK(coords.device() == volume.device() && corr_grad.device() == volume.device(),
              "corr_index_cuda_backward: all inputs must be on the same device");
  TORCH_CHECK(volume.dim() == 5,
              "corr_index_cuda_backward: volume must be 5D, got ", volume.dim(), "D");
  TORCH_CHECK(coords.dim() == 4,
              "corr_index_cuda_backward: coords must be 4D, got ", coords.dim(), "D");
  TORCH_CHECK(corr_grad.dim() == 5,
              "corr_index_cuda_backward: corr_grad must be 5D, got ", corr_grad.dim(), "D");
  TORCH_CHECK(radius >= 0,
              "corr_index_cuda_backward: radius must be non-negative, got ", radius);
  TORCH_CHECK(corr_grad.scalar_type() == volume.scalar_type(),
              "corr_index_cuda_backward: corr_grad dtype ", corr_grad.scalar_type(),
              " does not match volume dtype ", volume.scalar_type());

  const auto batch_size = volume.size(0);
  const auto ht = volume.size(1);
  const auto wd = volume.size(2);
  const int64_t rd = 2 * static_cast<int64_t>(radius) + 1;

  // The kernel reads coords and corr_grad at every (n, y, x) of the
  // (batch, ht, wd) grid; undersized inputs would be read out of bounds.
  TORCH_CHECK(coords.size(0) == batch_size && coords.size(1) >= 2 &&
              coords.size(2) == ht && coords.size(3) == wd,
              "corr_index_cuda_backward: coords must have shape (", batch_size,
              ", 2, ", ht, ", ", wd, "), got ", coords.sizes());
  TORCH_CHECK(corr_grad.size(0) == batch_size && corr_grad.size(1) == rd &&
              corr_grad.size(2) == rd && corr_grad.size(3) == ht &&
              corr_grad.size(4) == wd,
              "corr_index_cuda_backward: corr_grad must have shape (", batch_size,
              ", ", rd, ", ", rd, ", ", ht, ", ", wd, "), got ", corr_grad.sizes());
  // The batch is mapped onto gridDim.z, which is limited to 65535.
  TORCH_CHECK(batch_size <= 65535,
              "corr_index_cuda_backward: batch size ", batch_size,
              " exceeds the CUDA grid z-dimension limit (65535)");

  // Launch on the device that owns the tensors, not whatever is current.
  const c10::cuda::CUDAGuard device_guard(volume.device());

  // The kernel accumulates with +=, so the output must start at zero.
  auto volume_grad = torch::zeros_like(volume);

  // Nothing to scatter into; also avoids an invalid zero-sized launch.
  if (volume_grad.numel() == 0) {
    return {volume_grad};
  }

  const dim3 blocks((wd + BLOCK - 1) / BLOCK,
                    (ht + BLOCK - 1) / BLOCK,
                    batch_size);
  const dim3 threads(BLOCK, BLOCK);
  // Use PyTorch's current stream so the kernel is ordered with surrounding ops.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_backward_kernel", ([&] {
    corr_index_backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
        corr_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        volume_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        radius);
  }));

  // Kernel launches do not return errors; surface configuration errors here.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "corr_index_backward_kernel launch failed: ", cudaGetErrorString(err));

  return {volume_grad};
}