blob: b5fe0a8d31207a2ed2a8e24382ca741aa35c2428 [file] [log] [blame]
/*
* Copyright 2016 Google Inc.
*
* Use of this source code is governed by a BSD-style license that can
* be found in the LICENSE file.
*
*/
#include <stdio.h>
#include <stdlib.h>
//
//
//
#include "gen.h"
#include "transpose.h"
#include "common/util.h"
#include "common/macros.h"
//
//
//
//
// Context handed to hsg_transpose() as the opaque 'void *' argument of
// the hsg_transpose_blend() and hsg_transpose_remap() callbacks.
//
struct hsg_transpose_state
{
FILE * header; // stream receiving the emitted HS_TRANSPOSE_* macro lines
struct hsg_config const * config; // generator config -- remap reads config->warp.lanes_log2
};
//
// Map a transpose level (log2 of the column count) to the single
// letter used as that level's register-name prefix.  Level 0 maps to
// 'r' and successive levels walk forward through the alphabet,
// wrapping from 'z' back to 'a'.
//
static
char
hsg_transpose_reg_prefix(uint32_t const cols_log2)
{
  uint32_t const offset = ('r' - 'a') + cols_log2; // distance from 'a'

  return (char)('a' + offset % 26);
}
static
void
hsg_transpose_blend(uint32_t const cols_log2,
uint32_t const row_ll, // lower-left
uint32_t const row_ur, // upper-right
void * blend)
{
struct hsg_transpose_state * const state = blend;
// we're starting register names at '1' for now
fprintf(state->header,
" HS_TRANSPOSE_BLEND( %c, %c, %2u, %3u, %3u ) \\\n",
hsg_transpose_reg_prefix(cols_log2-1),
hsg_transpose_reg_prefix(cols_log2),
cols_log2,row_ll+1,row_ur+1);
}
static
void
hsg_transpose_remap(uint32_t const row_from,
uint32_t const row_to,
void * remap)
{
struct hsg_transpose_state * const state = remap;
// we're starting register names at '1' for now
fprintf(state->header,
" HS_TRANSPOSE_REMAP( %c, %3u, %3u ) \\\n",
hsg_transpose_reg_prefix(state->config->warp.lanes_log2),
row_from+1,row_to+1);
}
//
//
//
//
// Write the license banner that opens every generated file.
//
static
void
hsg_copyright(FILE * file)
{
  static char const banner[] =
    "// \n"
    "// Copyright 2016 Google Inc. \n"
    "// \n"
    "// Use of this source code is governed by a BSD-style \n"
    "// license that can be found in the LICENSE file. \n"
    "// \n"
    "\n";

  // the banner contains no conversion specifiers, so a plain fputs()
  // writes exactly what fprintf() would
  fputs(banner,file);
}
//
// Write the include prologue shared by every generated kernel source:
// the target config header, the GLSL preamble and the target-specific
// macros.
//
static
void
hsg_macros(FILE * file)
{
  static char const includes[] =
    "// target-specific config \n"
    "#include \"hs_config.h\" \n"
    " \n"
    "// GLSL preamble \n"
    "#include \"hs_glsl_preamble.h\"\n"
    " \n"
    "// arch/target-specific macros \n"
    "#include \"hs_glsl_macros.h\" \n"
    " \n"
    "// \n"
    "// \n"
    "// \n"
    "\n";

  // no conversion specifiers -- fputs() emits the same bytes as fprintf()
  fputs(includes,file);
}
//
//
//
//
// Output streams used by hsg_target_glsl() through target->state
// (presumably the type of 'state' declared in gen.h -- confirm).
// hsg_target_glsl() allocates it on HSG_OP_TYPE_TARGET_BEGIN and frees
// it on HSG_OP_TYPE_TARGET_END.
//
struct hsg_target_state
{
FILE * header;  // "hs_config.h": config #defines plus slab-row/transpose macros
FILE * modules; // "hs_modules.h": list of per-kernel .len.xxd/.spv.xxd includes
FILE * source;  // current kernel ".comp" file: opened by a *_KERNEL_PROTO op, closed by END at depth 0
};
//
//
//
//
// Large enough for any generated kernel file name: the longest pattern,
// "hs_fm_%u_%u.comp", needs 6 + 10 + 1 + 10 + 5 + 1 = 33 bytes.
//
enum { HSG_GLSL_FILENAME_SIZE = 64 };

//
// Terminate the generator with a diagnostic.  No usable GLSL target
// can be produced after an allocation or file-system failure, so this
// follows the same exit-on-error policy as the "type not found" case
// in hsg_target_glsl().
//
static
void
hsg_glsl_fatal(char const * const what, char const * const name)
{
  fprintf(stderr,"%s: %s\n",what,name);
  exit(EXIT_FAILURE);
}

//
// fopen() that never returns NULL -- every fprintf() in
// hsg_target_glsl() would otherwise write through a NULL stream.
//
static
FILE *
hsg_glsl_fopen(char const * const filename, char const * const mode)
{
  FILE * const file = fopen(filename,mode);

  if (file == NULL)
    hsg_glsl_fatal("failed to open",filename);

  return file;
}

//
// fclose() that reports failure.  For a write stream fclose() performs
// the final flush, so an error means the generated file is incomplete.
//
static
void
hsg_glsl_fclose(FILE * const file, char const * const name)
{
  if (fclose(file) != 0)
    hsg_glsl_fatal("failed to close",name);
}

//
// Create a kernel ".comp" source file and write the common prologue
// (copyright banner followed by the config/preamble/macro includes).
//
static
FILE *
hsg_glsl_source_open(char const * const filename)
{
  FILE * const source = hsg_glsl_fopen(filename,"w+");

  hsg_copyright(source);
  hsg_macros(source);

  return source;
}

//
// GLSL backend: emits the output for a single generator op.
//
//   TARGET_BEGIN    -- allocates target->state and creates "hs_config.h"
//                      and "hs_modules.h"
//   *_KERNEL_PROTO  -- creates the kernel's ".comp" file and appends its
//                      .len/.spv xxd includes to "hs_modules.h"
//   END, depth == 0 -- closes the current ".comp" file
//   TARGET_END      -- closes both headers and frees target->state
//
// Allocation or file failures and unknown op types terminate the
// process with EXIT_FAILURE.
//
void
hsg_target_glsl(struct hsg_target       * const target,
                struct hsg_config const * const config,
                struct hsg_merge  const * const merge,
                struct hsg_op     const * const ops,
                uint32_t                  const depth)
{
  switch (ops->type)
    {
    case HSG_OP_TYPE_END:
      fprintf(target->state->source,
              "}\n");

      // closing the outermost scope finishes the current kernel file
      if (depth == 0)
        {
          hsg_glsl_fclose(target->state->source,"kernel source");
          target->state->source = NULL;
        }
      break;

    case HSG_OP_TYPE_BEGIN:
      fprintf(target->state->source,
              "{\n");
      break;

    case HSG_OP_TYPE_ELSE:
      fprintf(target->state->source,
              "else\n");
      break;

    case HSG_OP_TYPE_TARGET_BEGIN:
      {
        // allocate state -- calloc() so 'source' starts out NULL
        target->state = calloc(1,sizeof(*target->state));

        if (target->state == NULL)
          hsg_glsl_fatal("failed to allocate","target state");

        // allocate files
        target->state->header  = hsg_glsl_fopen("hs_config.h", "wb");
        target->state->modules = hsg_glsl_fopen("hs_modules.h","wb");

        hsg_copyright(target->state->header);
        hsg_copyright(target->state->modules);

        // initialize header
        uint32_t const bc_max = msb_idx_u32(pow2_rd_u32(merge->warps));

        fprintf(target->state->header,
                "#ifndef HS_GLSL_ONCE \n"
                "#define HS_GLSL_ONCE \n"
                " \n"
                "#define HS_SLAB_THREADS_LOG2 %u \n"
                "#define HS_SLAB_THREADS (1 << HS_SLAB_THREADS_LOG2) \n"
                "#define HS_SLAB_WIDTH_LOG2 %u \n"
                "#define HS_SLAB_WIDTH (1 << HS_SLAB_WIDTH_LOG2) \n"
                "#define HS_SLAB_HEIGHT %u \n"
                "#define HS_SLAB_KEYS (HS_SLAB_WIDTH * HS_SLAB_HEIGHT)\n"
                "#define HS_REG_LAST(c) c##%u \n"
                "#define HS_KEY_WORDS %u \n"
                "#define HS_VAL_WORDS 0 \n"
                "#define HS_BS_SLABS %u \n"
                "#define HS_BS_SLABS_LOG2_RU %u \n"
                "#define HS_BC_SLABS_LOG2_MAX %u \n"
                "#define HS_FM_BLOCK_HEIGHT %u \n"
                "#define HS_FM_SCALE_MIN %u \n"
                "#define HS_FM_SCALE_MAX %u \n"
                "#define HS_HM_BLOCK_HEIGHT %u \n"
                "#define HS_HM_SCALE_MIN %u \n"
                "#define HS_HM_SCALE_MAX %u \n"
                "#define HS_EMPTY \n"
                " \n",
                config->warp.lanes_log2, // FIXME -- this matters for SIMD
                config->warp.lanes_log2,
                config->thread.regs,
                config->thread.regs,
                config->type.words,
                merge->warps,
                msb_idx_u32(pow2_ru_u32(merge->warps)),
                bc_max,
                config->merge.flip.warps,
                config->merge.flip.lo,
                config->merge.flip.hi,
                config->merge.half.warps,
                config->merge.half.lo,
                config->merge.half.hi);

        if (target->define != NULL)
          fprintf(target->state->header,"#define %s\n\n",target->define);

        fprintf(target->state->header,
                "#define HS_SLAB_ROWS() \\\n");

        for (uint32_t ii=1; ii<=config->thread.regs; ii++)
          fprintf(target->state->header,
                  " HS_SLAB_ROW( %3u, %3u ) \\\n",ii,ii-1);

        fprintf(target->state->header,
                " HS_EMPTY\n"
                " \n");

        fprintf(target->state->header,
                "#define HS_TRANSPOSE_SLAB() \\\n");

        for (uint32_t ii=1; ii<=config->warp.lanes_log2; ii++)
          fprintf(target->state->header,
                  " HS_TRANSPOSE_STAGE( %u ) \\\n",ii);

        struct hsg_transpose_state state[1] =
          {
            { .header = target->state->header,
              .config = config
            }
          };

        hsg_transpose(config->warp.lanes_log2,
                      config->thread.regs,
                      hsg_transpose_blend,state,
                      hsg_transpose_remap,state);

        fprintf(target->state->header,
                " HS_EMPTY\n"
                " \n");
      }
      break;

    case HSG_OP_TYPE_TARGET_END:
      // decorate the files
      fprintf(target->state->header,
              "#endif \n"
              " \n"
              "// \n"
              "// \n"
              "// \n"
              " \n");

      // close files -- report a failed final flush
      hsg_glsl_fclose(target->state->header, "hs_config.h");
      hsg_glsl_fclose(target->state->modules,"hs_modules.h");

      // free state and don't leave a dangling pointer behind
      free(target->state);
      target->state = NULL;
      break;

    case HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO:
      {
        fprintf(target->state->modules,
                "#include \"hs_transpose.len.xxd\"\n,\n"
                "#include \"hs_transpose.spv.xxd\"\n,\n");

        target->state->source = hsg_glsl_source_open("hs_transpose.comp");

        fprintf(target->state->source,
                "HS_TRANSPOSE_KERNEL_PROTO()\n");
      }
      break;

    case HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE:
      {
        fprintf(target->state->source,
                "HS_SUBGROUP_PREAMBLE();\n");

        fprintf(target->state->source,
                "HS_SLAB_GLOBAL_PREAMBLE();\n");
      }
      break;

    case HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY:
      {
        fprintf(target->state->source,
                "HS_TRANSPOSE_SLAB()\n");
      }
      break;

    case HSG_OP_TYPE_BS_KERNEL_PROTO:
      {
        struct hsg_merge const * const m = merge + ops->a;

        uint32_t const bs  = pow2_ru_u32(m->warps);
        uint32_t const msb = msb_idx_u32(bs);

        fprintf(target->state->modules,
                "#include \"hs_bs_%u.len.xxd\"\n,\n"
                "#include \"hs_bs_%u.spv.xxd\"\n,\n",
                msb,
                msb);

        char filename[HSG_GLSL_FILENAME_SIZE];

        snprintf(filename,sizeof(filename),"hs_bs_%u.comp",msb);

        target->state->source = hsg_glsl_source_open(filename);

        if (m->warps > 1)
          {
            fprintf(target->state->source,
                    "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
                    m->warps * config->warp.lanes,
                    m->rows_bs);
          }

        fprintf(target->state->source,
                "HS_BS_KERNEL_PROTO(%u,%u)\n",
                m->warps,msb);
      }
      break;

    case HSG_OP_TYPE_BS_KERNEL_PREAMBLE:
      {
        fprintf(target->state->source,
                "HS_SUBGROUP_PREAMBLE();\n");

        fprintf(target->state->source,
                "HS_SLAB_GLOBAL_PREAMBLE();\n");
      }
      break;

    case HSG_OP_TYPE_BC_KERNEL_PROTO:
      {
        struct hsg_merge const * const m = merge + ops->a;

        uint32_t const msb = msb_idx_u32(m->warps);

        fprintf(target->state->modules,
                "#include \"hs_bc_%u.len.xxd\"\n,\n"
                "#include \"hs_bc_%u.spv.xxd\"\n,\n",
                msb,
                msb);

        char filename[HSG_GLSL_FILENAME_SIZE];

        snprintf(filename,sizeof(filename),"hs_bc_%u.comp",msb);

        target->state->source = hsg_glsl_source_open(filename);

        if (m->warps > 1)
          {
            fprintf(target->state->source,
                    "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
                    m->warps * config->warp.lanes,
                    m->rows_bc);
          }

        fprintf(target->state->source,
                "HS_BC_KERNEL_PROTO(%u,%u)\n",
                m->warps,msb);
      }
      break;

    case HSG_OP_TYPE_BC_KERNEL_PREAMBLE:
      {
        // terminated like the TRANSPOSE/BS preambles
        fprintf(target->state->source,
                "HS_SUBGROUP_PREAMBLE();\n");

        fprintf(target->state->source,
                "HS_SLAB_GLOBAL_PREAMBLE();\n");
      }
      break;

    case HSG_OP_TYPE_FM_KERNEL_PROTO:
      {
        fprintf(target->state->modules,
                "#include \"hs_fm_%u_%u.len.xxd\"\n,\n"
                "#include \"hs_fm_%u_%u.spv.xxd\"\n,\n",
                ops->a,ops->b,
                ops->a,ops->b);

        // the old "hs_fm_X_XX.comp" buffer overflowed for a >= 10 or b >= 100
        char filename[HSG_GLSL_FILENAME_SIZE];

        snprintf(filename,sizeof(filename),"hs_fm_%u_%u.comp",ops->a,ops->b);

        target->state->source = hsg_glsl_source_open(filename);

        fprintf(target->state->source,
                "HS_FM_KERNEL_PROTO(%u,%u)\n",
                ops->a,ops->b);
      }
      break;

    case HSG_OP_TYPE_FM_KERNEL_PREAMBLE:
      {
        // terminated like the TRANSPOSE/BS preambles
        fprintf(target->state->source,
                "HS_SUBGROUP_PREAMBLE();\n");

        fprintf(target->state->source,
                "HS_FM_PREAMBLE(%u);\n",
                ops->a);
      }
      break;

    case HSG_OP_TYPE_HM_KERNEL_PROTO:
      {
        fprintf(target->state->modules,
                "#include \"hs_hm_%u.len.xxd\"\n,\n"
                "#include \"hs_hm_%u.spv.xxd\"\n,\n",
                ops->a,
                ops->a);

        // the old "hs_hm_X.comp" buffer overflowed for a >= 10
        char filename[HSG_GLSL_FILENAME_SIZE];

        snprintf(filename,sizeof(filename),"hs_hm_%u.comp",ops->a);

        target->state->source = hsg_glsl_source_open(filename);

        fprintf(target->state->source,
                "HS_HM_KERNEL_PROTO(%u)\n",
                ops->a);
      }
      break;

    case HSG_OP_TYPE_HM_KERNEL_PREAMBLE:
      {
        // terminated like the TRANSPOSE/BS preambles
        fprintf(target->state->source,
                "HS_SUBGROUP_PREAMBLE();\n");

        fprintf(target->state->source,
                "HS_HM_PREAMBLE(%u);\n",
                ops->a);
      }
      break;

    case HSG_OP_TYPE_BX_REG_GLOBAL_LOAD:
      {
        static char const * const vstr[] = { "vin", "vout" };

        fprintf(target->state->source,
                "HS_KEY_TYPE r%-3u = HS_SLAB_GLOBAL_LOAD(%s,%u);\n",
                ops->n,vstr[ops->v],ops->n-1);
      }
      break;

    case HSG_OP_TYPE_BX_REG_GLOBAL_STORE:
      fprintf(target->state->source,
              "HS_SLAB_GLOBAL_STORE(%u,r%u);\n",
              ops->n-1,ops->n);
      break;

    case HSG_OP_TYPE_HM_REG_GLOBAL_LOAD:
      fprintf(target->state->source,
              "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
              ops->a,ops->b);
      break;

    case HSG_OP_TYPE_HM_REG_GLOBAL_STORE:
      fprintf(target->state->source,
              "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
              ops->b,ops->a);
      break;

    case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT:
      fprintf(target->state->source,
              "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
              ops->a,ops->b);
      break;

    case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT:
      fprintf(target->state->source,
              "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
              ops->b,ops->a);
      break;

    case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT:
      fprintf(target->state->source,
              "HS_KEY_TYPE r%-3u = HS_FM_GLOBAL_LOAD_R(%u);\n",
              ops->b,ops->a);
      break;

    case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT:
      fprintf(target->state->source,
              "HS_FM_GLOBAL_STORE_R(%-3u,r%u);\n",
              ops->a,ops->b);
      break;

    case HSG_OP_TYPE_FM_MERGE_RIGHT_PRED:
      {
        if (ops->a <= ops->b)
          {
            fprintf(target->state->source,
                    "if (HS_FM_IS_NOT_LAST_SPAN() || (fm_frac == 0))\n");
          }
        else if (ops->b > 1)
          {
            fprintf(target->state->source,
                    "else if (fm_frac == %u)\n",
                    ops->b);
          }
        else
          {
            fprintf(target->state->source,
                    "else\n");
          }
      }
      break;

    case HSG_OP_TYPE_SLAB_FLIP:
      fprintf(target->state->source,
              "HS_SLAB_FLIP_PREAMBLE(%u);\n",
              ops->n-1);
      break;

    case HSG_OP_TYPE_SLAB_HALF:
      fprintf(target->state->source,
              "HS_SLAB_HALF_PREAMBLE(%u);\n",
              ops->n / 2);
      break;

    case HSG_OP_TYPE_CMP_FLIP:
      fprintf(target->state->source,
              "HS_CMP_FLIP(%-3u,r%-3u,r%-3u);\n",ops->a,ops->b,ops->c);
      break;

    case HSG_OP_TYPE_CMP_HALF:
      fprintf(target->state->source,
              "HS_CMP_HALF(%-3u,r%-3u);\n",ops->a,ops->b);
      break;

    case HSG_OP_TYPE_CMP_XCHG:
      // c == UINT32_MAX selects plain "r<n>" names, otherwise "r<c>_<n>"
      if (ops->c == UINT32_MAX)
        {
          fprintf(target->state->source,
                  "HS_CMP_XCHG(r%-3u,r%-3u);\n",
                  ops->a,ops->b);
        }
      else
        {
          fprintf(target->state->source,
                  "HS_CMP_XCHG(r%u_%u,r%u_%u);\n",
                  ops->c,ops->a,ops->c,ops->b);
        }
      break;

    case HSG_OP_TYPE_BS_REG_SHARED_STORE_V:
      fprintf(target->state->source,
              "HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u) = r%u;\n",
              merge[ops->a].warps,ops->c,ops->b);
      break;

    case HSG_OP_TYPE_BS_REG_SHARED_LOAD_V:
      fprintf(target->state->source,
              "r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
              ops->b,merge[ops->a].warps,ops->c);
      break;

    case HSG_OP_TYPE_BC_REG_SHARED_LOAD_V:
      fprintf(target->state->source,
              "HS_KEY_TYPE r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
              ops->b,ops->a,ops->c);
      break;

    case HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT:
      fprintf(target->state->source,
              "HS_SLAB_LOCAL_L(%5u) = r%u_%u;\n",
              ops->b * config->warp.lanes,
              ops->c,
              ops->a);
      break;

    case HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT:
      fprintf(target->state->source,
              "HS_SLAB_LOCAL_R(%5u) = r%u_%u;\n",
              ops->b * config->warp.lanes,
              ops->c,
              ops->a);
      break;

    case HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT:
      fprintf(target->state->source,
              "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_L(%u);\n",
              ops->c,
              ops->a,
              ops->b * config->warp.lanes);
      break;

    case HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT:
      fprintf(target->state->source,
              "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_R(%u);\n",
              ops->c,
              ops->a,
              ops->b * config->warp.lanes);
      break;

    case HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT:
      fprintf(target->state->source,
              "HS_KEY_TYPE r%u_%-3u = HS_BC_GLOBAL_LOAD_L(%u);\n",
              ops->c,
              ops->a,
              ops->b);
      break;

    case HSG_OP_TYPE_BLOCK_SYNC:
      fprintf(target->state->source,
              "HS_BLOCK_BARRIER();\n");
      //
      // FIXME - Named barriers to allow coordinating warps to proceed?
      //
      break;

    case HSG_OP_TYPE_BS_FRAC_PRED:
      {
        if (ops->m == 0)
          {
            fprintf(target->state->source,
                    "if (warp_idx < bs_full)\n");
          }
        else
          {
            fprintf(target->state->source,
                    "else if (bs_frac == %u)\n",
                    ops->w);
          }
      }
      break;

    case HSG_OP_TYPE_BS_MERGE_H_PREAMBLE:
      {
        struct hsg_merge const * const m = merge + ops->a;

        fprintf(target->state->source,
                "HS_BS_MERGE_H_PREAMBLE(%u);\n",
                m->warps);
      }
      break;

    case HSG_OP_TYPE_BC_MERGE_H_PREAMBLE:
      {
        struct hsg_merge const * const m = merge + ops->a;

        fprintf(target->state->source,
                "HS_BC_MERGE_H_PREAMBLE(%u);\n",
                m->warps);
      }
      break;

    case HSG_OP_TYPE_BX_MERGE_H_PRED:
      fprintf(target->state->source,
              "if (HS_SUBGROUP_ID() < %u)\n",
              ops->a);
      break;

    case HSG_OP_TYPE_BS_ACTIVE_PRED:
      {
        struct hsg_merge const * const m = merge + ops->a;

        // up to 32 warps fit a 32-bit mask, otherwise emit a 64-bit mask
        if (m->warps <= 32)
          {
            fprintf(target->state->source,
                    "if (((1u << HS_SUBGROUP_ID()) & 0x%08X) != 0)\n",
                    m->levels[ops->b].active.b32a2[0]);
          }
        else
          {
            fprintf(target->state->source,
                    "if (((1UL << HS_SUBGROUP_ID()) & 0x%08X%08XL) != 0L)\n",
                    m->levels[ops->b].active.b32a2[1],
                    m->levels[ops->b].active.b32a2[0]);
          }
      }
      break;

    default:
      fprintf(stderr,"type not found: %s\n",hsg_op_type_string[ops->type]);
      exit(EXIT_FAILURE);
      break;
    }
}
//
//
//