# userdataset.py
  1. import torch.nn
  2. from torch.utils import data
  3. import numpy as np
  4. import pandas as pd
  5. from torch.utils.data.dataset import T_co
  6. import pickle
  7. from tqdm import tqdm
  8. import sys
  9. sys.path.append('../')
  10. from common.log_utils import logFactory
  11. logger = logFactory("UserDataset").log
  12. class UserDataset(data.Dataset):
  13. def __init__(self, dict_file_path, file_type):
  14. self.data_path = dict_file_path
  15. self.label_dict = self.load_obj(file_type + "_label_dict")
  16. self.path_df = self.load_obj(file_type + "_path_df")
  17. # f = open(self.data_path, "r")
  18. # lines = f.read().split("\n")
  19. # self.label_dict = {}
  20. # self.path_df = {}
  21. # logger.info(f"开始加载所有{dict_file_path}文件")
  22. # for i, line in enumerate(tqdm(lines)):
  23. # if line == "":
  24. # continue
  25. # file_path = line.split(" ")[0]
  26. # key = file_path.split("/")[-1].split(".")[0]
  27. # label_value = line.split(" ")[1]
  28. # self.label_dict[key] = label_value
  29. # df = pd.read_pickle(file_path)
  30. # self.path_df[key] = df.drop(columns=['uuid'])
  31. # self.save_obj(self.path_df, file_type + "_path_df")
  32. # self.save_obj(self.label_dict, file_type + "_label_dict")
  33. logger.info(f"所有{dict_file_path}中的dataframe加载完成")
  34. # f.close()
  35. def save_obj(self, obj, name):
  36. with open('./obj/' + name + '.pkl', 'wb') as f:
  37. pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
  38. def load_obj(self, name):
  39. with open("C:\\project\\code\\unsubscribe_predict\\lstm\\obj\\" + name + ".pkl", "rb") as f:
  40. return pickle.load(f)
  41. def __getitem__(self, index) -> T_co:
  42. # logger.info(f"current get item index:{index}")
  43. label = int(list(self.label_dict.values())[index - 1])
  44. file_path = list(self.label_dict.keys())[index - 1]
  45. df = self.path_df[file_path]
  46. # df = pd.read_pickle(file_path)
  47. # df.drop(columns=['uuid'], inplace=True)
  48. content = np.array(df)
  49. # logger.info(f"file_name:{file_path}")
  50. return torch.from_numpy(content).float(), torch.LongTensor([label])
  51. def get_labels(self):
  52. return list(self.label_dict.values())
  53. def __len__(self) -> int:
  54. return len(self.label_dict)
  55. if __name__ == "__main__":
  56. custom_dataset = UserDataset("a")
  57. # hidden_dim =10
  58. # lstm = torch.nn.LSTM(2,3)
  59. # inputs = [torch.randn(1, 3) for _ in range(5)]
  60. # x = torch.randn(4,3,2)
  61. pass