I am writing PyTorch code. Alongside the torch inference code, I have added some peripheral code for my own interest. This code works fine, but it is too slow, probably because of the for iterations. So I need a parallel, fast way of doing this. It is okay to do it with tensors, NumPy, or plain Python arrays.
I made a function named `selective_max` to find the maximum value in arrays. The problem is that I don't want the maximum over the whole array, but over specific candidates designated by a `mask` array. Let me show the gist of this function (the code itself is below).
Input
`x [batch_size, dim, num_points, k]`: `x` is the original input, but it becomes `[batch_size, num_points, dim, k]` after `x.permute(0,2,1,3)`. `batch_size` is a well-known notion in the deep-learning community. In every mini-batch there are many points, and a single point is represented by a `dim`-length feature. For each feature element there are `k` potential candidates, which are the target of the `max` operation later.

`mask [batch_size, num_points, k]`: This array is like `x` but without the `dim` axis. Its elements are either `0` or `1`, so I use it as a mask signal: do the max operation only over the `1`-masked values.
Kindly see the code below with this explanation. I use a triple for iteration. Say we target a specific batch and a specific point: there, `x` is a `[dim, k]` array and `mask` is a `[k]` array of `0`s and `1`s. So I extract the non-zero indices from the `[k]` array and use them to pick the corresponding elements of `x`, dim by dim (the innermost loop over `dim`).
Toy example
Let's say we are inside the two outer loops, so we now have `[dim, k]` for `x` and `[k]` for `mask`. For this toy example I presume `k=3` and `dim=4`: `x = [[3,2,1],[5,6,4],[9,8,7],[12,11,10]]` and `mask = [0,1,1]`. So the output would be `[2,6,8,11]`, not `[3,6,9,12]`.
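To make the toy example concrete, here is a minimal NumPy sketch of the same masked max. It uses one standard trick (my choice, not from the question): replace masked-out entries with `-inf` so they can never win the max.

```python
import numpy as np

# Toy example from above: dim=4, k=3.
x = np.array([[3, 2, 1], [5, 6, 4], [9, 8, 7], [12, 11, 10]])
mask = np.array([0, 1, 1])

# Push masked-out candidates to -inf, then reduce over the k axis.
# mask (shape [k]) broadcasts against x (shape [dim, k]).
masked = np.where(mask.astype(bool), x, -np.inf)
out = masked.max(axis=-1)
print(out)  # [ 2.  6.  8. 11.]
```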
Previous attempt
I tried broadcasting `mask` over the `dim` axis (via `repeat`), multiplying it element-wise with `x`, and then doing the `max` operation. But `0` might then become the max value, because `x` can be all negative, so this gives a wrong result.
```python
def selective_max2(x, mask):
    # x : [batch_size, dim, num_points, k], mask : [batch_size, num_points, k]
    batch_size = x.size(0)
    dim = x.size(1)
    num_points = x.size(2)
    k = x.size(3)
    device = torch.device('cuda')
    x = x.permute(0, 2, 1, 3)  # [batch_size, num_points, dim, k]
    x = x.detach().cpu().numpy()
    mask = mask.cpu().numpy()
    output = np.zeros((batch_size, num_points, dim))
    for i in range(batch_size):
        for j in range(num_points):
            # Among the mask entries, get the indices of the non-zero values.
            query = np.nonzero(mask[i][j])
            for d in range(dim):
                # Using those indices, take the max over the selected candidates only.
                output[i][j][d] = np.max(x[i][j][d][query])
    output = torch.from_numpy(output).float().to(device=device)
    output = output.permute(0, 2, 1).contiguous()
    return output
```
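For reference, one possible vectorized replacement, sketched in NumPy with the same shapes as `selective_max2`. This is my own sketch, not code from above: it inserts the mask as a length-1 `dim` axis so it broadcasts, fills masked-out entries with `-inf`, and reduces over `k` with no Python loops. (Note it assumes every `[k]` mask row has at least one `1`; an all-zero row would yield `-inf`.)

```python
import numpy as np

def selective_max_vec(x, mask):
    # x    : [batch_size, dim, num_points, k]
    # mask : [batch_size, num_points, k]
    keep = mask.astype(bool)[:, None, :, :]       # [batch_size, 1, num_points, k], broadcasts over dim
    masked = np.where(keep, x, -np.inf)           # excluded candidates can never win the max
    return masked.max(axis=-1)                    # [batch_size, dim, num_points]

# The same idea in PyTorch, staying on the GPU, would be along the lines of:
#   x.masked_fill(mask.unsqueeze(1) == 0, float('-inf')).max(dim=-1).values
```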