Comments (21)
This line copies a permutation of B into A based on the indices in idx.
(A = matx::remap<0>(B, idx)).run();
This line does nothing since you are not storing the output into an operator.
A_prime = matx::remap<0>(B, idx);
You need to construct the operator and pass it into the next operator to chain these together. That is, your line should look something like this:
auto A_primeOp = (A_prime = matx::remap<0>(B, idx));
Now A_primeOp can be run directly like this:
A_primeOp.run();
Or passed into another operation like this:
(n = norm(A_primeOp)).run();
The advantage of chaining operators together into other operators is that you get fewer kernel calls, which generally leads to significantly lower launch latency and memory traffic. Each call to .run() corresponds to a single kernel call.
from matx.
I may as well provide all the code.
It's a work in progress haha, so please excuse it if it contains gross inefficiencies.
Assume k = 2, m = 2000, n = 70,000, d = 784
matx::tensor_t<float, 2> GsDBSCAN::findDistancesMatX(matx::tensor_t<float, 2> &X_t, matx::tensor_t<int, 2> &A_t, matx::tensor_t<int, 2> &B_t, float alpha) {
    const int k = A_t.Shape()[1] / 2;
    const int m = B_t.Shape()[1];
    const int n = X_t.Shape()[0];
    const int d = X_t.Shape()[1];
    const int batchSize = GsDBSCAN::findDistanceBatchSize(alpha, n, d, k, m); // Around 300

    auto AFlat_t = matx::reshape(A_t, {n * 2 * k});
    auto ABatchFlat_t = matx::make_tensor<int>({batchSize * 2 * k});
    auto BBatch_t = matx::make_tensor<int>({ABatchFlat_t.Size(0), m});
    auto XBatch_t = matx::make_tensor<float>({2 * batchSize * k * m, d});
    auto XSubset_t = matx::make_tensor<float>({batchSize, d});
    auto YBatch_t = matx::make_tensor<float>({batchSize, 2 * k * m, d});
    auto distancesBatch_t = matx::make_tensor<float>({batchSize, 2 * k * m});
    auto distances_t = matx::make_tensor<float>({n, 2 * k * m});

    for (int i = 0; i < n; i += batchSize) {
        int maxBatchIdx = i + batchSize - 1; // Index within X along the ROWS

        (XSubset_t = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd})).run();
        // XSubset_t = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd});

        (ABatchFlat_t = matx::slice(AFlat_t, {i * 2 * k}, {(maxBatchIdx + 1) * 2 * k})).run();
        // ABatchFlat_t = matx::slice(AFlat_t, {i * 2 * k}, {(maxBatchIdx + 1) * 2 * k});

        (BBatch_t = matx::remap<0>(B_t, ABatchFlat_t)).run();
        // BBatch_t = matx::remap<0>(B_t, ABatchFlat_t);

        auto BBatch_t_flat = matx::flatten(BBatch_t);

        (XBatch_t = matx::remap<0>(X_t, BBatch_t_flat)).run();
        // XBatch_t = matx::remap<0>(X_t, BBatch_t_flat);

        auto XBatchReshaped_t = matx::reshape(XBatch_t, {batchSize, 2 * k * m, d});
        auto XSubsetReshaped_t = matx::reshape(XSubset_t, {batchSize, 1, d});

        (YBatch_t = XBatchReshaped_t - matx::repmat(XSubsetReshaped_t, {1, 2 * k * m, 1})).run(); // repmat is a workaround for subtracting tensors with broadcast-incompatible shapes

        (distancesBatch_t = matx::vector_norm(YBatch_t, {2}, matx::NormOrder::L2)).run();

        (matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = distancesBatch_t).run();
    }

    return distances_t;
}
Essentially what I find is that my distances_t tensor varies greatly depending on which run() lines I uncomment. E.g. if I do:
// (XSubset_t = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd})).run();
XSubset_t = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd});
Then the new distances_t tensor is different, failing the tests I've written to check its contents.
FYI, these results are still different even after calling cudaDeviceSynchronize().
Also, I didn't inspect your code closely enough to see if you have this issue, but another possible explanation would be if you were writing to indices that other threads read from in the same kernel. Something like this is not allowed:
(A = transpose(A)).run();
The reason is that each element of the tensor gets processed by a different thread. This can lead to a race condition between the read of A(i,j) and the write of A(i,j).
Ok, to make sure I understand: A_prime = matx::remap<0>(B, idx) returns an operator that you then need to store? Hence why we need auto A_primeOp = (A_prime = matx::remap<0>(B, idx)); instead?
Ok thx, for the below:
auto A_primeOp = (A_prime = matx::remap<0>(B, idx));
is A_prime more or less a dummy variable then, assuming that the scope of A_prime ends after using norm? I.e. what I'm trying to say is: if you wish to do deferred execution of the operators, do you more or less have to use a dummy variable, assuming that's what A_prime is?
In that example A_prime is a tensor that you're storing to, and A_primeOp is a variable describing the operation of storing into that tensor. When you call run(), that triggers the statement to do what your operator describes, which is storing the output into that tensor.
In reality you likely wouldn't store that statement in a variable and would instead just do (A_prime = matx::remap<0>(B, idx)).run()
@cliffburdick ok thx. But if I wanted optimality in terms of not having to launch the kernel multiple times etc., I would be best off storing the statement, right?
Yes. See this example here where several lines are composed, and only the final line runs the kernel.
Ok thx, I followed the auto op = ... pattern for the entire loop body.
But now I've got a small error that I feel like shouldn't be a problem...
Would you know why the below doesn't work?
auto YBatch_t_op = (XBatchReshaped_t_op - matx::repmat(XSubsetReshaped_t_op, {1, 2*k*m, 1}));
auto YBatch_t_norm_op = matx::vector_norm(YBatch_t_op, {2}, matx::NormOrder::L2); // Gives 'no instance of constructor' error
(matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = YBatch_t_norm_op).run();
Can you paste the exact error?
Yep, it's quite long though.
/usr/local/include/matx/operators/norm.h(163): error: no instance of constructor "matx::detail::NormOp<OpA, NormType>::NormOp [with OpA=matx::detail::matxBinaryOp<matx::detail::ReshapeOp<3, matx::detail::RemapOp<0, matx::tensor_t<float, 2, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, matx::detail::FlattenOp<matx::detail::RemapOp<0, matx::tensor_t<int, 2, matx::basic_storage<matx::raw_pointer_buffer<int, matx::matx_allocator<int>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, matx::detail::SliceOp<1, matx::detail::ReshapeOp<1, matx::tensor_t<int, 2, matx::basic_storage<matx::raw_pointer_buffer<int, matx::matx_allocator<int>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, cuda::std::__4::array<long long, 1UL>>>>>>, cuda::std::__4::array<long long, 3UL>>, matx::detail::RepMatOp<matx::detail::ReshapeOp<3, matx::tensor_t<float, 2, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, cuda::std::__4::array<long long, 3UL>>, 3>, matx::detail::BinOp<float, float, matx::detail::SubF<float, float>>>, NormType=matx::detail::NormTypeVector]" matches the argument list
argument types are: (matx::detail::PermuteOp<matx::detail::matxBinaryOp<matx::detail::ReshapeOp<3, matx::detail::RemapOp<0, matx::tensor_t<float, 2, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, matx::detail::FlattenOp<matx::detail::RemapOp<0, matx::tensor_t<int, 2, matx::basic_storage<matx::raw_pointer_buffer<int, matx::matx_allocator<int>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, matx::detail::SliceOp<1, matx::detail::ReshapeOp<1, matx::tensor_t<int, 2, matx::basic_storage<matx::raw_pointer_buffer<int, matx::matx_allocator<int>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, cuda::std::__4::array<long long, 1UL>>>>>>, cuda::std::__4::array<long long, 3UL>>, matx::detail::RepMatOp<matx::detail::ReshapeOp<3, matx::tensor_t<float, 2, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, cuda::std::__4::array<long long, 3UL>>, 3>, matx::detail::BinOp<float, float, matx::detail::SubF<float, float>>>>, matx::NormOrder)
detected during instantiation of "auto matx::vector_norm(const Op &, const int (&)[D], matx::NormOrder) [with Op=matx::detail::matxBinaryOp<matx::detail::ReshapeOp<3, matx::detail::RemapOp<0, matx::tensor_t<float, 2, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, matx::detail::FlattenOp<matx::detail::RemapOp<0, matx::tensor_t<int, 2, matx::basic_storage<matx::raw_pointer_buffer<int, matx::matx_allocator<int>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, matx::detail::SliceOp<1, matx::detail::ReshapeOp<1, matx::tensor_t<int, 2, matx::basic_storage<matx::raw_pointer_buffer<int, matx::matx_allocator<int>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, cuda::std::__4::array<long long, 1UL>>>>>>, cuda::std::__4::array<long long, 3UL>>, matx::detail::RepMatOp<matx::detail::ReshapeOp<3, matx::tensor_t<float, 2, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, cuda::std::__4::array<long long, 3UL>>, 3>, matx::detail::BinOp<float, float, matx::detail::SubF<float, float>>>, D=1]"
Can you please paste all the code like you did above? I want to make sure this error makes sense for the type
Yep sure thing.
I removed virtually all the initializations from before the loop - they're no longer necessary.
matx::tensor_t<float, 2> GsDBSCAN::findDistancesMatX(matx::tensor_t<float, 2> &X_t, matx::tensor_t<int, 2> &A_t, matx::tensor_t<int, 2> &B_t, float alpha) {
    const int k = A_t.Shape()[1] / 2;
    const int m = B_t.Shape()[1];
    const int n = X_t.Shape()[0];
    const int d = X_t.Shape()[1];
    const int batchSize = GsDBSCAN::findDistanceBatchSize(alpha, n, d, k, m);

    auto AFlat_t = matx::flatten(A_t);
    auto distances_t = matx::make_tensor<float>({n, 2 * k * m});

    for (int i = 0; i < n; i += batchSize) {
        int maxBatchIdx = i + batchSize - 1; // Index within X along the ROWS
        auto XSubset_t_op = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd});
        auto ABatchFlat_t_op = matx::slice(AFlat_t, {i * 2 * k}, {(maxBatchIdx + 1) * 2 * k});
        auto BBatch_t_op = matx::remap<0>(B_t, ABatchFlat_t_op);
        auto XBatch_t_op = matx::remap<0>(X_t, matx::flatten(BBatch_t_op));
        auto XBatchReshaped_t_op = matx::reshape(XBatch_t_op, {batchSize, 2 * k * m, d});
        auto XSubsetReshaped_t_op = matx::reshape(XSubset_t_op, {batchSize, 1, d});
        auto YBatch_t_op = (XBatchReshaped_t_op - matx::repmat(XSubsetReshaped_t_op, {1, 2 * k * m, 1})); // repmat is a workaround for subtracting tensors with broadcast-incompatible shapes
        auto YBatch_t_norm_op = matx::vector_norm(YBatch_t_op, {2}, matx::NormOrder::L2);
        (matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = YBatch_t_norm_op).run();
    }

    return distances_t;
}
Thanks. I'm not at a computer right now but I'll take a look when I have a chance unless someone else beats me to it.
It looks like we have a template mismatch here. I think I understand what is going on. First look here:
https://github.com/NVIDIA/MatX/blob/main/include/matx/operators/norm.h#L162-L163
template <typename Op, int D>
auto vector_norm(const Op &op, const int (&dims)[D],
                 NormOrder order = NormOrder::NONE) {
    auto perm = detail::getPermuteDims<Op::Rank()>(dims);
    auto permop = permute(op, perm);
    return detail::NormOp<Op, detail::NormTypeVector>(permop, order);
}
What I believe is happening is this: Op is one type but permop is another type. The type passed into detail::NormOp<Op...> is therefore different from the type of permop, causing this error.
One fix would be to replace this:
return detail::NormOp<Op, detail::NormTypeVector>(permop, order);
with this:
return detail::NormOp(permop, order);
If that doesn't work then we can probably replace <Op...> with <decltype(permop)...>
@luitjens ok thx. The first fix didn't seem to work, but replacing <Op...> with <decltype(permop)...> did.
fixed in #696