12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
import os
import pickle
import sys

import numpy as np
import pandas as pd
import torch.nn
from torch.utils import data
from torch.utils.data.dataset import T_co
from tqdm import tqdm

# Make the local `common` package importable from the parent of the working directory.
sys.path.append('../')
from common.log_utils import logFactory
# Module-level logger built with the project-local `common.log_utils.logFactory`;
# used by UserDataset.__init__ and the __main__ smoke test.
logger = logFactory("UserDataset").log
class UserDataset(data.Dataset):
    """Map-style dataset serving one ``(features, label)`` tensor pair per key.

    All data comes from two pickled dictionaries loaded in ``__init__``:

    * ``<file_type>_label_dict`` -- maps each key to its label (converted with
      ``int()`` in ``__getitem__``; stored as strings by the original builder).
    * ``<file_type>_path_df`` -- maps the same key to an array-like, presumably a
      pandas DataFrame, whose values become the feature tensor.

    Item ``i`` is the ``i``-th key of ``label_dict`` in insertion order, which is
    also the order of ``get_labels()``.
    """

    # Default directory the cached pickles are read from. This machine-specific
    # absolute path is kept for backward compatibility; pass ``obj_dir`` to override.
    load_dir = "C:\\project\\code\\unsubscribe_predict\\lstm\\obj\\"
    # Default directory ``save_obj`` writes to (relative to the working directory).
    save_dir = "./obj/"

    def __init__(self, dict_file_path, file_type, obj_dir=None):
        """Load the cached label and feature dictionaries for ``file_type``.

        Args:
            dict_file_path: Path of the "<file_path> <label>" listing the cache was
                built from. It is stored as ``data_path`` and used in the log
                message; the data itself comes from the pickle cache.
            file_type: Prefix of the cache file names
                (``<file_type>_label_dict.pkl`` / ``<file_type>_path_df.pkl``).
            obj_dir: Optional directory used for both loading and saving pickles.
                When omitted, the class defaults ``load_dir`` / ``save_dir`` apply.
        """
        self.data_path = dict_file_path
        if obj_dir is not None:
            # One directory for both directions, so save/load round-trips work.
            self.load_dir = obj_dir
            self.save_dir = obj_dir
        # NOTE: these pickles are a cache. The builder that produced them (formerly
        # kept here as commented-out code) read ``dict_file_path`` line by line as
        # "<file_path> <label>", keyed each entry by the file's basename without
        # extension, stored the label string in label_dict and the file's DataFrame
        # minus its 'uuid' column in path_df, then saved both with save_obj.
        self.label_dict = self.load_obj(file_type + "_label_dict")
        self.path_df = self.load_obj(file_type + "_path_df")
        # Snapshot the key order once so __getitem__ is O(1) instead of rebuilding
        # a list of every key on each access.
        self._keys = list(self.label_dict)
        # Logs "all dataframes in <dict_file_path> loaded".
        logger.info(f"所有{dict_file_path}中的dataframe加载完成")

    def save_obj(self, obj, name):
        """Pickle ``obj`` to ``<save_dir>/<name>.pkl`` with the highest protocol."""
        with open(os.path.join(self.save_dir, name + ".pkl"), "wb") as f:
            pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

    def load_obj(self, name):
        """Unpickle and return the object stored in ``<load_dir>/<name>.pkl``.

        Only use with trusted files: unpickling can execute arbitrary code.
        """
        with open(os.path.join(self.load_dir, name + ".pkl"), "rb") as f:
            return pickle.load(f)

    def __getitem__(self, index) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(features, label)`` for the ``index``-th key.

        ``features`` is the key's stored array-like as a float32 tensor; ``label``
        is a one-element int64 tensor. Raises IndexError when ``index`` is out of
        range, which also terminates plain iteration over the dataset.
        """
        # Index directly so item i lines up with get_labels()[i] (e.g. for samplers
        # that weight items by that list) and index len(self) is out of range.
        key = self._keys[index]
        label = int(self.label_dict[key])
        # np.array copies, so the returned tensor never aliases the cached frame.
        content = np.array(self.path_df[key])
        return torch.from_numpy(content).float(), torch.tensor([label], dtype=torch.long)

    def get_labels(self):
        """Return the raw stored labels in item order (not converted to int)."""
        return list(self.label_dict.values())

    def __len__(self) -> int:
        """Return the number of items, i.e. the number of labelled keys."""
        return len(self.label_dict)
if __name__ == "__main__":
    # Smoke test: load the cached pickles and report the dataset size.
    # UserDataset requires both dict_file_path and file_type, so take them from
    # the command line instead of hard-coding a single argument.
    if len(sys.argv) != 3:
        sys.exit(f"usage: {sys.argv[0]} <dict_file_path> <file_type>")
    custom_dataset = UserDataset(sys.argv[1], sys.argv[2])
    logger.info(f"dataset size: {len(custom_dataset)}")
|