/github/workspace/src/MatrixFunctionsStride/mat_mult_trans_cmplx_stride/kernels/plp_mat_mult_trans_cmplx_stride_q8s_rv32im.c
Functions
Name | |
---|---|
void | plp_mat_mult_trans_cmplx_stride_q8s_rv32im(const int8_t restrict pSrcA, const int8_t restrict pSrcB, uint32_t M, uint32_t N, uint32_t O, uint32_t strideA, uint32_t strideB, uint32_t strideC, uint32_t shift, int8_t *restrict pDstC) Strided strided matrix transpose matrix multiplication for complex 8-bit fix-point on RV32IM. |
Defines
Name | |
---|---|
BASIC_VERSION |
Functions Documentation
function plp_mat_mult_trans_cmplx_stride_q8s_rv32im
void plp_mat_mult_trans_cmplx_stride_q8s_rv32im(
const int8_t *__restrict__ pSrcA,
const int8_t *__restrict__ pSrcB,
uint32_t M,
uint32_t N,
uint32_t O,
uint32_t strideA,
uint32_t strideB,
uint32_t strideC,
uint32_t shift,
int8_t *__restrict__ pDstC
)
Strided strided matrix transpose matrix multiplication for complex 8-bit fix-point on RV32IM.
Parameters:
- pSrcA Points to the first input matrix of shape MxN
- pSrcB Points to the second input matrix of shape OxN
- M Height of matrix SrcA and DstC
- N Width of matrix SrcA and SrcB
- O Height of matrix SrcB and width of matrix DstC
- strideA Stride of input matrix A (elements between each row)
- strideB Stride of input matrix B (elements between each row)
- strideC Stride of output matrix C (Elements between each row)
- shift Amount to shift the result of each multiplication ot the right
- pDstC Points to the output matrix of shape MxO
Return: none
Par: Fix-Point
Fix-Point and Shifting The result will be shifted by the parameter shift
to the right (which corresponds to a multiplication by 2^-shift
). Assume that matrix A is represente as pSrcA * 2^-x
and matrix B as pSrcB * 2^-y
(which means that A has x
, and B has y
bits after the binary point). Then, the output matrix C is represented as pDstC * 2^-(x + y - shift)
. The output matrix is also stored with the same number of bits as the inputs. Set the shift
parameter such that no overflow occurrs.
strided matrix transpose matrix multiplication for complex 8-bit fix-point on RV32IM
Macros Documentation
define BASIC_VERSION
#define BASIC_VERSION
Source code
/* =====================================================================
* Project: PULP DSP Library
* Title: plp_mat_mult_trans_cmplx_stride_q8s_rv32im.c
* Description: 8-bit complex strided matrix transpose matrix multiplication kernel for RV32IM
*
* $Date: 18. July 2020
* $Revision: V0
*
* Target Processor: PULP cores
* ===================================================================== */
/*
* Copyright (C) 2020 ETH Zurich and Ubiversity of Bologna. All rights reserved.
*
* Author: Tibor Schneider, ETH Zurich
*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plp_math.h"
void plp_mat_mult_trans_cmplx_stride_q8s_rv32im(const int8_t *__restrict__ pSrcA,
const int8_t *__restrict__ pSrcB,
uint32_t M,
uint32_t N,
uint32_t O,
uint32_t strideA,
uint32_t strideB,
uint32_t strideC,
uint32_t shift,
int8_t *__restrict__ pDstC) {
#define BASIC_VERSION // if used don' forget to also use undefine at end of file
#ifdef BASIC_VERSION
int32_t bias = 1 << (shift - 1);
for (int m = 0; m < M; m++) {
for (int o = 0; o < O; o++) {
int32_t sum_re = 0;
int32_t sum_im = 0;
for (int n = 0; n < N; n++) {
int32_t a_re = (int32_t)pSrcA[(m * strideA + n) * 2 + 0];
int32_t a_im = (int32_t)pSrcA[(m * strideA + n) * 2 + 1];
int32_t b_re = (int32_t)pSrcB[(o * strideB + n) * 2 + 0];
int32_t b_im = (int32_t)pSrcB[(o * strideB + n) * 2 + 1];
sum_re += (a_re * b_re - a_im * b_im + bias) >> shift;
sum_im += (a_re * b_im + a_im * b_re + bias) >> shift;
}
pDstC[(m * strideC + o) * 2 + 0] = (int8_t)sum_re;
pDstC[(m * strideC + o) * 2 + 1] = (int8_t)sum_im;
}
}
#else
// TODO: Hackathon
#endif
#undef BASIC_VERSION
}
Updated on 2023-03-01 at 16:16:33 +0000