Welcome to WuJiGu Developer Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


pytorch - max operation for selective elements, not for all

I am working with PyTorch. In between the inference code, I added some extra code of my own. It works, but it is too slow, probably because of the nested for loops. So I need a fast, parallel way of doing this.

A solution using tensors, NumPy, or plain Python lists is fine.

I wrote a function named selective_max to find the maximum value in an array. The problem is that I don't want the maximum over the whole array, but only over specific candidates designated by a mask array. Let me explain the gist of this function (the code itself is shown below).

Input

x [batch_size, dim, num_points, k]: the original input, which becomes [batch_size, num_points, dim, k] after x.permute(0,2,1,3).

batch_size has its usual deep learning meaning. Each mini-batch contains many points, and each point is represented by a feature vector of length dim. For each feature element there are k candidates, which are the targets of the max operation later.

mask [batch_size, num_points, k]: this array has the same shape as x without the dim axis. Each element is either 0 or 1, so I use it as a mask: the max operation should consider only the entries where the mask is 1.

Please read the code below with this explanation in mind. I use three nested for loops. Fix a specific batch and a specific point: x is then a [dim, k] array and mask is a [k] array of 0s and 1s. I extract the nonzero indices from the [k] array and use them to pick the selected elements of x, one feature element at a time (the innermost loop over range(dim)).

Toy example

Suppose we are inside the second for loop, so we have a [dim, k] array for x and a [k] array for the mask. For this toy example, assume k=3 and dim=4, with x = [[3,2,1],[5,6,4],[9,8,7],[12,11,10]] and mask = [0,1,1]. The output should then be [2,6,8,11], not [3,6,9,12].
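The toy example can be reproduced in a few lines of NumPy. This is only a sketch of the expected behaviour for a single [dim, k] slice; the variable names selected and expected are illustrative and do not come from the original code.

```python
import numpy as np

# One [dim, k] slice of x and its [k] mask, taken from the toy example.
x = np.array([[3, 2, 1],
              [5, 6, 4],
              [9, 8, 7],
              [12, 11, 10]])
mask = np.array([0, 1, 1])

# Column indices where the mask is nonzero: here columns 1 and 2.
selected = np.nonzero(mask)[0]

# Max over the selected columns only, one value per feature element.
expected = x[:, selected].max(axis=1)
print(expected.tolist())  # [2, 6, 8, 11]

# The plain max over all columns gives the unwanted result.
print(x.max(axis=1).tolist())  # [3, 6, 9, 12]
```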

Previous attempt

I tried multiplying x element-wise by the mask, { mask.repeat(0,0,1,0) * x }, and then taking the max. But the masked-out 0 might become the maximum, because all the selected values of x may be negative, so this gives the wrong result.
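A small sketch of why the multiplication approach breaks down, using a made-up all-negative slice (the values are hypothetical, chosen only to trigger the problem):

```python
import torch

# Hypothetical [dim=2, k=3] slice in which every value is negative.
x = torch.tensor([[-3.0, -2.0, -1.0],
                  [-5.0, -6.0, -4.0]])
mask = torch.tensor([0, 1, 1])

# Multiplying by the mask turns the unselected entry into zero,
# which then beats every negative selected value.
wrong = (x * mask).max(dim=1).values

# Selecting the masked columns first gives the intended result.
right = x[:, mask.bool()].max(dim=1).values

print(right.tolist())  # [-1.0, -4.0]
```

Here wrong contains zeros for both rows, even though no selected candidate is zero.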

import numpy as np
import torch

def selective_max2(x, mask):
    # x: [batch_size, dim, num_points, k], mask: [batch_size, num_points, k]
    batch_size = x.size(0)
    dim = x.size(1)
    num_points = x.size(2)
    k = x.size(3)
    device = x.device  # return the result on the same device as the input

    x = x.permute(0, 2, 1, 3)  # [batch_size, num_points, dim, k]

    x = x.detach().cpu().numpy()
    mask = mask.cpu().numpy()
    output = np.zeros((batch_size, num_points, dim))

    for i in range(batch_size):
        for j in range(num_points):
            # Indices of the nonzero mask entries, i.e. the selected candidates.
            query = np.nonzero(mask[i][j])
            for k in range(dim):  # here k indexes the feature elements, one max each
                # Restrict the max to the candidates selected by the mask.
                output[i][j][k] = np.max(x[i][j][k][query])

    output = torch.from_numpy(output).float().to(device=device)
    output = output.permute(0, 2, 1).contiguous()  # [batch_size, dim, num_points]
    return output


1 Answer


Disclaimer: I followed your toy example to write the following solution, while trying to keep it general.

First, expand your mask, which I call k here, to the shape of x (treating both as PyTorch tensors):

k_expanded = k.expand_as(x)

Then select the elements of x where k_expanded is 1, and view the result as a 2-D tensor with x.shape[0] rows and as many columns as there are 1's in k (the mask), which is what (k == 1).sum(0) computes. At this point we have selected exactly the range we want to search. Finally, take the maximum of each row with max(1):

values, indices = x[k_expanded == 1].view(x.shape[0], (k == 1).sum(0)).max(1)
values
Out[29]: tensor([ 2,  6,  8, 11])
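For completeness, here is a self-contained version of the snippet above that can be run as-is on the toy example. The names num_selected, selected and original_indices are illustrative additions, and the last step (mapping indices back to positions along the original k axis) is an extra that the snippet above does not perform.

```python
import torch

x = torch.tensor([[3, 2, 1],
                  [5, 6, 4],
                  [9, 8, 7],
                  [12, 11, 10]])
k = torch.tensor([0, 1, 1])  # the mask, named k as in this answer

k_expanded = k.expand_as(x)         # [4, 3]: the mask repeated for every row
num_selected = int((k == 1).sum())  # number of 1's in the mask, here 2

# Boolean indexing flattens the selection; view restores one row per feature.
selected = x[k_expanded == 1].view(x.shape[0], num_selected)
values, indices = selected.max(1)
print(values.tolist())  # [2, 6, 8, 11]

# indices refer to the selected columns, not to the original k axis.
# Map them back through the positions of the 1's in the mask.
original_indices = torch.nonzero(k).squeeze(1)[indices]
print(original_indices.tolist())  # [1, 1, 1, 1]
```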

Benchmarks

def find_max_elements_inside_tensor_range(arr, mask, return_indices=False):
    mask_expanded = mask.expand_as(arr)
    values, indices = arr[mask_expanded == 1].view(arr.shape[0], (mask == 1).sum(0)).max(1)
    return (values, indices) if return_indices else values

I added a third parameter, return_indices, in case you also want the indices of the maxima.

%timeit find_max_elements_inside_tensor_range(x, k)
38.4 μs ± 534 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

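Note: the above solution also works for tensors and masks of other shapes, as long as the mask can be expanded to the shape of the tensor and selects the same number of elements in every row (the view call relies on this).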
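In the original problem the mask has shape [batch_size, num_points, k], so different points may select different numbers of candidates and a single view is not possible. One possible alternative, which is a different technique from the one above, is to fill the unselected entries with -inf before taking the max. The following is a minimal sketch under the assumption that x is a floating-point tensor; the function name selective_max_vectorized is hypothetical.

```python
import torch

def selective_max_vectorized(x, mask):
    """Masked max over the last axis (a sketch, not the original method).

    x:    [batch_size, dim, num_points, k] floating-point tensor
    mask: [batch_size, num_points, k] with entries 0 or 1
    Returns a [batch_size, dim, num_points] tensor.
    """
    # Add a singleton dim axis so the mask broadcasts over all features.
    keep = mask.bool().unsqueeze(1)  # [batch_size, 1, num_points, k]
    # Unselected candidates become -inf, so they can never win the max.
    masked = x.masked_fill(~keep, float("-inf"))
    return masked.max(dim=-1).values

# Toy example from the question, reshaped to batch_size=1, num_points=1.
x = torch.tensor([[3.0, 2.0, 1.0],
                  [5.0, 6.0, 4.0],
                  [9.0, 8.0, 7.0],
                  [12.0, 11.0, 10.0]]).view(1, 4, 1, 3)
mask = torch.tensor([0, 1, 1]).view(1, 1, 3)

out = selective_max_vectorized(x, mask)
print(out.flatten().tolist())  # [2.0, 6.0, 8.0, 11.0]
```

This keeps everything on the input's device and avoids the Python loops. One difference in behaviour to be aware of: a point whose mask is all zeros yields -inf here, whereas the loop version in the question would raise an error on the empty selection.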

