Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QoL Improvements #383

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ fn read_file(file_name: &std::path::Path) -> String {
.read(true)
.write(false)
.create(false)
.open(&file_path);
.open(file_path);

let mut file = match options {
Ok(file) => file,
Expand Down Expand Up @@ -256,13 +256,13 @@ fn run_cmake_command(conf: &Config, build_dir: &std::path::Path) {

#[cfg(not(windows))]
fn run_cmake_command(conf: &Config, build_dir: &std::path::Path) {
let _ = fs::create_dir(&build_dir);
let _ = fs::create_dir(build_dir);

let options = prep_cmake_options(conf);
println!("options are {:?}", options);

let mut cmake_cmd = Command::new("cmake");
cmake_cmd.current_dir(&build_dir);
cmake_cmd.current_dir(build_dir);

run(
cmake_cmd
Expand All @@ -277,7 +277,7 @@ fn run_cmake_command(conf: &Config, build_dir: &std::path::Path) {
);

let mut make_cmd = Command::new("make");
make_cmd.current_dir(&build_dir);
make_cmd.current_dir(build_dir);
run(
make_cmd
.arg(format!("-j{}", conf.build_threads))
Expand Down
3 changes: 2 additions & 1 deletion examples/neural_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pub(crate) mod model {
pub trait Model {
fn predict(&self, feature: &Array<f32>) -> Array<f32>;

#[allow(clippy::too_many_arguments)]
fn train(
&mut self,
training_features: &Array<f32>,
Expand Down Expand Up @@ -285,7 +286,7 @@ mod ann {

fn back_propagate(
&mut self,
signals: &Vec<Array<f32>>,
signals: &[Array<f32>],
labels: &Array<u8>,
learning_rate_alpha: f64,
) {
Expand Down
2 changes: 1 addition & 1 deletion opencl-interop/examples/custom_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() {

af::info();
let dims = af::dim4!(8);
let af_buffer = af::constant(0f32, dims.clone());
let af_buffer = af::constant(0f32, dims);
af::af_print!("af_buffer", af_buffer);

let src = r#"
Expand Down
4 changes: 2 additions & 2 deletions opencl-interop/examples/ocl_af_app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ fn main() {

// Choose platform & device(s) to use. Create a context, queue,
let platform_id = ocl_core::default_platform().unwrap();
let device_ids = ocl_core::get_device_ids(&platform_id, None, None).unwrap();
let device_ids = ocl_core::get_device_ids(platform_id, None, None).unwrap();
let device_id = device_ids[0];
let context_properties = ContextProperties::new().platform(platform_id);
let context =
ocl_core::create_context(Some(&context_properties), &[device_id], None, None).unwrap();
let queue = ocl_core::create_command_queue(&context, &device_id, None).unwrap();
let queue = ocl_core::create_command_queue(&context, device_id, None).unwrap();
let dims = [8, 1, 1];

// Create a `Buffer`:
Expand Down
4 changes: 2 additions & 2 deletions opencl-interop/examples/trivial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn main() {
// (1) Define which platform and device(s) to use. Create a context,
// queue, and program then define some dims..
let platform_id = core::default_platform().unwrap();
let device_ids = core::get_device_ids(&platform_id, None, None).unwrap();
let device_ids = core::get_device_ids(platform_id, None, None).unwrap();
let device_id = device_ids[0];
let context_properties = ContextProperties::new().platform(platform_id);
let context =
Expand All @@ -33,7 +33,7 @@ fn main() {
None,
)
.unwrap();
let queue = core::create_command_queue(&context, &device_id, None).unwrap();
let queue = core::create_command_queue(&context, device_id, None).unwrap();
let dims = [1 << 20, 1, 1];

// (2) Create a `Buffer`:
Expand Down
5 changes: 3 additions & 2 deletions opencl-interop/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub enum DeviceType {
GPU = CL_DEVICE_TYPE_GPU,
ACCEL = CL_DEVICE_TYPE_ACCELERATOR,
ALL = CL_DEVICE_TYPE_ALL,
UNKNOWN,
}

extern "C" {
Expand Down Expand Up @@ -139,8 +140,8 @@ pub fn get_device_type() -> DeviceType {
let err_val = unsafe { afcl_get_device_type(&mut out as *mut c_int) };
handle_error_general(AfError::from(err_val));
match out {
-1 => unsafe { mem::transmute(out as u64) },
_ => DeviceType::ALL,
-1 => DeviceType::UNKNOWN,
_ => unsafe { mem::transmute::<u64, DeviceType>(out as u64) },
}
}

Expand Down
25 changes: 17 additions & 8 deletions src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use super::core::{
HANDLE_ERROR,
};

use core::borrow::Borrow;

use libc::{c_double, c_int, c_uint};

extern "C" {
Expand Down Expand Up @@ -145,13 +147,20 @@ extern "C" {
macro_rules! dim_reduce_func_def {
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
#[doc=$doc_str]
pub fn $fn_name<T>(input: &Array<T>, dim: i32) -> Array<$out_type>
pub fn $fn_name<T, A>(input: A, dim: i32) -> Array<$out_type>
where
T: HasAfEnum,
$out_type: HasAfEnum,
A: Borrow<Array<T>>,
{
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe { $ffi_name(&mut temp as *mut af_array, input.get(), dim) };
let err_val = unsafe {
$ffi_name(
std::ptr::from_mut::<af_array>(&mut temp),
input.borrow().get(),
dim,
)
};
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
Expand Down Expand Up @@ -783,7 +792,7 @@ all_reduce_func_def2!(
///
/// - `input` is the input Array
/// - `val` is the val that replaces all `NAN` values of the Array before reduction operation is
/// performed.
/// performed.
///
/// # Return Values
///
Expand Down Expand Up @@ -829,7 +838,7 @@ where
///
/// - `input` is the input Array
/// - `val` is the val that replaces all `NAN` values of the Array before reduction operation is
/// performed.
/// performed.
///
/// # Return Values
///
Expand Down Expand Up @@ -1238,7 +1247,7 @@ where
/// - `key` is the key Array
/// - `input` is the data on which scan is to be performed
/// - `dim` is the dimension along which scan operation is to be performed
/// - `op` takes value of [BinaryOp](./enum.BinaryOp.html) enum indicating
/// - `op` takes value of [`BinaryOp`](./enum.BinaryOp.html) enum indicating
/// the type of scan operation
/// - `inclusive` says if inclusive/exclusive scan is to be performed
///
Expand All @@ -1260,7 +1269,7 @@ where
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe {
af_scan_by_key(
&mut temp as *mut af_array,
std::ptr::from_mut::<af_array>(&mut temp),
key.get(),
input.get(),
dim,
Expand Down Expand Up @@ -1300,8 +1309,8 @@ macro_rules! dim_reduce_by_key_func_def {
let mut out_vals: af_array = std::ptr::null_mut();
let err_val = unsafe {
$ffi_name(
&mut out_keys as *mut af_array,
&mut out_vals as *mut af_array,
std::ptr::from_mut::<af_array>(&mut out_keys),
std::ptr::from_mut::<af_array>(&mut out_vals),
keys.get(),
vals.get(),
dim,
Expand Down
4 changes: 2 additions & 2 deletions src/blas/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ where
///
/// - `arr` is the input Array
/// - `conjugate` is a boolean that indicates if the transpose operation needs to be a conjugate
/// transpose
/// transpose
///
/// # Return Values
///
Expand All @@ -232,7 +232,7 @@ pub fn transpose<T: HasAfEnum>(arr: &Array<T>, conjugate: bool) -> Array<T> {
///
/// - `arr` is the input Array that has to be transposed
/// - `conjugate` is a boolean that indicates if the transpose operation needs to be a conjugate
/// transpose
/// transpose
pub fn transpose_inplace<T: HasAfEnum>(arr: &mut Array<T>, conjugate: bool) {
let err_val = unsafe { af_transpose_inplace(arr.get(), conjugate) };
HANDLE_ERROR(AfError::from(err_val));
Expand Down
33 changes: 23 additions & 10 deletions src/core/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use num::Zero;

use libc::c_int;
use num::Complex;
use std::borrow::Borrow;
use std::mem;
use std::ops::Neg;
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Not, Rem, Shl, Shr, Sub};
Expand Down Expand Up @@ -122,11 +123,11 @@ macro_rules! unary_func {
#[doc=$doc_str]
///
/// This is an element wise unary operation.
pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> Array< T::$out_type >
pub fn $fn_name<T: HasAfEnum, A: Borrow<Array<T>>>(input: A) -> Array< T::$out_type >
where T::$out_type: HasAfEnum {

let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe { $ffi_fn(&mut temp as *mut af_array, input.get()) };
let err_val = unsafe { $ffi_fn(&mut temp as *mut af_array, input.borrow().get()) };
HANDLE_ERROR(AfError::from(err_val));
temp.into()

Expand Down Expand Up @@ -251,7 +252,7 @@ unary_func!(
);

macro_rules! unary_boolean_func {
[$doc_str: expr, $fn_name: ident, $ffi_fn: ident] => (
[$doc_str: expr, $fn_name: ident, $method_name: ident, $ffi_fn: ident] => (
#[doc=$doc_str]
///
/// This is an element wise unary operation.
Expand All @@ -263,12 +264,22 @@ macro_rules! unary_boolean_func {
temp.into()

}

impl<T: HasAfEnum> Array<T> {
#[doc=$doc_str]
///
/// Element-wise unary operation.
pub fn $method_name(&self) -> Array<bool> {
crate::core::arith::$fn_name(self)
}
}
)
}

unary_boolean_func!("Check if values are zero", iszero, af_iszero);
unary_boolean_func!("Check if values are infinity", isinf, af_isinf);
unary_boolean_func!("Check if values are NaN", isnan, af_isnan);
// TODO: Re-evaluate method names..
unary_boolean_func!("Check if values are zero", iszero, zeros, af_iszero);
unary_boolean_func!("Check if values are infinity", isinf, infs, af_isinf);
unary_boolean_func!("Check if values are NaN", isnan, nans, af_isnan);

macro_rules! binary_func {
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
Expand Down Expand Up @@ -331,7 +342,7 @@ binary_func!(
af_hypot
);

/// Type Trait to convert to an [Array](./struct.Array.html)
/// Type Trait to convert to an [`Array`](./struct.Array.html)
///
/// Generic functions that overload the binary operations such as add, div, mul, rem, ge etc. are
/// bound by this trait to allow combinations of scalar values and Array objects as parameters
Expand Down Expand Up @@ -874,7 +885,9 @@ shift_spec!(Shr, shr);
#[cfg(op_assign)]
mod op_assign {

use super::*;
use super::{
add, bitand, bitor, bitxor, div, mem, mul, rem, shiftl, shiftr, sub, ImplicitPromote,
};
use crate::core::{assign_gen, Array, Indexer, Seq};
use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign, ShlAssign, ShrAssign};
Expand Down Expand Up @@ -973,12 +986,12 @@ where
}

/// Perform bitwise complement on all values of Array
pub fn bitnot<T: HasAfEnum>(input: &Array<T>) -> Array<T>
pub fn bitnot<T>(input: &Array<T>) -> Array<T>
where
T: HasAfEnum + IntegralType,
{
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe { af_bitnot(&mut temp as *mut af_array, input.get()) };
let err_val = unsafe { af_bitnot(std::ptr::from_mut::<af_array>(&mut temp), input.get()) };
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
Loading