Skip to content

PTX Backend#18

Open
WillTrojak wants to merge 7 commits into
PyFR:masterfrom
WillTrojak:feature/ptx
Open

PTX Backend#18
WillTrojak wants to merge 7 commits into
PyFR:masterfrom
WillTrojak:feature/ptx

Conversation

@WillTrojak
Copy link
Copy Markdown
Member

This adds a PTX backend to GiMMiK. The key features are:

  • Mild optimisation of exist CUDA algorithms.
  • Optional async loads for some sparse kernels
  • Added dense generation for Hopper and above

Optimisations have focused on FP64, FP32 is future work.

Comment thread gimmik/kernels/ptx/bstream-msplit.mako Outdated
Comment thread gimmik/ptx.py Outdated
Comment thread gimmik/ptx.py Outdated
Comment thread gimmik/ptx.py Outdated
yield (tpl, args, meta)

# Warp-specialised dense DMMA
if cc >= (10, 0):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this gate consumer cards with less shared memory?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what the best way to handle this is. I've added a DENSE_SMEM_MAX but we could set this via the ini or driver?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If consumer cards can pass the check they need to work. Not sure if there is a clear mapping from CC to max smem. Otherwise, have the caller pass in additional info about max shared memory.

Comment thread gimmik/ptx.py
@@ -0,0 +1,276 @@
# -*- coding: utf-8 -*-

import struct
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PEP8

Comment thread gimmik/ptx.py Outdated
Comment thread gimmik/ptx.py Outdated
Comment thread gimmik/ptx.py Outdated
Comment thread gimmik/ptx.py Outdated
i = m_tile * 8 + lane // 4
j = k_iter * 4 + lane % 4
v = float(a[i, j]) if (i < m and j < k) else 0.0
u = struct.unpack('<Q', struct.pack('<d', v))[0]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you unpick this for me?

Comment thread gimmik/ptx.py Outdated

# A in fragment layout: lane l -> A[m_tile*8 + l/4][k_iter*4 + l%4]
a_u64 = []
for m_tile in range(m_tiles):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can 3 arg range work here?

@FreddieWitherden
Copy link
Copy Markdown
Contributor

I know this is an utter pain but for FP32/FP64 can you confirm correctness for all relevant PyFR matrices at a suite of N values for all instances where a kernel is expected to work on A100/H100/B100)?

Comment thread gimmik/kernels/ptx/base.mako Outdated
.param .u64 _c)
{
% endif
.reg .u32 n, id, tid_x, tid_y;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure we throw higher up if n is too big.

Comment thread gimmik/kernels/ptx/bstream-msplit.mako Outdated
## Async fill of chunk 0
% for idx, kx in enumerate(bchunks[0]):
% if idx % msplit == cid:
% if n is None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See if we can come up with some consistent indentation for Mako. Am open to ideas.

Comment thread gimmik/kernels/ptx/bstream-msplit.mako Outdated
<%
buf_cur = bb % 2
buf_next = (bb + 1) % 2
is_last = (bb == len(bchunks) - 1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a Mako var for this.

Comment thread gimmik/kernels/ptx/bstream-msplit.mako Outdated
% if afix[row_j] == -1:
% if beta == 0:
{
.reg .${pftype} _tmp;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be factored up as appears in both branches?

Comment thread gimmik/kernels/ptx/cstream-ksplit.mako Outdated
fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp;
st.global.${pftype} [_cptr], _ctmp;
% else:
ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there scope to lifting these ld's up or does the assembler handle this?

Comment thread gimmik/kernels/ptx/bstream.mako
Comment thread gimmik/ptx.py
i = mt * 8 + lane // 4
j = kt * 4 + lane % 4
v = float(a[i, j]) if (i < m and j < k) else 0.0
u, = struct.unpack('<Q', struct.pack('<d', v))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought Python f-strings/format could do this for getting hex representation of floating point?

Comment thread gimmik/ptx.py
Comment thread gimmik/ptx.py
Comment thread gimmik/ptx.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants