Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions lib_nn/src/c/vpu_sim.c
Original file line number Diff line number Diff line change
Expand Up @@ -466,45 +466,54 @@ void VLSUB(xs3_vpu *vpu, const void *addr) {
}
}

/*
 * Return the right-shift amount that VLMUL applies to each product for the
 * VPU's current element mode on the given target architecture.
 *
 * The shift is (element width - 2) on XS3A and (element width - 1) on VX4A,
 * where the width used is 8, 16 or 32 for MODE_S8, MODE_S16 and MODE_S32.
 *
 * arch: must be TARGET_ARCH_XS3A or TARGET_ARCH_VX4A (checked with assert).
 * vpu:  simulated VPU state; only vpu->mode is read.
 *
 * An unrecognised mode trips assert(0); if assertions are compiled out the
 * function returns 0.
 */
static inline
unsigned _VLMUL_GET_SHIFT(const nn_target_arch_t arch, xs3_vpu *vpu) {
  assert(arch == TARGET_ARCH_XS3A || arch == TARGET_ARCH_VX4A);

  // Anything that is not XS3A shifts by one extra bit.
  const unsigned extra = (arch == TARGET_ARCH_XS3A) ? 0 : 1;

  if (vpu->mode == MODE_S8) {
    return 8 - 2 + extra;
  }
  if (vpu->mode == MODE_S16) {
    return 16 - 2 + extra;
  }
  if (vpu->mode == MODE_S32) {
    return 32 - 2 + extra;
  }

  // No known element mode matched.
  assert(0);
  return 0;
}

void VLMUL(xs3_vpu *vpu, const void *addr) {
#ifdef __XS3A__
assert_word_aligned(addr);
#endif

int VLMUL_SHR_S16;
if(NN_ARCH == TARGET_ARCH_XS3A){
VLMUL_SHR_S16 = VLMUL_SHR_XS3A;
} else if (NN_ARCH == TARGET_ARCH_VX4A){
VLMUL_SHR_S16 = VLMUL_SHR_VX4A;
} else {
assert(false);
}

const unsigned shift = _VLMUL_GET_SHIFT(NN_ARCH, vpu);
if (vpu->mode == MODE_S8) {
const int8_t *addr8 = (const int8_t *)addr;
for (int i = 0; i < VPU_INT8_EPV; i++) {
int32_t val = addr8[i];
int32_t res = ((int32_t)vpu->vR.s8[i] * val + (1<<5)) >> 6; // TODO use macros
if (NN_ARCH == TARGET_ARCH_VX4A){
res = res >> 1;
}
int32_t res = ((int32_t)vpu->vR.s8[i] * (int32_t)val + (1L<<(shift - 1))) >> shift;
vpu->vR.s8[i] = vpu_saturate(res, 8);
}
} else if (vpu->mode == MODE_S16) {
const int16_t *addr16 = (const int16_t *)addr;

for (int i = 0; i < VPU_INT16_EPV; i++) {
int64_t val = addr16[i];
int64_t res =
((int64_t)vpu->vR.s16[i] * (int64_t)val + (1LL<<(VLMUL_SHR_S16 - 1))) >> VLMUL_SHR_S16; // TODO use macros
int64_t res = ((int64_t)vpu->vR.s16[i] * (int64_t)val + (1LL<<(shift - 1))) >> shift;
vpu->vR.s16[i] = vpu_saturate(res, 16);
}
} else if (vpu->mode == MODE_S32) {
const int32_t *addr32 = (const int32_t *)addr;

for (int i = 0; i < VPU_INT32_EPV; i++) {
int64_t val = addr32[i];
int64_t res = (vpu->vR.s32[i] * val + (1<<29)) >> 30; // TODO use macros
int64_t res = ((int64_t)vpu->vR.s32[i] * (int64_t)val + (1LL<<(shift - 1))) >> shift;
vpu->vR.s32[i] = vpu_saturate(res, 32);
}
} else {
Expand Down