diff --git a/src/index.js b/src/index.js index 8838c73..0746826 100644 --- a/src/index.js +++ b/src/index.js @@ -68,6 +68,8 @@ module.exports = function (req, res, logFacilities, config, next) { } // Capture the response + const originalSetHeader = res.setHeader.bind(res); + const originalRemoveHeader = res.removeHeader.bind(res); const originalWriteHead = res.writeHead.bind(res); const originalWrite = res.write.bind(res); const originalEnd = res.end.bind(res); @@ -77,6 +79,16 @@ module.exports = function (req, res, logFacilities, config, next) { let maximumCachedResponseSizeExceeded = false; let piping = false; + res.setHeader = function (name, value) { + writtenHeaders[name.toLowerCase()] = value; + originalSetHeader(name, value); + }; + + res.removeHeader = function (name) { + delete writtenHeaders[name.toLowerCase()]; + originalRemoveHeader(name); + }; + res.writeHead = function (statusCode, statusCodeDescription, headers) { const properHeaders = headers ? headers : statusCodeDescription; if (typeof properHeaders === "object" && properHeaders !== null) { @@ -85,7 +97,7 @@ module.exports = function (req, res, logFacilities, config, next) { }); } writtenStatusCode = statusCode; - res.setHeader("X-SVRJS-Cache", "MISS"); + originalSetHeader("X-SVRJS-Cache", "MISS"); if (headers || typeof statusCodeDescription !== "object") { originalWriteHead( writtenStatusCode, diff --git a/tests/index.test.js b/tests/index.test.js index 30561aa..ea14a95 100644 --- a/tests/index.test.js +++ b/tests/index.test.js @@ -14,12 +14,9 @@ jest.mock("../src/utils/cacheControlUtils.js", () => ({ })); describe("SVR.JS Cache mod", () => { - let req, res, logFacilities, config, next, resWriteHead, resEnd; + let req, res, logFacilities, config, next; beforeEach(() => { - resWriteHead = jest.fn(); - resEnd = jest.fn(); - req = { method: "GET", headers: {}, @@ -29,9 +26,9 @@ describe("SVR.JS Cache mod", () => { res = { headers: {}, - writeHead: resWriteHead, + writeHead: jest.fn(), write: jest.fn(), - end: resEnd, + end: jest.fn(), setHeader: jest.fn(), getHeaderNames: jest.fn(() => []), getHeaders: jest.fn(() => ({})), @@ -96,6 +93,11 @@ describe("SVR.JS Cache mod", () => { // Reset mocks for the second invocation jest.clearAllMocks(); next.mockReset(); + res.setHeader = jest.fn(); + res.removeHeader = jest.fn(); + res.writeHead = jest.fn(); + res.write = jest.fn(); + res.end = jest.fn(); // Second request: retrieve from cache parseCacheControl.mockReturnValue({}); @@ -108,14 +110,12 @@ describe("SVR.JS Cache mod", () => { "The response is cached." ); expect(res.setHeader).toHaveBeenCalledWith("X-SVRJS-Cache", "HIT"); - expect(resWriteHead).toHaveBeenCalledWith(200, { + expect(res.writeHead).toHaveBeenCalledWith(200, { "cache-control": "max-age=300", "content-type": "application/json" }); - expect(resEnd).toHaveBeenCalledWith( - Buffer.from("cached response body", "latin1"), - undefined, - undefined + expect(res.end).toHaveBeenCalledWith( + Buffer.from("cached response body", "latin1") ); expect(next).not.toHaveBeenCalled(); // No middleware should be called });