以下按出现顺序列出作业文档中的全部代码片段。

```python
import torch

# Sample a 16x16 tensor on the host, then copy it to the current CUDA device.
a = torch.randn((16, 16)).cuda()
# Elementwise square of a CUDA tensor: executed as a pointwise kernel.
b = a * a
```

```python
import torch

# Whether PyTorch sees a usable CUDA device.
print(torch.cuda.is_available())
# True
# Whether this PyTorch build uses the C++11 ABI; this is the public wrapper
# around torch._C._GLIBCXX_USE_CXX11_ABI.
print(torch.compiled_with_cxx11_abi())
# e.g. True
# Human-readable summary of how this PyTorch build was configured/compiled.
print(torch.__config__.show())
```

```python
import torch
from flash_attn import flash_attn_func

B = 1     # batch size
S = 4096  # sequence length
H = 32    # number of attention heads
D = 128   # head dimension

# Random bf16 Q, K, V in the (batch, seqlen, nheads, headdim) layout that
# flash_attn_func takes.
q, k, v = [
    torch.randn((B, S, H, D), device='cuda', dtype=torch.bfloat16)
    for _ in range(3)
]

o = flash_attn_func(q, k, v)

# Note: PyTorch's API uses a different layout. scaled_dot_product_attention
# takes (batch, nheads, seqlen, headdim), so swap dims 1 and 2 going in and
# swap them back on the result.
qt, kt, vt = (x.transpose(1, 2) for x in (q, k, v))
torch_o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
torch_o = torch_o.transpose(1, 2)

# NOTE(review): atol=1e-4 is tight for bf16 outputs; if this prints False,
# inspect (o - torch_o).abs().max() before concluding the kernels disagree.
print(torch.allclose(o, torch_o, atol=1e-4, rtol=0.01))
```


```python
# train.py
def get_batch(_):
    """Return a random ``(inputs, targets)`` pair for next-token prediction.

    The single positional argument is ignored (presumably a split name such
    as 'train'/'val' -- confirm against callers). Token ids are drawn
    uniformly from [0, 50000) directly on the GPU; ``batch_size`` and
    ``block_size`` are module-level settings defined elsewhere in train.py.

    Targets are the inputs shifted left by one position. The last position
    has no successor and is filled with -1, the ignore index.
    """
    tokens = torch.randint(
        0, 50000, (batch_size, block_size), device='cuda', dtype=torch.int64
    )
    # roll(-1) moves token t+1 into slot t; the column that wraps around
    # into the last position is then overwritten with the ignore index.
    targets = tokens.roll(shifts=-1, dims=1)
    targets[:, -1] = -1  # ignore index -1
    return tokens, targets
```

```python
import time
import torch

# Drain any GPU work queued earlier so it is not billed to the timed region.
torch.cuda.synchronize()
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() is wall-clock time and can jump (e.g. NTP updates).
st = time.perf_counter()
... # torch cuda ops
# CUDA launches are asynchronous: wait for the timed work to actually finish
# before reading the clock, otherwise mostly launch overhead is measured.
torch.cuda.synchronize()
ed = time.perf_counter()
print(f"Time: {ed - st}s")
```

```python
import torch


def _mark():
    """Create a timing-enabled CUDA event and record it on the current stream."""
    event = torch.Event("cuda", enable_timing=True)
    event.record()  # or event.record(torch.cuda.current_stream())
    return event


start = _mark()
... # code 1
after_code1 = _mark()
... # code 2
after_code2 = _mark()

# Block the host until all work queued before the last event has completed,
# so both intervals below are final.
after_code2.synchronize()
# elapsed_time() reports milliseconds; divide by 1e3 to print seconds.
print(f"Time: {start.elapsed_time(after_code1) / 1e3}s")
print(f"Time: {after_code1.elapsed_time(after_code2) / 1e3}s")
```


```python
# vec_add.py
import torch

# Two random fp32 vectors of length 4096, allocated directly on the GPU
# (a is drawn first, then b).
a, b = (
    torch.randn(4096, device='cuda', dtype=torch.float32) for _ in range(2)
)
# Elementwise add; the kernel is launched asynchronously.
c = a + b
# Wait for the GPU to finish before the script moves on / exits.
torch.cuda.synchronize()
```


```python
# sgemm.py
import torch
import nvtx

# Two random 4096x4096 fp32 matrices on the GPU (a is drawn first, then b).
a, b = (
    torch.randn((4096, 4096), device='cuda', dtype=torch.float32)
    for _ in range(2)
)


def _run_sgemm(iterations):
    """Issue ``a @ b`` ``iterations`` times, discarding every result."""
    for _ in range(iterations):
        a @ b


# Launches before the annotated range (presumably warm-up).
_run_sgemm(3)

# Only these 10 launches are issued inside the "SGEMM" NVTX range.
with nvtx.annotate("SGEMM"):
    _run_sgemm(10)

# Launches after the annotated range.
_run_sgemm(3)

# Kernels run asynchronously; wait for all of them before exiting.
torch.cuda.synchronize()
```

```python
import torch

def int8_quantization(x: torch.Tensor):
    """Symmetric per-tensor INT8 quantization.

    Uses one scale for the whole tensor so that ``qx * scale`` approximates
    ``x``, with ``qx`` restricted to the symmetric range [-127, 127].

    Args:
        x: tensor to quantize (any shape).

    Returns:
        ``(qx, scale)``: ``qx`` is an ``int8`` tensor shaped like ``x``;
        ``scale`` is a 0-d tensor equal to ``max(|x|) / 127``, or 1 when that
        value is exactly zero (all-zero / underflowing input).
    """
    scale = x.abs().max() / 127
    # A zero scale would make x / scale = 0 / 0 = NaN, whose int8 cast is
    # undefined. Substitute 1 so qx comes out all zeros. Only an exact zero is
    # replaced, so NaN/inf scales still propagate to the caller.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    # In low-precision dtypes (e.g. bfloat16) x / scale for the max element can
    # round up to 127.5 and then to 128, which does not fit in int8. Clamp
    # into the symmetric range before casting.
    qx = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return qx, scale

class INT8LinearFunc(torch.autograd.Function):
    """Autograd function for a linear layer whose GEMMs use int8 operands.

    Forward computes ``y = x @ w.T (+ bias)``. ``x`` (flattened to 2-D) and
    ``w`` are each quantized with ``int8_quantization``, which uses a single
    per-tensor scale. The int32 result of ``torch._int_mm`` is cast back to
    ``x.dtype`` and rescaled by ``sx * sw``.

    Backward quantizes ``dy`` the same way and computes ``dx`` and ``dw`` with
    int8 GEMMs too; ``db`` is a plain full-precision sum over ``dy``.

    NOTE(review): ``w`` is assumed to be ``(out_features, in_features)`` as in
    ``nn.Linear``. ``torch._int_mm`` is a private PyTorch op; on CUDA it has
    shape/layout restrictions (e.g. a minimum row count, dims that are
    multiples of 8) -- confirm against the PyTorch version in use.
    """

    @staticmethod
    def forward(ctx, x, w, bias):
        """Return ``x @ w.T (+ bias)`` computed via an int8 GEMM, in x's shape/dtype."""
        # Keep the original (possibly >2-D) shape so the output and dx can be
        # reshaped back to it.
        x_shape = list(x.shape)
        # Flatten all leading dims into one so the GEMM sees a 2-D matrix.
        x = x.reshape(-1, x.size(-1))
        qx, sx = int8_quantization(x)
        qw, sw = int8_quantization(w)
        # Save the int8 copy of x (plus its scale) rather than x itself;
        # w and bias are saved in full precision.
        ctx.save_for_backward(qx, sx, w, bias)
        ctx.x_shape = x_shape
        # int8 x int8 -> int32 product, then dequantize with both scales.
        y = torch._int_mm(qx, qw.T).to(x.dtype) * (sx * sw)
        if bias is not None:
            # bias[..., :] is a full view of bias; it broadcasts over rows of y.
            y += bias[..., :]
        # Restore the leading dims; the last dim becomes the output width.
        return y.reshape(x_shape[:-1] + [-1])

    @staticmethod
    def backward(ctx, dy: torch.Tensor):
        """Return gradients ``(dx, dw, db)`` for the inputs ``(x, w, bias)``."""
        # Flatten dy to 2-D to match the flattened x used in forward.
        dy = dy.reshape(-1, dy.size(-1))
        x_shape = ctx.x_shape
        qx, sx, w, bias = ctx.saved_tensors
        # Bias gradient: sum of dy over all flattened leading positions.
        if bias is not None:
            db = dy.sum(dim=0)
        else:
            db = None
        # dx = dy @ w. Only full-precision w was saved, so it is re-quantized
        # here; with a per-tensor scale, quantizing w.T equals quantizing w
        # and transposing.
        qwt, swt = int8_quantization(w.T)
        qdy, sdy = int8_quantization(dy)
        dx = torch._int_mm(
            qdy, qwt.T
        ).to(dy.dtype).reshape(x_shape) * (swt * sdy)
        
        # Drop the quantized weight references so that memory can be freed
        # before the dw GEMM.
        del qwt, swt
        # transpose(0, 1).contiguous() followed by .T below hands _int_mm a
        # column-major view of qx (same layout as qw.T in forward).
        # NOTE(review): presumably required/preferred by _int_mm -- confirm.
        qxt = qx.transpose(0, 1).contiguous()
        
        # dw = dy.T @ x, dequantized with the x and dy scales.
        dw = torch._int_mm(
            qdy.T.contiguous(), qxt.T
        ).to(w.dtype) * (sx * sdy)
        
        return dx, dw, db



class INT8Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose computation is routed through INT8LinearFunc.

    Construction and parameters (``weight``, ``bias``) are inherited unchanged
    from ``nn.Linear``; only ``forward`` is overridden.
    """

    def forward(self, input):
        weight, bias = self.weight, self.bias
        # bias is None for layers built with bias=False; INT8LinearFunc
        # checks for that itself.
        return INT8LinearFunc.apply(input, weight, bias)
```