pytorch snippet

Dongsung Kim·2022년 10월 25일
0

torch.max

2022-11-23

# Reduce over dim 1 (e.g. the class dimension of a (batch, classes) output —
# assumption, confirm): returns the max values and their argmax indices.
values, indices = output.max(dim=1)

torch.randperm

import torch
import random

# Option 1: shuffle a Python list in place with the standard library.
l_ = list(range(20))
random.shuffle(l_)

# Option 2: shuffle with PyTorch. torch.randperm(n) takes an integer n and
# returns a random permutation of 0..n-1; indexing the tensor with it
# reorders the elements, and .tolist() converts back to a Python list.
rp = torch.tensor(l_)[torch.randperm(len(l_))].tolist()

dataloader

# in case there is 2 separate files for train and test
import torch
from torch.utils.data import DataLoader, TensorDataset

# torch.Tensor(...) copies the input into a tensor of the default float dtype
# (float32 unless changed). NOTE(review): integer class labels usually need
# torch.as_tensor(y, dtype=torch.long) instead — confirm for your task.
x_train_tensor = torch.Tensor(x_train)
y_train_tensor = torch.Tensor(y_train)

x_test_tensor = torch.Tensor(x_test)
y_test_tensor = torch.Tensor(y_test)

# TensorDataset pairs the tensors row by row: item i is (x[i], y[i]), so the
# x and y tensors must have the same size along dimension 0.
train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)

# Shuffle only the training data; keep evaluation order fixed so test
# results are reproducible.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Custom dataset

  • several time series in one DataFrame, each identified by a different code (`id`)

  • 각 id 별로 input_width에 맞는 조각 가능

    | id | time | temp |
    |----|------|------|
    | 1  | 0    | 25   |
    | 1  | 1    | 23   |
    | 1  | 2    | 17   |
class CusDataset(Dataset):
    """Sliding-window dataset over a DataFrame holding several time series.

    Each item is ``input_width`` consecutive rows of ``df`` (selected by
    position with ``iloc``). A window never spans rows whose ``id`` values
    differ, so every window belongs to a single series.

    Rows of the same ``id`` are assumed to be stored contiguously and already
    ordered in time (e.g. ``df`` sorted by id, then time) — TODO confirm with
    callers; windows are taken in ``df`` row order.

    Args:
        df: DataFrame with an ``'id'`` column. Any index works; windows are
            located by position, not by index label.
        input_width: number of rows per window; must be >= 1.

    Raises:
        ValueError: if ``input_width`` is smaller than 1.
    """

    def __init__(self, df, input_width):
        if input_width < 1:
            raise ValueError(f"input_width must be >= 1, got {input_width}")
        self.data = df
        # Kept for backward compatibility with code reading these attributes.
        self.idx_all = df.index.copy()
        self.l_all = list(self.idx_all)
        self.size_grp = df.groupby('id').size()
        self.size_grp_cum = self.size_grp.values.cumsum()
        self.input_width = input_width
        # Integer row positions where a complete window can start.
        self.idx_possible = self._create_poss_idx()

    def __len__(self):
        # One item per valid start position. This also covers ids with fewer
        # than input_width rows, which contribute zero windows.
        return len(self.idx_possible)

    def __getitem__(self, idx):
        start = int(self.idx_possible[idx])
        return self.data.iloc[start: start + self.input_width]

    def _create_poss_idx(self):
        """Return the row positions at which a full, single-id window starts."""
        ids = self.data['id'].to_numpy()
        n_rows = len(ids)
        # A new run begins at row 0 and wherever the id differs from the
        # previous row.
        run_starts = np.flatnonzero(np.r_[True, ids[1:] != ids[:-1]])
        run_ends = np.r_[run_starts[1:], n_rows]
        # Within a run [s, e), a window of width w fits when start <= e - w.
        # np.arange yields nothing when the run is shorter than w.
        starts = [
            np.arange(s, e - self.input_width + 1)
            for s, e in zip(run_starts, run_ends)
        ]
        return np.concatenate(starts).astype(np.int64)

variable batch size

class MyDataset(Dataset):
    """Dataset whose items are whole, variable-size batches.

    Item ``i`` is ``data[batch_indices[i]:batch_indices[i + 1]]``: consecutive
    entries of ``batch_indices`` are the batch boundaries along dimension 0.

    Args:
        data: tensor to slice along its first dimension. Defaults to a random
            ``(250, 1)`` tensor (the original demo data).
        batch_indices: non-decreasing boundary positions. Defaults to
            ``[0, 100, 129, 150, 200, 250]``, i.e. batches of 100, 29, 21,
            50 and 50 rows.
    """

    def __init__(self, data=None, batch_indices=None):
        self.data = torch.randn(250, 1) if data is None else data
        # Copy so later changes to the caller's list cannot alter the batches.
        self.batch_indices = (
            [0, 100, 129, 150, 200, 250]
            if batch_indices is None
            else list(batch_indices)
        )

    def __getitem__(self, index):
        n = len(self)
        if index < 0:
            index += n
        # An explicit IndexError also ends plain `for x in dataset` iteration.
        if not 0 <= index < n:
            raise IndexError(f"batch index {index} out of range for {n} batches")
        start_idx = self.batch_indices[index]
        end_idx = self.batch_indices[index + 1]
        return self.data[start_idx:end_idx]

    def __len__(self):
        # N boundaries delimit N - 1 batches; clamp so that an empty boundary
        # list gives 0 instead of -1.
        return max(len(self.batch_indices) - 1, 0)


# DataLoader worker processes (num_workers > 0) re-import the main module under
# the "spawn" start method (default on Windows and macOS), so the loader must
# only be created and iterated inside this guard.
if __name__ == "__main__":
    dataset = MyDataset()
    # batch_size=None disables automatic batching. Each MyDataset item is
    # already a full variable-size batch, so it is yielded as-is instead of
    # being stacked into a (1, k, 1) tensor that would need .view(-1, 1).
    loader = DataLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=2,
    )

    for data in loader:
        print(data.shape)
profile
Pick one

0개의 댓글