// DGEMM kernel variant "k18" (AVX-512 intrinsics via immintrin.h).
//
// NOTE(review): this copy of the file looks damaged. The code is collapsed
// onto two physical lines and every span that sat between a '<' and the
// following '>' appears to be missing (for example the second #include has no
// header name, and several for-loop headers read like "n_count1;" or
// "n_countN_BLOCKING)?"). Because of the collapse, the '//' comments inside
// the text now also swallow whatever follows them on the same line. Restore
// the file from its original before trying to compile or modify it; the notes
// below describe only what the surviving text shows.
//
// Indexing macros: A(i,j), B(i,j) and C(i,j) address element (i,j) at offset
// i + j*LD?, using the LDA/LDB/LDC names that are in scope at the point of use.
//
// Blocking constants: M_BLOCKING = 192, N_BLOCKING = 8640, K_BLOCKING = 384.
//
// scale_c_k18(C, M, N, LDC, scalar)
//   Surviving prologue: computes M8 = M & -8 and N4 = N & -4, broadcasts
//   `scalar` into a __m512d, and sets up four base pointers at C, C+LDC,
//   C+2*LDC and C+3*LDC. The loop body is not present in this copy;
//   presumably it multiplies C by `scalar` in place -- TODO confirm.
//
// The fragments between scale_c_k18 and mydgemm_cpu_v18 (loops over
// count_first/count_second/count_sub, and the n_count_sub loops that invoke
// macro_n8 / macro_n4 / macro_n2) have lost their function headers here.
// Presumably they belong to packing_a_k18, packing_b_k18 and
// macro_kernel_k18, which mydgemm_cpu_v18 calls -- TODO confirm. The final
// n_count_sub loop (n = 1 remainder) has an empty body and is marked TODO.
//
// mydgemm_cpu_v18(M, N, K, alpha, A, LDA, B, LDB, beta, C, LDC)
//   1. If beta != 1.0, calls scale_c_k18(C, M, N, LDC, beta).
//   2. Returns immediately when alpha == 0 or K == 0 (before any allocation).
//   3. Allocates two scratch buffers with aligned_alloc(4096, ...):
//        a_buffer: K_BLOCKING * M_BLOCKING doubles
//        b_buffer: K_BLOCKING * N_BLOCKING doubles
//      With the current constants both byte sizes are exact multiples of 4096.
//      NOTE(review): neither return value is checked for NULL.
//   4. Walks the problem in blocks; n_inc, k_inc and m_inc are clamped to
//      N_BLOCKING, K_BLOCKING and M_BLOCKING respectively.
//      - For the first M block, A is packed once via packing_a_k18 (alpha is
//        handed to that routine). B is then packed slice by slice into
//        b_buffer at offset (second_n_count - n_count) * k_inc, and
//        macro_kernel_k18 is called per slice. The slice width looks capped
//        at 16 columns, but the comparison is truncated in this copy.
//      - For the remaining M blocks, A is re-packed and macro_kernel_k18 is
//        called over the full n_inc width, reusing the already packed b_buffer.
//   5. Frees both buffers before returning.
//   NOTE(review): i, j, k, M4, N8, K4, second_m_count and second_m_inc are
//   declared but not referenced in the text that survives.
#include "immintrin.h" #include #include "../utils.h" #define A(i,j) A[(i)+(j)*LDA] #define B(i,j) B[(i)+(j)*LDB] #define C(i,j) C[(i)+(j)*LDC] #define M_BLOCKING 192 #define N_BLOCKING 8640 #define K_BLOCKING 384 void scale_c_k18(double *C,int M, int N, int LDC, double scalar){ int m_count,n_count; int M8=M&-8,N4=N&-4,LDC2=LDC<<1,LDC3=LDC2+LDC,LDC4=LDC<<2; __m512d vscalar = _mm512_set1_pd(scalar); double *c_ptr_base1 = C,*c_ptr_base2 = C+LDC,*c_ptr_base3 = C+LDC2,*c_ptr_base4 = C+LDC3; double *c_ptr_dyn1,*c_ptr_dyn2,*c_ptr_dyn3,*c_ptr_dyn4; for (n_count=0;n_count1;count_second+=2,count_sub-=2){ tosrc1=src+count_second*leading_dim;tosrc2=tosrc1+leading_dim; for (count_first=0;count_first0;count_second++,count_sub-=1){ tosrc1=src+count_second*leading_dim; for (count_first=0;count_first23;count_first+=24,count_sub-=24){ tosrc=src+count_first; for(count_second=0;count_second7;count_first+=8,count_sub-=8){ tosrc=src+count_first; for(count_second=0;count_second1;count_first+=2,count_sub-=2){ tosrc=src+count_first; for(count_second=0;count_second0;count_first+=1,count_sub-=1){ tosrc=src+count_first; for(count_second=0;count_second7;n_count_sub-=8,n_count+=8){ //call the m layer with n=8; macro_n8 //TODO: case when m is divisible by 1 } for (;n_count_sub>3;n_count_sub-=4,n_count+=4){ //call the m layer with n=4 macro_n4 //TODO: case when m is divisible by 1 } for (;n_count_sub>1;n_count_sub-=2,n_count+=2){ //call the m layer with n=2 macro_n2 } for (;n_count_sub>0;n_count_sub-=1,n_count+=1){ //TODO:call the m layer with n=1 } } void mydgemm_cpu_v18(\ int M, \ int N, \ int K, \ double alpha, \ double *A, \ int LDA, \ double *B, \ int LDB, \ double beta, \ double *C, \ int LDC)\ { int i,j,k; if (beta != 1.0) scale_c_k18(C,M,N,LDC,beta); if (alpha == 0.||K==0) return; int M4,N8=N&-8,K4; double *a_buffer = (double *)aligned_alloc(4096,K_BLOCKING*M_BLOCKING*sizeof(double)); double *b_buffer = (double *)aligned_alloc(4096,K_BLOCKING*N_BLOCKING*sizeof(double)); int 
second_m_count,second_n_count,second_m_inc,second_n_inc; int m_count,n_count,k_count; int m_inc,n_inc,k_inc; for (n_count=0;n_countN_BLOCKING)?N_BLOCKING:N-n_count; for (k_count=0;k_countK_BLOCKING)?K_BLOCKING:K-k_count; m_inc=M>M_BLOCKING?M_BLOCKING:M; packing_a_k18(alpha,A+k_count*LDA,a_buffer,LDA,m_inc,k_inc); for (second_n_count=n_count;second_n_count16?16:n_count+n_inc-second_n_count; packing_b_k18(B+k_count+second_n_count*LDB,b_buffer+(second_n_count-n_count)*k_inc,LDB,k_inc,second_n_inc); macro_kernel_k18(a_buffer,b_buffer+(second_n_count-n_count)*k_inc,m_inc,second_n_inc,k_inc,&C(0,second_n_count),LDC); } for (m_count=m_inc;m_countM_BLOCKING)?M_BLOCKING:M-m_count; packing_a_k18(alpha,A+m_count+k_count*LDA,a_buffer,LDA,m_inc,k_inc); macro_kernel_k18(a_buffer,b_buffer,m_inc,n_inc,k_inc,&C(m_count,n_count),LDC); } } } free(a_buffer);free(b_buffer); }