-
Notifications
You must be signed in to change notification settings - Fork 19
Cooperative vector API #141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
7912570
82cf902
a6fd226
b0a6e21
38c4a3c
122b673
9532b48
8c88e17
89c0db2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
+4 −2 | include/nanothread/nanothread.h | |
+8 −8 | src/queue.cpp | |
+2 −0 | src/queue.h |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -179,6 +179,9 @@ extern JIT_EXPORT void* jit_cuda_pop_context(); | |||||
/// Query the compute capability of the current device (e.g. '52') | ||||||
extern JIT_EXPORT int jit_cuda_compute_capability(); | ||||||
|
||||||
/// Get the major and minor CUDA version | ||||||
extern JIT_EXPORT void jit_cuda_version(int *major, int *minor); | ||||||
|
||||||
/** | ||||||
* \brief Override generated PTX version and compute capability | ||||||
* | ||||||
|
@@ -637,7 +640,10 @@ enum class JitOp : uint32_t { | |||||
Rcp, Rsqrt, | ||||||
|
||||||
// Multi-function generator (CUDA) | ||||||
Sin, Cos, Exp2, Log2, | ||||||
Sin, Cos, Exp2, Log2, Tanh, | ||||||
|
||||||
// Step function | ||||||
Step, | ||||||
|
||||||
// Total number of operations | ||||||
Count | ||||||
|
@@ -799,6 +805,9 @@ extern JIT_EXPORT uint32_t jit_var_exp2_intrinsic(uint32_t a0); | |||||
/// Approximate `log2(a0)` and return a variable representing the result | ||||||
extern JIT_EXPORT uint32_t jit_var_log2_intrinsic(uint32_t a0); | ||||||
|
||||||
/// Approximate `tanh(a0)` and return a variable representing the result | ||||||
extern JIT_EXPORT uint32_t jit_var_tanh_intrinsic(uint32_t a0); | ||||||
|
||||||
/// Return a variable indicating valid lanes within a function call | ||||||
extern JIT_EXPORT uint32_t jit_var_call_mask(JitBackend backend); | ||||||
|
||||||
|
@@ -2403,6 +2412,7 @@ struct VarInfo { | |||||
void *data; | ||||||
}; | ||||||
bool is_array; | ||||||
bool is_coop_vec; | ||||||
bool unaligned; | ||||||
}; | ||||||
|
||||||
|
@@ -2677,6 +2687,97 @@ extern JIT_EXPORT void jit_freeze_abort(JitBackend backend); | |||||
*/ | ||||||
extern JIT_EXPORT void jit_freeze_destroy(Recording *recording); | ||||||
|
||||||
// ==================================================================== | ||||||
// Cooperative vector API | ||||||
// ==================================================================== | ||||||
|
||||||
/// Pack a set of regular Dr.Jit variables to form a cooperative vector | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_pack(uint32_t n, const uint32_t *in); | ||||||
|
||||||
/// Unpack a cooperative vector into its components | ||||||
extern JIT_EXPORT void jit_coop_vec_unpack(uint32_t index, uint32_t n, uint32_t *out); | ||||||
|
||||||
/// Create a cooperative vectors, whose components are a uniform literal constant | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_literal(JIT_ENUM JitBackend backend, | ||||||
JIT_ENUM VarType type, | ||||||
const void *value, | ||||||
size_t size JIT_DEF(1), | ||||||
uint32_t length JIT_DEF(1)); | ||||||
|
||||||
/// Load a cooperative vector from memory | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_load(uint32_t buffer, uint32_t offset, uint32_t length); | ||||||
|
||||||
/// Determine the length of a cooperative vector | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_length(uint32_t index); | ||||||
|
||||||
/// Perform a unary operation on a cooperative vector | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_unary_op(JitOp op, uint32_t a0); | ||||||
|
||||||
/// Perform a binary operation on a pair of cooperative vectors | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_binary_op(JitOp op, uint32_t a0, uint32_t a1); | ||||||
|
||||||
/// Perform a ternary operation on a triplet of cooperative vectors | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_ternary_op(JitOp op, uint32_t a0, uint32_t a1, uint32_t a2); | ||||||
|
||||||
/// Encodes a type of request for jit_coop_vec_pack_matrices() | ||||||
enum class MatrixLayout : uint32_t { | ||||||
RowMajor, | ||||||
InferencingOptimal, | ||||||
TrainingOptimal | ||||||
}; | ||||||
|
||||||
/// Summary of a matrix/vector that has been packed into a buffer | ||||||
struct MatrixDescr { | ||||||
VarType dtype; //< Variable type | ||||||
MatrixLayout layout; //< Layout type | ||||||
uint32_t rows, cols; //< Shape of the matrix | ||||||
uint32_t offset; //< Offset from the beginning of the buffer (in elements) | ||||||
uint32_t stride; //< Row stride (in elements) | ||||||
uint32_t size; //< Total size (in elements) | ||||||
}; | ||||||
|
||||||
/// Pack a sequence of matrices from row-major into a representation that is | ||||||
/// optimal for inference/training, or do the reverse. | ||||||
extern JIT_EXPORT void jit_coop_vec_pack_matrices(uint32_t count, | ||||||
uint32_t in, | ||||||
const MatrixDescr *in_descr, | ||||||
uint32_t out, | ||||||
const MatrixDescr *out_descr); | ||||||
|
||||||
/// Query the backend to compute the size of an array/vector in a given layout | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Specify which fields will be ignored in |
||||||
extern JIT_EXPORT MatrixDescr jit_coop_vec_compute_layout(uint32_t index, | ||||||
const MatrixDescr *in, | ||||||
MatrixLayout layout, | ||||||
uint32_t offset); | ||||||
|
||||||
/// Perform a matrix-vector multiplication + bias addition | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Specify which arguments are optional (e.g. |
||||||
extern JIT_EXPORT uint32_t jit_coop_vec_matvec(uint32_t A_index, | ||||||
const MatrixDescr *A_descr, | ||||||
uint32_t x_index, | ||||||
uint32_t b_index, | ||||||
const MatrixDescr *b_descr, | ||||||
int transpose); | ||||||
|
||||||
/// Accumulate the coop. vector 'index' into the buffer 'target' with offset. | ||||||
/// Potentially create a new buffer of size 'size' if target == 0. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
extern JIT_EXPORT uint32_t jit_coop_vec_accum(uint32_t target, | ||||||
uint32_t target_size, | ||||||
uint32_t offset, | ||||||
uint32_t index); | ||||||
|
||||||
/// Cast a cooperative vector to a different precision | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_cast(uint32_t index, VarType vt); | ||||||
|
||||||
/// Accumulate the outer product of cooperative vectors 'a' and 'b' into buffer | ||||||
/// 'target', at a location described by 'descr'. Potentially create a new | ||||||
/// buffer of size 'size' if target == 0. | ||||||
extern JIT_EXPORT uint32_t jit_coop_vec_outer_product_accum( | ||||||
uint32_t target, | ||||||
uint32_t target_size, | ||||||
const MatrixDescr *descr, | ||||||
uint32_t a, | ||||||
uint32_t b); | ||||||
|
||||||
#if defined(__cplusplus) | ||||||
} | ||||||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could make a note that in some layouts, the size in number of elements may be larger than the actual number of elements of the matrix? (e.g. when there's mandatory padding, etc).