# Simple row/column indexing (简单行,列索引操作)
import torch
import numpy as np
def test01():
    """Demonstrate basic row/column indexing and slicing on a 2-D tensor.

    Builds a 4x5 tensor of random integers in [0, 10) and prints the whole
    tensor, one row, one column, a single element, and two slices. The RNG is
    seeded so the printed output is the same on every run.
    """
    # Seed first: without it each run prints a different tensor, so the
    # tutorial output cannot be reproduced or compared against notes.
    torch.manual_seed(0)
    data = torch.randint(0, 10, [4, 5])
    print(data)
    # Row 0, all columns -> shape (5,)
    print(data[0])
    # Column 0, all rows -> shape (4,)
    print(data[:, 0])
    # Single element at row 1, column 2 -> 0-d tensor
    print(data[1, 2])
    # Rows 0..2 of column 2 -> shape (3,)
    print(data[:3, 2])
    # Rows 0..2, columns 0..1 -> shape (3, 2)
    print(data[:3, :2])
def test02():
    """Demonstrate integer-array (advanced) indexing on a 2-D tensor.

    Uses a seeded 4x5 tensor of random integers in [0, 10) and prints:

    * ``data[[0, 2, 3], [0, 1, 2]]``: the two index lists are paired
      element-wise, picking the elements at (0, 0), (2, 1) and (3, 2)
      -> shape (3,).
    * ``data[[[0], [2], [3]], [0, 1, 2]]``: a (3, 1) row index and a (3,)
      column index broadcast together to (3, 3), selecting the sub-matrix
      made of rows 0, 2, 3 and columns 0, 1, 2.
    """
    # Seed first so the printed tensors are reproducible run to run.
    torch.manual_seed(0)
    data = torch.randint(0, 10, [4, 5])
    print(data)
    print(data[[0, 2, 3], [0, 1, 2]])
    print(data[[[0], [2], [3]], [0, 1, 2]])
if __name__ == "__main__":
    # Script entry point: runs only the integer-array indexing demo (test02
    # as defined just above); test01 is defined but not called here.
    # NOTE(review): this file holds two concatenated scripts, each with its
    # own __main__ guard, so running it directly executes both guards.
    test02()
# Boolean indexing (布尔索引); test02 below covers multi-dimensional indexing
import torch
import numpy as np
def test01():
    """Show boolean-mask indexing on a seeded 4x5 integer tensor.

    Prints, in order: the tensor itself, its element-wise ``> 3`` mask, the
    1-D tensor of elements selected by that mask, the rows whose column-1
    value exceeds 6, and the columns whose row-1 value exceeds 3.
    """
    torch.manual_seed(0)
    values = torch.randint(0, 10, [4, 5])

    mask = values > 3            # same shape as values: (4, 5)
    row_mask = values[:, 1] > 6  # one flag per row: (4,)
    col_mask = values[1] > 3     # one flag per column: (5,)

    for shown in (values, mask, values[mask], values[row_mask], values[:, col_mask]):
        print(shown)
def test02():
    """Show integer/slice indexing along each axis of a 3-D tensor.

    A seeded (3, 4, 5) tensor of random integers in [0, 10) is printed in
    full, then with index 0 fixed along axis 0, axis 1 and axis 2 in turn.
    """
    torch.manual_seed(0)
    cube = torch.randint(0, 10, [3, 4, 5])

    first_along_axis = (
        cube[0, :, :],  # fix axis 0 -> shape (4, 5)
        cube[:, 0, :],  # fix axis 1 -> shape (3, 5)
        cube[:, :, 0],  # fix axis 2 -> shape (3, 4)
    )
    for piece in (cube, *first_along_axis):
        print(piece)
if __name__ == "__main__":
    # Script entry point: runs only the 3-D per-axis indexing demo (test02
    # as defined just above); the boolean-mask demo test01 is not called.
    # NOTE(review): an earlier __main__ guard in this file also fires when the
    # file is run directly, because two scripts were concatenated here.
    test02()