Seregon/ShadPKG

A tool for deriving PS4 PKG (package) encryption keys, written in C++.

C++/47.3 KB/No license
common/io_file.cpp
ShadPKG / common / io_file.cpp
1// SPDX-FileCopyrightText: Copyright 2021 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include <vector>
5#include <bit>
6#include <cstdio>
7
8#include "common/alignment.h"
9#include "common/assert.h"
10#include "common/error.h"
11#include "common/io_file.h"
12#include "common/logging/log.h"
13#include "common/path_util.h"
14
15#ifdef _WIN32
16#include "common/ntapi.h"
17
18#include <io.h>
19#include <share.h>
20#include <windows.h>
21#else
22#include <unistd.h>
23#include <stdio.h>
24#endif
25
26#ifdef _MSC_VER
27#define fileno _fileno
28#define fseeko _fseeki64
29#define ftello _ftelli64
30#endif
31
32namespace Common::FS {
33
34namespace fs = std::filesystem;
35
36namespace {
37
38#ifdef _WIN32
39
40[[nodiscard]] constexpr const wchar_t* AccessModeToWStr(FileAccessMode mode, FileType type) {
41 switch (type) {
42 case FileType::BinaryFile:
43 switch (mode) {
44 case FileAccessMode::Read:
45 return L"rb";
46 case FileAccessMode::Write:
47 return L"wb";
48 case FileAccessMode::Append:
49 return L"ab";
50 case FileAccessMode::ReadWrite:
51 return L"r+b";
52 case FileAccessMode::ReadAppend:
53 return L"a+b";
54 }
55 break;
56 case FileType::TextFile:
57 switch (mode) {
58 case FileAccessMode::Read:
59 return L"r";
60 case FileAccessMode::Write:
61 return L"w";
62 case FileAccessMode::Append:
63 return L"a";
64 case FileAccessMode::ReadWrite:
65 return L"r+";
66 case FileAccessMode::ReadAppend:
67 return L"a+";
68 }
69 break;
70 }
71
72 return L"";
73}
74
75[[nodiscard]] constexpr int ToWindowsFileShareFlag(FileShareFlag flag) {
76 switch (flag) {
77 case FileShareFlag::ShareNone:
78 default:
79 return _SH_DENYRW;
80 case FileShareFlag::ShareReadOnly:
81 return _SH_DENYWR;
82 case FileShareFlag::ShareWriteOnly:
83 return _SH_DENYRD;
84 case FileShareFlag::ShareReadWrite:
85 return _SH_DENYNO;
86 }
87}
88
89#else
90
91[[nodiscard]] constexpr const char* AccessModeToStr(FileAccessMode mode, FileType type) {
92 switch (type) {
93 case FileType::BinaryFile:
94 switch (mode) {
95 case FileAccessMode::Read:
96 return "rb";
97 case FileAccessMode::Write:
98 return "wb";
99 case FileAccessMode::Append:
100 return "ab";
101 case FileAccessMode::ReadWrite:
102 return "r+b";
103 case FileAccessMode::ReadAppend:
104 return "a+b";
105 }
106 break;
107 case FileType::TextFile:
108 switch (mode) {
109 case FileAccessMode::Read:
110 return "r";
111 case FileAccessMode::Write:
112 return "w";
113 case FileAccessMode::Append:
114 return "a";
115 case FileAccessMode::ReadWrite:
116 return "r+";
117 case FileAccessMode::ReadAppend:
118 return "a+";
119 }
120 break;
121 }
122
123 return "";
124}
125
126#endif
127
128[[nodiscard]] constexpr int ToSeekOrigin(SeekOrigin origin) {
129 switch (origin) {
130 case SeekOrigin::SetOrigin:
131 default:
132 return SEEK_SET;
133 case SeekOrigin::CurrentPosition:
134 return SEEK_CUR;
135 case SeekOrigin::End:
136 return SEEK_END;
137 }
138}
139
140} // Anonymous namespace
141
142IOFile::IOFile() = default;
143
144IOFile::IOFile(const std::string& path, FileAccessMode mode, FileType type, FileShareFlag flag) {
145 Open(path, mode, type, flag);
146}
147
148IOFile::IOFile(std::string_view path, FileAccessMode mode, FileType type, FileShareFlag flag) {
149 Open(path, mode, type, flag);
150}
151
152IOFile::IOFile(const fs::path& path, FileAccessMode mode, FileType type, FileShareFlag flag) {
153 Open(path, mode, type, flag);
154}
155
156IOFile::~IOFile() {
157 Close();
158}
159
160IOFile::IOFile(IOFile&& other) noexcept {
161 std::swap(file_path, other.file_path);
162 std::swap(file_access_mode, other.file_access_mode);
163 std::swap(file_type, other.file_type);
164 std::swap(file, other.file);
165}
166
167IOFile& IOFile::operator=(IOFile&& other) noexcept {
168 std::swap(file_path, other.file_path);
169 std::swap(file_access_mode, other.file_access_mode);
170 std::swap(file_type, other.file_type);
171 std::swap(file, other.file);
172 return *this;
173}
174
175int IOFile::Open(const fs::path& path, FileAccessMode mode, FileType type, FileShareFlag flag) {
176 Close();
177
178 file_path = path;
179 file_access_mode = mode;
180 file_type = type;
181
182 errno = 0;
183 int result = 0;
184
185#ifdef _WIN32
186 if (flag != FileShareFlag::ShareNone) {
187 file = _wfsopen(path.c_str(), AccessModeToWStr(mode, type), ToWindowsFileShareFlag(flag));
188 result = errno;
189 } else {
190 result = _wfopen_s(&file, path.c_str(), AccessModeToWStr(mode, type));
191 }
192#else
193 file = std::fopen(path.c_str(), AccessModeToStr(mode, type));
194 result = errno;
195#endif
196
197 if (!IsOpen()) {
198 const auto ec = std::error_code{result, std::generic_category()};
199 LOG_ERROR(Common_Filesystem, "Failed to open the file at path={}, error_message={}",
200 PathToUTF8String(file_path), ec.message());
201 }
202
203 return result;
204}
205
206void IOFile::Close() {
207 if (!IsOpen()) {
208 return;
209 }
210
211 errno = 0;
212
213 const auto close_result = std::fclose(file) == 0;
214
215 if (!close_result) {
216 const auto ec = std::error_code{errno, std::generic_category()};
217 LOG_ERROR(Common_Filesystem, "Failed to close the file at path={}, ec_message={}",
218 PathToUTF8String(file_path), ec.message());
219 }
220
221 file = nullptr;
222
223#ifdef _WIN64
224 if (file_mapping && file_access_mode == FileAccessMode::ReadWrite) {
225 CloseHandle(std::bit_cast<HANDLE>(file_mapping));
226 }
227#endif
228}
229
230void IOFile::Unlink() {
231 if (!IsOpen()) {
232 return;
233 }
234
235 // Mark the file for deletion
236 // TODO: Also remove the file path?
237#ifdef _WIN64
238 FILE_DISPOSITION_INFORMATION disposition;
239 IO_STATUS_BLOCK iosb;
240
241 const int fd = fileno(file);
242 HANDLE hfile = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
243
244 disposition.DeleteFile = TRUE;
245 NtSetInformationFile(hfile, &iosb, &disposition, sizeof(disposition),
246 FileDispositionInformation);
247#else
248 if (unlink(file_path.c_str()) != 0) {
249 const auto ec = std::error_code{errno, std::generic_category()};
250 LOG_ERROR(Common_Filesystem, "Failed to unlink the file at path={}, ec_message={}",
251 PathToUTF8String(file_path), ec.message());
252 }
253#endif
254}
255
256uintptr_t IOFile::GetFileMapping() {
257 if (file_mapping) {
258 return file_mapping;
259 }
260#ifdef _WIN64
261 const int fd = fileno(file);
262
263 HANDLE hfile = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
264 HANDLE mapping = nullptr;
265
266 if (file_access_mode == FileAccessMode::ReadWrite) {
267 mapping = CreateFileMapping2(hfile, NULL, FILE_MAP_WRITE, PAGE_READWRITE, SEC_COMMIT, 0,
268 NULL, NULL, 0);
269 } else {
270 mapping = hfile;
271 }
272
273 file_mapping = std::bit_cast<uintptr_t>(mapping);
274 ASSERT_MSG(file_mapping, "{}", Common::GetLastErrorMsg());
275 return file_mapping;
276#else
277 file_mapping = fileno(file);
278 return file_mapping;
279#endif
280}
281
282std::string IOFile::ReadString(size_t length) const {
283 std::vector<char> string_buffer(length);
284
285 const auto chars_read = ReadSpan<char>(string_buffer);
286 const auto string_size = chars_read != length ? chars_read : length;
287
288 return std::string{string_buffer.data(), string_size};
289}
290
291bool IOFile::Flush() const {
292 if (!IsOpen()) {
293 return false;
294 }
295
296 errno = 0;
297
298#ifdef _WIN32
299 const auto flush_result = std::fflush(file) == 0;
300#else
301 const auto flush_result = std::fflush(file) == 0;
302#endif
303
304 if (!flush_result) {
305 const auto ec = std::error_code{errno, std::generic_category()};
306 LOG_ERROR(Common_Filesystem, "Failed to flush the file at path={}, ec_message={}",
307 PathToUTF8String(file_path), ec.message());
308 }
309
310 return flush_result;
311}
312
313bool IOFile::Commit() const {
314 if (!IsOpen()) {
315 return false;
316 }
317
318 errno = 0;
319
320#ifdef _WIN32
321 const auto commit_result = std::fflush(file) == 0 && _commit(fileno(file)) == 0;
322#else
323 const auto commit_result = std::fflush(file) == 0 && fsync(fileno(file)) == 0;
324#endif
325
326 if (!commit_result) {
327 const auto ec = std::error_code{errno, std::generic_category()};
328 LOG_ERROR(Common_Filesystem, "Failed to commit the file at path={}, ec_message={}",
329 PathToUTF8String(file_path), ec.message());
330 }
331
332 return commit_result;
333}
334
335bool IOFile::SetSize(u64 size) const {
336 if (!IsOpen()) {
337 return false;
338 }
339
340 errno = 0;
341
342#ifdef _WIN32
343 const auto set_size_result = _chsize_s(fileno(file), static_cast<s64>(size)) == 0;
344#else
345 const auto set_size_result = ftruncate(fileno(file), static_cast<s64>(size)) == 0;
346#endif
347
348 if (!set_size_result) {
349 const auto ec = std::error_code{errno, std::generic_category()};
350 LOG_ERROR(Common_Filesystem, "Failed to resize the file at path={}, size={}, ec_message={}",
351 PathToUTF8String(file_path), size, ec.message());
352 }
353
354 return set_size_result;
355}
356
357u64 IOFile::GetSize() const {
358 if (!IsOpen()) {
359 return 0;
360 }
361
362 // Flush any unwritten buffered data into the file prior to retrieving the file size.
363 std::fflush(file);
364
365 std::error_code ec;
366
367 const auto file_size = fs::file_size(file_path, ec);
368
369 if (ec) {
370 LOG_ERROR(Common_Filesystem, "Failed to retrieve the file size of path={}, ec_message={}",
371 PathToUTF8String(file_path), ec.message());
372 return 0;
373 }
374
375 return file_size;
376}
377
378bool IOFile::Seek(s64 offset, SeekOrigin origin) const {
379 if (!IsOpen()) {
380 return false;
381 }
382
383 if (False(file_access_mode & (FileAccessMode::Write | FileAccessMode::Append))) {
384 u64 size = GetSize();
385 if (origin == SeekOrigin::CurrentPosition && Tell() + offset > size) {
386 LOG_ERROR(Common_Filesystem, "Seeking past the end of the file");
387 return false;
388 } else if (origin == SeekOrigin::SetOrigin && (u64)offset > size) {
389 LOG_ERROR(Common_Filesystem, "Seeking past the end of the file");
390 return false;
391 } else if (origin == SeekOrigin::End && offset > 0) {
392 LOG_ERROR(Common_Filesystem, "Seeking past the end of the file");
393 return false;
394 }
395 }
396
397 errno = 0;
398
399 const auto seek_result = fseeko(file, offset, ToSeekOrigin(origin)) == 0;
400
401 if (!seek_result) {
402 const auto ec = std::error_code{errno, std::generic_category()};
403 LOG_ERROR(Common_Filesystem,
404 "Failed to seek the file at path={}, offset={}, origin={}, ec_message={}",
405 PathToUTF8String(file_path), offset, static_cast<u32>(origin), ec.message());
406 }
407
408 return seek_result;
409}
410
411s64 IOFile::Tell() const {
412 if (!IsOpen()) {
413 return 0;
414 }
415
416 errno = 0;
417
418 return ftello(file);
419}
420
421u64 GetDirectorySize(const std::filesystem::path& path) {
422 if (!fs::exists(path)) {
423 return 0;
424 }
425
426 u64 total = 0;
427 for (const auto& entry : fs::recursive_directory_iterator(path)) {
428 if (fs::is_regular_file(entry.path())) {
429 total += fs::file_size(entry.path());
430 }
431 }
432 return total;
433}
434
435} // namespace Common::FS
436